"""
Enhanced RAG Pipeline with Guardrails Integration

This module extends the existing RAG pipeline with comprehensive
guardrails for response quality and safety validation.
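
Example usage (a minimal sketch; assumes ``base`` is an already-constructed
``RAGPipeline`` with working search and LLM services, and that default
guardrails settings are acceptable; the question text is illustrative):

    pipeline = EnhancedRAGPipeline(base_pipeline=base)
    result = pipeline.generate_answer("How many vacation days do I accrue per year?")
    print(result.answer)
    print(result.guardrails_approved, result.quality_score)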
"""

import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from ..guardrails import GuardrailsResult, GuardrailsSystem
from .rag_pipeline import RAGConfig, RAGPipeline, RAGResponse

logger = logging.getLogger(__name__)


@dataclass
class EnhancedRAGResponse(RAGResponse):
    """Enhanced RAG response with guardrails metadata."""

    guardrails_approved: bool = True
    guardrails_confidence: float = 1.0
    safety_passed: bool = True
    quality_score: float = 1.0
    guardrails_warnings: Optional[List[str]] = None
    guardrails_fallbacks: Optional[List[str]] = None

    def __post_init__(self):
        if self.guardrails_warnings is None:
            self.guardrails_warnings = []
        if self.guardrails_fallbacks is None:
            self.guardrails_fallbacks = []


class EnhancedRAGPipeline:
    """
    Enhanced RAG pipeline with integrated guardrails system.

    Extends the base RAG pipeline with:
    - Comprehensive response validation
    - Content safety filtering
    - Quality scoring and metrics
    - Source attribution and citations
    - Error handling and fallbacks
    """

    def __init__(
        self,
        base_pipeline: RAGPipeline,
        guardrails_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize enhanced RAG pipeline.

        Args:
            base_pipeline: Base RAG pipeline instance
            guardrails_config: Configuration for guardrails system
        """
        self.base_pipeline = base_pipeline
        self.guardrails = GuardrailsSystem(guardrails_config)

        logger.info("EnhancedRAGPipeline initialized with guardrails")

    def generate_answer(self, question: str) -> EnhancedRAGResponse:
        """
        Generate answer with comprehensive guardrails validation.

        Args:
            question: User's question about corporate policies

        Returns:
            EnhancedRAGResponse with validation and safety checks
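
        Example (illustrative sketch; the question text and the logging call
        are not part of the pipeline contract):

            response = pipeline.generate_answer("What is the remote work policy?")
            if not response.guardrails_approved:
                logger.warning("Rejected: %s", response.error_message)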
        """
        start_time = time.time()

        try:
            # Step 1: Generate initial response using base pipeline
            base_response = self.base_pipeline.generate_answer(question)

            if not base_response.success:
                return self._create_enhanced_response_from_base(base_response)

            # Step 2: Apply comprehensive guardrails validation
            guardrails_result = self.guardrails.validate_response(
                response=base_response.answer,
                query=question,
                sources=base_response.sources,
                context=None,  # Could be enhanced with additional context
            )

            # Step 3: Create enhanced response based on guardrails result
            if guardrails_result.is_approved:
                # Use enhanced response with improved citations
                enhanced_answer = guardrails_result.enhanced_response

                # Update confidence based on guardrails assessment
                enhanced_confidence = (base_response.confidence + guardrails_result.confidence_score) / 2

                return EnhancedRAGResponse(
                    answer=enhanced_answer,
                    sources=base_response.sources,
                    confidence=enhanced_confidence,
                    processing_time=time.time() - start_time,
                    llm_provider=base_response.llm_provider,
                    llm_model=base_response.llm_model,
                    context_length=base_response.context_length,
                    search_results_count=base_response.search_results_count,
                    success=True,
                    error_message=None,
                    # Guardrails metadata
                    guardrails_approved=True,
                    guardrails_confidence=guardrails_result.confidence_score,
                    safety_passed=guardrails_result.safety_result.is_safe,
                    quality_score=guardrails_result.quality_score.overall_score,
                    guardrails_warnings=guardrails_result.warnings,
                    guardrails_fallbacks=guardrails_result.fallbacks_applied,
                )
            else:
                # Response was rejected by guardrails
                rejection_reason = self._format_rejection_reason(guardrails_result)

                return EnhancedRAGResponse(
                    answer=rejection_reason,
                    sources=[],
                    confidence=0.0,
                    processing_time=time.time() - start_time,
                    llm_provider=base_response.llm_provider,
                    llm_model=base_response.llm_model,
                    context_length=0,
                    search_results_count=0,
                    success=False,
                    error_message="Response rejected by guardrails",
                    # Guardrails metadata
                    guardrails_approved=False,
                    guardrails_confidence=guardrails_result.confidence_score,
                    safety_passed=guardrails_result.safety_result.is_safe,
                    quality_score=guardrails_result.quality_score.overall_score,
                    guardrails_warnings=guardrails_result.warnings + [f"Rejected: {rejection_reason}"],
                    guardrails_fallbacks=guardrails_result.fallbacks_applied,
                )

        except Exception as e:
            logger.error(f"Enhanced RAG pipeline error: {e}")

            # Fall back to regenerating the base pipeline response (guardrails skipped)
            try:
                base_response = self.base_pipeline.generate_answer(question)
                if base_response.success:
                    # Create enhanced response with error warning
                    enhanced = self._create_enhanced_response_from_base(base_response)
                    enhanced.error_message = f"Guardrails validation failed: {str(e)}"
                    if enhanced.guardrails_warnings is not None:
                        enhanced.guardrails_warnings.append("Guardrails validation failed")
                    return enhanced
            except Exception:
                pass

            # Final fallback
            return EnhancedRAGResponse(
                answer=(
                    "I apologize, but I encountered an error processing your question. "
                    "Please try again or contact support if the issue persists."
                ),
                sources=[],
                confidence=0.0,
                processing_time=time.time() - start_time,
                llm_provider="error",
                llm_model="error",
                context_length=0,
                search_results_count=0,
                success=False,
                error_message=f"Enhanced pipeline error: {str(e)}",
                guardrails_approved=False,
                guardrails_confidence=0.0,
                safety_passed=False,
                quality_score=0.0,
                guardrails_warnings=[f"Pipeline error: {str(e)}"],
            )

    def _create_enhanced_response_from_base(self, base_response: RAGResponse) -> EnhancedRAGResponse:
        """Create enhanced response from base response."""
        return EnhancedRAGResponse(
            answer=base_response.answer,
            sources=base_response.sources,
            confidence=base_response.confidence,
            processing_time=base_response.processing_time,
            llm_provider=base_response.llm_provider,
            llm_model=base_response.llm_model,
            context_length=base_response.context_length,
            search_results_count=base_response.search_results_count,
            success=base_response.success,
            error_message=base_response.error_message,
            # Default guardrails values (bypassed)
            guardrails_approved=True,
            guardrails_confidence=0.5,
            safety_passed=True,
            quality_score=0.5,
            guardrails_warnings=["Guardrails bypassed due to base pipeline issue"],
            guardrails_fallbacks=["base_pipeline_fallback"],
        )

    def _format_rejection_reason(self, guardrails_result: GuardrailsResult) -> str:
        """Format user-friendly rejection reason."""
        if not guardrails_result.safety_result.is_safe:
            return (
                "I cannot provide this response due to safety concerns. "
                "Please rephrase your question or contact HR for assistance."
            )

        if guardrails_result.quality_score.overall_score < 0.5:
            low_quality_msg = (
                "I couldn't generate a sufficiently detailed response to your "
                "question. Please try rephrasing your question or contact HR "
                "for more specific guidance."
            )
            return low_quality_msg

        if not guardrails_result.citations:
            return (
                "I couldn't find adequate source documentation to support a response. "
                "Please contact HR or check our policy documentation directly."
            )

        return (
            "I couldn't provide a complete response to your question. "
            "Please contact HR for assistance or try rephrasing your question."
        )

    def get_health_status(self) -> Dict[str, Any]:
        """Get health status of enhanced pipeline."""
        base_health = {
            "base_pipeline": "healthy",  # Assume healthy for now
            "llm_service": "healthy",
            "search_service": "healthy",
        }

        guardrails_health = self.guardrails.get_system_health()

        overall_status = "healthy" if guardrails_health["status"] == "healthy" else "degraded"

        return {
            "status": overall_status,
            "base_pipeline": base_health,
            "guardrails": guardrails_health,
        }

    @property
    def config(self) -> RAGConfig:
        """Access base pipeline configuration."""
        return self.base_pipeline.config

    def validate_response_only(self, response: str, query: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Validate a response using only guardrails (without generating).

        Useful for testing and external validation.
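
        Example (a minimal sketch; the ``sources`` dictionary schema shown is
        an assumption about what the guardrails system expects):

            report = pipeline.validate_response_only(
                response="Employees accrue 15 vacation days per year.",
                query="How many vacation days do I get?",
                sources=[{"document": "vacation_policy.md", "content": "..."}],
            )
            print(report["approved"], report["quality_score"]["overall"])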
        """
        guardrails_result = self.guardrails.validate_response(response=response, query=query, sources=sources)

        return {
            "approved": guardrails_result.is_approved,
            "confidence": guardrails_result.confidence_score,
            "safety_result": {
                "is_safe": guardrails_result.safety_result.is_safe,
                "risk_level": guardrails_result.safety_result.risk_level,
                "issues": guardrails_result.safety_result.issues_found,
            },
            "quality_score": {
                "overall": guardrails_result.quality_score.overall_score,
                "relevance": guardrails_result.quality_score.relevance_score,
                "completeness": guardrails_result.quality_score.completeness_score,
                "coherence": guardrails_result.quality_score.coherence_score,
                "source_fidelity": (guardrails_result.quality_score.source_fidelity_score),
            },
            "citations": [
                {
                    "document": citation.document,
                    "confidence": citation.confidence,
                    "excerpt": citation.excerpt,
                }
                for citation in guardrails_result.citations
            ],
            "recommendations": guardrails_result.recommendations,
            "warnings": guardrails_result.warnings,
            "processing_time": guardrails_result.processing_time,
        }