Spaces:
Sleeping
Sleeping
import logging
import os
import pickle
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

from app.config import Config
class SessionStorage:
    """
    Manages session persistence using a hybrid storage approach:
    - Stores session metadata in local pickle files
    - Stores vector data in Qdrant collections
    - Links the two through the ``session_<session_id>`` collection name
    """

    def __init__(self):
        """
        Initialize the session storage system.

        Ensures the storage directory exists (via Config.create_storage_dir)
        and opens a Qdrant client using Config.QDRANT_HOST / QDRANT_PORT.

        Raises:
            RuntimeError: If any setup step fails; the original exception is
                chained as the cause.
        """
        # Bind the logger BEFORE the try block: the except handler below uses
        # it, so it must exist even when the very first setup step fails.
        # Otherwise the handler itself dies with AttributeError and hides the
        # real error.
        self.logger = logging.getLogger(__name__)
        try:
            Config.create_storage_dir()
            # Initialize Qdrant client with configuration from Config
            self.qdrant_client = QdrantClient(
                host=Config.QDRANT_HOST,
                port=Config.QDRANT_PORT,
                prefer_grpc=True,  # Use gRPC for better performance
            )
            self.logger.info("Qdrant client initialized")
        except Exception as e:
            self.logger.error(f"Storage initialization error: {str(e)}")
            raise RuntimeError("Storage initialization failed") from e

    @staticmethod
    def _collection_name(session_id: str) -> str:
        """Return the Qdrant collection name associated with a session."""
        return f"session_{session_id}"

    def get_session_path(self, session_id: str) -> str:
        """
        Get the filesystem path for a session's pickle file.

        Args:
            session_id: Unique session identifier

        Returns:
            str: Path of the form ``<Config.STORAGE_DIR>/<session_id>.pkl``
        """
        # NOTE(review): session_id is interpolated into the path unchecked;
        # if it can come from an untrusted client, validate it upstream so
        # it cannot contain path separators or "..".
        return os.path.join(Config.STORAGE_DIR, f"{session_id}.pkl")

    def save_session(self, session_id: str, data: Dict) -> None:
        """
        Persist session data to disk (excluding Qdrant references).

        The caller's dictionary is not modified.

        Args:
            session_id: Session identifier
            data: Session data dictionary
        """
        session_path = self.get_session_path(session_id)
        # Work on a shallow copy so the caller keeps its 'qdrant_collection'
        # entry; that reference is rebuilt on load and is not persisted.
        persisted = data.copy()
        persisted.pop('qdrant_collection', None)
        with open(session_path, 'wb') as f:
            pickle.dump(persisted, f)

    def load_session(self, session_id: str) -> Optional[Dict]:
        """
        Load session data from disk and reconnect to Qdrant collection.

        Args:
            session_id: Session identifier

        Returns:
            Optional[Dict]: Session data with the 'qdrant_collection' key
            restored, or None when no session file exists.
        """
        session_path = self.get_session_path(session_id)
        # EAFP: open directly instead of exists()+open(), which would race
        # with a concurrent delete_session.
        try:
            f = open(session_path, 'rb')
        except FileNotFoundError:
            return None
        with f:
            # SECURITY: pickle.load can execute arbitrary code. This is only
            # safe while the storage directory is writable by trusted code
            # alone; never point it at user-supplied files.
            data = pickle.load(f)

        # Restore Qdrant collection reference
        collection_name = self._collection_name(session_id)
        data['qdrant_collection'] = collection_name

        # Ensure collection exists in Qdrant (create an empty one if missing)
        if not self.qdrant_client.collection_exists(collection_name):
            self.logger.warning(f"Qdrant collection {collection_name} missing, creating new")
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=Config.EMBEDDING_SIZE,
                    distance=Distance.COSINE,
                ),
            )
        return data

    def delete_session(self, session_id: str) -> None:
        """
        Completely remove a session (both disk and Qdrant storage).

        Qdrant deletion is best-effort: a failure is logged and the session
        file is still removed. A missing session file is not an error.

        Args:
            session_id: Session identifier to delete
        """
        session_path = self.get_session_path(session_id)

        # Delete Qdrant collection first
        collection_name = self._collection_name(session_id)
        try:
            self.qdrant_client.delete_collection(collection_name)
            self.logger.info(f"Deleted Qdrant collection: {collection_name}")
        except Exception as e:
            self.logger.error(f"Error deleting Qdrant collection: {str(e)}")

        # Delete session file; tolerate it already being gone (avoids the
        # exists()/remove() race with a concurrent delete).
        try:
            os.remove(session_path)
        except FileNotFoundError:
            pass
