File size: 3,764 Bytes
e272f4f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import logging
from fastapi import WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from typing import Optional, Dict
from app.rag import RAGSystem

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

class ConnectionManager:
    """
    Registry of the WebSocket connections that are currently open.

    Connections are keyed by session ID. Registering a second socket under an
    ID that is already present replaces the earlier entry.
    """

    def __init__(self):
        # session_id -> the live WebSocket serving that session
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """
        Complete the WebSocket handshake and record the connection.

        Args:
            websocket: Socket whose handshake should be accepted
            session_id: Key under which the socket is stored
        """
        await websocket.accept()
        self.active_connections.update({session_id: websocket})
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """
        Forget the connection stored under ``session_id``.

        Only the registry entry is removed; the socket itself is not closed
        here. Unknown IDs are ignored, and the call is logged either way.

        Args:
            session_id: Key of the connection to drop
        """
        if session_id in self.active_connections:
            del self.active_connections[session_id]
        logger.info(f"WebSocket disconnected: {session_id}")

# Module-level singletons shared by every websocket_endpoint invocation.
manager: ConnectionManager = ConnectionManager()  # registry of open sockets, keyed by session_id
rag: RAGSystem = RAGSystem()  # its process_query() produces the reply for each chat message

class ResponseFormatter:
    """
    Hook for shaping responses before they are delivered to clients.

    At present this is an identity transform. Override or extend
    ``format_response`` to impose a standard response envelope.
    """

    def format_response(self, response: dict) -> dict:
        """
        Return the payload that should be sent for ``response``.

        Args:
            response: Response dictionary as produced upstream

        Returns:
            dict: The very same object, unmodified (no copy is made)
        """
        formatted = response
        return formatted

# Shared formatter instance for outgoing responses.
formatter = ResponseFormatter()

class ChatMessage(BaseModel):
    """
    Schema of a single inbound chat message, validated by pydantic.

    Fields:
        text: Body of the user's message (required).
        url: URL supplied alongside the message for context; defaults to None.
    """

    text: str
    url: Optional[str] = None

async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    Serve one chat session over a WebSocket.

    Each incoming JSON message is validated as a ``ChatMessage``. Its text is
    passed to ``rag.process_query`` and the result is sent back as JSON,
    after passing through the shared ``formatter``.

    Malformed input gets an error reply and the session continues. This
    covers undecodable JSON, non-object payloads, and payloads that fail
    ``ChatMessage`` validation. Any other error is reported to the client,
    if the socket can still be written to, and ends the session. In every
    case the session is removed from ``manager`` on exit.

    Error replies have the shape ``{"status": "error", "message": str}``.

    Args:
        websocket: The WebSocket connection
        session_id: Unique identifier for the chat session
    """
    await manager.connect(websocket, session_id)

    try:
        while True:
            try:
                data = await websocket.receive_json()
                logger.info("Received message: %s", data)
                # ``ChatMessage(**data)`` only works for a JSON object. Reject
                # arrays/strings/numbers with a clear message instead of an
                # opaque "argument after ** must be a mapping" TypeError.
                if not isinstance(data, dict):
                    raise TypeError("Message must be a JSON object")
                message = ChatMessage(**data)
            except (TypeError, ValueError) as e:
                # json.JSONDecodeError (from decoding the frame) and pydantic's
                # ValidationError both subclass ValueError. A bad frame is
                # therefore answered and skipped rather than ending the session.
                logger.warning("Invalid message from %s: %s", session_id, e)
                await websocket.send_json({
                    "status": "error",
                    "message": str(e)
                })
                continue

            # NOTE(review): message.url is validated but not forwarded to the
            # RAG system. Confirm whether process_query should receive it.
            response = await rag.process_query(message.text, session_id)

            logger.info("Sending response: %s", response)
            await websocket.send_json(formatter.format_response(response))

    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", session_id)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error in websocket: %s", e)
        # NOTE(review): str(e) exposes internal error text to the client.
        # Consider a generic message.
        try:
            await websocket.send_json({
                "status": "error",
                "message": str(e)
            })
        except Exception:
            # The socket may already be closed (e.g. the failure came from
            # send/receive itself), so the client can no longer be notified.
            logger.debug("Could not deliver error to %s", session_id)
    finally:
        # Always drop the registry entry: on unexpected errors and task
        # cancellation too, so dead sockets don't accumulate in the manager.
        manager.disconnect(session_id)