File size: 12,568 Bytes
afecdc5
9988b25
7793bb6
9988b25
7793bb6
 
 
dca679b
0a7f9b4
9988b25
0a7f9b4
afecdc5
159faf0
dca679b
 
 
 
 
 
 
 
 
 
9988b25
 
 
dca679b
9988b25
159faf0
dca679b
 
 
 
 
 
 
 
 
 
afecdc5
 
7793bb6
0a7f9b4
 
 
 
 
afecdc5
 
7793bb6
afecdc5
 
 
 
 
 
7793bb6
afecdc5
 
7793bb6
0a7f9b4
 
 
 
 
 
 
 
 
 
b3b90ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a7f9b4
7793bb6
afecdc5
15f6c83
7793bb6
159faf0
7793bb6
afecdc5
 
 
7793bb6
0a7f9b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
afecdc5
7793bb6
 
 
 
 
0a7f9b4
afecdc5
0a7f9b4
7793bb6
afecdc5
 
0a7f9b4
 
afecdc5
7793bb6
afecdc5
0a7f9b4
afecdc5
0a7f9b4
 
 
159faf0
0a7f9b4
 
afecdc5
 
0a7f9b4
 
 
 
7793bb6
 
0a7f9b4
 
 
afecdc5
7793bb6
afecdc5
 
0a7f9b4
 
7793bb6
0a7f9b4
159faf0
afecdc5
 
7793bb6
afecdc5
 
 
7793bb6
afecdc5
 
 
 
 
 
 
7793bb6
afecdc5
0a7f9b4
afecdc5
 
7793bb6
afecdc5
0a7f9b4
7793bb6
afecdc5
 
7793bb6
 
 
afecdc5
7793bb6
 
 
 
afecdc5
 
7793bb6
afecdc5
 
7793bb6
afecdc5
 
 
7793bb6
afecdc5
 
 
 
 
 
 
7793bb6
afecdc5
 
 
 
 
 
 
 
 
7793bb6
afecdc5
 
 
 
 
 
 
 
 
7793bb6
afecdc5
 
 
 
7793bb6
afecdc5
 
7793bb6
f88b1d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
import logging
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

import chromadb

from src.config import VECTOR_STORAGE_TYPE
from src.utils.memory_utils import log_memory_checkpoint, memory_monitor
from src.vector_db.postgres_adapter import PostgresVectorAdapter


def create_vector_database(persist_path: Optional[str] = None, collection_name: Optional[str] = None):
    """
    Factory that returns the configured vector database implementation.

    Args:
        persist_path: Persistence directory (only used by the ChromaDB backend)
        collection_name: Name of the collection (Chroma) or table (Postgres)

    Returns:
        A PostgresVectorAdapter or a VectorDatabase, depending on configuration.
    """
    # The environment variable takes precedence over the baked-in config value
    # so tests and deploy-time configuration can switch backends at runtime.
    storage_type = os.getenv("VECTOR_STORAGE_TYPE") or VECTOR_STORAGE_TYPE

    if storage_type != "postgres":
        # ChromaDB is the default backend; pull its defaults lazily to avoid
        # importing Chroma-specific config on the Postgres path.
        from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH

        return VectorDatabase(
            persist_path=persist_path or VECTOR_DB_PERSIST_PATH,
            collection_name=collection_name or COLLECTION_NAME,
        )

    return PostgresVectorAdapter(table_name=collection_name or "document_embeddings")


