| """ | |
| Enhanced Flask app with integrated guardrails system. | |
| This module demonstrates how to integrate the guardrails system | |
| with the existing Flask API endpoints. | |
| """ | |
| # ...existing code... | |
| from dotenv import load_dotenv | |
| from flask import Flask, jsonify, render_template, request | |
| # Load environment variables from .env file | |
| load_dotenv() | |
| app = Flask(__name__) | |

# Route paths below follow the endpoint names; adjust them if the app uses
# different URLs.
@app.route("/")
def index():
    """
    Renders the chat interface.
    """
    return render_template("chat.html")

@app.route("/health")
def health():
    """
    Health check endpoint.
    """
    return jsonify({"status": "ok"}), 200

@app.route("/chat", methods=["POST"])
def chat():
    """
    Enhanced endpoint for conversational RAG interactions with guardrails.

    Accepts JSON requests with user messages and returns AI-generated
    responses with comprehensive validation and safety checks.
    """
    try:
        # Validate request contains JSON data
        if not request.is_json:
            return (
                jsonify(
                    {
                        "status": "error",
                        "message": "Content-Type must be application/json",
                    }
                ),
                400,
            )

        data = request.get_json()

        # Validate required message parameter
        message = data.get("message")
        if message is None:
            return (
                jsonify({"status": "error", "message": "message parameter is required"}),
                400,
            )
        if not isinstance(message, str) or not message.strip():
            return (
                jsonify({"status": "error", "message": "message must be a non-empty string"}),
                400,
            )

        # Extract optional parameters
        conversation_id = data.get("conversation_id")
        include_sources = data.get("include_sources", True)
        include_debug = data.get("include_debug", False)
        enable_guardrails = data.get("enable_guardrails", True)

        # Initialize enhanced RAG pipeline components
        try:
            from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
            from src.embedding.embedding_service import EmbeddingService
            from src.llm.llm_service import LLMService
            from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
            from src.rag.rag_pipeline import RAGPipeline
            from src.rag.response_formatter import ResponseFormatter
            from src.search.search_service import SearchService
            from src.vector_store.vector_db import VectorDatabase

            # Initialize services
            vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
            embedding_service = EmbeddingService()
            search_service = SearchService(vector_db, embedding_service)

            # Initialize LLM service from environment
            llm_service = LLMService.from_environment()

            # Initialize base RAG pipeline
            base_rag_pipeline = RAGPipeline(search_service, llm_service)

            # Initialize enhanced pipeline with guardrails if enabled
            if enable_guardrails:
                # Configure guardrails for production use
                guardrails_config = {
                    "min_confidence_threshold": 0.7,
                    "strict_mode": False,
                    "enable_response_enhancement": True,
                    "log_all_results": True,
                }
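                # Presumed semantics, inferred from the key names: responses
                # scoring below min_confidence_threshold trigger fallbacks,
                # strict_mode=False fails soft instead of hard-rejecting,
                # enable_response_enhancement lets the layer rewrite weak
                # responses, and log_all_results records every validation
                # outcome.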
                rag_pipeline = EnhancedRAGPipeline(base_rag_pipeline, guardrails_config)
            else:
                rag_pipeline = base_rag_pipeline

            # Initialize response formatter
            formatter = ResponseFormatter()
        except ValueError as e:
            return (
                jsonify(
                    {
                        "status": "error",
                        "message": f"LLM service configuration error: {str(e)}",
                        "details": (
                            "Please ensure the OPENROUTER_API_KEY or GROQ_API_KEY "
                            "environment variable is set"
                        ),
                    }
                ),
                503,
            )
        except Exception as e:
            return (
                jsonify(
                    {
                        "status": "error",
                        "message": f"Service initialization failed: {str(e)}",
                    }
                ),
                500,
            )

        # Generate RAG response with enhanced validation
        rag_response = rag_pipeline.generate_answer(message.strip())

        # Format response for API with guardrails information
        if include_sources:
            formatted_response = formatter.format_api_response(rag_response, include_debug)
            # Add guardrails information if available
            if hasattr(rag_response, "guardrails_approved"):
                formatted_response["guardrails"] = {
                    "approved": rag_response.guardrails_approved,
                    "confidence": rag_response.guardrails_confidence,
                    "safety_passed": rag_response.safety_passed,
                    "quality_score": rag_response.quality_score,
                    "warnings": getattr(rag_response, "guardrails_warnings", []),
                    "fallbacks": getattr(rag_response, "guardrails_fallbacks", []),
                }
        else:
            formatted_response = formatter.format_chat_response(
                rag_response, conversation_id, include_sources=False
            )

        return jsonify(formatted_response)
    except Exception as e:
        return (
            jsonify({"status": "error", "message": f"Chat request failed: {str(e)}"}),
            500,
        )

@app.route("/chat/health")
def chat_health():
    """
    Health check endpoint for enhanced RAG chat functionality.

    Returns the status of all RAG pipeline components, including guardrails.
    """
    try:
        from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
        from src.embedding.embedding_service import EmbeddingService
        from src.llm.llm_service import LLMService
        from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
        from src.rag.rag_pipeline import RAGPipeline
        from src.search.search_service import SearchService
        from src.vector_store.vector_db import VectorDatabase

        # Initialize services
        vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
        embedding_service = EmbeddingService()
        search_service = SearchService(vector_db, embedding_service)
        llm_service = LLMService.from_environment()

        # Initialize enhanced pipeline
        base_rag_pipeline = RAGPipeline(search_service, llm_service)
        enhanced_pipeline = EnhancedRAGPipeline(base_rag_pipeline)

        # Get comprehensive health status
        health_status = enhanced_pipeline.get_health_status()

        return jsonify(
            {
                "status": "healthy",
                "components": health_status,
                "timestamp": health_status.get("timestamp", "unknown"),
            }
        )
    except ValueError as e:
        # Specific handling for LLM configuration errors
        return (
            jsonify(
                {
                    "status": "error",
                    "message": f"LLM configuration error: {str(e)}",
                    "health": {
                        "pipeline_status": "unhealthy",
                        "components": {
                            "llm_service": {
                                "status": "unconfigured",
                                "error": str(e),
                            }
                        },
                    },
                }
            ),
            503,
        )
    except Exception as e:
        return (
            jsonify(
                {
                    "status": "unhealthy",
                    "error": str(e),
                    "components": {"error": "Failed to initialize components"},
                }
            ),
            500,
        )

@app.route("/validate", methods=["POST"])
def validate_response():
    """
    Standalone endpoint for validating responses with guardrails.

    Allows testing guardrails validation without running the full RAG pipeline.
    """
    try:
        if not request.is_json:
            return (
                jsonify(
                    {
                        "status": "error",
                        "message": "Content-Type must be application/json",
                    }
                ),
                400,
            )

        data = request.get_json()

        # Validate required parameters
        response_text = data.get("response")
        query_text = data.get("query")
        sources = data.get("sources", [])

        if not response_text or not query_text:
            return (
                jsonify(
                    {
                        "status": "error",
                        "message": "response and query parameters are required",
                    }
                ),
                400,
            )

        # Initialize enhanced pipeline for validation
        from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
        from src.embedding.embedding_service import EmbeddingService
        from src.llm.llm_service import LLMService
        from src.rag.enhanced_rag_pipeline import EnhancedRAGPipeline
        from src.rag.rag_pipeline import RAGPipeline
        from src.search.search_service import SearchService
        from src.vector_store.vector_db import VectorDatabase

        # Initialize services
        vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
        embedding_service = EmbeddingService()
        search_service = SearchService(vector_db, embedding_service)
        llm_service = LLMService.from_environment()

        # Initialize enhanced pipeline
        base_rag_pipeline = RAGPipeline(search_service, llm_service)
        enhanced_pipeline = EnhancedRAGPipeline(base_rag_pipeline)

        # Perform validation
        validation_result = enhanced_pipeline.validate_response_only(response_text, query_text, sources)

        return jsonify({"status": "success", "validation": validation_result})
    except Exception as e:
        return (
            jsonify({"status": "error", "message": f"Validation failed: {str(e)}"}),
            500,
        )

if __name__ == "__main__":
    # debug=True is for local development only; disable it in production.
    app.run(debug=True, host="0.0.0.0", port=8080)
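
# Example usage (assumes the route paths defined above):
#   curl -X POST http://localhost:8080/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "What does the guardrails layer check?"}'
#   curl http://localhost:8080/chat/health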