class QdrantStorage:
    """
    Manages vector storage operations using Qdrant.
    Handles collection management and vector operations.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 host: str = Config.QDRANT_HOST, port: int = Config.QDRANT_PORT):
        """
        Initialize Qdrant storage for a specific collection.

        Opens a client connection and makes sure the collection exists.

        Args:
            collection_name: Name of the Qdrant collection
            vector_size: Dimensionality of vectors to store
            host: Qdrant server host (default from Config)
            port: Qdrant server port (default from Config)
        """
        self.logger = logging.getLogger(__name__)
        self.collection_name = collection_name
        self.vector_size = vector_size
        # Initialize Qdrant client with gRPC preference
        self.qdrant = QdrantClient(host=host, port=port, prefer_grpc=True)
        self._ensure_collection()

    def _ensure_collection(self) -> None:
        """
        Ensure the collection exists in Qdrant, creating it if missing.

        An existing collection is left untouched. Errors talking to Qdrant
        propagate to the caller instead of being treated as "collection
        missing" -- the latter would risk dropping stored vectors.
        """
        # Explicit existence check rather than "get_collection raised, so
        # recreate": recreate_collection drops an existing collection, so a
        # transient error or an unset vectors_count must never reach it.
        if self.qdrant.collection_exists(self.collection_name):
            self.logger.info(f"Using existing Qdrant collection: {self.collection_name}")
            return

        self.logger.info(f"Creating Qdrant collection: {self.collection_name}")
        self.qdrant.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=Distance.COSINE,  # Using cosine similarity
            ),
        )

    def add_vectors(self, vectors: List[List[float]], payloads: List[Dict[str, Any]],
                    offset: int = 0) -> None:
        """
        Add vectors and associated metadata to the collection.

        Points get sequential integer IDs starting at ``offset``; upserting
        an ID that already exists overwrites that point.

        Args:
            vectors: List of vector embeddings
            payloads: List of metadata dictionaries, one per vector
            offset: Starting ID for new points (default 0)

        Raises:
            ValueError: If ``vectors`` and ``payloads`` differ in length.
        """
        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(vectors) != len(payloads):
            raise ValueError(
                f"vectors and payloads must have the same length "
                f"(got {len(vectors)} and {len(payloads)})"
            )

        points = [
            PointStruct(
                id=offset + idx,  # Sequential IDs with optional offset
                vector=vector,
                payload=payload,
            )
            for idx, (vector, payload) in enumerate(zip(vectors, payloads))
        ]
        if not points:
            # Nothing to store; skip the round trip to Qdrant.
            return

        self.qdrant.upsert(
            collection_name=self.collection_name,
            points=points,
            wait=True,  # Block until the write has been applied
        )
        self.logger.info(f"Added {len(points)} vectors to Qdrant collection '{self.collection_name}'")

    def search(self, query_vector: List[float], session_id: str,
               limit: int = 5) -> List[Dict[str, Any]]:
        """
        Search the collection for similar vectors, filtered by session.

        Only points whose payload field ``session_id`` equals the given
        value are considered.

        Args:
            query_vector: The vector to compare against
            session_id: Session identifier to filter results
            limit: Maximum number of results to return

        Returns:
            List[Dict]: One dict per hit with keys "id", "score", "payload"
        """
        # Add session filter to ensure only current session results
        session_filter = Filter(
            must=[
                FieldCondition(
                    key="session_id",
                    match=MatchValue(value=session_id),
                )
            ]
        )
        # NOTE(review): newer qdrant-client releases favour query_points()
        # over search(); confirm against the pinned client version before
        # migrating.
        results = self.qdrant.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            query_filter=session_filter,
            limit=limit,
        )
        return [
            {
                "id": hit.id,
                "score": hit.score,
                "payload": hit.payload,
            }
            for hit in results
        ]