""" Context Manager for RAG Pipeline This module handles context retrieval, formatting, and management for the RAG pipeline, ensuring optimal context window utilization. """ import logging from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @dataclass class ContextConfig: """Configuration for context management.""" max_context_length: int = 3000 # Maximum characters in context max_results: int = 5 # Maximum search results to include min_similarity: float = 0.1 # Minimum similarity threshold overlap_penalty: float = 0.1 # Penalty for overlapping content class ContextManager: """ Manages context retrieval and optimization for RAG pipeline. Handles: - Context length management - Relevance filtering - Duplicate content removal - Source prioritization """ def __init__(self, config: Optional[ContextConfig] = None): """ Initialize ContextManager with configuration. Args: config: Context configuration, uses defaults if None """ self.config = config or ContextConfig() logger.info("ContextManager initialized") def prepare_context(self, search_results: List[Dict[str, Any]], query: str) -> Tuple[str, List[Dict[str, Any]]]: """ Prepare optimized context from search results. Args: search_results: Results from SearchService query: Original user query for context optimization Returns: Tuple of (formatted_context, filtered_results) """ if not search_results: return "No relevant information found.", [] # Filter and rank results filtered_results = self._filter_results(search_results) # Remove duplicates and optimize for context window optimized_results = self._optimize_context(filtered_results) # Format for prompt formatted_context = self._format_context(optimized_results) logger.debug( f"Prepared context from {len(search_results)} results, " f"filtered to {len(optimized_results)} results, " f"{len(formatted_context)} characters" ) return formatted_context, optimized_results def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Filter search results by relevance and quality. Args: results: Raw search results Returns: Filtered and sorted results """ filtered = [] for result in results: similarity = result.get("similarity_score", 0.0) content = result.get("content", "").strip() # Apply filters if similarity >= self.config.min_similarity and content and len(content) > 20: # Minimum content length filtered.append(result) # Sort by similarity score (descending) filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True) # Limit to max results return filtered[: self.config.max_results] def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Optimize context to fit within token limits while maximizing relevance. Args: results: Filtered search results Returns: Optimized results list """ if not results: return [] optimized = [] current_length = 0 seen_content = set() for result in results: content = result.get("content", "").strip() content_length = len(content) # Check if adding this result would exceed limit estimated_formatted_length = current_length + content_length + 100 # Buffer if estimated_formatted_length > self.config.max_context_length: # Try to truncate content remaining_space = self.config.max_context_length - current_length - 100 if remaining_space > 200: # Minimum useful content truncated_content = content[:remaining_space] + "..." 

    def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Optimize context to fit within length limits while maximizing relevance.

        Args:
            results: Filtered search results

        Returns:
            Optimized results list
        """
        if not results:
            return []

        optimized = []
        current_length = 0
        seen_content = set()

        for result in results:
            content = result.get("content", "").strip()
            content_length = len(content)

            # Check if adding this result would exceed the limit
            estimated_formatted_length = current_length + content_length + 100  # Formatting buffer
            if estimated_formatted_length > self.config.max_context_length:
                # Try to truncate content into the remaining space
                remaining_space = self.config.max_context_length - current_length - 100
                if remaining_space > 200:  # Minimum useful content
                    truncated_content = content[:remaining_space] + "..."
                    result_copy = result.copy()
                    result_copy["content"] = truncated_content
                    optimized.append(result_copy)
                break

            # Check for duplicate or highly similar content
            content_lower = content.lower()
            is_duplicate = False
            for seen in seen_content:
                # Simple word-overlap similarity check for duplicates
                overlap = len(set(content_lower.split()) & set(seen.split())) / max(
                    len(content_lower.split()), len(seen.split())
                )
                if overlap > 0.8:
                    is_duplicate = True
                    break

            if not is_duplicate:
                optimized.append(result)
                seen_content.add(content_lower)
                current_length += content_length

        return optimized

    def _format_context(self, results: List[Dict[str, Any]]) -> str:
        """
        Format optimized results into a context string.

        Args:
            results: Optimized search results

        Returns:
            Formatted context string
        """
        if not results:
            return "No relevant information found in corporate policies."

        context_parts = []
        for i, result in enumerate(results, 1):
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", f"document_{i}")
            content = result.get("content", "").strip()

            # Format with document info
            context_parts.append(f"Document: {filename}\nContent: {content}")

        return "\n\n---\n\n".join(context_parts)

    def validate_context_quality(
        self, context: str, query: str, min_quality_score: float = 0.3
    ) -> Dict[str, Any]:
        """
        Validate the quality of prepared context for a given query.

        Args:
            context: Formatted context string
            query: Original user query
            min_quality_score: Minimum acceptable quality score

        Returns:
            Dictionary with quality metrics and validation result
        """
        # Simple quality checks
        quality_metrics = {
            "length": len(context),
            "has_content": bool(context.strip()),
            "estimated_relevance": 0.0,
            "passes_validation": False,
        }

        if not context.strip():
            return quality_metrics

        # Estimate relevance based on query term overlap
        query_terms = set(query.lower().split())
        context_terms = set(context.lower().split())

        if query_terms and context_terms:
            overlap = len(query_terms & context_terms)
            relevance = overlap / len(query_terms)

            quality_metrics["estimated_relevance"] = relevance
            quality_metrics["passes_validation"] = relevance >= min_quality_score

        return quality_metrics

    def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate a summary of sources used in context.

        Args:
            results: Search results used for context

        Returns:
            Summary of sources and their contribution
        """
        sources: Dict[str, Dict[str, Any]] = {}
        total_content_length = 0

        for result in results:
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", "unknown")
            content_length = len(result.get("content", ""))
            similarity = result.get("similarity_score", 0.0)

            if filename not in sources:
                sources[filename] = {
                    "chunks": 0,
                    "total_content_length": 0,
                    "max_similarity": 0.0,
                    "avg_similarity": 0.0,
                    "_similarity_sum": 0.0,  # Internal accumulator for the average
                }

            source = sources[filename]
            source["chunks"] += 1
            source["total_content_length"] += content_length
            source["max_similarity"] = max(source["max_similarity"], similarity)
            source["_similarity_sum"] += similarity
            total_content_length += content_length

        # Calculate averages and percentages
        for source_info in sources.values():
            source_info["avg_similarity"] = source_info.pop("_similarity_sum") / source_info["chunks"]
            source_info["content_percentage"] = (
                source_info["total_content_length"] / max(total_content_length, 1) * 100
            )

        return {
            "total_sources": len(sources),
            "total_chunks": len(results),
            "total_content_length": total_content_length,
            "sources": sources,
        }
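

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The result dicts below are hypothetical
# and simply mirror the keys this module reads ("content", "similarity_score",
# and "metadata" with a "filename"); they are not a confirmed SearchService
# output contract. Run the module directly to see the pipeline end to end.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    sample_results = [
        {
            "content": (
                "Employees may work remotely up to three days per week "
                "with manager approval, per the hybrid work policy."
            ),
            "similarity_score": 0.82,
            "metadata": {"filename": "remote_work_policy.md"},
        },
        {
            "content": (
                "Expense reports must be submitted within 30 days of "
                "purchase and must include itemized receipts."
            ),
            "similarity_score": 0.41,
            "metadata": {"filename": "expense_policy.md"},
        },
    ]

    query = "Can I work from home?"
    manager = ContextManager(ContextConfig(max_context_length=1000))

    context, used_results = manager.prepare_context(sample_results, query)
    print(context)
    print(manager.validate_context_quality(context, query))
    print(manager.get_source_summary(used_results))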