Tobias Pasquale committed
Commit: 7793bb6
Parent(s): 7effb84
style: Fix code formatting and linting issues for CI/CD compliance

- Fix black code formatting across all Python files
- Fix isort import ordering
- Remove unused imports (pytest, pathlib, numpy, Union)
- Fix line length issues (split long lines)
- Remove unused variables in tests
- Fix whitespace and end-of-file issues
- Address flake8 linting requirements

All changes maintain functionality while ensuring the CI/CD pipeline passes.
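For context on the import-ordering changes in the diffs below: isort groups imports into standard-library, third-party, and local sections separated by blank lines. A minimal sketch of the resulting pattern, using module names that appear in this commit (the grouping comments are illustrative only):

# Standard-library imports
import logging
from pathlib import Path
from typing import Any, Dict, List

# Third-party imports
import numpy as np

# First-party (local) imports
from src.embedding.embedding_service import EmbeddingService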
- CHANGELOG.md +19 -1
- app.py +16 -18
- src/__init__.py +1 -1
- src/config.py +3 -3
- src/embedding/__init__.py +1 -1
- src/embedding/embedding_service.py +55 -48
- src/ingestion/__init__.py +1 -1
- src/ingestion/document_chunker.py +43 -36
- src/ingestion/document_parser.py +19 -21
- src/ingestion/ingestion_pipeline.py +26 -20
- src/vector_store/__init__.py +1 -1
- src/vector_store/vector_db.py +57 -49
- tests/test_app.py +2 -1
- tests/test_embedding/__init__.py +1 -1
- tests/test_embedding/test_embedding_service.py +57 -46
- tests/test_ingestion/__init__.py +1 -1
- tests/test_ingestion/test_document_chunker.py +66 -58
- tests/test_ingestion/test_document_parser.py +39 -30
- tests/test_ingestion/test_ingestion_pipeline.py +70 -58
- tests/test_integration.py +66 -47
- tests/test_vector_store/__init__.py +1 -1
- tests/test_vector_store/test_vector_db.py +51 -41
CHANGELOG.md CHANGED
@@ -205,6 +205,24 @@ Each entry includes:
 - **Foundation Complete**: ChromaDB + HuggingFace embeddings fully integrated and tested
 - **Phase 2A Status**: ✅ COMPLETED SUCCESSFULLY - Ready for Phase 2B Enhanced Ingestion Pipeline

+#### Entry #012 - 2025-10-17 17:30
+- **Action Type**: DEPLOY + COLLABORATE
+- **Component**: Project Documentation & Team Collaboration
+- **Description**: Moved development changelog to root directory and committed to git for better team collaboration and visibility
+- **Files Changed**:
+  - Moved: `planning/development-changelog.md` → `CHANGELOG.md` (root directory)
+  - Modified: `README.md` (added Development Progress section)
+  - Committed: All Phase 2A changes to `feat/embedding-vector-storage` branch
+- **Tests**: N/A (documentation/collaboration improvement)
+- **CI/CD**: Branch pushed to GitHub with comprehensive commit history
+- **Notes**:
+  - **Team Collaboration**: CHANGELOG.md now visible in repository for partner collaboration
+  - **Comprehensive Commit**: All Phase 2A changes committed with detailed descriptions
+  - **Documentation Enhancement**: README updated to reference changelog for development tracking
+  - **Branch Status**: `feat/embedding-vector-storage` ready for pull request and code review
+  - **Visibility Improvement**: Development progress now trackable by all team members
+  - **Next Steps**: Ready for partner review and Phase 2B planning collaboration
+
 ---

 ## Next Planned Actions
@@ -248,4 +266,4 @@ Each entry includes:

 ---

-*This changelog is automatically updated after each development action to maintain complete project transparency and audit trail.*
+*This changelog is automatically updated after each development action to maintain complete project transparency and audit trail.*
app.py CHANGED
@@ -19,32 +19,30 @@ def health():
     return jsonify({"status": "ok"}), 200


+@app.route("/ingest", methods=["POST"])
 def ingest():
     """Endpoint to trigger document ingestion"""
     try:
+        from src.config import (CORPUS_DIRECTORY, DEFAULT_CHUNK_SIZE,
+                                DEFAULT_OVERLAP, RANDOM_SEED)
         from src.ingestion.ingestion_pipeline import IngestionPipeline
+
         pipeline = IngestionPipeline(
+            chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_OVERLAP, seed=RANDOM_SEED
         )
+
         chunks = pipeline.process_directory(CORPUS_DIRECTORY)
+
+        return jsonify(
+            {
+                "status": "success",
+                "chunks_processed": len(chunks),
+                "message": f"Successfully processed {len(chunks)} chunks",
+            }
+        )
+
     except Exception as e:
+        return jsonify({"status": "error", "message": str(e)}), 500


 if __name__ == "__main__":
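A quick way to exercise the reformatted /ingest endpoint is Flask's test client, mirroring the pattern used in tests/test_app.py (a usage sketch only; the success payload follows the JSON built in the handler above):

from app import app  # Flask app defined in app.py

client = app.test_client()
response = client.post("/ingest")  # triggers ingestion of the configured corpus directory

# Expected success payload, per the handler above:
# {"status": "success", "chunks_processed": <n>, "message": "Successfully processed <n> chunks"}
print(response.status_code, response.get_json())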
src/__init__.py CHANGED
@@ -1 +1 @@
-# Empty file to make src a package
+# Empty file to make src a package
src/config.py CHANGED
@@ -6,10 +6,10 @@ DEFAULT_OVERLAP = 200
 RANDOM_SEED = 42

 # Supported file formats
+SUPPORTED_FORMATS = {".txt", ".md", ".markdown"}

 # Corpus directory
+CORPUS_DIRECTORY = "synthetic_policies"

 # Vector Database Settings
 VECTOR_DB_PERSIST_PATH = "data/chroma_db"
@@ -25,4 +25,4 @@ EMBEDDING_DEVICE = "cpu"  # Use CPU for free tier compatibility
 # Search Settings
 DEFAULT_TOP_K = 5
 MAX_TOP_K = 20
+MIN_SIMILARITY_THRESHOLD = 0.3
src/embedding/__init__.py CHANGED
@@ -1 +1 @@
-# Embedding service package for HuggingFace model integration
+# Embedding service package for HuggingFace model integration
src/embedding/embedding_service.py CHANGED
@@ -1,22 +1,24 @@
-from sentence_transformers import SentenceTransformer
-from typing import List, Union
 import logging
+from typing import List
+
 import numpy as np
+from sentence_transformers import SentenceTransformer
+

 class EmbeddingService:
     """HuggingFace sentence-transformers wrapper for generating embeddings"""
+
     _model_cache = {}  # Class-level cache for model instances
+
     def __init__(
+        self,
         model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
         device: str = "cpu",
+        batch_size: int = 32,
     ):
         """
         Initialize the embedding service
+
         Args:
             model_name: HuggingFace model name
             device: Device to run the model on ('cpu' or 'cuda')
@@ -25,64 +27,69 @@ class EmbeddingService:
         self.model_name = model_name
         self.device = device
         self.batch_size = batch_size
+
         # Load model (with caching)
         self.model = self._load_model()
+
+        logging.info(
+            f"Initialized EmbeddingService with model "
+            f"'{model_name}' on device '{device}'"
+        )
+
     def _load_model(self) -> SentenceTransformer:
         """Load the sentence transformer model with caching"""
         cache_key = f"{self.model_name}_{self.device}"
+
         if cache_key not in self._model_cache:
+            logging.info(
+                f"Loading model '{self.model_name}' on device '{self.device}'..."
+            )
             model = SentenceTransformer(self.model_name, device=self.device)
             self._model_cache[cache_key] = model
+            logging.info("Model loaded successfully")
         else:
             logging.info(f"Using cached model '{self.model_name}'")
+
         return self._model_cache[cache_key]
+
     def embed_text(self, text: str) -> List[float]:
         """
         Generate embedding for a single text
+
         Args:
             text: Text to embed
+
         Returns:
             List of float values representing the embedding
         """
         if not text.strip():
             # Handle empty text - still generate embedding
             text = " "  # Single space to avoid completely empty input
+
         try:
             # Generate embedding
             embedding = self.model.encode(text, convert_to_numpy=True)
+
             # Convert to Python list of floats
             return embedding.tolist()
+
         except Exception as e:
             logging.error(f"Failed to generate embedding for text: {e}")
             raise e
+
     def embed_texts(self, texts: List[str]) -> List[List[float]]:
         """
         Generate embeddings for multiple texts
+
         Args:
             texts: List of texts to embed
+
         Returns:
             List of embeddings (each embedding is a list of floats)
         """
         if not texts:
             return []
+
         try:
             # Preprocess empty texts
             processed_texts = []
@@ -91,48 +98,48 @@ class EmbeddingService:
                     processed_texts.append(" ")  # Single space for empty texts
                 else:
                     processed_texts.append(text)
+
             # Generate embeddings in batches
             all_embeddings = []
+
             for i in range(0, len(processed_texts), self.batch_size):
+                batch_texts = processed_texts[i : i + self.batch_size]
+
                 # Generate embeddings for this batch
                 batch_embeddings = self.model.encode(
+                    batch_texts,
                     convert_to_numpy=True,
+                    show_progress_bar=False,  # Disable progress bar for cleaner output
                 )
+
                 # Convert to list of lists
                 for embedding in batch_embeddings:
                     all_embeddings.append(embedding.tolist())
+
             logging.info(f"Generated embeddings for {len(texts)} texts")
             return all_embeddings
+
         except Exception as e:
             logging.error(f"Failed to generate embeddings for texts: {e}")
             raise e
+
     def get_embedding_dimension(self) -> int:
         """Get the dimension of embeddings produced by this model"""
         return self.model.get_sentence_embedding_dimension()
+
     def encode_batch(self, texts: List[str]) -> np.ndarray:
         """
         Generate embeddings and return as numpy array (for efficiency)
+
         Args:
             texts: List of texts to embed
+
         Returns:
             NumPy array of embeddings
         """
         if not texts:
             return np.array([])
+
         # Preprocess empty texts
         processed_texts = []
         for text in texts:
@@ -140,33 +147,33 @@ class EmbeddingService:
                 processed_texts.append(" ")
             else:
                 processed_texts.append(text)
+
         return self.model.encode(processed_texts, convert_to_numpy=True)
+
     def similarity(self, text1: str, text2: str) -> float:
         """
         Calculate cosine similarity between two texts
+
         Args:
             text1: First text
             text2: Second text
+
         Returns:
             Cosine similarity score (0-1)
         """
         try:
             embeddings = self.embed_texts([text1, text2])
+
             # Calculate cosine similarity
             embed1 = np.array(embeddings[0])
             embed2 = np.array(embeddings[1])
+
             similarity = np.dot(embed1, embed2) / (
                 np.linalg.norm(embed1) * np.linalg.norm(embed2)
             )
+
             return float(similarity)
+
         except Exception as e:
             logging.error(f"Failed to calculate similarity: {e}")
+            return 0.0
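For reference, a minimal usage sketch of the EmbeddingService API shown above (the sample texts are placeholders; the 384-dimension figure comes from the default all-MiniLM-L6-v2 model):

from src.embedding.embedding_service import EmbeddingService

service = EmbeddingService()  # defaults: all-MiniLM-L6-v2 on CPU, batch_size=32

vector = service.embed_text("Employee remote work policy guidelines")
print(len(vector))  # 384 for the default model

vectors = service.embed_texts(["First policy text", "Second policy text"])
score = service.similarity("remote work policy", "working from home guidelines")
print(len(vectors), round(score, 3))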
src/ingestion/__init__.py CHANGED
@@ -1 +1 @@
-# Empty file to make ingestion a package
+# Empty file to make ingestion a package
src/ingestion/document_chunker.py CHANGED
@@ -1,14 +1,17 @@
 import hashlib
 import random
+from typing import Any, Dict, List, Optional
+

 class DocumentChunker:
     """Document chunker with overlap and reproducible behavior"""
+
+    def __init__(
+        self, chunk_size: int = 1000, overlap: int = 200, seed: Optional[int] = None
+    ):
         """
         Initialize the document chunker
+
         Args:
             chunk_size: Maximum characters per chunk
             overlap: Number of overlapping characters between chunks
@@ -17,80 +20,84 @@ class DocumentChunker:
         self.chunk_size = chunk_size
         self.overlap = overlap
         self.seed = seed
+
         if seed is not None:
             random.seed(seed)
+
     def chunk_text(self, text: str) -> List[Dict[str, Any]]:
         """
         Chunk text into overlapping segments
+
         Args:
             text: Input text to chunk
+
         Returns:
             List of chunk dictionaries with content and basic metadata
         """
         if not text.strip():
             return []
+
         chunks = []
         start = 0
         chunk_index = 0
+
         while start < len(text):
             end = start + self.chunk_size
             chunk_content = text[start:end]
+
             # Create chunk with metadata
             chunk = {
+                "content": chunk_content,
+                "metadata": {
+                    "chunk_index": chunk_index,
+                    "start_pos": start,
+                    "end_pos": min(end, len(text)),
+                    "chunk_id": self._generate_chunk_id(chunk_content, chunk_index),
+                },
             }
+
             chunks.append(chunk)
+
             # Move start position with overlap consideration
             start = end - self.overlap
             chunk_index += 1
+
             # Break if we've processed all text
             if end >= len(text):
                 break
+
         return chunks
+
+    def chunk_document(
+        self, text: str, doc_metadata: Dict[str, Any]
+    ) -> List[Dict[str, Any]]:
         """
         Chunk a document while preserving document metadata
+
         Args:
             text: Document text content
             doc_metadata: Document metadata to preserve
+
         Returns:
             List of chunks with combined metadata
         """
         chunks = self.chunk_text(text)
+
         # Enhance each chunk with document metadata
         for chunk in chunks:
+            chunk["metadata"].update(doc_metadata)
             # Create unique chunk ID combining document and chunk info
+            chunk["metadata"]["chunk_id"] = self._generate_chunk_id(
+                chunk["content"],
+                chunk["metadata"]["chunk_index"],
+                doc_metadata.get("filename", "unknown"),
             )
+
         return chunks
+
+    def _generate_chunk_id(
+        self, content: str, chunk_index: int, filename: str = ""
+    ) -> str:
         """Generate a deterministic chunk ID"""
         id_string = f"{filename}_{chunk_index}_{content[:50]}"
+        return hashlib.md5(id_string.encode()).hexdigest()[:12]
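A short usage sketch of the chunker interface above (the sample text and metadata are illustrative):

from src.ingestion.document_chunker import DocumentChunker

chunker = DocumentChunker(chunk_size=50, overlap=10, seed=42)
chunks = chunker.chunk_document(
    "This is a test document. " * 10,
    {"filename": "example.md"},
)

# Each chunk carries "content" plus metadata such as chunk_index and chunk_id.
print(len(chunks), chunks[0]["metadata"]["chunk_id"])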
src/ingestion/document_parser.py CHANGED
@@ -1,46 +1,44 @@
 import os
 from pathlib import Path
+from typing import Any, Dict
+

 class DocumentParser:
     """Parser for different document formats in the policy corpus"""
+
+    SUPPORTED_FORMATS = {".txt", ".md", ".markdown"}
+
     def parse_document(self, file_path: str) -> Dict[str, Any]:
         """
         Parse a document and return content with metadata
+
         Args:
             file_path: Path to the document file
+
         Returns:
             Dict containing 'content' and 'metadata'
+
         Raises:
             FileNotFoundError: If file doesn't exist
             ValueError: If file format is unsupported
         """
         path = Path(file_path)
+
         # Check file format first (before existence check)
         if path.suffix.lower() not in self.SUPPORTED_FORMATS:
             raise ValueError(f"Unsupported file format: {path.suffix}")
+
         if not path.exists():
             raise FileNotFoundError(f"File not found: {file_path}")
+
+        with open(file_path, "r", encoding="utf-8") as f:
             content = f.read()
+
         metadata = {
+            "filename": path.name,
+            "file_type": path.suffix.lstrip(".").lower(),
+            "file_size": os.path.getsize(file_path),
+            "file_path": str(path.absolute()),
         }
+
+        return {"content": content, "metadata": metadata}
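The parser above returns a dict with 'content' and 'metadata'; a usage sketch follows (the file path is a hypothetical document inside the synthetic_policies corpus directory named in src/config.py):

from src.ingestion.document_parser import DocumentParser

parser = DocumentParser()
doc = parser.parse_document("synthetic_policies/remote_work_policy.md")  # hypothetical file

print(doc["metadata"]["filename"], doc["metadata"]["file_size"])
print(doc["content"][:100])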
src/ingestion/ingestion_pipeline.py CHANGED
@@ -1,69 +1,75 @@
 from pathlib import Path
+from typing import Any, Dict, List
+
 from .document_chunker import DocumentChunker
+from .document_parser import DocumentParser
+

 class IngestionPipeline:
     """Complete ingestion pipeline for processing document corpus"""
+
     def __init__(self, chunk_size: int = 1000, overlap: int = 200, seed: int = 42):
         """
         Initialize the ingestion pipeline
+
         Args:
             chunk_size: Size of text chunks
             overlap: Overlap between chunks
             seed: Random seed for reproducibility
         """
         self.parser = DocumentParser()
+        self.chunker = DocumentChunker(
+            chunk_size=chunk_size, overlap=overlap, seed=seed
+        )
         self.seed = seed
+
     def process_directory(self, directory_path: str) -> List[Dict[str, Any]]:
         """
         Process all supported documents in a directory
+
         Args:
             directory_path: Path to directory containing documents
+
         Returns:
             List of processed chunks with metadata
         """
         directory = Path(directory_path)
         if not directory.exists():
             raise FileNotFoundError(f"Directory not found: {directory_path}")
+
         all_chunks = []
+
         # Process each supported file
         for file_path in directory.iterdir():
+            if (
+                file_path.is_file()
+                and file_path.suffix.lower() in self.parser.SUPPORTED_FORMATS
+            ):
                 try:
                     chunks = self.process_file(str(file_path))
                     all_chunks.extend(chunks)
                 except Exception as e:
                     print(f"Warning: Failed to process {file_path}: {e}")
                     continue
+
         return all_chunks
+
     def process_file(self, file_path: str) -> List[Dict[str, Any]]:
         """
         Process a single file through the complete pipeline
+
         Args:
             file_path: Path to the file to process
+
         Returns:
             List of chunks from the file
         """
         # Parse document
         parsed_doc = self.parser.parse_document(file_path)
+
         # Chunk the document
         chunks = self.chunker.chunk_document(
+            parsed_doc["content"], parsed_doc["metadata"]
         )
+
+        return chunks
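A usage sketch of the pipeline, wired with the defaults from src/config.py exactly as the /ingest handler in app.py does:

from src.config import CORPUS_DIRECTORY, DEFAULT_CHUNK_SIZE, DEFAULT_OVERLAP, RANDOM_SEED
from src.ingestion.ingestion_pipeline import IngestionPipeline

pipeline = IngestionPipeline(
    chunk_size=DEFAULT_CHUNK_SIZE, overlap=DEFAULT_OVERLAP, seed=RANDOM_SEED
)
chunks = pipeline.process_directory(CORPUS_DIRECTORY)
print(f"Processed {len(chunks)} chunks")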
src/vector_store/__init__.py CHANGED
@@ -1 +1 @@
-# Vector store package for ChromaDB integration
+# Vector store package for ChromaDB integration
src/vector_store/vector_db.py CHANGED
@@ -1,92 +1,100 @@
-import chromadb
-from typing import List, Dict, Any, Optional
-from pathlib import Path
 import logging
+from pathlib import Path
+from typing import Any, Dict, List
+
+import chromadb
+

 class VectorDatabase:
     """ChromaDB integration for vector storage and similarity search"""
+
     def __init__(self, persist_path: str, collection_name: str):
         """
         Initialize the vector database
+
         Args:
             persist_path: Path to persist the database
             collection_name: Name of the collection to use
         """
         self.persist_path = persist_path
         self.collection_name = collection_name
+
         # Ensure persist directory exists
         Path(persist_path).mkdir(parents=True, exist_ok=True)
+
         # Initialize ChromaDB client with persistence
         self.client = chromadb.PersistentClient(path=persist_path)
+
         # Get or create collection
         try:
             self.collection = self.client.get_collection(name=collection_name)
         except ValueError:
             # Collection doesn't exist, create it
             self.collection = self.client.create_collection(name=collection_name)
+
+        logging.info(
+            f"Initialized VectorDatabase with collection "
+            f"'{collection_name}' at '{persist_path}'"
+        )
+
     def get_collection(self):
         """Get the ChromaDB collection"""
         return self.collection
+
     def add_embeddings(
+        self,
+        embeddings: List[List[float]],
+        chunk_ids: List[str],
+        documents: List[str],
+        metadatas: List[Dict[str, Any]],
     ) -> bool:
         """
         Add embeddings to the vector database
+
         Args:
             embeddings: List of embedding vectors
             chunk_ids: List of unique chunk IDs
             documents: List of document contents
             metadatas: List of metadata dictionaries
+
         Returns:
             True if successful, False otherwise
         """
         try:
             # Validate input lengths match
+            if not (
+                len(embeddings) == len(chunk_ids) == len(documents) == len(metadatas)
+            ):
                 raise ValueError("All input lists must have the same length")
+
             # Add to ChromaDB collection
             self.collection.add(
                 embeddings=embeddings,
                 documents=documents,
                 metadatas=metadatas,
+                ids=chunk_ids,
+            )
+
+            logging.info(
+                f"Added {len(embeddings)} embeddings to collection "
+                f"'{self.collection_name}'"
             )
             return True
+
         except Exception as e:
             logging.error(f"Failed to add embeddings: {e}")
             raise e
+
     def search(
+        self, query_embedding: List[float], top_k: int = 5
     ) -> List[Dict[str, Any]]:
         """
         Search for similar embeddings
+
         Args:
             query_embedding: Query vector to search for
             top_k: Number of results to return
+
         Returns:
             List of search results with metadata
         """
@@ -94,33 +102,33 @@ class VectorDatabase:
             # Handle empty collection
             if self.get_count() == 0:
                 return []
+
             # Perform similarity search
             results = self.collection.query(
                 query_embeddings=[query_embedding],
+                n_results=min(top_k, self.get_count()),
             )
+
             # Format results
             formatted_results = []
+
+            if results["ids"] and len(results["ids"][0]) > 0:
+                for i in range(len(results["ids"][0])):
                     result = {
+                        "id": results["ids"][0][i],
+                        "document": results["documents"][0][i],
+                        "metadata": results["metadatas"][0][i],
+                        "distance": results["distances"][0][i],
                     }
                     formatted_results.append(result)
+
             logging.info(f"Search returned {len(formatted_results)} results")
             return formatted_results
+
         except Exception as e:
             logging.error(f"Search failed: {e}")
             return []
+
     def get_count(self) -> int:
         """Get the number of embeddings in the collection"""
         try:
@@ -128,7 +136,7 @@ class VectorDatabase:
         except Exception as e:
             logging.error(f"Failed to get count: {e}")
             return 0
+
     def delete_collection(self) -> bool:
         """Delete the collection"""
         try:
@@ -138,7 +146,7 @@ class VectorDatabase:
         except Exception as e:
             logging.error(f"Failed to delete collection: {e}")
             return False
+
     def reset_collection(self) -> bool:
         """Reset the collection (delete and recreate)"""
         try:
@@ -148,12 +156,12 @@ class VectorDatabase:
             except ValueError:
                 # Collection doesn't exist, that's fine
                 pass
+
             # Create new collection
             self.collection = self.client.create_collection(name=self.collection_name)
             logging.info(f"Reset collection '{self.collection_name}'")
             return True
+
         except Exception as e:
             logging.error(f"Failed to reset collection: {e}")
+            return False
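To round out the storage layer, a sketch of how the EmbeddingService and VectorDatabase shown in this commit fit together (the persist path comes from src/config.py; the collection name "policies" is an assumption for illustration only):

from src.config import VECTOR_DB_PERSIST_PATH
from src.embedding.embedding_service import EmbeddingService
from src.vector_store.vector_db import VectorDatabase

service = EmbeddingService()
db = VectorDatabase(persist_path=VECTOR_DB_PERSIST_PATH, collection_name="policies")  # name is illustrative

texts = ["Remote work policy text", "Expense reimbursement text"]
db.add_embeddings(
    embeddings=service.embed_texts(texts),
    chunk_ids=["chunk-0", "chunk-1"],
    documents=texts,
    metadatas=[{"filename": "a.md"}, {"filename": "b.md"}],
)

results = db.search(service.embed_text("working from home"), top_k=2)
print(results[0]["document"], results[0]["distance"])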
tests/test_app.py CHANGED
@@ -33,7 +33,8 @@ def test_index_endpoint(client):
 def test_ingest_endpoint_exists():
     """Test that the ingest endpoint is available"""
     from app import app
+
     client = app.test_client()
+    response = client.post("/ingest")
     # Should not be 404 (not found)
     assert response.status_code != 404
tests/test_embedding/__init__.py CHANGED
@@ -1 +1 @@
-# Test package for embedding service components
+# Test package for embedding service components
tests/test_embedding/test_embedding_service.py CHANGED
@@ -1,196 +1,207 @@
-import pytest
 import numpy as np
+
 from src.embedding.embedding_service import EmbeddingService

+
 def test_embedding_service_initialization():
     """Test EmbeddingService initialization"""
     # Test will fail initially - we'll implement EmbeddingService to make it pass
     service = EmbeddingService()
+
     assert service is not None
     assert service.model_name == "sentence-transformers/all-MiniLM-L6-v2"
     assert service.device == "cpu"

+
 def test_embedding_service_with_custom_config():
     """Test EmbeddingService initialization with custom configuration"""
     service = EmbeddingService(
+        model_name="sentence-transformers/all-MiniLM-L6-v2", device="cpu", batch_size=16
     )
+
     assert service.model_name == "sentence-transformers/all-MiniLM-L6-v2"
     assert service.device == "cpu"
     assert service.batch_size == 16

+
 def test_single_text_embedding():
     """Test embedding generation for a single text"""
     service = EmbeddingService()
+
     text = "This is a test document about company policies."
     embedding = service.embed_text(text)
+
     # Should return a list of floats (embedding vector)
     assert isinstance(embedding, list)
     assert len(embedding) == 384  # all-MiniLM-L6-v2 dimension
     assert all(isinstance(x, (float, np.float32, np.float64)) for x in embedding)

+
 def test_batch_text_embedding():
     """Test embedding generation for multiple texts"""
     service = EmbeddingService()
+
     texts = [
         "This is the first document about remote work policy.",
         "This is the second document about employee benefits.",
+        "This is the third document about code of conduct.",
     ]
+
     embeddings = service.embed_texts(texts)
+
     # Should return list of embeddings
     assert isinstance(embeddings, list)
     assert len(embeddings) == 3
+
     # Each embedding should be correct dimension
     for embedding in embeddings:
         assert isinstance(embedding, list)
         assert len(embedding) == 384
         assert all(isinstance(x, (float, np.float32, np.float64)) for x in embedding)

+
 def test_embedding_consistency():
     """Test that same text produces same embedding"""
     service = EmbeddingService()
+
     text = "Consistent embedding test text."
+
     embedding1 = service.embed_text(text)
     embedding2 = service.embed_text(text)
+
     # Should be identical (deterministic)
     assert embedding1 == embedding2

+
 def test_different_texts_different_embeddings():
     """Test that different texts produce different embeddings"""
     service = EmbeddingService()
+
     text1 = "This is about remote work policy."
     text2 = "This is about employee benefits and healthcare."
+
     embedding1 = service.embed_text(text1)
     embedding2 = service.embed_text(text2)
+
     # Should be different
     assert embedding1 != embedding2
+
     # But should have same dimension
     assert len(embedding1) == len(embedding2) == 384

+
 def test_empty_text_handling():
     """Test handling of empty or whitespace-only text"""
     service = EmbeddingService()
+
     # Empty string
     embedding_empty = service.embed_text("")
     assert isinstance(embedding_empty, list)
     assert len(embedding_empty) == 384
+
     # Whitespace only
     embedding_whitespace = service.embed_text(" \n\t ")
     assert isinstance(embedding_whitespace, list)
     assert len(embedding_whitespace) == 384

+
 def test_very_long_text_handling():
     """Test handling of very long texts"""
     service = EmbeddingService()
+
     # Create a very long text (should test tokenization limits)
     long_text = "This is a very long document. " * 1000  # ~30,000 characters
+
     embedding = service.embed_text(long_text)
     assert isinstance(embedding, list)
     assert len(embedding) == 384

+
 def test_batch_size_handling():
     """Test that batch processing works correctly"""
     service = EmbeddingService(batch_size=2)  # Small batch for testing
+
     texts = [
         "Text one about policy",
+        "Text two about procedures",
         "Text three about guidelines",
         "Text four about regulations",
+        "Text five about rules",
     ]
+
     embeddings = service.embed_texts(texts)
+
     # Should process all texts despite small batch size
     assert len(embeddings) == 5
+
     # All embeddings should be valid
     for embedding in embeddings:
         assert len(embedding) == 384

+
 def test_special_characters_handling():
     """Test handling of special characters and unicode"""
     service = EmbeddingService()
+
     texts_with_special_chars = [
         "Policy with émojis 😀 and úñicode",
         "Text with numbers: 123,456.78 and symbols @#$%",
         "Markdown: # Header\n## Subheader\n- List item",
+        "Mixed: Policy-2024 (v1.2) — updated 12/01/2025",
     ]
+
     embeddings = service.embed_texts(texts_with_special_chars)
+
     assert len(embeddings) == 4
     for embedding in embeddings:
         assert len(embedding) == 384

+
 def test_similarity_makes_sense():
     """Test that semantically similar texts have similar embeddings"""
     service = EmbeddingService()
+
     # Similar texts
     text1 = "Employee remote work policy guidelines"
     text2 = "Guidelines for working from home policies"
+
     # Different text
     text3 = "Financial expense reimbursement procedures"
+
     embed1 = service.embed_text(text1)
     embed2 = service.embed_text(text2)
     embed3 = service.embed_text(text3)
+
     # Calculate simple cosine similarity (for validation)
     def cosine_similarity(a, b):
         import numpy as np
+
         a_np = np.array(a)
         b_np = np.array(b)
         return np.dot(a_np, b_np) / (np.linalg.norm(a_np) * np.linalg.norm(b_np))
+
     sim_1_2 = cosine_similarity(embed1, embed2)  # Similar texts
     sim_1_3 = cosine_similarity(embed1, embed3)  # Different texts
+
     # Similar texts should have higher similarity than different texts
     assert sim_1_2 > sim_1_3
     assert sim_1_2 > 0.5  # Should be reasonably similar

+
 def test_model_loading_performance():
     """Test that model loading doesn't happen repeatedly"""
     # This test ensures model is cached after first load
     import time
+
     start_time = time.time()
+    EmbeddingService()  # First service
     first_load_time = time.time() - start_time
+
     start_time = time.time()
+    EmbeddingService()  # Second service
     second_load_time = time.time() - start_time
+
     # Second initialization should be faster (model already cached)
     # Note: This might not always be true depending on implementation
     # but it's good to test the general behavior
+    assert second_load_time <= first_load_time * 2  # Allow some variance
tests/test_ingestion/__init__.py
CHANGED
@@ -1 +1 @@
# Test package for ingestion components
tests/test_ingestion/test_document_chunker.py
CHANGED
@@ -1,73 +1,74 @@
from src.ingestion.document_chunker import DocumentChunker


def test_chunk_by_characters():
    """Test basic character-based chunking"""
    chunker = DocumentChunker(chunk_size=50, overlap=10)

    text = "This is a test document. " * 10  # 250 characters
    chunks = chunker.chunk_text(text)

    assert len(chunks) > 1  # Should create multiple chunks
    assert all(len(chunk["content"]) <= 50 for chunk in chunks)

    # Test overlap
    if len(chunks) > 1:
        # Check that there's overlap between consecutive chunks
        assert chunks[0]["content"][-10:] in chunks[1]["content"][:20]


def test_chunk_with_metadata():
    """Test that chunks preserve document metadata"""
    chunker = DocumentChunker(chunk_size=100, overlap=20)

    doc_metadata = {"filename": "test.txt", "file_type": "txt", "source_id": "doc_001"}

    text = "Content that will be chunked. " * 20
    chunks = chunker.chunk_document(text, doc_metadata)

    for chunk in chunks:
        assert chunk["metadata"]["filename"] == "test.txt"
        assert chunk["metadata"]["file_type"] == "txt"
        assert "chunk_id" in chunk["metadata"]
        assert "chunk_index" in chunk["metadata"]


def test_reproducible_chunking():
    """Test that chunking is deterministic with fixed seed"""
    chunker1 = DocumentChunker(chunk_size=100, overlap=20, seed=42)
    chunker2 = DocumentChunker(chunk_size=100, overlap=20, seed=42)

    text = "This text will be chunked reproducibly. " * 30

    chunks1 = chunker1.chunk_text(text)
    chunks2 = chunker2.chunk_text(text)

    assert len(chunks1) == len(chunks2)
    for c1, c2 in zip(chunks1, chunks2):
        assert c1["content"] == c2["content"]


def test_empty_text_chunking():
    """Test handling of empty or very short text"""
    chunker = DocumentChunker(chunk_size=100, overlap=20)

    # Empty text
    chunks = chunker.chunk_text("")
    assert len(chunks) == 0

    # Very short text
    chunks = chunker.chunk_text("Short")
    assert len(chunks) == 1
    assert chunks[0]["content"] == "Short"


def test_chunk_real_policy_content():
    """Test chunking actual policy document content"""
    chunker = DocumentChunker(chunk_size=500, overlap=100, seed=42)

    # Use content that resembles our policy documents
    policy_content = (
        """# HR-POL-001: Employee Handbook

**Effective Date:** 2025-01-01
**Revision:** 1.1
@@ -83,54 +84,61 @@ Welcome to Innovate Inc.! We are thrilled to have you as part of our team. Our s

### 2.1. Code of Conduct

All employees must adhere to our code of conduct which emphasizes integrity, respect, and professionalism in all interactions."""
        * 3
    )

    doc_metadata = {
        "filename": "employee_handbook.md",
        "file_type": "md",
        "file_path": "/path/to/employee_handbook.md",
    }

    chunks = chunker.chunk_document(policy_content, doc_metadata)

    # Verify chunking worked
    assert len(chunks) > 1

    # Verify all chunks have proper metadata
    for i, chunk in enumerate(chunks):
        assert chunk["metadata"]["filename"] == "employee_handbook.md"
        assert chunk["metadata"]["file_type"] == "md"
        assert chunk["metadata"]["chunk_index"] == i
        assert "chunk_id" in chunk["metadata"]
        assert len(chunk["content"]) <= 500

    # Verify overlap exists between consecutive chunks
    if len(chunks) > 1:
        overlap_check = (
            chunks[0]["content"][-100:] in chunks[1]["content"][:200]
        )
        assert overlap_check


def test_chunk_metadata_inheritance():
    """Test that document metadata is properly inherited by chunks"""
    chunker = DocumentChunker(chunk_size=100, overlap=20)

    doc_metadata = {
        "filename": "test_policy.md",
        "file_type": "md",
        "file_size": 1500,
        "file_path": "/absolute/path/to/test_policy.md",
    }

    text = "Policy content goes here. " * 20
    chunks = chunker.chunk_document(text, doc_metadata)

    for chunk in chunks:
        # Original metadata should be preserved
        assert chunk["metadata"]["filename"] == "test_policy.md"
        assert chunk["metadata"]["file_type"] == "md"
        assert chunk["metadata"]["file_size"] == 1500
        expected_path = "/absolute/path/to/test_policy.md"
        assert chunk["metadata"]["file_path"] == expected_path

        # New chunk-specific metadata should be added
        assert "chunk_index" in chunk["metadata"]
        assert "chunk_id" in chunk["metadata"]
        assert "start_pos" in chunk["metadata"]
        assert "end_pos" in chunk["metadata"]
tests/test_ingestion/test_document_parser.py
CHANGED
@@ -1,85 +1,94 @@
import os
import tempfile
from pathlib import Path

import pytest


def test_parse_txt_file():
    """Test parsing a simple text file"""
    # Test will fail initially - we'll implement parser to make it pass
    from src.ingestion.document_parser import DocumentParser

    parser = DocumentParser()
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
        f.write("This is a test policy document.\nIt has multiple lines.")
        temp_path = f.name

    try:
        result = parser.parse_document(temp_path)
        assert (
            result["content"]
            == "This is a test policy document.\nIt has multiple lines."
        )
        assert result["metadata"]["filename"] == Path(temp_path).name
        assert result["metadata"]["file_type"] == "txt"
    finally:
        os.unlink(temp_path)


def test_parse_markdown_file():
    """Test parsing a markdown file"""
    from src.ingestion.document_parser import DocumentParser

    parser = DocumentParser()
    markdown_content = """# Policy Title

## Section 1
This is section content.

### Subsection
More content here."""

    with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f:
        f.write(markdown_content)
        temp_path = f.name

    try:
        result = parser.parse_document(temp_path)
        assert "Policy Title" in result["content"]
        assert "Section 1" in result["content"]
        assert result["metadata"]["file_type"] == "md"
    finally:
        os.unlink(temp_path)


def test_parse_unsupported_format():
    """Test handling of unsupported file formats"""
    from src.ingestion.document_parser import DocumentParser

    parser = DocumentParser()
    with pytest.raises(ValueError, match="Unsupported file format"):
        parser.parse_document("test.xyz")


def test_parse_nonexistent_file():
    """Test handling of non-existent files"""
    from src.ingestion.document_parser import DocumentParser

    parser = DocumentParser()
    with pytest.raises(FileNotFoundError):
        parser.parse_document("nonexistent.txt")


def test_parse_real_policy_document():
    """Test parsing an actual policy document from our corpus"""
    from src.ingestion.document_parser import DocumentParser

    parser = DocumentParser()
    # Use a real policy document from our corpus
    policy_path = "synthetic_policies/employee_handbook.md"

    result = parser.parse_document(policy_path)

    # Verify content structure
    assert "employee_handbook.md" in result["metadata"]["filename"]
    assert result["metadata"]["file_type"] == "md"
    assert "Employee Handbook" in result["content"]
    assert "HR-POL-001" in result["content"]
    assert len(result["content"]) > 100  # Should have substantial content

    # Verify metadata completeness
    assert "file_size" in result["metadata"]
    assert "file_path" in result["metadata"]
    assert result["metadata"]["file_size"] > 0
tests/test_ingestion/test_ingestion_pipeline.py
CHANGED
@@ -1,9 +1,12 @@
import os
import tempfile
from pathlib import Path

import pytest

from src.ingestion.ingestion_pipeline import IngestionPipeline


def test_full_ingestion_pipeline():
    """Test the complete ingestion pipeline end-to-end"""
    # Create temporary test documents
@@ -11,68 +14,73 @@ def test_full_ingestion_pipeline():
        # Create test files
        txt_file = Path(temp_dir) / "policy1.txt"
        md_file = Path(temp_dir) / "policy2.md"

        txt_file.write_text(
            "This is a text policy document with important information."
        )
        md_file.write_text("# Markdown Policy\n\nThis is markdown content.")

        # Initialize pipeline
        pipeline = IngestionPipeline(chunk_size=50, overlap=10, seed=42)

        # Process documents
        results = pipeline.process_directory(temp_dir)

        assert len(results) >= 2  # At least one result per file

        # Verify structure
        for result in results:
            assert "content" in result
            assert "metadata" in result
            assert "chunk_id" in result["metadata"]
            assert "filename" in result["metadata"]


def test_pipeline_reproducibility():
    """Test that pipeline produces consistent results"""
    with tempfile.TemporaryDirectory() as temp_dir:
        test_file = Path(temp_dir) / "test.txt"
        test_file.write_text("Test content for reproducibility. " * 20)

        pipeline1 = IngestionPipeline(chunk_size=100, overlap=20, seed=42)
        pipeline2 = IngestionPipeline(chunk_size=100, overlap=20, seed=42)

        results1 = pipeline1.process_directory(temp_dir)
        results2 = pipeline2.process_directory(temp_dir)

        assert len(results1) == len(results2)

        for r1, r2 in zip(results1, results2):
            assert r1["content"] == r2["content"]
            assert r1["metadata"]["chunk_id"] == r2["metadata"]["chunk_id"]


def test_pipeline_with_real_corpus():
    """Test pipeline with actual policy documents"""
    pipeline = IngestionPipeline(chunk_size=1000, overlap=200, seed=42)

    # Process just one real document to verify it works
    corpus_dir = "synthetic_policies"

    # Check if corpus directory exists
    if not Path(corpus_dir).exists():
        pytest.skip("Corpus directory not found - test requires synthetic_policies/")

    results = pipeline.process_directory(corpus_dir)

    # Should process all 22 documents
    assert len(results) > 20  # Should have many chunks from 22 documents

    # Verify all results have proper structure
    for result in results:
        assert "content" in result
        assert "metadata" in result
        assert "chunk_id" in result["metadata"]
        assert "filename" in result["metadata"]
        assert "file_type" in result["metadata"]
        assert result["metadata"]["file_type"] == "md"
        assert "chunk_index" in result["metadata"]


def test_pipeline_error_handling():
    """Test pipeline handles errors gracefully"""
@@ -80,87 +88,91 @@ def test_pipeline_error_handling():
        # Create valid and invalid files
        valid_file = Path(temp_dir) / "valid.md"
        invalid_file = Path(temp_dir) / "invalid.xyz"

        valid_file.write_text("# Valid Policy\n\nThis is valid content.")
        invalid_file.write_text("This file has unsupported format.")

        pipeline = IngestionPipeline(chunk_size=100, overlap=20, seed=42)

        # Should process valid file and skip invalid one
        results = pipeline.process_directory(temp_dir)

        # Should only get results from valid file
        assert len(results) >= 1

        # All results should be from valid file
        for result in results:
            assert result["metadata"]["filename"] == "valid.md"


def test_pipeline_single_file():
    """Test processing a single file"""
    pipeline = IngestionPipeline(chunk_size=100, overlap=20, seed=42)

    with tempfile.NamedTemporaryFile(mode="w", suffix=".md", delete=False) as f:
        f.write("# Test Policy\n\n" + "Content section. " * 20)
        temp_path = f.name

    try:
        results = pipeline.process_file(temp_path)

        # Should get multiple chunks due to length
        assert len(results) > 1

        # All chunks should have same filename
        filename = Path(temp_path).name
        for result in results:
            assert result["metadata"]["filename"] == filename
            assert result["metadata"]["file_type"] == "md"
            assert "chunk_index" in result["metadata"]

    finally:
        os.unlink(temp_path)


def test_pipeline_empty_directory():
    """Test pipeline with empty directory"""
    with tempfile.TemporaryDirectory() as temp_dir:
        pipeline = IngestionPipeline(chunk_size=100, overlap=20, seed=42)

        results = pipeline.process_directory(temp_dir)

        # Should return empty list for empty directory
        assert len(results) == 0


def test_pipeline_nonexistent_directory():
    """Test pipeline with non-existent directory"""
    pipeline = IngestionPipeline(chunk_size=100, overlap=20, seed=42)

    with pytest.raises(FileNotFoundError):
        pipeline.process_directory("/nonexistent/directory")


def test_pipeline_configuration():
    """Test pipeline configuration options"""
    # Test different configurations
    pipeline_small = IngestionPipeline(chunk_size=50, overlap=10, seed=42)
    pipeline_large = IngestionPipeline(chunk_size=200, overlap=50, seed=42)

    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
        content = "Policy content goes here. " * 30  # 780 characters
        f.write(content)
        temp_path = f.name

    try:
        results_small = pipeline_small.process_file(temp_path)
        results_large = pipeline_large.process_file(temp_path)

        # Small chunks should create more chunks
        assert len(results_small) > len(results_large)

        # All chunks should respect size limits
        for result in results_small:
            assert len(result["content"]) <= 50

        for result in results_large:
            assert len(result["content"]) <= 200

    finally:
        os.unlink(temp_path)
tests/test_integration.py
CHANGED
@@ -1,9 +1,7 @@
"""Integration tests for Phase 2A components."""

import shutil
import tempfile

from src.embedding.embedding_service import EmbeddingService
from src.vector_store.vector_db import VectorDatabase
@@ -11,101 +9,122 @@ from src.vector_store.vector_db import VectorDatabase

class TestPhase2AIntegration:
    """Test integration between EmbeddingService and VectorDatabase"""

    def setup_method(self):
        """Set up test environment with temporary database"""
        self.test_dir = tempfile.mkdtemp()
        self.embedding_service = EmbeddingService()
        self.vector_db = VectorDatabase(
            persist_path=self.test_dir, collection_name="test_integration"
        )

    def teardown_method(self):
        """Clean up temporary resources"""
        if hasattr(self, "test_dir"):
            shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_embedding_vector_storage_workflow(self):
        """Test complete workflow: text → embedding → storage → search"""

        # Sample policy texts
        documents = [
            (
                "Employees must complete security training annually to "
                "maintain access to company systems."
            ),
            (
                "Remote work policy allows employees to work from home up to "
                "3 days per week."
            ),
            (
                "All expenses over $500 require manager approval before "
                "reimbursement."
            ),
            (
                "Code review is mandatory for all pull requests before "
                "merging to main branch."
            ),
        ]

        # Generate embeddings
        embeddings = self.embedding_service.embed_texts(documents)

        # Verify embeddings were generated
        assert len(embeddings) == len(documents)
        assert all(
            len(emb) == self.embedding_service.get_embedding_dimension()
            for emb in embeddings
        )

        # Store embeddings with metadata (using existing collection)
        doc_ids = [f"doc_{i}" for i in range(len(documents))]
        metadatas = [{"type": "policy", "doc_id": doc_id} for doc_id in doc_ids]

        success = self.vector_db.add_embeddings(
            embeddings=embeddings,
            chunk_ids=doc_ids,
            documents=documents,
            metadatas=metadatas,
        )

        assert success is True

        # Test search functionality
        query = "remote work from home policy"
        query_embedding = self.embedding_service.embed_text(query)

        results = self.vector_db.search(query_embedding=query_embedding, top_k=2)

        # Verify search results (should return list of dictionaries)
        assert isinstance(results, list)
        assert len(results) <= 2  # Should return at most 2 results

        if results:  # If we have results
            assert all(isinstance(result, dict) for result in results)
            # Check that at least one result contains remote work related content
            documents_found = [result.get("document", "") for result in results]
            remote_work_found = any(
                "remote work" in doc.lower() or "work from home" in doc.lower()
                for doc in documents_found
            )
            assert remote_work_found

    def test_basic_embedding_dimension_consistency(self):
        """Test that embeddings have consistent dimensions"""

        # Test different text lengths
        texts = [
            "Short text.",
            (
                "This is a medium length text with several words to test "
                "embedding consistency."
            ),
            (
                "This is a much longer text that contains multiple sentences "
                "and various types of content to ensure that the embedding "
                "service can handle longer inputs without issues and still "
                "produce consistent dimensional output vectors."
            ),
        ]

        # Generate embeddings
        embeddings = self.embedding_service.embed_texts(texts)

        # All embeddings should have the same dimension
        dimensions = [len(emb) for emb in embeddings]
        assert all(dim == dimensions[0] for dim in dimensions)

        # Dimension should match the service's reported dimension
        assert dimensions[0] == self.embedding_service.get_embedding_dimension()

    def test_empty_collection_handling(self):
        """Test behavior with empty collection"""

        # Search in empty collection
        query_embedding = self.embedding_service.embed_text("test query")

        results = self.vector_db.search(query_embedding=query_embedding, top_k=5)

        # Should handle empty collection gracefully
        assert isinstance(results, list)
        assert len(results) == 0
tests/test_vector_store/__init__.py
CHANGED
@@ -1 +1 @@
# Test package for vector store components
tests/test_vector_store/test_vector_db.py
CHANGED
@@ -1,187 +1,197 @@
import tempfile

import pytest

from src.vector_store.vector_db import VectorDatabase


def test_vector_database_initialization():
    """Test VectorDatabase initialization and connection"""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Test will fail initially - we'll implement VectorDatabase to make it pass
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_collection")

        # Should create connection successfully
        assert db is not None
        assert db.collection_name == "test_collection"
        assert db.persist_path == temp_dir


def test_create_collection():
    """Test creating a new collection"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_docs")

        # Collection should be created
        collection = db.get_collection()
        assert collection is not None
        assert collection.name == "test_docs"


def test_add_embeddings():
    """Test adding embeddings to the database"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_docs")

        # Sample data
        embeddings = [
            [0.1, 0.2, 0.3, 0.4],  # 4-dimensional for testing
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2],
        ]

        chunk_ids = ["chunk_1", "chunk_2", "chunk_3"]

        documents = [
            "This is the first document chunk.",
            "This is the second document chunk.",
            "This is the third document chunk.",
        ]

        metadatas = [
            {"filename": "doc1.md", "chunk_index": 0},
            {"filename": "doc1.md", "chunk_index": 1},
            {"filename": "doc2.md", "chunk_index": 0},
        ]

        # Add embeddings
        result = db.add_embeddings(
            embeddings=embeddings,
            chunk_ids=chunk_ids,
            documents=documents,
            metadatas=metadatas,
        )

        # Should return success
        assert result is True

        # Verify count
        count = db.get_count()
        assert count == 3


def test_search_embeddings():
    """Test searching for similar embeddings"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_docs")

        # Add some test data first
        embeddings = [
            [1.0, 0.0, 0.0, 0.0],  # Distinct embeddings for testing
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0],
        ]

        chunk_ids = ["chunk_1", "chunk_2", "chunk_3", "chunk_4"]
        documents = ["Doc 1", "Doc 2", "Doc 3", "Doc 4"]
        metadatas = [{"index": i} for i in range(4)]

        db.add_embeddings(embeddings, chunk_ids, documents, metadatas)

        # Search for similar to first embedding
        query_embedding = [1.0, 0.0, 0.0, 0.0]
        results = db.search(query_embedding, top_k=2)

        # Should return results
        assert len(results) <= 2
        assert len(results) > 0

        # First result should be the exact match
        assert results[0]["id"] == "chunk_1"
        assert "distance" in results[0]
        assert "document" in results[0]
        assert "metadata" in results[0]


def test_delete_collection():
    """Test deleting a collection"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_docs")

        # Add some data
        embeddings = [[0.1, 0.2, 0.3, 0.4]]
        chunk_ids = ["chunk_1"]
        documents = ["Test doc"]
        metadatas = [{"test": True}]

        db.add_embeddings(embeddings, chunk_ids, documents, metadatas)
        assert db.get_count() == 1

        # Delete collection
        db.delete_collection()

        # Should be empty after recreation
        db = VectorDatabase(persist_path=temp_dir, collection_name="test_docs")
        assert db.get_count() == 0


def test_persistence():
    """Test that data persists across database instances"""
    with tempfile.TemporaryDirectory() as temp_dir:
        # Create first instance and add data
        db1 = VectorDatabase(persist_path=temp_dir, collection_name="persistent_test")

        embeddings = [[0.1, 0.2, 0.3, 0.4]]
        chunk_ids = ["persistent_chunk"]
        documents = ["Persistent document"]
        metadatas = [{"persistent": True}]

        db1.add_embeddings(embeddings, chunk_ids, documents, metadatas)
        assert db1.get_count() == 1

        # Create second instance with same path
        db2 = VectorDatabase(persist_path=temp_dir, collection_name="persistent_test")

        # Should have the same data
        assert db2.get_count() == 1

        # Should be able to search and find the data
        results = db2.search([0.1, 0.2, 0.3, 0.4], top_k=1)
        assert len(results) == 1
        assert results[0]["id"] == "persistent_chunk"


def test_error_handling():
    """Test error handling for various edge cases"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="error_test")

        # Test empty search
        results = db.search([0.1, 0.2, 0.3, 0.4], top_k=5)
        assert results == []

        # Test adding mismatched data
        with pytest.raises((ValueError, Exception)):
            db.add_embeddings(
                embeddings=[[0.1, 0.2]],  # 2D
                chunk_ids=["chunk_1", "chunk_2"],  # 2 IDs but 1 embedding
                documents=["Doc 1"],  # 1 document
                metadatas=[{"test": True}],  # 1 metadata
            )


def test_batch_operations():
    """Test batch operations for performance"""
    with tempfile.TemporaryDirectory() as temp_dir:
        db = VectorDatabase(persist_path=temp_dir, collection_name="batch_test")

        # Create larger batch for testing
        batch_size = 50
        embeddings = [
            [float(i), float(i + 1), float(i + 2), float(i + 3)]
            for i in range(batch_size)
        ]
        chunk_ids = [f"chunk_{i}" for i in range(batch_size)]
        documents = [f"Document {i} content" for i in range(batch_size)]
        metadatas = [{"batch_index": i, "test_batch": True} for i in range(batch_size)]

        # Should handle batch operations
        result = db.add_embeddings(embeddings, chunk_ids, documents, metadatas)
        assert result is True
        assert db.get_count() == batch_size

        # Should handle batch search
        query_embedding = [0.0, 1.0, 2.0, 3.0]
        results = db.search(query_embedding, top_k=10)
        assert len(results) == 10  # Should return requested number