| """ | |
| Enhanced RAG Pipeline with Guardrails Integration | |
| This module extends the existing RAG pipeline with comprehensive | |
| guardrails for response quality and safety validation. | |
| """ | |
| import logging | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Any, Dict, List, Optional | |
| from ..guardrails import GuardrailsResult, GuardrailsSystem | |
| from .rag_pipeline import RAGConfig, RAGPipeline, RAGResponse | |
| logger = logging.getLogger(__name__) | |

@dataclass
class EnhancedRAGResponse(RAGResponse):
    """Enhanced RAG response with guardrails metadata."""

    guardrails_approved: bool = True
    guardrails_confidence: float = 1.0
    safety_passed: bool = True
    quality_score: float = 1.0
    guardrails_warnings: Optional[List[str]] = None
    guardrails_fallbacks: Optional[List[str]] = None

    def __post_init__(self):
        # Dataclass defaults must be immutable, so mutable lists are created here.
        if self.guardrails_warnings is None:
            self.guardrails_warnings = []
        if self.guardrails_fallbacks is None:
            self.guardrails_fallbacks = []

class EnhancedRAGPipeline:
    """
    Enhanced RAG pipeline with integrated guardrails system.

    Extends the base RAG pipeline with:
    - Comprehensive response validation
    - Content safety filtering
    - Quality scoring and metrics
    - Source attribution and citations
    - Error handling and fallbacks
    """

    def __init__(
        self,
        base_pipeline: RAGPipeline,
        guardrails_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the enhanced RAG pipeline.

        Args:
            base_pipeline: Base RAG pipeline instance
            guardrails_config: Configuration for the guardrails system
        """
        self.base_pipeline = base_pipeline
        self.guardrails = GuardrailsSystem(guardrails_config)
        logger.info("EnhancedRAGPipeline initialized with guardrails")

    def generate_answer(self, question: str) -> EnhancedRAGResponse:
        """
        Generate an answer with comprehensive guardrails validation.

        Args:
            question: User's question about corporate policies

        Returns:
            EnhancedRAGResponse with validation and safety checks
        """
        start_time = time.time()

        try:
            # Step 1: Generate the initial response using the base pipeline
            base_response = self.base_pipeline.generate_answer(question)
            if not base_response.success:
                return self._create_enhanced_response_from_base(base_response)

            # Step 2: Apply comprehensive guardrails validation
            guardrails_result = self.guardrails.validate_response(
                response=base_response.answer,
                query=question,
                sources=base_response.sources,
                context=None,  # Could be enhanced with additional context
            )

            # Step 3: Create the enhanced response based on the guardrails result
            if guardrails_result.is_approved:
                # Use the enhanced response with improved citations
                enhanced_answer = guardrails_result.enhanced_response

                # Update confidence based on the guardrails assessment
                enhanced_confidence = (
                    base_response.confidence + guardrails_result.confidence_score
                ) / 2

                return EnhancedRAGResponse(
                    answer=enhanced_answer,
                    sources=base_response.sources,
                    confidence=enhanced_confidence,
                    processing_time=time.time() - start_time,
                    llm_provider=base_response.llm_provider,
                    llm_model=base_response.llm_model,
                    context_length=base_response.context_length,
                    search_results_count=base_response.search_results_count,
                    success=True,
                    error_message=None,
                    # Guardrails metadata
                    guardrails_approved=True,
                    guardrails_confidence=guardrails_result.confidence_score,
                    safety_passed=guardrails_result.safety_result.is_safe,
                    quality_score=guardrails_result.quality_score.overall_score,
                    guardrails_warnings=guardrails_result.warnings,
                    guardrails_fallbacks=guardrails_result.fallbacks_applied,
                )
            else:
                # The response was rejected by guardrails
                rejection_reason = self._format_rejection_reason(guardrails_result)

                return EnhancedRAGResponse(
                    answer=rejection_reason,
                    sources=[],
                    confidence=0.0,
                    processing_time=time.time() - start_time,
                    llm_provider=base_response.llm_provider,
                    llm_model=base_response.llm_model,
                    context_length=0,
                    search_results_count=0,
                    success=False,
                    error_message="Response rejected by guardrails",
                    # Guardrails metadata
                    guardrails_approved=False,
                    guardrails_confidence=guardrails_result.confidence_score,
                    safety_passed=guardrails_result.safety_result.is_safe,
                    quality_score=guardrails_result.quality_score.overall_score,
                    guardrails_warnings=guardrails_result.warnings
                    + [f"Rejected: {rejection_reason}"],
                    guardrails_fallbacks=guardrails_result.fallbacks_applied,
                )
        except Exception as e:
            logger.error(f"Enhanced RAG pipeline error: {e}")

            # Fall back to the base pipeline response if it is available
            try:
                base_response = self.base_pipeline.generate_answer(question)
                if base_response.success:
                    # Create an enhanced response carrying an error warning
                    enhanced = self._create_enhanced_response_from_base(base_response)
                    enhanced.error_message = f"Guardrails validation failed: {str(e)}"
                    if enhanced.guardrails_warnings is not None:
                        enhanced.guardrails_warnings.append("Guardrails validation failed")
                    return enhanced
            except Exception:
                pass

            # Final fallback
            return EnhancedRAGResponse(
                answer=(
                    "I apologize, but I encountered an error processing your question. "
                    "Please try again or contact support if the issue persists."
                ),
                sources=[],
                confidence=0.0,
                processing_time=time.time() - start_time,
                llm_provider="error",
                llm_model="error",
                context_length=0,
                search_results_count=0,
                success=False,
                error_message=f"Enhanced pipeline error: {str(e)}",
                guardrails_approved=False,
                guardrails_confidence=0.0,
                safety_passed=False,
                quality_score=0.0,
                guardrails_warnings=[f"Pipeline error: {str(e)}"],
            )

    def _create_enhanced_response_from_base(
        self, base_response: RAGResponse
    ) -> EnhancedRAGResponse:
        """Create an enhanced response from a base response."""
        return EnhancedRAGResponse(
            answer=base_response.answer,
            sources=base_response.sources,
            confidence=base_response.confidence,
            processing_time=base_response.processing_time,
            llm_provider=base_response.llm_provider,
            llm_model=base_response.llm_model,
            context_length=base_response.context_length,
            search_results_count=base_response.search_results_count,
            success=base_response.success,
            error_message=base_response.error_message,
            # Default guardrails values (guardrails bypassed)
            guardrails_approved=True,
            guardrails_confidence=0.5,
            safety_passed=True,
            quality_score=0.5,
            guardrails_warnings=["Guardrails bypassed due to base pipeline issue"],
            guardrails_fallbacks=["base_pipeline_fallback"],
        )

    def _format_rejection_reason(self, guardrails_result: GuardrailsResult) -> str:
        """Format a user-friendly rejection reason."""
        if not guardrails_result.safety_result.is_safe:
            return (
                "I cannot provide this response due to safety concerns. "
                "Please rephrase your question or contact HR for assistance."
            )

        if guardrails_result.quality_score.overall_score < 0.5:
            return (
                "I couldn't generate a sufficiently detailed response to your "
                "question. Please try rephrasing your question or contact HR "
                "for more specific guidance."
            )

        if not guardrails_result.citations:
            return (
                "I couldn't find adequate source documentation to support a response. "
                "Please contact HR or check our policy documentation directly."
            )

        return (
            "I couldn't provide a complete response to your question. "
            "Please contact HR for assistance or try rephrasing your question."
        )

    def get_health_status(self) -> Dict[str, Any]:
        """Get the health status of the enhanced pipeline."""
        base_health = {
            "base_pipeline": "healthy",  # Assumed healthy for now
            "llm_service": "healthy",
            "search_service": "healthy",
        }

        guardrails_health = self.guardrails.get_system_health()
        overall_status = "healthy" if guardrails_health["status"] == "healthy" else "degraded"

        return {
            "status": overall_status,
            "base_pipeline": base_health,
            "guardrails": guardrails_health,
        }

    def config(self) -> RAGConfig:
        """Access the base pipeline configuration."""
        return self.base_pipeline.config

    def validate_response_only(
        self, response: str, query: str, sources: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Validate a response using only the guardrails (without generating one).

        Useful for testing and external validation.
        """
        guardrails_result = self.guardrails.validate_response(
            response=response, query=query, sources=sources
        )

        return {
            "approved": guardrails_result.is_approved,
            "confidence": guardrails_result.confidence_score,
            "safety_result": {
                "is_safe": guardrails_result.safety_result.is_safe,
                "risk_level": guardrails_result.safety_result.risk_level,
                "issues": guardrails_result.safety_result.issues_found,
            },
            "quality_score": {
                "overall": guardrails_result.quality_score.overall_score,
                "relevance": guardrails_result.quality_score.relevance_score,
                "completeness": guardrails_result.quality_score.completeness_score,
                "coherence": guardrails_result.quality_score.coherence_score,
                "source_fidelity": guardrails_result.quality_score.source_fidelity_score,
            },
            "citations": [
                {
                    "document": citation.document,
                    "confidence": citation.confidence,
                    "excerpt": citation.excerpt,
                }
                for citation in guardrails_result.citations
            ],
            "recommendations": guardrails_result.recommendations,
            "warnings": guardrails_result.warnings,
            "processing_time": guardrails_result.processing_time,
        }
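

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the pipeline). It
# assumes RAGPipeline can be constructed from a default RAGConfig; the real
# constructor may require additional services (retriever, LLM client), so
# adapt the setup to the actual rag_pipeline API. The question, response, and
# source payloads below are placeholders. Because this module uses relative
# imports, run it with `python -m` from the package root rather than directly.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Assumption: a default RAGConfig is sufficient to build the base pipeline.
    base = RAGPipeline(RAGConfig())
    pipeline = EnhancedRAGPipeline(base_pipeline=base, guardrails_config=None)

    # End-to-end generation with guardrails applied.
    result = pipeline.generate_answer("How many vacation days do new employees get?")
    print(f"approved={result.guardrails_approved} quality={result.quality_score:.2f}")
    print(result.answer)

    # Standalone guardrails validation of an externally generated response.
    report = pipeline.validate_response_only(
        response="New employees accrue 15 vacation days per year.",
        query="How many vacation days do new employees get?",
        sources=[{"document": "employee_handbook.pdf", "content": "Placeholder excerpt."}],
    )
    print(f"approved={report['approved']} confidence={report['confidence']:.2f}")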