File size: 8,837 Bytes
c280a92
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
159faf0
c280a92
 
508a7e5
c280a92
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
508a7e5
c280a92
 
508a7e5
c280a92
 
508a7e5
c280a92
 
 
 
 
508a7e5
c280a92
 
 
 
 
508a7e5
c280a92
 
508a7e5
c280a92
 
 
 
508a7e5
c280a92
 
 
508a7e5
c280a92
159faf0
c280a92
508a7e5
c280a92
 
508a7e5
c280a92
508a7e5
c280a92
 
 
 
508a7e5
c280a92
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
508a7e5
c280a92
 
508a7e5
 
 
 
 
c280a92
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
508a7e5
c280a92
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
508a7e5
c280a92
508a7e5
c280a92
 
 
159faf0
c280a92
 
508a7e5
c280a92
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
 
 
 
 
 
 
 
 
 
 
 
508a7e5
c280a92
 
508a7e5
c280a92
 
 
 
 
508a7e5
c280a92
 
 
 
 
508a7e5
c280a92
 
 
 
 
508a7e5
c280a92
508a7e5
c280a92
 
159faf0
508a7e5
c280a92
 
 
 
159faf0
c280a92
 
 
 
 
508a7e5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""
Context Manager for RAG Pipeline

This module handles context retrieval, formatting, and management
for the RAG pipeline, ensuring optimal context window utilization.
"""

import logging
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)


@dataclass
class ContextConfig:
    """Tunable parameters that control how retrieval context is assembled."""

    max_context_length: int = 3000  # Hard cap on the assembled context, in characters
    max_results: int = 5  # Upper bound on search results kept after filtering
    min_similarity: float = 0.1  # Results scoring below this are discarded
    overlap_penalty: float = 0.1  # Weight applied against duplicated content


class ContextManager:
    """
    Manages context retrieval and optimization for RAG pipeline.

    Handles:
    - Context length management
    - Relevance filtering
    - Duplicate content removal
    - Source prioritization
    """

    # Characters reserved per result for formatting overhead ("Document: ..." headers, separators).
    _FORMAT_BUFFER = 100
    # Smallest truncated excerpt (in characters) still worth including in the context.
    _MIN_TRUNCATED_LENGTH = 200
    # Results whose stripped content is shorter than this are discarded as noise.
    _MIN_CONTENT_LENGTH = 20
    # Word-overlap ratio above which two chunks are treated as duplicates.
    _DUPLICATE_THRESHOLD = 0.8

    def __init__(self, config: Optional["ContextConfig"] = None):
        """
        Initialize ContextManager with configuration.

        Args:
            config: Context configuration, uses defaults if None
        """
        self.config = config or ContextConfig()
        logger.info("ContextManager initialized")

    def prepare_context(self, search_results: List[Dict[str, Any]], query: str) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Prepare optimized context from search results.

        Args:
            search_results: Results from SearchService
            query: Original user query for context optimization

        Returns:
            Tuple of (formatted_context, filtered_results)
        """
        if not search_results:
            return "No relevant information found.", []

        # Filter and rank results by similarity, then dedupe/truncate to fit
        # the context window, then render the final prompt string.
        filtered_results = self._filter_results(search_results)
        optimized_results = self._optimize_context(filtered_results)
        formatted_context = self._format_context(optimized_results)

        logger.debug(
            f"Prepared context from {len(search_results)} results, "
            f"filtered to {len(optimized_results)} results, "
            f"{len(formatted_context)} characters"
        )

        return formatted_context, optimized_results

    def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter search results by relevance and quality.

        Keeps results whose similarity meets the configured threshold and whose
        content is non-trivially long, sorted by similarity (best first) and
        capped at ``config.max_results``.

        Args:
            results: Raw search results

        Returns:
            Filtered and sorted results
        """
        filtered = [
            result
            for result in results
            if result.get("similarity_score", 0.0) >= self.config.min_similarity
            and len(result.get("content", "").strip()) > self._MIN_CONTENT_LENGTH
        ]

        # Sort by similarity score (descending) and keep only the top results.
        filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
        return filtered[: self.config.max_results]

    def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Optimize context to fit within length limits while maximizing relevance.

        Walks the (already ranked) results in order, skipping near-duplicate
        chunks and stopping — possibly after truncating the last chunk — once
        the configured character budget is exhausted.

        Args:
            results: Filtered search results

        Returns:
            Optimized results list
        """
        if not results:
            return []

        optimized: List[Dict[str, Any]] = []
        current_length = 0
        # Word sets of accepted chunks, precomputed once instead of re-splitting
        # every seen string on each comparison (was O(n^2 * len) work).
        seen_word_sets: List[set] = []

        for result in results:
            content = result.get("content", "").strip()
            content_length = len(content)

            # Would adding this result (plus formatting overhead) blow the budget?
            if current_length + content_length + self._FORMAT_BUFFER > self.config.max_context_length:
                remaining_space = self.config.max_context_length - current_length - self._FORMAT_BUFFER
                if remaining_space > self._MIN_TRUNCATED_LENGTH:
                    # Include a truncated excerpt if enough useful space remains.
                    result_copy = result.copy()
                    result_copy["content"] = content[:remaining_space] + "..."
                    optimized.append(result_copy)
                break

            words = set(content.lower().split())
            if self._is_duplicate(words, seen_word_sets):
                continue

            optimized.append(result)
            seen_word_sets.append(words)
            current_length += content_length

        return optimized

    def _is_duplicate(self, words: set, seen_word_sets: List[set]) -> bool:
        """
        Return True if ``words`` overlaps any previously seen chunk's word set
        beyond the duplicate threshold (overlap / larger-set-size > threshold).
        """
        for seen in seen_word_sets:
            # max(..., 1) guards against empty content causing ZeroDivisionError.
            denominator = max(len(words), len(seen), 1)
            if len(words & seen) / denominator > self._DUPLICATE_THRESHOLD:
                return True
        return False

    def _format_context(self, results: List[Dict[str, Any]]) -> str:
        """
        Format optimized results into context string.

        Args:
            results: Optimized search results

        Returns:
            Formatted context string, one "Document/Content" section per result
            separated by "---" dividers
        """
        if not results:
            return "No relevant information found in corporate policies."

        context_parts = []

        for i, result in enumerate(results, 1):
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", f"document_{i}")
            content = result.get("content", "").strip()

            # Bug fix: the source filename was hard-coded as "(unknown)" and the
            # computed `filename` was never used; include the real source name.
            context_parts.append(f"Document: {filename}\nContent: {content}")

        return "\n\n---\n\n".join(context_parts)

    def validate_context_quality(self, context: str, query: str, min_quality_score: float = 0.3) -> Dict[str, Any]:
        """
        Validate the quality of prepared context for a given query.

        Relevance is estimated as the fraction of query terms that appear
        anywhere in the context (simple bag-of-words overlap).

        Args:
            context: Formatted context string
            query: Original user query
            min_quality_score: Minimum acceptable quality score

        Returns:
            Dictionary with quality metrics and validation result
        """
        quality_metrics: Dict[str, Any] = {
            "length": len(context),
            "has_content": bool(context.strip()),
            "estimated_relevance": 0.0,
            "passes_validation": False,
        }

        if not context.strip():
            return quality_metrics

        query_terms = set(query.lower().split())
        context_terms = set(context.lower().split())

        if query_terms and context_terms:
            overlap = len(query_terms & context_terms)
            relevance = overlap / len(query_terms)
            quality_metrics["estimated_relevance"] = relevance
            quality_metrics["passes_validation"] = relevance >= min_quality_score

        return quality_metrics

    def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate summary of sources used in context.

        Args:
            results: Search results used for context

        Returns:
            Summary of sources and their contribution, including per-source
            chunk counts, content lengths, max/avg similarity, and the
            percentage of total content each source contributed
        """
        sources: Dict[str, Dict[str, Any]] = {}
        similarity_totals: Dict[str, float] = {}
        total_content_length = 0

        for result in results:
            metadata = result.get("metadata", {})
            filename = metadata.get("filename", "unknown")
            content_length = len(result.get("content", ""))
            similarity = result.get("similarity_score", 0.0)

            entry = sources.setdefault(
                filename,
                {
                    "chunks": 0,
                    "total_content_length": 0,
                    "max_similarity": 0.0,
                    "avg_similarity": 0.0,
                },
            )

            entry["chunks"] += 1
            entry["total_content_length"] += content_length
            entry["max_similarity"] = max(entry["max_similarity"], similarity)
            # Bug fix: avg_similarity was initialized but never computed;
            # accumulate per-source similarity so it can be averaged below.
            similarity_totals[filename] = similarity_totals.get(filename, 0.0) + similarity

            total_content_length += content_length

        # Finalize averages and content percentages (max(..., 1) avoids /0).
        for filename, source_info in sources.items():
            source_info["avg_similarity"] = similarity_totals[filename] / source_info["chunks"]
            source_info["content_percentage"] = source_info["total_content_length"] / max(total_content_length, 1) * 100

        return {
            "total_sources": len(sources),
            "total_chunks": len(results),
            "total_content_length": total_content_length,
            "sources": sources,
        }