import json
import logging
from typing import Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel

from app.rag import RAGSystem

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class ConnectionManager:
    """
    Manages active WebSocket connections.

    Tracks connected clients by session_id and handles connection/disconnection events.
    """

    def __init__(self):
        # Dictionary to store active WebSocket connections by session_id
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """
        Accept a new WebSocket connection and track it.

        If a connection is already registered under ``session_id`` it is
        replaced (the old socket is not closed here).

        Args:
            websocket: The WebSocket connection object
            session_id: Unique identifier for the session
        """
        await websocket.accept()
        if session_id in self.active_connections:
            logger.warning("Replacing existing connection for session: %s", session_id)
        self.active_connections[session_id] = websocket
        logger.info("WebSocket connected: %s", session_id)

    def disconnect(self, session_id: str, websocket: Optional[WebSocket] = None):
        """
        Remove a WebSocket connection from active connections.

        Args:
            session_id: The session ID to disconnect
            websocket: Optional connection being torn down. When given, the
                entry is removed only if it is still this exact socket, so a
                late cleanup from an old connection cannot evict a newer
                connection that reused the same session_id.
        """
        if websocket is not None and self.active_connections.get(session_id) is not websocket:
            logger.info("Ignoring stale disconnect for session: %s", session_id)
            return
        self.active_connections.pop(session_id, None)
        logger.info("WebSocket disconnected: %s", session_id)


# Initialize system components
manager = ConnectionManager()  # Manages WebSocket connections
rag = RAGSystem()  # The RAG processing system


class ResponseFormatter:
    """
    Formats responses before sending to clients.

    Can be extended to standardize response structure.
    """

    def format_response(self, response: dict) -> dict:
        """
        Format a response dictionary.

        Args:
            response: Raw response dictionary

        Returns:
            dict: Formatted response (currently the input, unchanged)
        """
        return response  # Currently passes through unchanged


formatter = ResponseFormatter()  # Create formatter instance


class ChatMessage(BaseModel):
    """
    Pydantic model for validating incoming chat messages.

    Ensures proper message structure.
    """

    text: str  # The message text/content
    url: Optional[str] = None  # Optional URL for context


def _error_payload(message: str) -> dict:
    """Build the error payload sent to clients."""
    return {"status": "error", "message": message}


async def _send_error_best_effort(websocket: WebSocket, message: str) -> None:
    """Try to notify the client of an error; log instead of raising if the socket is unusable."""
    try:
        await websocket.send_json(_error_payload(message))
    except Exception:  # socket may already be closed/broken; nothing more we can do
        logger.warning("Could not deliver error message to client", exc_info=True)


async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for handling real-time chat interactions.

    Args:
        websocket: The WebSocket connection
        session_id: Unique identifier for the chat session

    Handles:
        - Connection management (the connection is always unregistered on exit)
        - Message validation (invalid messages get an error reply; the session stays open)
        - Message processing through the RAG system
        - Error handling (unexpected errors are logged and reported, then the session ends)
    """
    # Register the new connection
    await manager.connect(websocket, session_id)
    try:
        while True:
            # Wait for and receive incoming message data
            data = await websocket.receive_json()
            logger.info("Received message: %s", data)

            # Parse and validate message using Pydantic model.
            # ValidationError subclasses ValueError; TypeError covers payloads
            # that are not JSON objects (e.g. a list or null) passed via **data.
            try:
                message = ChatMessage(**data)
            except (TypeError, ValueError) as e:
                logger.warning("Invalid message from %s: %s", session_id, e)
                await websocket.send_json(_error_payload(str(e)))
                continue

            # Process the message through RAG system
            response = await rag.process_query(message.text, session_id)
            response = formatter.format_response(response)

            # Log and send the response
            logger.info("Sending response: %s", response)
            await websocket.send_json(response)

    except WebSocketDisconnect:
        # Handle graceful disconnection
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        # Handle other errors (with traceback) and notify client if still possible
        logger.exception("Error in websocket: %s", e)
        await _send_error_best_effort(websocket, str(e))
    finally:
        # Always unregister, but only if this socket still owns the session_id
        manager.disconnect(session_id, websocket)