Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| from fastapi import WebSocket, WebSocketDisconnect | |
| from pydantic import BaseModel | |
| from typing import Optional, Dict | |
| from app.rag import RAGSystem | |
# Configure logging: give the root logger a default INFO-level handler
# (logging.basicConfig does nothing if the root logger already has handlers),
# then fetch the named logger this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ConnectionManager:
    """
    Registry of the WebSocket connections that are currently open.

    Connections are stored under the session_id they were opened with;
    connecting again with an id that is already present replaces the
    previously stored socket.
    """

    def __init__(self):
        # session_id -> accepted WebSocket
        self.active_connections: Dict[str, WebSocket] = dict()

    async def connect(self, websocket: WebSocket, session_id: str):
        """
        Complete the WebSocket handshake and record the socket.

        Args:
            websocket: Connection to accept.
            session_id: Key under which the connection is stored.
        """
        await websocket.accept()
        registry = self.active_connections
        registry[session_id] = websocket
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """
        Forget the connection stored under session_id, if any.

        Unknown session ids are silently ignored; the disconnect is logged
        either way.

        Args:
            session_id: Key of the connection to drop.
        """
        if session_id in self.active_connections:
            del self.active_connections[session_id]
        logger.info(f"WebSocket disconnected: {session_id}")
# Initialize system components as module-level singletons (constructed once,
# when this module is imported) shared by every WebSocket session.
manager = ConnectionManager()  # Registry of open WebSocket connections, keyed by session_id
rag = RAGSystem()  # The RAG processing system; the endpoint calls rag.process_query(text, session_id)
class ResponseFormatter:
    """
    Hook for shaping outbound payloads before they are sent to a client.

    The current implementation is an identity transform; extend
    format_response to impose a common response structure.
    """

    def format_response(self, response: dict) -> dict:
        """
        Produce the payload to send for a given response.

        Args:
            response: Raw response dictionary.

        Returns:
            dict: The same object that was passed in (no copy is made).
        """
        formatted = response
        return formatted


formatter = ResponseFormatter()  # Module-level formatter instance
class ChatMessage(BaseModel):
    """
    Pydantic model for validating incoming chat messages.

    Instances are built directly from the decoded JSON payload
    (ChatMessage(**data)); pydantic raises a validation error when a
    required field is missing or has the wrong type.
    """
    text: str  # The message text/content; required
    # NOTE(review): url is validated but websocket_endpoint only reads
    # message.text — confirm whether the RAG system should receive it.
    url: Optional[str] = None  # Optional URL for context; defaults to None
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Serve one real-time chat session over a WebSocket.

    Accepts and registers the connection, then loops: receive a JSON
    message, validate it as a ChatMessage, pass its text to the RAG system,
    and send the result back. The loop ends when the client disconnects or
    any error occurs; on error the client is sent
    {"status": "error", "message": <error text>} (best effort) before the
    handler returns. The session is always removed from the connection
    manager on exit.

    Args:
        websocket: The WebSocket connection.
        session_id: Unique identifier for the chat session; used as the
            connection-manager key and passed to the RAG system.
    """
    # Register outside the try: if accept() fails, there is no registry
    # entry to clean up.
    await manager.connect(websocket, session_id)
    try:
        while True:
            # Wait for and receive incoming message
            data = await websocket.receive_json()
            logger.info("Received message: %s", data)
            # Validate the payload; raises if fields are missing/invalid.
            message = ChatMessage(**data)
            # NOTE(review): message.url is validated but not forwarded —
            # confirm whether process_query should receive it.
            response = await rag.process_query(message.text, session_id)
            logger.info("Sending response: %s", response)
            # Route every outbound payload through the module's formatter.
            await websocket.send_json(formatter.format_response(response))
    except WebSocketDisconnect:
        # Graceful client-side close.
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Error in websocket: %s", e)
        # NOTE(review): str(e) exposes internal error text to the client;
        # consider a generic message if that is undesirable.
        try:
            await websocket.send_json(formatter.format_response({
                "status": "error",
                "message": str(e),
            }))
        except Exception:
            # The socket may already be closed (possibly the cause of the
            # error); failing to notify must not mask the original error.
            logger.warning(
                "Could not send error to client %s", session_id, exc_info=True
            )
    finally:
        # Always drop the registry entry. Previously only the
        # WebSocketDisconnect path did this, leaking entries on other errors.
        manager.disconnect(session_id)