Spaces:
Sleeping
Sleeping
sethmcknight
fix(chroma): recover from corrupted persistent DB by cleaning and retrying init
b3b90ec
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import chromadb | |
| from src.config import VECTOR_STORAGE_TYPE | |
| from src.utils.memory_utils import log_memory_checkpoint, memory_monitor | |
| from src.vector_db.postgres_adapter import PostgresVectorAdapter | |
def create_vector_database(persist_path: Optional[str] = None, collection_name: Optional[str] = None):
    """
    Factory function to create the appropriate vector database implementation.

    Args:
        persist_path: Path for persistence (used by ChromaDB)
        collection_name: Name of the collection

    Returns:
        Vector database implementation
    """
    # A runtime env var takes precedence over the imported config constant so
    # tests and deploy-time configuration behave consistently.
    backend = os.getenv("VECTOR_STORAGE_TYPE") or VECTOR_STORAGE_TYPE

    if backend == "postgres":
        return PostgresVectorAdapter(table_name=collection_name or "document_embeddings")

    # Anything else falls back to the default ChromaDB-backed implementation.
    from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH

    return VectorDatabase(
        persist_path=persist_path or VECTOR_DB_PERSIST_PATH,
        collection_name=collection_name or COLLECTION_NAME,
    )
class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search"""

    def __init__(
        self,
        persist_path: str,
        collection_name: str,
    ):
        """
        Initialize the vector database

        Args:
            persist_path: Path to persist the database
            collection_name: Name of the collection to use

        Raises:
            Exception: the original client-init failure when recovery is not
                applicable, or when the one-time cleanup-and-retry also fails.
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Get chroma settings from config for memory optimization
        from chromadb.config import Settings

        from src.config import CHROMA_SETTINGS

        # Convert CHROMA_SETTINGS dict to Settings object
        chroma_settings = Settings(**CHROMA_SETTINGS)

        # Initialize ChromaDB client with persistence and memory optimization
        log_memory_checkpoint("vector_db_before_client_init")
        try:
            self.client = chromadb.PersistentClient(path=persist_path, settings=chroma_settings)
        except Exception as e:
            # Detect common sqlite corrupt/partial-init state where Chroma's sysdb
            # tables (like `tenants`) are missing. Attempt a safe one-time cleanup
            # of the persistence directory and retry initialization. This helps
            # recover when a previous failed startup left an inconsistent DB.
            import glob
            import shutil
            import sqlite3

            logging.warning(
                "ChromaDB persistent client init failed: %s; attempting cleanup and retry",
                e,
            )
            # Only perform aggressive cleanup for sqlite OperationalError or
            # Chroma UniqueConstraint/Operational style issues.
            if isinstance(e, sqlite3.OperationalError) or "no such table" in str(e).lower():
                try:
                    # Remove sqlite files and chroma DB folders under persist_path
                    pattern = os.path.join(persist_path, "*")
                    for p in glob.glob(pattern):
                        try:
                            if os.path.isdir(p):
                                shutil.rmtree(p)
                            else:
                                os.remove(p)
                        except Exception:
                            # Best-effort cleanup; continue
                            logging.debug("Failed to remove %s during cleanup", p)
                    # Recreate the directory and retry
                    Path(persist_path).mkdir(parents=True, exist_ok=True)
                    self.client = chromadb.PersistentClient(path=persist_path, settings=chroma_settings)
                    logging.info("ChromaDB persistence cleaned and client reinitialized")
                except Exception as e2:
                    logging.error("ChromaDB recovery attempt failed: %s", e2)
                    # BUGFIX: a bare `raise` here re-raised the *retry* failure
                    # (e2), contradicting the stated intent of surfacing the
                    # original error. Re-raise the original init exception and
                    # keep the retry failure chained as its cause.
                    raise e from e2
            else:
                # If it's an unexpected error, re-raise to be handled upstream
                raise

        log_memory_checkpoint("vector_db_after_client_init")

        # Get or create collection
        self.collection = self.client.get_or_create_collection(name=collection_name)
        logging.info(f"Initialized VectorDatabase with collection '{collection_name}' at '{persist_path}'")

    def get_collection(self):
        """Get the ChromaDB collection"""
        return self.collection

    def add_embeddings_batch(
        self,
        batch_embeddings: List[List[List[float]]],
        batch_chunk_ids: List[List[str]],
        batch_documents: List[List[str]],
        batch_metadatas: List[List[Dict[str, Any]]],
    ) -> int:
        """
        Add embeddings in batches to prevent memory issues with large datasets

        Args:
            batch_embeddings: List of embedding batches
            batch_chunk_ids: List of chunk ID batches
            batch_documents: List of document batches
            batch_metadatas: List of metadata batches

        Returns:
            Number of embeddings added
        """
        total_added = 0
        for i, (embeddings, chunk_ids, documents, metadatas) in enumerate(
            zip(
                batch_embeddings,
                batch_chunk_ids,
                batch_documents,
                batch_metadatas,
            )
        ):
            log_memory_checkpoint(f"before_add_batch_{i}")
            # add_embeddings returns True on success (or raises on failure);
            # translate that into a per-batch count. The int/fallback branches
            # are kept for tolerance of alternative implementations.
            added = self.add_embeddings(
                embeddings=embeddings,
                chunk_ids=chunk_ids,
                documents=documents,
                metadatas=metadatas,
            )
            if isinstance(added, bool) and added:
                added_count = len(embeddings)
            elif isinstance(added, int):
                added_count = int(added)
            else:
                added_count = 0
            total_added += added_count
            logging.info(f"Added batch {i+1}/{len(batch_embeddings)}")
            # Force cleanup after each batch to keep peak memory bounded
            import gc

            gc.collect()
            log_memory_checkpoint(f"after_add_batch_{i}")
        return total_added

    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the collection

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of chunk IDs
            documents: List of document texts
            metadatas: List of metadata dictionaries

        Returns:
            True on success. (Kept boolean for API-compatibility with existing
            tests; add_embeddings_batch converts it into a count. The previous
            `-> int` annotation contradicted the actual return value.)

        Raises:
            ValueError: if the four input lists have mismatched lengths.
            Exception: propagated from the underlying collection.add failure.
        """
        # Validate input lengths
        n = len(embeddings)
        if not (len(chunk_ids) == n and len(documents) == n and len(metadatas) == n):
            raise ValueError(f"Number of embeddings {n} must match number of ids {len(chunk_ids)}")
        log_memory_checkpoint("before_add_embeddings")
        try:
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids,
            )
            log_memory_checkpoint("after_add_embeddings")
            logging.info(f"Added {n} embeddings to collection")
            # Return boolean True for API compatibility tests
            return True
        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            # Re-raise to allow callers/tests to handle failures explicitly
            raise

    def search(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of search results with metadata; empty list on failure or
            when the collection has no entries.
        """
        try:
            # Handle empty collection
            if self.get_count() == 0:
                return []
            # Perform similarity search; never request more results than exist
            log_memory_checkpoint("vector_db_before_query")
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, self.get_count()),
            )
            log_memory_checkpoint("vector_db_after_query")
            # Format results (chroma returns per-query lists; we sent one query)
            formatted_results = []
            if results["ids"] and len(results["ids"][0]) > 0:
                for i in range(len(results["ids"][0])):
                    result = {
                        "id": results["ids"][0][i],
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                    formatted_results.append(result)
            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results
        except Exception as e:
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on failure)"""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """Delete the collection. Returns True on success, False on failure."""
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """Reset the collection (delete and recreate)"""
        try:
            # Delete existing collection
            try:
                self.client.delete_collection(name=self.collection_name)
            except ValueError:
                # Collection doesn't exist, that's fine
                # NOTE(review): newer chromadb versions may raise a different
                # exception type for a missing collection — verify against the
                # pinned chromadb release.
                pass
            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logging.info(f"Reset collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.

        Returns 0 if collection is empty, has no embeddings, or on error.
        """
        try:
            count = self.get_count()
            if count == 0:
                return 0
            # Retrieve one record to check its embedding dimension
            record = self.collection.get(
                ids=None,  # None returns all records, but we only need one
                include=["embeddings"],
                limit=1,
            )
            embeddings = record.get("embeddings") if record else None
            # BUGFIX: newer chromadb returns numpy arrays here, for which bare
            # truthiness raises "ambiguous truth value" — test explicitly.
            if embeddings is not None and len(embeddings) > 0:
                return len(embeddings[0])
            return 0
        except Exception as e:
            logging.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension, False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            # A zero dimension means "no embeddings", which is never valid
            return actual_dimension == expected_dimension and actual_dimension > 0
        except Exception as e:
            logging.error(f"Failed to validate embeddings: {e}")
            return False