class VectorDatabase:
    """ChromaDB integration for vector storage and similarity search.

    Wraps a persistent ChromaDB client and a single collection, exposing
    add/search/count and maintenance operations with memory-usage
    instrumentation via the memory_utils helpers.
    """

    def __init__(
        self,
        persist_path: str,
        collection_name: str,
    ):
        """
        Initialize the vector database

        Args:
            persist_path: Path to persist the database
            collection_name: Name of the collection to use

        Raises:
            Exception: If the ChromaDB client cannot be initialized and the
                one-time corrupt-persistence recovery attempt also fails.
        """
        self.persist_path = persist_path
        self.collection_name = collection_name

        # Ensure persist directory exists
        Path(persist_path).mkdir(parents=True, exist_ok=True)

        # Get chroma settings from config for memory optimization
        from chromadb.config import Settings

        from src.config import CHROMA_SETTINGS

        # Convert CHROMA_SETTINGS dict to Settings object
        chroma_settings = Settings(**CHROMA_SETTINGS)

        # Initialize ChromaDB client with persistence and memory optimization
        log_memory_checkpoint("vector_db_before_client_init")
        try:
            self.client = chromadb.PersistentClient(path=persist_path, settings=chroma_settings)
        except Exception as e:
            # Detect common sqlite corrupt/partial-init state where Chroma's sysdb
            # tables (like `tenants`) are missing. Attempt a safe one-time cleanup
            # of the persistence directory and retry initialization. This helps
            # recover when a previous failed startup left an inconsistent DB.
            import glob
            import shutil
            import sqlite3

            logging.warning(
                "ChromaDB persistent client init failed: %s; attempting cleanup and retry",
                e,
            )

            # Only perform aggressive cleanup for sqlite OperationalError or
            # Chroma UniqueConstraint/Operational style issues.
            if isinstance(e, sqlite3.OperationalError) or "no such table" in str(e).lower():
                try:
                    # Remove sqlite files and chroma DB folders under persist_path
                    pattern = os.path.join(persist_path, "*")
                    for p in glob.glob(pattern):
                        try:
                            if os.path.isdir(p):
                                shutil.rmtree(p)
                            else:
                                os.remove(p)
                        except Exception:
                            # Best-effort cleanup; continue
                            logging.debug("Failed to remove %s during cleanup", p)

                    # Recreate the directory and retry
                    Path(persist_path).mkdir(parents=True, exist_ok=True)
                    self.client = chromadb.PersistentClient(path=persist_path, settings=chroma_settings)
                    logging.info("ChromaDB persistence cleaned and client reinitialized")
                except Exception as e2:
                    logging.error("ChromaDB recovery attempt failed: %s", e2)
                    # Re-raise the recovery failure (with the original error
                    # attached as __context__) so the caller sees the failure.
                    raise
            else:
                # If it's an unexpected error, re-raise to be handled upstream
                raise

        log_memory_checkpoint("vector_db_after_client_init")

        # Get or create collection
        self.collection = self.client.get_or_create_collection(name=collection_name)

        logging.info(f"Initialized VectorDatabase with collection '{collection_name}' at '{persist_path}'")

    def get_collection(self):
        """Get the ChromaDB collection"""
        return self.collection

    @memory_monitor
    def add_embeddings_batch(
        self,
        batch_embeddings: List[List[List[float]]],
        batch_chunk_ids: List[List[str]],
        batch_documents: List[List[str]],
        batch_metadatas: List[List[Dict[str, Any]]],
    ) -> int:
        """
        Add embeddings in batches to prevent memory issues with large datasets

        Args:
            batch_embeddings: List of embedding batches
            batch_chunk_ids: List of chunk ID batches
            batch_documents: List of document batches
            batch_metadatas: List of metadata batches

        Returns:
            Number of embeddings added
        """
        # Hoisted out of the loop: the module object is the same every
        # iteration, so importing per-batch was pure overhead.
        import gc

        total_added = 0

        for i, (embeddings, chunk_ids, documents, metadatas) in enumerate(
            zip(
                batch_embeddings,
                batch_chunk_ids,
                batch_documents,
                batch_metadatas,
            )
        ):
            log_memory_checkpoint(f"before_add_batch_{i}")
            # add_embeddings may return True on success (or raise on failure)
            added = self.add_embeddings(
                embeddings=embeddings,
                chunk_ids=chunk_ids,
                documents=documents,
                metadatas=metadatas,
            )
            # If add_embeddings returns True, treat as all embeddings added
            if isinstance(added, bool) and added:
                added_count = len(embeddings)
            elif isinstance(added, int):
                added_count = int(added)
            else:
                added_count = 0
            total_added += added_count
            logging.info(f"Added batch {i+1}/{len(batch_embeddings)}")

            # Force cleanup after each batch
            gc.collect()
            log_memory_checkpoint(f"after_add_batch_{i}")

        return total_added

    @memory_monitor
    def add_embeddings(
        self,
        embeddings: List[List[float]],
        chunk_ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> bool:
        """
        Add embeddings to the collection

        Args:
            embeddings: List of embedding vectors
            chunk_ids: List of chunk IDs
            documents: List of document texts
            metadatas: List of metadata dictionaries

        Returns:
            True when all embeddings were added (kept as a bool for API
            compatibility with existing tests; see add_embeddings_batch,
            which accepts either a bool or an int here)

        Raises:
            ValueError: If the four input lists are not the same length
            Exception: Re-raised from the underlying collection.add failure
        """
        # Validate input lengths
        n = len(embeddings)
        if not (len(chunk_ids) == n and len(documents) == n and len(metadatas) == n):
            raise ValueError(f"Number of embeddings {n} must match number of ids {len(chunk_ids)}")

        log_memory_checkpoint("before_add_embeddings")
        try:
            self.collection.add(
                embeddings=embeddings,
                documents=documents,
                metadatas=metadatas,
                ids=chunk_ids,
            )

            log_memory_checkpoint("after_add_embeddings")
            logging.info(f"Added {n} embeddings to collection")
            # Return boolean True for API compatibility tests
            return True

        except Exception as e:
            logging.error(f"Failed to add embeddings: {e}")
            # Re-raise to allow callers/tests to handle failures explicitly
            raise

    @memory_monitor
    def search(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for similar embeddings

        Args:
            query_embedding: Query vector to search for
            top_k: Number of results to return

        Returns:
            List of result dicts with keys "id", "document", "metadata",
            "distance"; empty list when the collection is empty or on error
        """
        try:
            # One count() round-trip instead of two (it was previously called
            # again to clamp n_results).
            count = self.get_count()

            # Handle empty collection
            if count == 0:
                return []

            # Perform similarity search
            log_memory_checkpoint("vector_db_before_query")
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=min(top_k, count),
            )
            log_memory_checkpoint("vector_db_after_query")

            # Format results
            formatted_results = []

            if results["ids"] and len(results["ids"][0]) > 0:
                for i in range(len(results["ids"][0])):
                    result = {
                        "id": results["ids"][0][i],
                        "document": results["documents"][0][i],
                        "metadata": results["metadatas"][0][i],
                        "distance": results["distances"][0][i],
                    }
                    formatted_results.append(result)

            logging.info(f"Search returned {len(formatted_results)} results")
            return formatted_results

        except Exception as e:
            logging.error(f"Search failed: {e}")
            return []

    def get_count(self) -> int:
        """Get the number of embeddings in the collection (0 on error)"""
        try:
            return self.collection.count()
        except Exception as e:
            logging.error(f"Failed to get count: {e}")
            return 0

    def delete_collection(self) -> bool:
        """Delete the collection. Returns True on success, False on error."""
        try:
            self.client.delete_collection(name=self.collection_name)
            logging.info(f"Deleted collection '{self.collection_name}'")
            return True
        except Exception as e:
            logging.error(f"Failed to delete collection: {e}")
            return False

    def reset_collection(self) -> bool:
        """Reset the collection (delete and recreate)"""
        try:
            # Delete existing collection
            try:
                self.client.delete_collection(name=self.collection_name)
            except ValueError:
                # Collection doesn't exist, that's fine
                pass

            # Create new collection
            self.collection = self.client.create_collection(name=self.collection_name)
            logging.info(f"Reset collection '{self.collection_name}'")
            return True

        except Exception as e:
            logging.error(f"Failed to reset collection: {e}")
            return False

    def get_embedding_dimension(self) -> int:
        """
        Get the embedding dimension from existing data in the collection.
        Returns 0 if collection is empty or has no embeddings.
        """
        try:
            count = self.get_count()
            if count == 0:
                return 0

            # Retrieve one record to check its embedding dimension
            record = self.collection.get(
                ids=None,  # None returns all records, but we only need one
                include=["embeddings"],
                limit=1,
            )

            # Newer chromadb versions return embeddings as numpy arrays, whose
            # bare truthiness raises ValueError ("truth value of an array is
            # ambiguous") — use explicit None/length checks instead.
            embeddings = record.get("embeddings") if record else None
            if embeddings is not None and len(embeddings) > 0:
                return len(embeddings[0])

            return 0

        except Exception as e:
            logging.error(f"Failed to get embedding dimension: {e}")
            return 0

    def has_valid_embeddings(self, expected_dimension: int) -> bool:
        """
        Check if the collection has embeddings with the expected dimension.

        Args:
            expected_dimension: The expected embedding dimension

        Returns:
            True if collection has embeddings with correct dimension, False otherwise
        """
        try:
            actual_dimension = self.get_embedding_dimension()
            # Positive-dimension guard also rejects the empty-collection case
            # (get_embedding_dimension returns 0 there).
            return actual_dimension == expected_dimension and actual_dimension > 0

        except Exception as e:
            logging.error(f"Failed to validate embeddings: {e}")
            return False