"""
Context Manager for RAG Pipeline
This module handles context retrieval, formatting, and management
for the RAG pipeline, ensuring optimal context window utilization.
"""
import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)


@dataclass
class ContextConfig:
    """Configuration for context management."""

    max_context_length: int = 3000  # Maximum characters in context
    max_results: int = 5  # Maximum search results to include
    min_similarity: float = 0.1  # Minimum similarity threshold
    overlap_penalty: float = 0.1  # Penalty for overlapping content
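

# Illustrative only: overriding the defaults above when constructing the manager.
# The numbers here are arbitrary examples, not tuned recommendations.
#
#     small_window_config = ContextConfig(max_context_length=1500, max_results=3)
#     manager = ContextManager(config=small_window_config)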


class ContextManager:
    """
    Manages context retrieval and optimization for RAG pipeline.

    Handles:
    - Context length management
    - Relevance filtering
    - Duplicate content removal
    - Source prioritization
    """

    def __init__(self, config: Optional[ContextConfig] = None):
        """
        Initialize ContextManager with configuration.

        Args:
            config: Context configuration, uses defaults if None
        """
        self.config = config or ContextConfig()
        logger.info("ContextManager initialized")

    def prepare_context(self, search_results: List[Dict[str, Any]], query: str) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Prepare optimized context from search results.

        Args:
            search_results: Results from SearchService
            query: Original user query for context optimization

        Returns:
            Tuple of (formatted_context, filtered_results)
        """
        if not search_results:
            return "No relevant information found.", []

        # Filter and rank results
        filtered_results = self._filter_results(search_results)

        # Remove duplicates and optimize for context window
        optimized_results = self._optimize_context(filtered_results)

        # Format for prompt
        formatted_context = self._format_context(optimized_results)

        logger.debug(
            f"Prepared context from {len(search_results)} results, "
            f"filtered to {len(optimized_results)} results, "
            f"{len(formatted_context)} characters"
        )

        return formatted_context, optimized_results

    def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter search results by relevance and quality.

        Args:
            results: Raw search results

        Returns:
            Filtered and sorted results
        """
        filtered = []
        for result in results:
            similarity = result.get("similarity_score", 0.0)
            content = result.get("content", "").strip()

            # Apply filters
            if similarity >= self.config.min_similarity and content and len(content) > 20:  # Minimum content length
                filtered.append(result)

        # Sort by similarity score (descending)
        filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)

        # Limit to max results
        return filtered[: self.config.max_results]

    def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Optimize context to fit within token limits while maximizing relevance.

        Args:
            results: Filtered search results

        Returns:
            Optimized results list
        """
        if not results:
            return []

        optimized = []
        current_length = 0
        seen_content = set()

        for result in results:
            content = result.get("content", "").strip()
            content_length = len(content)

            # Check if adding this result would exceed the limit
            estimated_formatted_length = current_length + content_length + 100  # Buffer
            if estimated_formatted_length > self.config.max_context_length:
                # Try to truncate content to fill the remaining space, then stop
                remaining_space = self.config.max_context_length - current_length - 100
                if remaining_space > 200:  # Minimum useful content
                    truncated_content = content[:remaining_space] + "..."
                    result_copy = result.copy()
                    result_copy["content"] = truncated_content
                    optimized.append(result_copy)
                break

            # Check for duplicate or highly similar content
            content_lower = content.lower()
            is_duplicate = False
            for seen in seen_content:
                # Simple word-overlap check for near-duplicate chunks
                if (
                    len(set(content_lower.split()) & set(seen.split()))
                    / max(len(content_lower.split()), len(seen.split()))
                    > 0.8
                ):
                    is_duplicate = True
                    break

            if not is_duplicate:
                optimized.append(result)
                seen_content.add(content_lower)
                current_length += content_length

        return optimized

    def _format_context(self, results: List[Dict[str, Any]]) -> str:
        """
        Format optimized results into context string.

        Args:
            results: Optimized search results

        Returns:
            Formatted context string
        """
        if not results:
            return "No relevant information found in corporate policies."

        context_parts = []
        for i, result in enumerate(results, 1):
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", f"document_{i}")
            content = result.get("content", "").strip()

            # Format with document info
            context_parts.append(f"Document: {filename}\nContent: {content}")

        return "\n\n---\n\n".join(context_parts)

    def validate_context_quality(self, context: str, query: str, min_quality_score: float = 0.3) -> Dict[str, Any]:
        """
        Validate the quality of prepared context for a given query.

        Args:
            context: Formatted context string
            query: Original user query
            min_quality_score: Minimum acceptable quality score

        Returns:
            Dictionary with quality metrics and validation result
        """
        # Simple quality checks
        quality_metrics = {
            "length": len(context),
            "has_content": bool(context.strip()),
            "estimated_relevance": 0.0,
            "passes_validation": False,
        }

        if not context.strip():
            quality_metrics["passes_validation"] = False
            return quality_metrics

        # Estimate relevance based on query term overlap
        query_terms = set(query.lower().split())
        context_terms = set(context.lower().split())

        if query_terms and context_terms:
            overlap = len(query_terms & context_terms)
            relevance = overlap / len(query_terms)
            quality_metrics["estimated_relevance"] = relevance
            quality_metrics["passes_validation"] = relevance >= min_quality_score
        else:
            quality_metrics["passes_validation"] = False

        return quality_metrics

    def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate summary of sources used in context.

        Args:
            results: Search results used for context

        Returns:
            Summary of sources and their contribution
        """
        sources = {}
        total_content_length = 0

        for result in results:
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", "unknown")
            content_length = len(result.get("content", ""))
            similarity = result.get("similarity_score", 0.0)

            if filename not in sources:
                sources[filename] = {
                    "chunks": 0,
                    "total_content_length": 0,
                    "max_similarity": 0.0,
                    "similarity_sum": 0.0,
                    "avg_similarity": 0.0,
                }

            sources[filename]["chunks"] += 1
            sources[filename]["total_content_length"] += content_length
            sources[filename]["max_similarity"] = max(sources[filename]["max_similarity"], similarity)
            sources[filename]["similarity_sum"] += similarity
            total_content_length += content_length

        # Calculate averages and percentages
        for source_info in sources.values():
            source_info["avg_similarity"] = source_info.pop("similarity_sum") / source_info["chunks"]
            source_info["content_percentage"] = (
                source_info["total_content_length"] / max(total_content_length, 1) * 100
            )

        return {
            "total_sources": len(sources),
            "total_chunks": len(results),
            "total_content_length": total_content_length,
            "sources": sources,
        }
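

if __name__ == "__main__":
    # Illustrative sketch only: the results below are made-up samples shaped like
    # the dicts the methods above expect (content, similarity_score,
    # metadata["filename"]); the real SearchService output may differ.
    logging.basicConfig(level=logging.DEBUG)

    sample_results = [
        {
            "content": "Employees may work remotely up to three days per week, subject to manager approval and team coverage requirements.",
            "similarity_score": 0.82,
            "metadata": {"filename": "remote_work_policy.md"},
        },
        {
            "content": "Expense reports must be submitted within 30 days of purchase via the finance portal.",
            "similarity_score": 0.05,  # Below min_similarity, expected to be filtered out
            "metadata": {"filename": "expense_policy.md"},
        },
    ]

    manager = ContextManager(ContextConfig(max_context_length=2000, max_results=3))
    query = "How many days can I work remotely?"

    context, used_results = manager.prepare_context(sample_results, query)
    print(context)
    print(manager.validate_context_quality(context, query))
    print(manager.get_source_summary(used_results))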