# ConvoBot / src/handlers.py
# ashish-ninehertz
# changes
# e272f4f
import json
import logging
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from typing import Optional, Dict
from app.rag import RAGSystem
# Configure logging: install an INFO-level handler on the root logger (a no-op
# when the root logger already has handlers), then fetch this module's logger.
# NOTE(review): calling basicConfig from an importable module configures
# logging for the whole process; consider moving it to the app entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ConnectionManager:
    """
    Registry of the WebSocket connections currently open, keyed by session id.

    ``connect`` completes the handshake and records the socket; ``disconnect``
    forgets it. The registry only stores references; it never closes sockets.
    """

    def __init__(self):
        # session_id -> WebSocket for every connection accepted via connect()
        # and not yet removed via disconnect().
        self.active_connections: Dict[str, WebSocket] = dict()

    async def connect(self, websocket: WebSocket, session_id: str):
        """
        Accept the handshake on a WebSocket and register it.

        Args:
            websocket: The WebSocket connection object to accept.
            session_id: Key under which the socket is recorded. A socket
                already registered under the same key is replaced (the old
                one is not closed here).
        """
        await websocket.accept()
        self.active_connections.update({session_id: websocket})
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """
        Drop the socket registered under a session id.

        Args:
            session_id: Key to remove; an unknown key is silently ignored.
        """
        if session_id in self.active_connections:
            del self.active_connections[session_id]
        logger.info(f"WebSocket disconnected: {session_id}")
# Initialize system components.
# Module-level singletons, constructed once when this module is imported and
# shared by every websocket_endpoint call in the process.
manager = ConnectionManager() # Registry of open WebSocket connections by session_id
# NOTE(review): RAGSystem() runs at import time; whatever setup it performs
# (not visible here) happens before any connection is served.
rag = RAGSystem() # The RAG processing system; its process_query is awaited per message
class ResponseFormatter:
    """
    Hook for shaping outgoing payloads before they reach a client.

    The current implementation is an identity transform; extend
    ``format_response`` to impose a common response envelope.
    """

    def format_response(self, response: dict) -> dict:
        """
        Produce the payload to send for a raw response.

        Args:
            response: Raw response dictionary.

        Returns:
            dict: The very same object that was passed in, unmodified.
        """
        return response
# Module-level ResponseFormatter instance, created once at import time.
formatter = ResponseFormatter() # Create formatter instance
class ChatMessage(BaseModel):
    """
    Pydantic model validating one inbound chat message.

    Construction raises a pydantic validation error when ``text`` is missing
    or not a string.

    NOTE(review): in this file, websocket_endpoint reads only ``text``;
    ``url`` is accepted but not used here. Confirm whether another module
    consumes it.
    """
    text: str # The message text/content (required)
    url: Optional[str] = None # Optional URL for context; defaults to None
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Serve one chat session over a WebSocket until it ends.

    Accepts and registers the connection, then loops: receive a JSON message,
    validate it as a ``ChatMessage``, run ``rag.process_query`` on its text,
    pass the result through ``formatter`` and send it back.

    A client disconnect ends the session quietly. Any other exception is
    logged with its traceback, reported to the client as
    ``{"status": "error", "message": str(e)}`` when the socket still allows
    it, and ends the session.

    On every exit path (including task cancellation) the session is removed
    from ``manager``, so failed sessions do not linger in
    ``manager.active_connections``.

    Args:
        websocket: The WebSocket connection.
        session_id: Unique identifier for the chat session.
    """
    await manager.connect(websocket, session_id)
    try:
        while True:
            data = await websocket.receive_json()
            logger.info("Received message: %s", data)
            # Raises (and ends the session) if the payload is not a mapping
            # or does not match the ChatMessage schema.
            message = ChatMessage(**data)
            response = await rag.process_query(message.text, session_id)
            # Single place to standardize outgoing payloads (identity today).
            response = formatter.format_response(response)
            logger.info("Sending response: %s", response)
            await websocket.send_json(response)
    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        logger.exception("Error in websocket: %s", e)
        # NOTE(review): str(e) is sent verbatim to the client and may expose
        # internal details; consider a generic message.
        try:
            await websocket.send_json({
                "status": "error",
                "message": str(e)
            })
        except Exception:
            # The socket may already be closed, or sending may be what failed;
            # nothing more can be reported to this client.
            logger.warning("Could not deliver error to client: %s", session_id)
    finally:
        # Previously skipped on non-disconnect errors, leaking the entry.
        manager.disconnect(session_id)