sethmcknight commited on
Commit
15f6c83
·
1 Parent(s): 2f59223

fix(chroma): use get_or_create_collection to prevent race conditions

Browse files
src/config.py CHANGED
@@ -14,7 +14,7 @@ SUPPORTED_FORMATS = {".txt", ".md", ".markdown"}
14
  CORPUS_DIRECTORY = "synthetic_policies"
15
 
16
  # Vector Database Settings
17
- VECTOR_STORAGE_TYPE = os.getenv("VECTOR_STORAGE_TYPE", "postgres") # "chroma" or "postgres"
18
  VECTOR_DB_PERSIST_PATH = "data/chroma_db" # Used for ChromaDB
19
  DATABASE_URL = os.getenv("DATABASE_URL") # Used for PostgreSQL
20
  COLLECTION_NAME = "policy_documents"
 
14
  CORPUS_DIRECTORY = "synthetic_policies"
15
 
16
  # Vector Database Settings
17
+ VECTOR_STORAGE_TYPE = os.getenv("VECTOR_STORAGE_TYPE", "chroma") # "chroma" or "postgres"
18
  VECTOR_DB_PERSIST_PATH = "data/chroma_db" # Used for ChromaDB
19
  DATABASE_URL = os.getenv("DATABASE_URL") # Used for PostgreSQL
20
  COLLECTION_NAME = "policy_documents"
src/embedding/embedding_service.py CHANGED
@@ -151,9 +151,11 @@ class EmbeddingService:
151
  file_name,
152
  e,
153
  )
 
154
  model = ORTModelForFeatureExtraction.from_pretrained(
155
  self.model_name,
156
  provider=provider,
 
157
  session_options=so,
158
  )
159
  logging.info(
 
151
  file_name,
152
  e,
153
  )
154
+ # The key change: we now pass the file_name to the fallback as well
155
  model = ORTModelForFeatureExtraction.from_pretrained(
156
  self.model_name,
157
  provider=provider,
158
+ file_name=file_name, # Added this line
159
  session_options=so,
160
  )
161
  logging.info(
src/vector_store/vector_db.py CHANGED
@@ -72,11 +72,7 @@ class VectorDatabase:
72
  log_memory_checkpoint("vector_db_after_client_init")
73
 
74
  # Get or create collection
75
- try:
76
- self.collection = self.client.get_collection(name=collection_name)
77
- except ValueError:
78
- # Collection doesn't exist, create it
79
- self.collection = self.client.create_collection(name=collection_name)
80
 
81
  logging.info(f"Initialized VectorDatabase with collection " f"'{collection_name}' at '{persist_path}'")
82
 
 
72
  log_memory_checkpoint("vector_db_after_client_init")
73
 
74
  # Get or create collection
75
+ self.collection = self.client.get_or_create_collection(name=collection_name)
 
 
 
 
76
 
77
  logging.info(f"Initialized VectorDatabase with collection " f"'{collection_name}' at '{persist_path}'")
78