Spaces:
Sleeping
Sleeping
| """ | |
| Context Manager for RAG Pipeline | |
| This module handles context retrieval, formatting, and management | |
| for the RAG pipeline, ensuring optimal context window utilization. | |
| """ | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional, Tuple | |
| logger = logging.getLogger(__name__) | |
@dataclass
class ContextConfig:
    """Configuration for context management.

    Declared as a dataclass so individual settings can be overridden with
    keyword arguments (e.g. ``ContextConfig(max_results=3)``); constructing it
    with no arguments yields the defaults listed below.
    """

    max_context_length: int = 3000  # Maximum characters in context
    max_results: int = 5  # Maximum search results to include
    min_similarity: float = 0.1  # Minimum similarity threshold
    overlap_penalty: float = 0.1  # Penalty for overlapping content
class ContextManager:
    """
    Manages context retrieval and optimization for RAG pipeline.

    Handles:
    - Context length management
    - Relevance filtering
    - Duplicate content removal
    - Source prioritization
    """

    # Characters reserved for formatting overhead when budgeting the context.
    _FORMAT_BUFFER = 100
    # A truncated chunk is only kept if at least this many characters fit.
    _MIN_TRUNCATED_LENGTH = 200
    # Word-overlap ratio above which two chunks are treated as duplicates.
    _DUPLICATE_THRESHOLD = 0.8
    # Stripped content must be strictly longer than this to pass filtering.
    _MIN_CONTENT_LENGTH = 20

    def __init__(self, config: Optional["ContextConfig"] = None):
        """
        Initialize ContextManager with configuration.

        Args:
            config: Context configuration, uses defaults if None
        """
        self.config = config or ContextConfig()
        logger.info("ContextManager initialized")

    def prepare_context(self, search_results: List[Dict[str, Any]], query: str) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Prepare optimized context from search results.

        Runs the results through filtering, de-duplication / length
        budgeting, and formatting, in that order.

        Args:
            search_results: Result dicts; the keys read here are
                "content", "similarity_score" and "metadata"
            query: Original user query (currently not used by this method)

        Returns:
            Tuple of (formatted_context, filtered_results)
        """
        if not search_results:
            return "No relevant information found.", []

        # Filter and rank results
        filtered_results = self._filter_results(search_results)

        # Remove duplicates and optimize for context window
        optimized_results = self._optimize_context(filtered_results)

        # Format for prompt
        formatted_context = self._format_context(optimized_results)

        # Lazy %-formatting: the message is only built if DEBUG is enabled.
        logger.debug(
            "Prepared context from %d results, filtered to %d results, %d characters",
            len(search_results),
            len(optimized_results),
            len(formatted_context),
        )
        return formatted_context, optimized_results

    def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter search results by relevance and quality.

        A result is kept when its "similarity_score" (default 0.0) is at
        least ``config.min_similarity`` and its stripped "content" is longer
        than 20 characters.

        Args:
            results: Raw search results

        Returns:
            Kept results sorted by similarity (descending), limited to
            ``config.max_results`` entries
        """
        min_similarity = self.config.min_similarity
        filtered = [
            result
            for result in results
            if result.get("similarity_score", 0.0) >= min_similarity
            and len(result.get("content", "").strip()) > self._MIN_CONTENT_LENGTH
        ]

        # Sort by similarity score (descending); the sort is stable, so ties
        # keep their input order.
        filtered.sort(key=lambda item: item.get("similarity_score", 0.0), reverse=True)

        # Limit to max results
        return filtered[: self.config.max_results]

    def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Optimize context to fit within the length limit while dropping duplicates.

        Results are consumed in order. Once a result would push the running
        content length (plus a fixed buffer) past ``config.max_context_length``,
        it is truncated to the remaining space if that space is large enough,
        and processing stops. Otherwise the result is kept unless its words
        overlap heavily with an already kept result.

        Args:
            results: Filtered search results

        Returns:
            Optimized results list; a truncated entry is a shallow copy, so
            the input dicts are never modified
        """
        if not results:
            return []

        max_length = self.config.max_context_length
        optimized = []
        current_length = 0
        # Unique lowercase word sets of the results kept so far. Cached so
        # each content is tokenised once rather than once per comparison.
        seen_word_sets = []

        for result in results:
            content = result.get("content", "").strip()
            content_length = len(content)

            # Check if adding this result would exceed limit
            if current_length + content_length + self._FORMAT_BUFFER > max_length:
                # Try to truncate content
                remaining_space = max_length - current_length - self._FORMAT_BUFFER
                if remaining_space > self._MIN_TRUNCATED_LENGTH:
                    result_copy = result.copy()
                    result_copy["content"] = content[:remaining_space] + "..."
                    optimized.append(result_copy)
                break

            # Check for duplicate or highly similar content
            words = set(content.lower().split())
            if any(self._is_near_duplicate(words, seen) for seen in seen_word_sets):
                continue

            optimized.append(result)
            seen_word_sets.append(words)
            current_length += content_length

        return optimized

    def _is_near_duplicate(self, words: set, other_words: set) -> bool:
        """Return True when two unique-word sets overlap above the duplicate threshold."""
        # The denominator is the larger *unique* word count. Dividing by the
        # raw word count (repeats included) lets verbatim duplicates score
        # below the threshold whenever the text repeats words.
        larger = max(len(words), len(other_words))
        if larger == 0:
            # Both contents are empty, hence identical; also avoids 0/0.
            return True
        return len(words & other_words) / larger > self._DUPLICATE_THRESHOLD

    def _format_context(self, results: List[Dict[str, Any]]) -> str:
        """
        Format optimized results into context string.

        Each result becomes a "Document: <filename>" line followed by a
        "Content: <content>" line; entries are joined by a "---" separator.
        When "metadata" has no "filename", ``document_<position>`` (1-based)
        is used instead.

        Args:
            results: Optimized search results

        Returns:
            Formatted context string, or a fixed fallback message when
            ``results`` is empty
        """
        if not results:
            return "No relevant information found in corporate policies."

        context_parts = []
        for position, result in enumerate(results, 1):
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", f"document_{position}")
            content = result.get("content", "").strip()

            # Format with document info
            context_parts.append(f"Document: {filename}\nContent: {content}")

        return "\n\n---\n\n".join(context_parts)

    def validate_context_quality(self, context: str, query: str, min_quality_score: float = 0.3) -> Dict[str, Any]:
        """
        Validate the quality of prepared context for a given query.

        Relevance is estimated as the fraction of distinct lowercase,
        whitespace-separated query terms that also appear in the context.

        Args:
            context: Formatted context string
            query: Original user query
            min_quality_score: Minimum acceptable quality score

        Returns:
            Dictionary with keys "length", "has_content",
            "estimated_relevance" and "passes_validation"
        """
        has_content = bool(context.strip())
        quality_metrics = {
            "length": len(context),
            "has_content": has_content,
            "estimated_relevance": 0.0,
            "passes_validation": False,
        }

        # Blank context can never pass; the defaults above already say so.
        if not has_content:
            return quality_metrics

        # Estimate relevance based on query term overlap
        query_terms = set(query.lower().split())
        context_terms = set(context.lower().split())

        # A query with no terms leaves relevance at 0.0 and validation failed.
        if query_terms and context_terms:
            relevance = len(query_terms & context_terms) / len(query_terms)
            quality_metrics["estimated_relevance"] = relevance
            quality_metrics["passes_validation"] = relevance >= min_quality_score

        return quality_metrics

    def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate summary of sources used in context.

        Results are grouped by ``metadata["filename"]`` ("unknown" when
        absent). Each source entry reports "chunks", "total_content_length",
        "max_similarity", "avg_similarity" and "content_percentage".

        Args:
            results: Search results used for context

        Returns:
            Dictionary with "total_sources", "total_chunks",
            "total_content_length" and the per-file "sources" mapping
        """
        sources = {}
        # Running similarity sums per file, kept outside ``sources`` so the
        # returned entries carry no helper keys.
        similarity_totals = {}
        total_content_length = 0

        for result in results:
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", "unknown")
            content_length = len(result.get("content", ""))
            similarity = result.get("similarity_score", 0.0)

            if filename not in sources:
                sources[filename] = {
                    "chunks": 0,
                    "total_content_length": 0,
                    "max_similarity": 0.0,
                    "avg_similarity": 0.0,
                }
                similarity_totals[filename] = 0.0

            source_info = sources[filename]
            source_info["chunks"] += 1
            source_info["total_content_length"] += content_length
            source_info["max_similarity"] = max(source_info["max_similarity"], similarity)
            similarity_totals[filename] += similarity
            total_content_length += content_length

        # Calculate averages and percentages
        for filename, source_info in sources.items():
            # Every entry has at least one chunk, so the division is safe.
            source_info["avg_similarity"] = similarity_totals[filename] / source_info["chunks"]
            # max(..., 1) guards the case where every chunk has empty content.
            source_info["content_percentage"] = source_info["total_content_length"] / max(total_content_length, 1) * 100

        return {
            "total_sources": len(sources),
            "total_chunks": len(results),
            "total_content_length": total_content_length,
            "sources": sources,
        }