Spaces:
Sleeping
Sleeping
"""
PostgreSQL vector database service using the pgvector extension.

This module provides ``PostgresVectorService``, which stores documents,
their embeddings and JSONB metadata in a PostgreSQL table and performs
cosine-distance similarity search. The connection string is taken from the
constructor argument or, if omitted, the ``DATABASE_URL`` environment variable.
"""
import logging
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional

import psycopg2
import psycopg2.extras
from psycopg2 import sql

# Module-level logger used by every method of the service.
logger = logging.getLogger(__name__)
class PostgresVectorService:
    """Vector store backed by PostgreSQL and the pgvector extension.

    Documents live in a single table (``table_name``) with the columns
    ``id`` (SERIAL), ``content``, ``embedding`` (pgvector ``vector``),
    ``metadata`` (JSONB), ``created_at`` and ``updated_at``. Every public
    method opens its own short-lived connection through ``_get_connection``.
    Document IDs are exposed to callers as strings.
    """

    def __init__(
        self,
        connection_string: Optional[str] = None,
        table_name: str = "document_embeddings",
    ):
        """
        Initialize PostgreSQL vector service.

        Args:
            connection_string: PostgreSQL connection string.
                If None (or empty), the DATABASE_URL environment variable is used.
            table_name: Name of the table to store embeddings.

        Raises:
            ValueError: If no connection string is given and DATABASE_URL is unset.
        """
        self.connection_string = connection_string or os.getenv("DATABASE_URL")
        if not self.connection_string:
            raise ValueError("DATABASE_URL environment variable is required")
        self.table_name = table_name
        # Dimension last applied to the embedding column by this instance;
        # None until _ensure_embedding_dimension has run.
        self.dimension = None
        # Fail fast: verify connectivity and create extension/table/indexes.
        self._initialize_database()

    @contextmanager
    def _get_connection(self):
        """Yield a new psycopg2 connection and always close it afterwards.

        If the ``with`` body raises, the transaction is rolled back, the error is
        logged and re-raised. Callers are responsible for calling ``commit()``.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.connection_string)
            yield conn
        except Exception as e:
            if conn is not None:
                try:
                    conn.rollback()
                except Exception:
                    # A broken connection can make rollback fail too; never let
                    # that replace the original error.
                    logger.warning("Rollback failed after database error", exc_info=True)
            logger.error("Database connection error: %s", e)
            raise
        finally:
            if conn is not None:
                conn.close()

    def _initialize_database(self):
        """Create the pgvector extension, the documents table and its text index.

        Every statement uses IF NOT EXISTS, so this is safe to run against an
        already-initialized database. Errors are logged and re-raised.
        """
        conn = None
        try:
            conn = psycopg2.connect(self.connection_string)
            # Use context-managed cursor so test mocks that set __enter__ work correctly
            with conn.cursor() as cur:
                # Enable pgvector extension
                cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
                # The embedding column starts without a dimension; it is fixed
                # by _ensure_embedding_dimension when the first batch arrives.
                # '{{}}' is the escaped form of '{}' for sql.SQL.format.
                cur.execute(
                    sql.SQL(
                        """
                        CREATE TABLE IF NOT EXISTS {} (
                            id SERIAL PRIMARY KEY,
                            content TEXT NOT NULL,
                            embedding vector,
                            metadata JSONB DEFAULT '{{}}',
                            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        );
                        """
                    ).format(sql.Identifier(self.table_name))
                )
                # GIN index over the English full-text vector of the content.
                cur.execute(
                    sql.SQL(
                        "CREATE INDEX IF NOT EXISTS {} " "ON {} USING gin(to_tsvector('english', content));"
                    ).format(
                        sql.Identifier(f"idx_{self.table_name}_content"),
                        sql.Identifier(self.table_name),
                    )
                )
            conn.commit()
            logger.info("Database initialized with table: %s", self.table_name)
        except Exception as e:
            # Any initialization errors should be logged and re-raised to surface issues
            logger.error("Database initialization error: %s", e)
            raise
        finally:
            if conn is not None:
                conn.close()

    def _ensure_embedding_dimension(self, dimension: int):
        """Make sure the ``embedding`` column is typed ``vector(dimension)``.

        The current column type is read from the system catalog. The column is
        altered, and the ivfflat cosine index rebuilt, only when the type
        differs. Nothing is altered if the table is not found. The dimension
        is cached on the instance, so repeated calls with the same value
        are no-ops.

        Args:
            dimension: Number of components in each embedding vector.
        """
        if self.dimension == dimension:
            return
        expected_type = f"vector({dimension})"
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                # information_schema.columns reports pgvector columns only as
                # 'USER-DEFINED' (without the dimension). format_type() renders
                # the full type, e.g. 'vector(384)', or plain 'vector'.
                cur.execute(
                    """
                    SELECT format_type(a.atttypid, a.atttypmod)
                    FROM pg_attribute AS a
                    JOIN pg_class AS c ON c.oid = a.attrelid
                    WHERE c.relname = %s
                      AND pg_table_is_visible(c.oid)
                      AND a.attname = 'embedding'
                      AND NOT a.attisdropped;
                    """,
                    (self.table_name,),
                )
                result = cur.fetchone()
                if result and result[0] != expected_type:
                    # The index is tied to the old column type; drop it first.
                    cur.execute(
                        sql.SQL("DROP INDEX IF EXISTS {};").format(
                            sql.Identifier(f"idx_{self.table_name}_embedding_cosine")
                        )
                    )
                    cur.execute(
                        sql.SQL("ALTER TABLE {} ALTER COLUMN embedding TYPE vector({});").format(
                            sql.Identifier(self.table_name), sql.Literal(dimension)
                        )
                    )
                    # Approximate-nearest-neighbour index for cosine distance.
                    cur.execute(
                        sql.SQL(
                            "CREATE INDEX IF NOT EXISTS {} ON {} "
                            "USING ivfflat (embedding vector_cosine_ops) "
                            "WITH (lists = 100);"
                        ).format(
                            sql.Identifier(f"idx_{self.table_name}_embedding_cosine"),
                            sql.Identifier(self.table_name),
                        )
                    )
                    conn.commit()
                    logger.info("Updated embedding dimension to %s", dimension)
        self.dimension = dimension

    def add_documents(
        self,
        texts: List[str],
        embeddings: List[List[float]],
        metadatas: Optional[List[Dict[str, Any]]] = None,
    ) -> List[str]:
        """
        Add documents with their embeddings to the database.

        Args:
            texts: List of document texts.
            embeddings: List of embedding vectors, one per text. The first vector's
                length sets the embedding column dimension.
            metadatas: Optional list of metadata dictionaries, one per text.
                None or an empty list means "no metadata" ({} for every row).

        Returns:
            List of the new document IDs (as strings), in input order.
            Empty if ``texts`` or ``embeddings`` is empty.

        Raises:
            ValueError: If the lengths of texts, embeddings and (non-empty)
                metadatas disagree.
        """
        if not texts or not embeddings:
            return []
        if len(texts) != len(embeddings):
            raise ValueError("Number of texts must match number of embeddings")
        if metadatas and len(metadatas) != len(texts):
            raise ValueError("Number of metadatas must match number of texts")
        # Treat [] like None; otherwise zip() below would silently insert nothing.
        if not metadatas:
            metadatas = [{} for _ in texts]

        self._ensure_embedding_dimension(len(embeddings[0]))

        insert_query = sql.SQL(
            "INSERT INTO {} (content, embedding, metadata) " "VALUES (%s, %s, %s) RETURNING id;"
        ).format(sql.Identifier(self.table_name))

        document_ids = []
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                for text, embedding, metadata in zip(texts, embeddings, metadatas):
                    cur.execute(insert_query, (text, embedding, psycopg2.extras.Json(metadata)))
                    document_ids.append(str(cur.fetchone()[0]))
            conn.commit()
        logger.info("Added %d documents to database", len(document_ids))
        return document_ids

    @staticmethod
    def _build_search_clause(query_embedding, k, filter_metadata) -> tuple:
        """Build the WHERE clause and the full parameter list for similarity_search.

        Parameters are returned in the order their placeholders appear in the
        query: the SELECT vector, then the filter key/value pairs, then the
        ORDER BY vector, then the LIMIT.

        Supported filter value types are ``str`` (text equality), ``bool``
        (matched against JSONB 'true'/'false') and ``int``/``float`` (numeric
        equality). Other value types are skipped, with a warning.

        Returns:
            ``(where_clause, params)``. The clause contains only fixed SQL and
            ``%s`` placeholders; all values travel in ``params``.
        """
        # Rows without an embedding have a NULL distance and cannot be ranked.
        conditions = ["embedding IS NOT NULL"]
        filter_params: List[Any] = []
        for key, value in (filter_metadata or {}).items():
            if isinstance(value, bool):
                # Test bool before int (bool subclasses int). The ->> operator
                # renders JSONB booleans as 'true' / 'false'.
                conditions.append("metadata->>%s = %s")
                filter_params.extend([key, "true" if value else "false"])
            elif isinstance(value, str):
                conditions.append("metadata->>%s = %s")
                filter_params.extend([key, value])
            elif isinstance(value, (int, float)):
                conditions.append("(metadata->>%s)::numeric = %s")
                filter_params.extend([key, value])
            else:
                logger.warning(
                    "Ignoring metadata filter %r: unsupported value type %s",
                    key,
                    type(value).__name__,
                )
        where_clause = "WHERE " + " AND ".join(conditions)
        params = [query_embedding, *filter_params, query_embedding, k]
        return where_clause, params

    def similarity_search(
        self,
        query_embedding: List[float],
        k: int = 5,
        filter_metadata: Optional[Dict[str, Any]] = None,
    ) -> List[Dict]:
        """
        Perform similarity search using cosine distance.

        Args:
            query_embedding: Query embedding vector.
            k: Maximum number of results to return.
            filter_metadata: Optional metadata equality filters (str, bool,
                int or float values; other value types are ignored).

        Returns:
            Up to ``k`` documents, closest first, each with ``id``, ``content``,
            ``metadata`` and ``similarity_score`` (1 - cosine distance).
            Returns an empty list for an empty query embedding.
        """
        if not query_embedding:
            return []

        where_clause, params = self._build_search_clause(query_embedding, k, filter_metadata)
        # Table name is a safe identifier; where_clause holds only fixed SQL
        # and %s placeholders. The explicit ::vector cast is needed because
        # psycopg2 sends a Python list as an array, and arrays only convert
        # to vector implicitly in assignments, not as an operand of <=>.
        query = sql.SQL(
            """
            SELECT id, content, metadata,
                   1 - (embedding <=> %s::vector) AS similarity_score
            FROM {}
            {}
            ORDER BY embedding <=> %s::vector
            LIMIT %s;
            """
        ).format(sql.Identifier(self.table_name), sql.SQL(where_clause))

        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(query, params)
                results = cur.fetchall()
        return [
            {
                "id": str(row["id"]),
                "content": row["content"],
                "metadata": row["metadata"] or {},
                "similarity_score": float(row["similarity_score"]),
            }
            for row in results
        ]

    def get_collection_info(self) -> Dict[str, Any]:
        """Return row count, on-disk size and embedding column info for the table.

        Returns:
            Dict with ``document_count``, ``table_size`` (human-readable, from
            pg_size_pretty), ``embedding_dimension`` (as cached on this
            instance, possibly None), ``table_name`` and
            ``embedding_column_type`` (information_schema ``data_type``, or None).
        """
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(sql.SQL("SELECT COUNT(*) FROM {};").format(sql.Identifier(self.table_name)))
                doc_count = cur.fetchone()[0]

                # pg_total_relation_size() expects a regclass value. A quoted
                # identifier here would be parsed as a column reference, so pass
                # the name as a parameter; quote_ident() preserves its case.
                cur.execute(
                    "SELECT pg_size_pretty(pg_total_relation_size(quote_ident(%s)::regclass)) AS size;",
                    (self.table_name,),
                )
                table_size = cur.fetchone()[0]

                cur.execute(
                    """
                    SELECT column_name, data_type
                    FROM information_schema.columns
                    WHERE table_name = %s AND column_name = 'embedding';
                    """,
                    (self.table_name,),
                )
                embedding_info = cur.fetchone()
        return {
            "document_count": doc_count,
            "table_size": table_size,
            "embedding_dimension": self.dimension,
            "table_name": self.table_name,
            "embedding_column_type": (embedding_info[1] if embedding_info else None),
        }

    def delete_documents(self, document_ids: List[str]) -> int:
        """
        Delete documents by their IDs.

        Args:
            document_ids: IDs of the documents to delete (strings of integers).

        Returns:
            Number of documents actually deleted.

        Raises:
            ValueError: If an ID is not an integer string (raised before any
                database connection is opened).
        """
        if not document_ids:
            return 0
        int_ids = [int(doc_id) for doc_id in document_ids]
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    sql.SQL("DELETE FROM {} WHERE id = ANY(%s);").format(sql.Identifier(self.table_name)),
                    (int_ids,),
                )
                deleted_count = cur.rowcount
            conn.commit()
        logger.info("Deleted %d documents", deleted_count)
        return deleted_count

    def delete_all_documents(self) -> int:
        """
        Delete all documents and restart the ID sequence at 1.

        Returns:
            Number of documents deleted.
        """
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(sql.SQL("DELETE FROM {};").format(sql.Identifier(self.table_name)))
                # rowcount is exactly what this DELETE removed (a prior COUNT(*)
                # could disagree if rows were inserted in between).
                deleted_count = cur.rowcount
                # Look up the id column's sequence instead of guessing its name:
                # Postgres truncates long '<table>_id_seq' names.
                # setval(..., 1, false) makes the next id 1.
                cur.execute(
                    "SELECT setval(pg_get_serial_sequence(quote_ident(%s), 'id'), 1, false);",
                    (self.table_name,),
                )
            conn.commit()
        logger.info("Deleted all %d documents", deleted_count)
        return deleted_count

    def update_document(
        self,
        document_id: str,
        content: Optional[str] = None,
        embedding: Optional[List[float]] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Update a document's content, embedding, and/or metadata.

        Only arguments that are not None are written; falsy values such as
        ``""`` or ``{}`` are real updates (e.g. ``metadata={}`` clears metadata).
        ``updated_at`` is refreshed on every update.

        Args:
            document_id: ID of the document to update.
            content: New content (optional).
            embedding: New embedding (optional).
            metadata: New metadata; replaces the stored value entirely (optional).

        Returns:
            True if a row was updated. False if nothing was requested or the
            document was not found.
        """
        if content is None and embedding is None and metadata is None:
            return False

        updates = []
        params: List[Any] = []
        if content is not None:
            updates.append("content = %s")
            params.append(content)
        if embedding is not None:
            updates.append("embedding = %s")
            params.append(embedding)
        if metadata is not None:
            updates.append("metadata = %s")
            params.append(psycopg2.extras.Json(metadata))
        updates.append("updated_at = CURRENT_TIMESTAMP")
        params.append(int(document_id))

        # The SET fragments are fixed strings from this method; only the table
        # name needs identifier quoting.
        query = sql.SQL("UPDATE {} SET {} WHERE id = %s").format(
            sql.Identifier(self.table_name),
            sql.SQL(", ").join(sql.SQL(fragment) for fragment in updates),
        )
        with self._get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(query, params)
                updated = cur.rowcount > 0
            conn.commit()
        if updated:
            logger.info("Updated document %s", document_id)
        else:
            logger.warning("Document %s not found for update", document_id)
        return updated

    def get_document(self, document_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a single document by ID.

        Args:
            document_id: ID of the document to retrieve.

        Returns:
            Dict with ``id``, ``content``, ``metadata`` and ISO-8601
            ``created_at`` / ``updated_at`` strings (or None), or None if no
            such document exists.

        Raises:
            ValueError: If ``document_id`` is not an integer string.
        """
        doc_pk = int(document_id)
        with self._get_connection() as conn:
            with conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                cur.execute(
                    sql.SQL("SELECT id, content, metadata, created_at, " "updated_at FROM {} WHERE id = %s;").format(
                        sql.Identifier(self.table_name)
                    ),
                    (doc_pk,),
                )
                row = cur.fetchone()
        if not row:
            return None
        return {
            "id": str(row["id"]),
            "content": row["content"],
            "metadata": row["metadata"] or {},
            "created_at": (row["created_at"].isoformat() if row["created_at"] else None),
            "updated_at": (row["updated_at"].isoformat() if row["updated_at"] else None),
        }

    def health_check(self) -> Dict[str, Any]:
        """
        Check the health of the vector database service.

        Returns:
            ``{"status": "healthy", "pgvector_installed", "connection": "ok",
            "collection_info"}`` on success. On any error, returns
            ``{"status": "unhealthy", "error", "connection": "failed"}``
            (the error is logged, not raised).
        """
        try:
            with self._get_connection() as conn:
                with conn.cursor() as cur:
                    # Test basic connectivity
                    cur.execute("SELECT 1")
                    # Consume the result to keep fetchone() calls aligned with
                    # mocked side_effect ordering. A failure here is deliberately
                    # ignored because execute() already proved connectivity.
                    try:
                        cur.fetchone()
                    except Exception:
                        logger.debug("Ignoring fetchone() failure after SELECT 1", exc_info=True)
                    # Check if pgvector extension is installed
                    cur.execute("SELECT EXISTS(SELECT 1 FROM pg_extension " "WHERE extname = 'vector')")
                    result = cur.fetchone()
                    pgvector_installed = bool(result[0]) if result else False
            # Collected after the first connection is released, so the check
            # never holds two connections at once.
            info = self.get_collection_info()
            return {
                "status": "healthy",
                "pgvector_installed": pgvector_installed,
                "connection": "ok",
                "collection_info": info,
            }
        except Exception as e:
            logger.error("Health check failed: %s", e)
            return {"status": "unhealthy", "error": str(e), "connection": "failed"}