"""
Response Validator - Core response quality and safety validation

This module provides comprehensive validation of RAG responses including
quality metrics, safety checks, and content validation.
"""

import logging
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Pattern

logger = logging.getLogger(__name__)


@dataclass
class ValidationResult:
    """Result of response validation with detailed metrics."""

    is_valid: bool
    confidence_score: float
    safety_passed: bool
    quality_score: float
    issues: List[str]
    suggestions: List[str]

    # Detailed quality metrics
    relevance_score: float = 0.0
    completeness_score: float = 0.0
    coherence_score: float = 0.0
    source_fidelity_score: float = 0.0

    # Safety metrics
    contains_pii: bool = False
    inappropriate_content: bool = False
    potential_bias: bool = False
    prompt_injection_detected: bool = False


class ResponseValidator:
    """
    Validates response quality and safety for the RAG system.

    Provides comprehensive validation including:
    - Content safety and appropriateness
    - Response quality metrics
    - Source alignment validation
    - Professional tone assessment
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize ResponseValidator with configuration.

        Args:
            config: Configuration dictionary with validation thresholds
        """
        # Merge user overrides onto the defaults so a partial config
        # does not drop required keys.
        self.config = {**self._get_default_config(), **(config or {})}

        # Compile regex patterns for efficiency
        self._pii_patterns = self._compile_pii_patterns()
        self._inappropriate_patterns = self._compile_inappropriate_patterns()
        self._bias_patterns = self._compile_bias_patterns()

        logger.info("ResponseValidator initialized")

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default validation configuration."""
        return {
            "min_relevance_score": 0.7,
            "min_completeness_score": 0.6,
            "min_coherence_score": 0.7,
            "min_source_fidelity_score": 0.8,
            "min_overall_quality": 0.7,
            "max_response_length": 1000,
            "min_response_length": 20,
            "require_citations": True,
            "strict_safety_mode": True,
        }

    def validate_response(self, response: str, sources: List[Dict[str, Any]], query: str) -> ValidationResult:
        """
        Validate response quality and safety.

        Args:
            response: Generated response text
            sources: Source documents used for generation
            query: Original user query

        Returns:
            ValidationResult with detailed validation metrics
        """
        try:
            # Perform safety checks
            safety_result = self.check_safety(response)

            # Calculate quality metrics
            quality_scores = self._calculate_quality_scores(response, sources, query)

            # Check response format and citations
            format_issues = self._validate_format(response, sources)

            # Calculate overall confidence
            confidence = self.calculate_confidence(response, sources, quality_scores)

            # Determine if response passes validation
            is_valid = (
                safety_result["passed"]
                and quality_scores["overall"] >= self.config["min_overall_quality"]
                and len(format_issues) == 0
            )

            # Compile suggestions
            suggestions = []
            if not is_valid:
                suggestions.extend(
                    self._generate_improvement_suggestions(safety_result, quality_scores, format_issues)
                )

            return ValidationResult(
                is_valid=is_valid,
                confidence_score=confidence,
                safety_passed=safety_result["passed"],
                quality_score=quality_scores["overall"],
                issues=safety_result["issues"] + format_issues,
                suggestions=suggestions,
                relevance_score=quality_scores["relevance"],
                completeness_score=quality_scores["completeness"],
                coherence_score=quality_scores["coherence"],
                source_fidelity_score=quality_scores["source_fidelity"],
                contains_pii=safety_result["contains_pii"],
                inappropriate_content=safety_result["inappropriate_content"],
                potential_bias=safety_result["potential_bias"],
                prompt_injection_detected=safety_result["prompt_injection"],
            )
        except Exception as e:
            logger.error(f"Validation error: {e}")
            return ValidationResult(
                is_valid=False,
                confidence_score=0.0,
                safety_passed=False,
                quality_score=0.0,
                issues=[f"Validation error: {str(e)}"],
                suggestions=["Please retry the request"],
            )

    def calculate_confidence(
        self,
        response: str,
        sources: List[Dict[str, Any]],
        quality_scores: Optional[Dict[str, float]] = None,
    ) -> float:
        """
        Calculate overall confidence score for response.

        Args:
            response: Generated response text
            sources: Source documents used
            quality_scores: Pre-calculated quality scores

        Returns:
            Confidence score between 0.0 and 1.0
        """
        if quality_scores is None:
            quality_scores = self._calculate_quality_scores(response, sources, "")

        # Weight different factors
        weights = {
            "source_count": 0.2,
            "avg_source_relevance": 0.3,
            "response_quality": 0.4,
            "citation_presence": 0.1,
        }
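        # Worked example of the weighting above (illustrative numbers only,
        # not taken from this module): three sources (source_count_score = 1.0),
        # average source relevance 0.8, overall quality 0.75, and citations
        # present (citation_score = 1.0) would give
        #   0.2 * 1.0 + 0.3 * 0.8 + 0.4 * 0.75 + 0.1 * 1.0 = 0.84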

        # Source-based confidence
        source_count_score = min(len(sources) / 3.0, 1.0)  # Max at 3 sources
        avg_relevance = (
            sum(source.get("relevance_score", 0.0) for source in sources) / len(sources) if sources else 0.0
        )

        # Citation presence
        has_citations = self._has_proper_citations(response, sources)
        citation_score = 1.0 if has_citations else 0.3

        # Combine scores
        confidence = (
            weights["source_count"] * source_count_score
            + weights["avg_source_relevance"] * avg_relevance
            + weights["response_quality"] * quality_scores["overall"]
            + weights["citation_presence"] * citation_score
        )

        return min(max(confidence, 0.0), 1.0)

    def check_safety(self, content: str) -> Dict[str, Any]:
        """
        Perform comprehensive safety checks on content.

        Args:
            content: Text content to check

        Returns:
            Dictionary with safety check results
        """
        issues = []

        # Check for PII
        contains_pii = self._detect_pii(content)
        if contains_pii:
            issues.append("Content may contain personally identifiable information")

        # Check for inappropriate content
        inappropriate_content = self._detect_inappropriate_content(content)
        if inappropriate_content:
            issues.append("Content contains inappropriate material")

        # Check for potential bias
        potential_bias = self._detect_bias(content)
        if potential_bias:
            issues.append("Content may contain biased language")

        # Check for prompt injection
        prompt_injection = self._detect_prompt_injection(content)
        if prompt_injection:
            issues.append("Potential prompt injection detected")

        # Overall safety assessment
        passed = (
            not contains_pii
            and not inappropriate_content
            and not prompt_injection
            and (not potential_bias or not self.config["strict_safety_mode"])
        )

        return {
            "passed": passed,
            "issues": issues,
            "contains_pii": contains_pii,
            "inappropriate_content": inappropriate_content,
            "potential_bias": potential_bias,
            "prompt_injection": prompt_injection,
        }

    def _calculate_quality_scores(self, response: str, sources: List[Dict[str, Any]], query: str) -> Dict[str, float]:
        """Calculate detailed quality metrics."""
        # Relevance: How well does response address the query
        relevance = self._calculate_relevance(response, query)

        # Completeness: Does response adequately address the question
        completeness = self._calculate_completeness(response, query)

        # Coherence: Is the response logically structured and coherent
        coherence = self._calculate_coherence(response)

        # Source fidelity: How well does response align with sources
        source_fidelity = self._calculate_source_fidelity(response, sources)

        # Overall quality (weighted average)
        overall = 0.3 * relevance + 0.25 * completeness + 0.2 * coherence + 0.25 * source_fidelity
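        # Illustrative example (hypothetical numbers): relevance 0.8,
        # completeness 0.7, coherence 0.9, source fidelity 0.6 would yield
        #   0.3 * 0.8 + 0.25 * 0.7 + 0.2 * 0.9 + 0.25 * 0.6 = 0.745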

        return {
            "relevance": relevance,
            "completeness": completeness,
            "coherence": coherence,
            "source_fidelity": source_fidelity,
            "overall": overall,
        }

    def _calculate_relevance(self, response: str, query: str) -> float:
        """Calculate relevance score between response and query."""
        if not query.strip():
            return 1.0  # No query to compare against

        # Simple keyword overlap for now (can be enhanced with embeddings)
        query_words = set(query.lower().split())
        response_words = set(response.lower().split())

        if not query_words:
            return 1.0

        overlap = len(query_words.intersection(response_words))
        return min(overlap / len(query_words), 1.0)

    def _calculate_completeness(self, response: str, query: str) -> float:
        """Calculate completeness score based on response length and structure."""
        target_length = 200  # Ideal response length

        # Length-based score
        length_score = min(len(response) / target_length, 1.0)

        # Structure score (presence of clear statements)
        has_conclusion = any(
            phrase in response.lower() for phrase in ["according to", "based on", "in summary", "therefore"]
        )
        structure_score = 1.0 if has_conclusion else 0.7

        return (length_score + structure_score) / 2.0

    def _calculate_coherence(self, response: str) -> float:
        """Calculate coherence score based on response structure."""
        sentences = response.split(".")
        if len(sentences) < 2:
            return 0.8  # Short responses are typically coherent

        # Check for repetition (guard against responses with no non-empty sentences)
        non_empty_sentences = [s.strip().lower() for s in sentences if s.strip()]
        if not non_empty_sentences:
            return 0.8
        repetition_score = len(set(non_empty_sentences)) / len(non_empty_sentences)

        # Check for logical flow indicators
        flow_indicators = [
            "however",
            "therefore",
            "additionally",
            "furthermore",
            "consequently",
        ]
        has_flow = any(indicator in response.lower() for indicator in flow_indicators)
        flow_score = 1.0 if has_flow else 0.8

        return (repetition_score + flow_score) / 2.0

    def _calculate_source_fidelity(self, response: str, sources: List[Dict[str, Any]]) -> float:
        """Calculate how well response aligns with source documents."""
        if not sources:
            return 0.5  # Neutral score if no sources

        # Check for citation presence
        has_citations = self._has_proper_citations(response, sources)
        citation_score = 1.0 if has_citations else 0.3

        # Check for content alignment (simplified)
        source_content = " ".join(source.get("excerpt", "") for source in sources).lower()
        response_lower = response.lower()

        # Look for key terms from sources in response
        source_words = set(source_content.split())
        response_words = set(response_lower.split())

        if source_words:
            alignment = len(source_words.intersection(response_words)) / len(source_words)
        else:
            alignment = 0.5

        return (citation_score + min(alignment * 2, 1.0)) / 2.0

    def _has_proper_citations(self, response: str, sources: List[Dict[str, Any]]) -> bool:
        """Check if response contains proper citations."""
        if not self.config["require_citations"]:
            return True

        # Look for citation patterns
        citation_patterns = [
            r"\[.*?\]",  # [source]
            r"\(.*?\)",  # (source)
            r"according to.*?",  # according to X
            r"based on.*?",  # based on X
        ]
        has_citation_format = any(re.search(pattern, response, re.IGNORECASE) for pattern in citation_patterns)

        # Check if source documents are mentioned
        source_names = [source.get("document", "").lower() for source in sources]
        response_lower = response.lower()
        mentions_sources = any(name in response_lower for name in source_names if name)

        return has_citation_format or mentions_sources

    def _validate_format(self, response: str, sources: List[Dict[str, Any]]) -> List[str]:
        """Validate response format and structure."""
        issues = []

        # Length validation
        if len(response) < self.config["min_response_length"]:
            min_length = self.config["min_response_length"]
            issues.append(f"Response too short (minimum {min_length} characters)")

        if len(response) > self.config["max_response_length"]:
            max_length = self.config["max_response_length"]
            issues.append(f"Response too long (maximum {max_length} characters)")

        # Professional tone check (basic)
        informal_patterns = [
            r"\byo\b",
            r"\bwassup\b",
            r"\bgonna\b",
            r"\bwanna\b",
            r"\bunrealz\b",
            r"\bwtf\b",
            r"\bomg\b",
        ]
        if any(re.search(pattern, response, re.IGNORECASE) for pattern in informal_patterns):
            issues.append("Response contains informal language")

        return issues

    def _generate_improvement_suggestions(
        self,
        safety_result: Dict[str, Any],
        quality_scores: Dict[str, float],
        format_issues: List[str],
    ) -> List[str]:
        """Generate suggestions for improving response quality."""
        suggestions = []

        if not safety_result["passed"]:
            suggestions.append("Review content for safety and appropriateness")

        if quality_scores["relevance"] < self.config["min_relevance_score"]:
            suggestions.append("Ensure response directly addresses the user's question")

        if quality_scores["completeness"] < self.config["min_completeness_score"]:
            suggestions.append("Provide more comprehensive information")

        if quality_scores["source_fidelity"] < self.config["min_source_fidelity_score"]:
            suggestions.append("Include proper citations and source references")

        if format_issues:
            suggestions.append("Review response format and professional tone")

        return suggestions

    def _compile_pii_patterns(self) -> List[Pattern[str]]:
        """Compile regex patterns for PII detection."""
        patterns = [
            r"\b\d{3}-\d{2}-\d{4}\b",  # SSN
            r"\b\d{4}[- ]?\d{4}[- ]?\d{4}[- ]?\d{4}\b",  # Credit card
            r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",  # Email
            r"\b\d{3}[-.]\d{3}[-.]\d{4}\b",  # Phone number
        ]
        return [re.compile(pattern) for pattern in patterns]

    def _compile_inappropriate_patterns(self) -> List[Pattern[str]]:
        """Compile regex patterns for inappropriate content detection."""
        # Basic patterns (expand as needed)
        patterns = [
            r"\b(?:hate|discriminat|harass)\w*\b",
            r"\b(?:offensive|inappropriate|unprofessional)\b",
        ]
        return [re.compile(pattern, re.IGNORECASE) for pattern in patterns]

    def _compile_bias_patterns(self) -> List[Pattern[str]]:
        """Compile regex patterns for bias detection."""
        patterns = [
            r"\b(?:always|never|all|none)\s+(?:men|women|people)\b",
            r"\b(?:typical|usual)\s+(?:man|woman|person)\b",
        ]
        return [re.compile(pattern, re.IGNORECASE) for pattern in patterns]

    def _detect_pii(self, content: str) -> bool:
        """Detect personally identifiable information."""
        return any(pattern.search(content) for pattern in self._pii_patterns)

    def _detect_inappropriate_content(self, content: str) -> bool:
        """Detect inappropriate content."""
        return any(pattern.search(content) for pattern in self._inappropriate_patterns)

    def _detect_bias(self, content: str) -> bool:
        """Detect potential bias in content."""
        return any(pattern.search(content) for pattern in self._bias_patterns)

    def _detect_prompt_injection(self, content: str) -> bool:
        """Detect potential prompt injection attempts."""
        injection_patterns = [
            r"ignore\s+(?:previous|all)\s+instructions",
            r"system\s*:",
            r"assistant\s*:",
            r"user\s*:",
            r"prompt\s*:",
        ]
        return any(re.search(pattern, content, re.IGNORECASE) for pattern in injection_patterns)
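

# Minimal usage sketch (illustrative only): the query, response text, and source
# dicts below are hypothetical sample data, not part of the module. Source dicts
# are assumed to carry "document", "excerpt", and "relevance_score" keys,
# matching the lookups used by the validator above.
if __name__ == "__main__":
    validator = ResponseValidator()

    sample_query = "What is the refund policy?"
    sample_response = (
        "According to the customer handbook, refunds are issued within 30 days "
        "of purchase when the item is returned in its original condition."
    )
    sample_sources = [
        {
            "document": "customer handbook",
            "excerpt": "Refunds are issued within 30 days of purchase.",
            "relevance_score": 0.9,
        }
    ]

    result = validator.validate_response(sample_response, sample_sources, sample_query)
    print(f"valid={result.is_valid} confidence={result.confidence_score:.2f}")
    print(f"quality={result.quality_score:.2f} issues={result.issues}")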