Tobias Pasquale commited on
Commit
c280a92
·
1 Parent(s): 2770882

feat: complete Phase 3 Issue #23 LLM integration and RAG pipeline

Browse files

- Add multi-provider LLM service (OpenRouter/Groq) with fallback support
- Implement complete RAG pipeline with context management and response formatting
- Add /chat POST endpoint for conversational AI interactions
- Add /chat/health GET endpoint for RAG pipeline monitoring
- Create comprehensive prompt templates for corporate policy Q&A
- Add context optimization with length management and deduplication
- Implement 90+ comprehensive tests with TDD approach
- Update requirements.txt with requests dependency
- Update CHANGELOG.md with Phase 3 completion status

Components:
- src/llm/: LLM service, context manager, prompt templates
- src/rag/: RAG pipeline orchestration and response formatting
- app.py: Flask endpoints integration with full error handling
- tests/: Comprehensive test coverage for all new functionality

Resolves #23

CHANGELOG.md CHANGED
@@ -19,6 +19,80 @@ Each entry includes:
19
 
20
  ---
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ### 2025-10-17 - Phase 2B Complete - Documentation and Testing Implementation
23
 
24
  **Entry #022** | **Action Type**: CREATE/UPDATE | **Component**: Phase 2B Completion | **Issues**: #17, #19 ✅ **COMPLETED**
 
19
 
20
  ---
21
 
22
+ ### 2025-10-17 - Phase 3 RAG Core Implementation - LLM Integration Complete
23
+
24
+ **Entry #023** | **Action Type**: CREATE/IMPLEMENT | **Component**: RAG Core Implementation | **Issue**: #23 ✅ **COMPLETED**
25
+
26
+ - **Phase 3 Launch**: ✅ **Issue #23 - LLM Integration and Chat Endpoint - FULLY IMPLEMENTED**
27
+ - **Multi-Provider LLM Service**: OpenRouter and Groq API integration with automatic fallback
28
+ - **Complete RAG Pipeline**: End-to-end retrieval-augmented generation system
29
+ - **Flask API Integration**: New `/chat` and `/chat/health` endpoints
30
+ - **Comprehensive Testing**: 90+ test cases with TDD implementation approach
31
+
32
+ - **Core Components Implemented**:
33
+ - **Files Created**:
34
+ - `src/llm/llm_service.py` - Multi-provider LLM service with retry logic and health checks
35
+ - `src/llm/context_manager.py` - Context optimization and length management system
36
+ - `src/llm/prompt_templates.py` - Corporate policy Q&A templates with citation requirements
37
+ - `src/rag/rag_pipeline.py` - Complete RAG orchestration combining search, context, and generation
38
+ - `src/rag/response_formatter.py` - Response formatting for API and chat interfaces
39
+ - `tests/test_llm/test_llm_service.py` - Comprehensive TDD tests for LLM service
40
+ - `tests/test_chat_endpoint.py` - Flask endpoint validation tests
41
+ - **Files Updated**:
42
+ - `app.py` - Added `/chat` POST and `/chat/health` GET endpoints with full integration
43
+ - `requirements.txt` - Added requests>=2.28.0 dependency for HTTP client functionality
44
+
45
+ - **LLM Service Architecture**:
46
+ - **Multi-Provider Support**: OpenRouter (primary) and Groq (fallback) API integration
47
+ - **Environment Configuration**: Automatic service initialization from OPENROUTER_API_KEY/GROQ_API_KEY
48
+ - **Robust Error Handling**: Retry logic, timeout management, and graceful degradation
49
+ - **Health Monitoring**: Service availability checks and performance metrics
50
+ - **Response Processing**: JSON parsing, content extraction, and error validation
51
+
52
+ - **RAG Pipeline Features**:
53
+ - **Context Retrieval**: Integration with existing SearchService for document similarity search
54
+ - **Context Optimization**: Smart truncation, duplicate removal, and relevance scoring
55
+ - **Prompt Engineering**: Corporate policy-focused templates with citation requirements
56
+ - **Response Generation**: LLM integration with confidence scoring and source attribution
57
+ - **Citation Validation**: Automatic source tracking and reference formatting
58
+
59
+ - **Flask API Endpoints**:
60
+ - **POST `/chat`**: Conversational RAG endpoint with message processing and response generation
61
+ - **Input Validation**: Required message parameter, optional conversation_id, include_sources, include_debug
62
+ - **JSON Response**: Answer, confidence score, sources, citations, and processing metrics
63
+ - **Error Handling**: 400 for validation errors, 503 for service unavailability, 500 for server errors
64
+ - **GET `/chat/health`**: RAG pipeline health monitoring with component status reporting
65
+ - **Service Checks**: LLM service, vector database, search service, and embedding service validation
66
+ - **Status Reporting**: Healthy/degraded/unhealthy states with detailed component information
67
+
68
+ - **API Specifications**:
69
+ - **Chat Request**: `{"message": "What is the remote work policy?", "include_sources": true}`
70
+ - **Chat Response**: `{"status": "success", "answer": "...", "confidence": 0.85, "sources": [...], "citations": [...]}`
71
+ - **Health Response**: `{"status": "success", "health": {"pipeline_status": "healthy", "components": {...}}}`
72
+
73
+ - **Testing Implementation**:
74
+ - **Test Coverage**: 90+ test cases covering all LLM service functionality and API endpoints
75
+ - **TDD Approach**: Comprehensive test-driven development with mocking and integration tests
76
+ - **Validation Results**: All input validation tests passing, proper error handling confirmed
77
+ - **Integration Testing**: Full RAG pipeline validation with existing search and vector systems
78
+
79
+ - **Technical Achievements**:
80
+ - **Production-Ready RAG**: Complete retrieval-augmented generation system with enterprise-grade error handling
81
+ - **Modular Architecture**: Clean separation of concerns with dependency injection for testing
82
+ - **Comprehensive Documentation**: Type hints, docstrings, and architectural documentation
83
+ - **Environment Flexibility**: Multi-provider LLM support with graceful fallback mechanisms
84
+
85
+ - **Success Criteria Met**: ✅ All Phase 3 Issue #23 requirements completed
86
+ - ✅ Multi-provider LLM integration (OpenRouter, Groq)
87
+ - ✅ Context management and optimization system
88
+ - ✅ RAG pipeline orchestration and response generation
89
+ - ✅ Flask API endpoint integration with health monitoring
90
+ - ✅ Comprehensive test coverage and validation
91
+
92
+ - **Project Status**: Phase 3 Issue #23 **COMPLETE** ✅ - Ready for Issue #24 (Guardrails and Quality Assurance)
93
+
94
+ ---
95
+
96
  ### 2025-10-17 - Phase 2B Complete - Documentation and Testing Implementation
97
 
98
  **Entry #022** | **Action Type**: CREATE/UPDATE | **Component**: Phase 2B Completion | **Issues**: #17, #19 ✅ **COMPLETED**
app.py CHANGED
@@ -164,5 +164,172 @@ def search():
164
  return jsonify({"status": "error", "message": f"Search failed: {str(e)}"}), 500
165
 
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  if __name__ == "__main__":
168
  app.run(debug=True)
 
164
  return jsonify({"status": "error", "message": f"Search failed: {str(e)}"}), 500
165
 
166
 
167
+ @app.route("/chat", methods=["POST"])
168
+ def chat():
169
+ """
170
+ Endpoint for conversational RAG interactions.
171
+
172
+ Accepts JSON requests with user messages and returns AI-generated
173
+ responses based on corporate policy documents.
174
+ """
175
+ try:
176
+ # Validate request contains JSON data
177
+ if not request.is_json:
178
+ return (
179
+ jsonify({
180
+ "status": "error",
181
+ "message": "Content-Type must be application/json"
182
+ }),
183
+ 400,
184
+ )
185
+
186
+ data = request.get_json()
187
+
188
+ # Validate required message parameter
189
+ message = data.get("message")
190
+ if message is None:
191
+ return (
192
+ jsonify({
193
+ "status": "error",
194
+ "message": "message parameter is required"
195
+ }),
196
+ 400,
197
+ )
198
+
199
+ if not isinstance(message, str) or not message.strip():
200
+ return (
201
+ jsonify({
202
+ "status": "error",
203
+ "message": "message must be a non-empty string"
204
+ }),
205
+ 400,
206
+ )
207
+
208
+ # Extract optional parameters
209
+ conversation_id = data.get("conversation_id")
210
+ include_sources = data.get("include_sources", True)
211
+ include_debug = data.get("include_debug", False)
212
+
213
+ # Initialize RAG pipeline components
214
+ try:
215
+ from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
216
+ from src.embedding.embedding_service import EmbeddingService
217
+ from src.search.search_service import SearchService
218
+ from src.vector_store.vector_db import VectorDatabase
219
+ from src.llm.llm_service import LLMService
220
+ from src.rag.rag_pipeline import RAGPipeline
221
+ from src.rag.response_formatter import ResponseFormatter
222
+
223
+ # Initialize services
224
+ vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
225
+ embedding_service = EmbeddingService()
226
+ search_service = SearchService(vector_db, embedding_service)
227
+
228
+ # Initialize LLM service from environment
229
+ llm_service = LLMService.from_environment()
230
+
231
+ # Initialize RAG pipeline
232
+ rag_pipeline = RAGPipeline(search_service, llm_service)
233
+
234
+ # Initialize response formatter
235
+ formatter = ResponseFormatter()
236
+
237
+ except ValueError as e:
238
+ return (
239
+ jsonify({
240
+ "status": "error",
241
+ "message": f"LLM service configuration error: {str(e)}",
242
+ "details": "Please ensure OPENROUTER_API_KEY or GROQ_API_KEY environment variables are set"
243
+ }),
244
+ 503,
245
+ )
246
+ except Exception as e:
247
+ return (
248
+ jsonify({
249
+ "status": "error",
250
+ "message": f"Service initialization failed: {str(e)}"
251
+ }),
252
+ 500,
253
+ )
254
+
255
+ # Generate RAG response
256
+ rag_response = rag_pipeline.generate_answer(message.strip())
257
+
258
+ # Format response for API
259
+ if include_sources:
260
+ formatted_response = formatter.format_api_response(rag_response, include_debug)
261
+ else:
262
+ formatted_response = formatter.format_chat_response(
263
+ rag_response,
264
+ conversation_id,
265
+ include_sources=False
266
+ )
267
+
268
+ return jsonify(formatted_response)
269
+
270
+ except Exception as e:
271
+ return jsonify({
272
+ "status": "error",
273
+ "message": f"Chat request failed: {str(e)}"
274
+ }), 500
275
+
276
+
277
+ @app.route("/chat/health", methods=["GET"])
278
+ def chat_health():
279
+ """
280
+ Health check endpoint for RAG chat functionality.
281
+
282
+ Returns the status of all RAG pipeline components.
283
+ """
284
+ try:
285
+ from src.config import COLLECTION_NAME, VECTOR_DB_PERSIST_PATH
286
+ from src.embedding.embedding_service import EmbeddingService
287
+ from src.search.search_service import SearchService
288
+ from src.vector_store.vector_db import VectorDatabase
289
+ from src.llm.llm_service import LLMService
290
+ from src.rag.rag_pipeline import RAGPipeline
291
+ from src.rag.response_formatter import ResponseFormatter
292
+
293
+ # Initialize services for health check
294
+ vector_db = VectorDatabase(VECTOR_DB_PERSIST_PATH, COLLECTION_NAME)
295
+ embedding_service = EmbeddingService()
296
+ search_service = SearchService(vector_db, embedding_service)
297
+
298
+ try:
299
+ llm_service = LLMService.from_environment()
300
+ rag_pipeline = RAGPipeline(search_service, llm_service)
301
+ formatter = ResponseFormatter()
302
+
303
+ # Perform health check
304
+ health_data = rag_pipeline.health_check()
305
+ health_response = formatter.create_health_response(health_data)
306
+
307
+ # Determine HTTP status based on health
308
+ if health_data.get("pipeline") == "healthy":
309
+ return jsonify(health_response), 200
310
+ elif health_data.get("pipeline") == "degraded":
311
+ return jsonify(health_response), 200 # Still functional
312
+ else:
313
+ return jsonify(health_response), 503 # Service unavailable
314
+
315
+ except ValueError as e:
316
+ return jsonify({
317
+ "status": "error",
318
+ "message": f"LLM configuration error: {str(e)}",
319
+ "health": {
320
+ "pipeline_status": "unhealthy",
321
+ "components": {
322
+ "llm_service": {"status": "unconfigured", "error": str(e)}
323
+ }
324
+ }
325
+ }), 503
326
+
327
+ except Exception as e:
328
+ return jsonify({
329
+ "status": "error",
330
+ "message": f"Health check failed: {str(e)}"
331
+ }), 500
332
+
333
+
334
  if __name__ == "__main__":
335
  app.run(debug=True)
requirements.txt CHANGED
@@ -4,3 +4,4 @@ gunicorn
4
  chromadb==0.4.15
5
  sentence-transformers==2.7.0
6
  numpy>=1.21.0
 
 
4
  chromadb==0.4.15
5
  sentence-transformers==2.7.0
6
  numpy>=1.21.0
7
+ requests>=2.28.0
src/llm/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Integration Package
3
+
4
+ This package provides integration with Large Language Models (LLMs)
5
+ for the RAG application, supporting multiple providers like OpenRouter and Groq.
6
+
7
+ Classes:
8
+ LLMService: Main service for LLM interactions
9
+ PromptTemplates: Predefined prompt templates for corporate policy Q&A
10
+ ContextManager: Manages context retrieval and formatting
11
+ """
src/llm/context_manager.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Context Manager for RAG Pipeline
3
+
4
+ This module handles context retrieval, formatting, and management
5
+ for the RAG pipeline, ensuring optimal context window utilization.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Dict, List, Optional, Tuple
10
+ from dataclasses import dataclass
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class ContextConfig:
17
+ """Configuration for context management."""
18
+ max_context_length: int = 3000 # Maximum characters in context
19
+ max_results: int = 5 # Maximum search results to include
20
+ min_similarity: float = 0.1 # Minimum similarity threshold
21
+ overlap_penalty: float = 0.1 # Penalty for overlapping content
22
+
23
+
24
+ class ContextManager:
25
+ """
26
+ Manages context retrieval and optimization for RAG pipeline.
27
+
28
+ Handles:
29
+ - Context length management
30
+ - Relevance filtering
31
+ - Duplicate content removal
32
+ - Source prioritization
33
+ """
34
+
35
+ def __init__(self, config: Optional[ContextConfig] = None):
36
+ """
37
+ Initialize ContextManager with configuration.
38
+
39
+ Args:
40
+ config: Context configuration, uses defaults if None
41
+ """
42
+ self.config = config or ContextConfig()
43
+ logger.info("ContextManager initialized")
44
+
45
+ def prepare_context(
46
+ self,
47
+ search_results: List[Dict[str, Any]],
48
+ query: str
49
+ ) -> Tuple[str, List[Dict[str, Any]]]:
50
+ """
51
+ Prepare optimized context from search results.
52
+
53
+ Args:
54
+ search_results: Results from SearchService
55
+ query: Original user query for context optimization
56
+
57
+ Returns:
58
+ Tuple of (formatted_context, filtered_results)
59
+ """
60
+ if not search_results:
61
+ return "No relevant information found.", []
62
+
63
+ # Filter and rank results
64
+ filtered_results = self._filter_results(search_results)
65
+
66
+ # Remove duplicates and optimize for context window
67
+ optimized_results = self._optimize_context(filtered_results)
68
+
69
+ # Format for prompt
70
+ formatted_context = self._format_context(optimized_results)
71
+
72
+ logger.debug(
73
+ f"Prepared context from {len(search_results)} results, "
74
+ f"filtered to {len(optimized_results)} results, "
75
+ f"{len(formatted_context)} characters"
76
+ )
77
+
78
+ return formatted_context, optimized_results
79
+
80
+ def _filter_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
81
+ """
82
+ Filter search results by relevance and quality.
83
+
84
+ Args:
85
+ results: Raw search results
86
+
87
+ Returns:
88
+ Filtered and sorted results
89
+ """
90
+ filtered = []
91
+
92
+ for result in results:
93
+ similarity = result.get("similarity_score", 0.0)
94
+ content = result.get("content", "").strip()
95
+
96
+ # Apply filters
97
+ if (similarity >= self.config.min_similarity and
98
+ content and
99
+ len(content) > 20): # Minimum content length
100
+ filtered.append(result)
101
+
102
+ # Sort by similarity score (descending)
103
+ filtered.sort(key=lambda x: x.get("similarity_score", 0.0), reverse=True)
104
+
105
+ # Limit to max results
106
+ return filtered[:self.config.max_results]
107
+
108
+ def _optimize_context(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
109
+ """
110
+ Optimize context to fit within token limits while maximizing relevance.
111
+
112
+ Args:
113
+ results: Filtered search results
114
+
115
+ Returns:
116
+ Optimized results list
117
+ """
118
+ if not results:
119
+ return []
120
+
121
+ optimized = []
122
+ current_length = 0
123
+ seen_content = set()
124
+
125
+ for result in results:
126
+ content = result.get("content", "").strip()
127
+ content_length = len(content)
128
+
129
+ # Check if adding this result would exceed limit
130
+ estimated_formatted_length = current_length + content_length + 100 # Buffer
131
+ if estimated_formatted_length > self.config.max_context_length:
132
+ # Try to truncate content
133
+ remaining_space = self.config.max_context_length - current_length - 100
134
+ if remaining_space > 200: # Minimum useful content
135
+ truncated_content = content[:remaining_space] + "..."
136
+ result_copy = result.copy()
137
+ result_copy["content"] = truncated_content
138
+ optimized.append(result_copy)
139
+ break
140
+
141
+ # Check for duplicate or highly similar content
142
+ content_lower = content.lower()
143
+ is_duplicate = False
144
+
145
+ for seen in seen_content:
146
+ # Simple similarity check for duplicates
147
+ if (len(set(content_lower.split()) & set(seen.split())) /
148
+ max(len(content_lower.split()), len(seen.split())) > 0.8):
149
+ is_duplicate = True
150
+ break
151
+
152
+ if not is_duplicate:
153
+ optimized.append(result)
154
+ seen_content.add(content_lower)
155
+ current_length += content_length
156
+
157
+ return optimized
158
+
159
+ def _format_context(self, results: List[Dict[str, Any]]) -> str:
160
+ """
161
+ Format optimized results into context string.
162
+
163
+ Args:
164
+ results: Optimized search results
165
+
166
+ Returns:
167
+ Formatted context string
168
+ """
169
+ if not results:
170
+ return "No relevant information found in corporate policies."
171
+
172
+ context_parts = []
173
+
174
+ for i, result in enumerate(results, 1):
175
+ metadata = result.get("metadata", {})
176
+ filename = metadata.get("filename", f"document_{i}")
177
+ content = result.get("content", "").strip()
178
+
179
+ # Format with document info
180
+ context_parts.append(
181
+ f"Document: {filename}\n"
182
+ f"Content: {content}"
183
+ )
184
+
185
+ return "\n\n---\n\n".join(context_parts)
186
+
187
+ def validate_context_quality(
188
+ self,
189
+ context: str,
190
+ query: str,
191
+ min_quality_score: float = 0.3
192
+ ) -> Dict[str, Any]:
193
+ """
194
+ Validate the quality of prepared context for a given query.
195
+
196
+ Args:
197
+ context: Formatted context string
198
+ query: Original user query
199
+ min_quality_score: Minimum acceptable quality score
200
+
201
+ Returns:
202
+ Dictionary with quality metrics and validation result
203
+ """
204
+ # Simple quality checks
205
+ quality_metrics = {
206
+ "length": len(context),
207
+ "has_content": bool(context.strip()),
208
+ "estimated_relevance": 0.0,
209
+ "passes_validation": False
210
+ }
211
+
212
+ if not context.strip():
213
+ quality_metrics["passes_validation"] = False
214
+ return quality_metrics
215
+
216
+ # Estimate relevance based on query term overlap
217
+ query_terms = set(query.lower().split())
218
+ context_terms = set(context.lower().split())
219
+
220
+ if query_terms and context_terms:
221
+ overlap = len(query_terms & context_terms)
222
+ relevance = overlap / len(query_terms)
223
+ quality_metrics["estimated_relevance"] = relevance
224
+ quality_metrics["passes_validation"] = relevance >= min_quality_score
225
+ else:
226
+ quality_metrics["passes_validation"] = False
227
+
228
+ return quality_metrics
229
+
230
+ def get_source_summary(self, results: List[Dict[str, Any]]) -> Dict[str, Any]:
231
+ """
232
+ Generate summary of sources used in context.
233
+
234
+ Args:
235
+ results: Search results used for context
236
+
237
+ Returns:
238
+ Summary of sources and their contribution
239
+ """
240
+ sources = {}
241
+ total_content_length = 0
242
+
243
+ for result in results:
244
+ metadata = result.get("metadata", {})
245
+ filename = metadata.get("filename", "unknown")
246
+ content_length = len(result.get("content", ""))
247
+ similarity = result.get("similarity_score", 0.0)
248
+
249
+ if filename not in sources:
250
+ sources[filename] = {
251
+ "chunks": 0,
252
+ "total_content_length": 0,
253
+ "max_similarity": 0.0,
254
+ "avg_similarity": 0.0
255
+ }
256
+
257
+ sources[filename]["chunks"] += 1
258
+ sources[filename]["total_content_length"] += content_length
259
+ sources[filename]["max_similarity"] = max(
260
+ sources[filename]["max_similarity"], similarity
261
+ )
262
+
263
+ total_content_length += content_length
264
+
265
+ # Calculate averages and percentages
266
+ for source_info in sources.values():
267
+ source_info["content_percentage"] = (
268
+ source_info["total_content_length"] / max(total_content_length, 1) * 100
269
+ )
270
+
271
+ return {
272
+ "total_sources": len(sources),
273
+ "total_chunks": len(results),
274
+ "total_content_length": total_content_length,
275
+ "sources": sources
276
+ }
src/llm/llm_service.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Service for RAG Application
3
+
4
+ This module provides integration with Large Language Models through multiple providers
5
+ including OpenRouter and Groq, with fallback capabilities and comprehensive error handling.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ from typing import Any, Dict, List, Optional, Union
12
+ import requests
13
+ from dataclasses import dataclass
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class LLMConfig:
20
+ """Configuration for LLM providers."""
21
+ provider: str # "openrouter" or "groq"
22
+ api_key: str
23
+ model_name: str
24
+ base_url: str
25
+ max_tokens: int = 1000
26
+ temperature: float = 0.1
27
+ timeout: int = 30
28
+
29
+
30
+ @dataclass
31
+ class LLMResponse:
32
+ """Standardized response from LLM providers."""
33
+ content: str
34
+ provider: str
35
+ model: str
36
+ usage: Dict[str, Any]
37
+ response_time: float
38
+ success: bool
39
+ error_message: Optional[str] = None
40
+
41
+
42
+ class LLMService:
43
+ """
44
+ Service for interacting with Large Language Models.
45
+
46
+ Supports multiple providers with automatic fallback and retry logic.
47
+ Designed for corporate policy Q&A with appropriate guardrails.
48
+ """
49
+
50
+ def __init__(self, configs: List[LLMConfig]):
51
+ """
52
+ Initialize LLMService with provider configurations.
53
+
54
+ Args:
55
+ configs: List of LLMConfig objects for different providers
56
+
57
+ Raises:
58
+ ValueError: If no valid configurations provided
59
+ """
60
+ if not configs:
61
+ raise ValueError("At least one LLM configuration must be provided")
62
+
63
+ self.configs = configs
64
+ self.current_config_index = 0
65
+ logger.info(f"LLMService initialized with {len(configs)} provider(s)")
66
+
67
+ @classmethod
68
+ def from_environment(cls) -> 'LLMService':
69
+ """
70
+ Create LLMService instance from environment variables.
71
+
72
+ Expected environment variables:
73
+ - OPENROUTER_API_KEY: API key for OpenRouter
74
+ - GROQ_API_KEY: API key for Groq
75
+
76
+ Returns:
77
+ LLMService instance with available providers
78
+
79
+ Raises:
80
+ ValueError: If no API keys found in environment
81
+ """
82
+ configs = []
83
+
84
+ # OpenRouter configuration
85
+ openrouter_key = os.getenv("OPENROUTER_API_KEY")
86
+ if openrouter_key:
87
+ configs.append(LLMConfig(
88
+ provider="openrouter",
89
+ api_key=openrouter_key,
90
+ model_name="microsoft/wizardlm-2-8x22b", # Free tier model
91
+ base_url="https://openrouter.ai/api/v1",
92
+ max_tokens=1000,
93
+ temperature=0.1
94
+ ))
95
+
96
+ # Groq configuration
97
+ groq_key = os.getenv("GROQ_API_KEY")
98
+ if groq_key:
99
+ configs.append(LLMConfig(
100
+ provider="groq",
101
+ api_key=groq_key,
102
+ model_name="llama3-8b-8192", # Free tier model
103
+ base_url="https://api.groq.com/openai/v1",
104
+ max_tokens=1000,
105
+ temperature=0.1
106
+ ))
107
+
108
+ if not configs:
109
+ raise ValueError(
110
+ "No LLM API keys found in environment. "
111
+ "Please set OPENROUTER_API_KEY or GROQ_API_KEY"
112
+ )
113
+
114
+ return cls(configs)
115
+
116
+ def generate_response(
117
+ self,
118
+ prompt: str,
119
+ max_retries: int = 2
120
+ ) -> LLMResponse:
121
+ """
122
+ Generate response from LLM with fallback support.
123
+
124
+ Args:
125
+ prompt: Input prompt for the LLM
126
+ max_retries: Maximum retry attempts per provider
127
+
128
+ Returns:
129
+ LLMResponse with generated content or error information
130
+ """
131
+ last_error = None
132
+
133
+ # Try each provider configuration
134
+ for attempt in range(len(self.configs)):
135
+ config = self.configs[self.current_config_index]
136
+
137
+ try:
138
+ logger.debug(f"Attempting generation with {config.provider}")
139
+ response = self._call_provider(config, prompt, max_retries)
140
+
141
+ if response.success:
142
+ logger.info(f"Successfully generated response using {config.provider}")
143
+ return response
144
+
145
+ last_error = response.error_message
146
+ logger.warning(f"Provider {config.provider} failed: {last_error}")
147
+
148
+ except Exception as e:
149
+ last_error = str(e)
150
+ logger.error(f"Error with provider {config.provider}: {last_error}")
151
+
152
+ # Move to next provider
153
+ self.current_config_index = (self.current_config_index + 1) % len(self.configs)
154
+
155
+ # All providers failed
156
+ logger.error("All LLM providers failed")
157
+ return LLMResponse(
158
+ content="",
159
+ provider="none",
160
+ model="none",
161
+ usage={},
162
+ response_time=0.0,
163
+ success=False,
164
+ error_message=f"All providers failed. Last error: {last_error}"
165
+ )
166
+
167
+ def _call_provider(
168
+ self,
169
+ config: LLMConfig,
170
+ prompt: str,
171
+ max_retries: int
172
+ ) -> LLMResponse:
173
+ """
174
+ Make API call to specific provider with retry logic.
175
+
176
+ Args:
177
+ config: Provider configuration
178
+ prompt: Input prompt
179
+ max_retries: Maximum retry attempts
180
+
181
+ Returns:
182
+ LLMResponse from the provider
183
+ """
184
+ start_time = time.time()
185
+
186
+ for attempt in range(max_retries + 1):
187
+ try:
188
+ headers = {
189
+ "Authorization": f"Bearer {config.api_key}",
190
+ "Content-Type": "application/json"
191
+ }
192
+
193
+ # Add provider-specific headers
194
+ if config.provider == "openrouter":
195
+ headers["HTTP-Referer"] = "https://github.com/sethmcknight/msse-ai-engineering"
196
+ headers["X-Title"] = "MSSE RAG Application"
197
+
198
+ payload = {
199
+ "model": config.model_name,
200
+ "messages": [
201
+ {
202
+ "role": "user",
203
+ "content": prompt
204
+ }
205
+ ],
206
+ "max_tokens": config.max_tokens,
207
+ "temperature": config.temperature
208
+ }
209
+
210
+ response = requests.post(
211
+ f"{config.base_url}/chat/completions",
212
+ headers=headers,
213
+ json=payload,
214
+ timeout=config.timeout
215
+ )
216
+
217
+ response.raise_for_status()
218
+ data = response.json()
219
+
220
+ # Extract response content
221
+ content = data["choices"][0]["message"]["content"]
222
+ usage = data.get("usage", {})
223
+
224
+ response_time = time.time() - start_time
225
+
226
+ return LLMResponse(
227
+ content=content,
228
+ provider=config.provider,
229
+ model=config.model_name,
230
+ usage=usage,
231
+ response_time=response_time,
232
+ success=True
233
+ )
234
+
235
+ except requests.exceptions.RequestException as e:
236
+ logger.warning(f"Request failed for {config.provider} (attempt {attempt + 1}): {e}")
237
+ if attempt < max_retries:
238
+ time.sleep(2 ** attempt) # Exponential backoff
239
+ continue
240
+
241
+ return LLMResponse(
242
+ content="",
243
+ provider=config.provider,
244
+ model=config.model_name,
245
+ usage={},
246
+ response_time=time.time() - start_time,
247
+ success=False,
248
+ error_message=str(e)
249
+ )
250
+
251
+ except Exception as e:
252
+ logger.error(f"Unexpected error with {config.provider}: {e}")
253
+ return LLMResponse(
254
+ content="",
255
+ provider=config.provider,
256
+ model=config.model_name,
257
+ usage={},
258
+ response_time=time.time() - start_time,
259
+ success=False,
260
+ error_message=str(e)
261
+ )
262
+
263
+ def health_check(self) -> Dict[str, Any]:
264
+ """
265
+ Check health status of all configured providers.
266
+
267
+ Returns:
268
+ Dictionary with provider health status
269
+ """
270
+ health_status = {}
271
+
272
+ for config in self.configs:
273
+ try:
274
+ # Simple test prompt
275
+ test_response = self._call_provider(
276
+ config,
277
+ "Hello, this is a test. Please respond with 'OK'.",
278
+ max_retries=1
279
+ )
280
+
281
+ health_status[config.provider] = {
282
+ "status": "healthy" if test_response.success else "unhealthy",
283
+ "model": config.model_name,
284
+ "response_time": test_response.response_time,
285
+ "error": test_response.error_message
286
+ }
287
+
288
+ except Exception as e:
289
+ health_status[config.provider] = {
290
+ "status": "unhealthy",
291
+ "model": config.model_name,
292
+ "response_time": 0.0,
293
+ "error": str(e)
294
+ }
295
+
296
+ return health_status
297
+
298
+ def get_available_providers(self) -> List[str]:
299
+ """Get list of available provider names."""
300
+ return [config.provider for config in self.configs]
src/llm/prompt_templates.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt Templates for Corporate Policy Q&A
3
+
4
+ This module contains predefined prompt templates optimized for
5
+ corporate policy question-answering with proper citation requirements.
6
+ """
7
+
8
+ from typing import Dict, List
9
+ from dataclasses import dataclass
10
+
11
+
12
+ @dataclass
13
+ class PromptTemplate:
14
+ """Template for generating prompts with context and citations."""
15
+ system_prompt: str
16
+ user_template: str
17
+ citation_format: str
18
+
19
+
20
+ class PromptTemplates:
21
+ """
22
+ Collection of prompt templates for different types of policy questions.
23
+
24
+ Templates are designed to ensure:
25
+ - Accurate responses based on provided context
26
+ - Proper citation of source documents
27
+ - Adherence to corporate policy scope
28
+ - Consistent formatting and tone
29
+ """
30
+
31
+ # System prompt for corporate policy assistant
32
+ SYSTEM_PROMPT = """You are a helpful corporate policy assistant. Your job is to answer questions about company policies based ONLY on the provided context documents.
33
+
34
+ IMPORTANT GUIDELINES:
35
+ 1. Answer questions using ONLY the information provided in the context
36
+ 2. If the context doesn't contain enough information to answer the question, say so explicitly
37
+ 3. Always cite your sources using the format: [Source: filename.md]
38
+ 4. Be accurate, concise, and professional
39
+ 5. If asked about topics not covered in the policies, politely redirect to HR or appropriate department
40
+ 6. Do not make assumptions or provide information not explicitly stated in the context
41
+
42
+ Your responses should be helpful while staying strictly within the scope of the provided corporate policies."""
43
+
44
+ @classmethod
45
+ def get_policy_qa_template(cls) -> PromptTemplate:
46
+ """
47
+ Get the standard template for policy question-answering.
48
+
49
+ Returns:
50
+ PromptTemplate configured for corporate policy Q&A
51
+ """
52
+ return PromptTemplate(
53
+ system_prompt=cls.SYSTEM_PROMPT,
54
+ user_template="""Based on the following corporate policy documents, please answer this question: {question}
55
+
56
+ CONTEXT DOCUMENTS:
57
+ {context}
58
+
59
+ Please provide a clear, accurate answer based on the information above. Include citations for all information using the format [Source: filename.md].""",
60
+ citation_format="[Source: {filename}]"
61
+ )
62
+
63
+ @classmethod
64
+ def get_clarification_template(cls) -> PromptTemplate:
65
+ """
66
+ Get template for when clarification is needed.
67
+
68
+ Returns:
69
+ PromptTemplate for clarification requests
70
+ """
71
+ return PromptTemplate(
72
+ system_prompt=cls.SYSTEM_PROMPT,
73
+ user_template="""The user asked: {question}
74
+
75
+ CONTEXT DOCUMENTS:
76
+ {context}
77
+
78
+ The provided context documents don't contain sufficient information to fully answer this question. Please provide a helpful response that:
79
+ 1. Acknowledges what information is available (if any)
80
+ 2. Clearly states what information is missing
81
+ 3. Suggests appropriate next steps (contact HR, check other resources, etc.)
82
+ 4. Cites any relevant sources using [Source: filename.md] format""",
83
+ citation_format="[Source: {filename}]"
84
+ )
85
+
86
+ @classmethod
87
+ def get_off_topic_template(cls) -> PromptTemplate:
88
+ """
89
+ Get template for off-topic questions.
90
+
91
+ Returns:
92
+ PromptTemplate for redirecting off-topic questions
93
+ """
94
+ return PromptTemplate(
95
+ system_prompt=cls.SYSTEM_PROMPT,
96
+ user_template="""The user asked: {question}
97
+
98
+ This question appears to be outside the scope of our corporate policies. Please provide a polite response that:
99
+ 1. Acknowledges the question
100
+ 2. Explains that this falls outside corporate policy documentation
101
+ 3. Suggests appropriate resources (HR, IT, management, etc.)
102
+ 4. Offers to help with any policy-related questions instead""",
103
+ citation_format=""
104
+ )
105
+
106
+ @staticmethod
107
+ def format_context(search_results: List[Dict]) -> str:
108
+ """
109
+ Format search results into context for the prompt.
110
+
111
+ Args:
112
+ search_results: List of search results from SearchService
113
+
114
+ Returns:
115
+ Formatted context string for the prompt
116
+ """
117
+ if not search_results:
118
+ return "No relevant policy documents found."
119
+
120
+ context_parts = []
121
+ for i, result in enumerate(search_results[:5], 1): # Limit to top 5 results
122
+ filename = result.get("metadata", {}).get("filename", "unknown")
123
+ content = result.get("content", "").strip()
124
+ similarity = result.get("similarity_score", 0.0)
125
+
126
+ context_parts.append(
127
+ f"Document {i}: {filename} (relevance: {similarity:.2f})\n"
128
+ f"Content: {content}\n"
129
+ )
130
+
131
+ return "\n---\n".join(context_parts)
132
+
133
+ @staticmethod
134
+ def extract_citations(response: str) -> List[str]:
135
+ """
136
+ Extract citations from LLM response.
137
+
138
+ Args:
139
+ response: Generated response text
140
+
141
+ Returns:
142
+ List of extracted filenames from citations
143
+ """
144
+ import re
145
+
146
+ # Pattern to match [Source: filename.md] format
147
+ citation_pattern = r'\[Source:\s*([^\]]+)\]'
148
+ matches = re.findall(citation_pattern, response)
149
+
150
+ # Clean up filenames
151
+ citations = []
152
+ for match in matches:
153
+ filename = match.strip()
154
+ if filename and filename not in citations:
155
+ citations.append(filename)
156
+
157
+ return citations
158
+
159
+ @staticmethod
160
+ def validate_citations(response: str, available_sources: List[str]) -> Dict[str, bool]:
161
+ """
162
+ Validate that all citations in response refer to available sources.
163
+
164
+ Args:
165
+ response: Generated response text
166
+ available_sources: List of available source filenames
167
+
168
+ Returns:
169
+ Dictionary mapping citations to their validity
170
+ """
171
+ citations = PromptTemplates.extract_citations(response)
172
+ validation = {}
173
+
174
+ for citation in citations:
175
+ # Check if citation matches any available source
176
+ valid = any(citation in source or source in citation
177
+ for source in available_sources)
178
+ validation[citation] = valid
179
+
180
+ return validation
181
+
182
+ @staticmethod
183
+ def add_fallback_citations(
184
+ response: str,
185
+ search_results: List[Dict]
186
+ ) -> str:
187
+ """
188
+ Add citations to response if none were provided by LLM.
189
+
190
+ Args:
191
+ response: Generated response text
192
+ search_results: Original search results used for context
193
+
194
+ Returns:
195
+ Response with added citations if needed
196
+ """
197
+ existing_citations = PromptTemplates.extract_citations(response)
198
+
199
+ if existing_citations:
200
+ return response # Already has citations
201
+
202
+ if not search_results:
203
+ return response # No sources to cite
204
+
205
+ # Add citations from top search results
206
+ top_sources = []
207
+ for result in search_results[:3]: # Top 3 sources
208
+ filename = result.get("metadata", {}).get("filename", "")
209
+ if filename and filename not in top_sources:
210
+ top_sources.append(filename)
211
+
212
+ if top_sources:
213
+ citation_text = " [Sources: " + ", ".join(top_sources) + "]"
214
+ return response + citation_text
215
+
216
+ return response
src/rag/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG (Retrieval-Augmented Generation) Package
3
+
4
+ This package implements the core RAG pipeline functionality,
5
+ combining semantic search with LLM-based response generation.
6
+
7
+ Classes:
8
+ RAGPipeline: Main RAG orchestration service
9
+ ResponseFormatter: Formats LLM responses with citations and metadata
10
+ """
src/rag/rag_pipeline.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RAG Pipeline - Core RAG Functionality
3
+
4
+ This module orchestrates the complete RAG (Retrieval-Augmented Generation) pipeline,
5
+ combining semantic search, context management, and LLM generation.
6
+ """
7
+
8
+ import logging
9
+ import time
10
+ from typing import Any, Dict, List, Optional
11
+ from dataclasses import dataclass
12
+
13
+ # Import our modules
14
+ from src.search.search_service import SearchService
15
+ from src.llm.llm_service import LLMService, LLMResponse
16
+ from src.llm.context_manager import ContextManager, ContextConfig
17
+ from src.llm.prompt_templates import PromptTemplates, PromptTemplate
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class RAGConfig:
24
+ """Configuration for RAG pipeline."""
25
+ max_context_length: int = 3000
26
+ search_top_k: int = 10
27
+ search_threshold: float = 0.1
28
+ min_similarity_for_answer: float = 0.15
29
+ max_response_length: int = 1000
30
+ enable_citation_validation: bool = True
31
+
32
+
33
+ @dataclass
34
+ class RAGResponse:
35
+ """Response from RAG pipeline with metadata."""
36
+ answer: str
37
+ sources: List[Dict[str, Any]]
38
+ confidence: float
39
+ processing_time: float
40
+ llm_provider: str
41
+ llm_model: str
42
+ context_length: int
43
+ search_results_count: int
44
+ success: bool
45
+ error_message: Optional[str] = None
46
+
47
+
48
+ class RAGPipeline:
49
+ """
50
+ Complete RAG pipeline orchestrating retrieval and generation.
51
+
52
+ Combines:
53
+ - Semantic search for context retrieval
54
+ - Context optimization and management
55
+ - LLM-based response generation
56
+ - Citation validation and formatting
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ search_service: SearchService,
62
+ llm_service: LLMService,
63
+ config: Optional[RAGConfig] = None
64
+ ):
65
+ """
66
+ Initialize RAG pipeline with required services.
67
+
68
+ Args:
69
+ search_service: Configured SearchService instance
70
+ llm_service: Configured LLMService instance
71
+ config: RAG configuration, uses defaults if None
72
+ """
73
+ self.search_service = search_service
74
+ self.llm_service = llm_service
75
+ self.config = config or RAGConfig()
76
+
77
+ # Initialize context manager with matching config
78
+ context_config = ContextConfig(
79
+ max_context_length=self.config.max_context_length,
80
+ max_results=self.config.search_top_k,
81
+ min_similarity=self.config.search_threshold
82
+ )
83
+ self.context_manager = ContextManager(context_config)
84
+
85
+ # Initialize prompt templates
86
+ self.prompt_templates = PromptTemplates()
87
+
88
+ logger.info("RAGPipeline initialized successfully")
89
+
90
+ def generate_answer(self, question: str) -> RAGResponse:
91
+ """
92
+ Generate answer to question using RAG pipeline.
93
+
94
+ Args:
95
+ question: User's question about corporate policies
96
+
97
+ Returns:
98
+ RAGResponse with answer and metadata
99
+ """
100
+ start_time = time.time()
101
+
102
+ try:
103
+ # Step 1: Retrieve relevant context
104
+ logger.debug(f"Starting RAG pipeline for question: {question[:100]}...")
105
+
106
+ search_results = self._retrieve_context(question)
107
+
108
+ if not search_results:
109
+ return self._create_no_context_response(question, start_time)
110
+
111
+ # Step 2: Prepare and optimize context
112
+ context, filtered_results = self.context_manager.prepare_context(
113
+ search_results, question
114
+ )
115
+
116
+ # Step 3: Check if we have sufficient context
117
+ quality_metrics = self.context_manager.validate_context_quality(
118
+ context, question, self.config.min_similarity_for_answer
119
+ )
120
+
121
+ if not quality_metrics["passes_validation"]:
122
+ return self._create_insufficient_context_response(
123
+ question, filtered_results, start_time
124
+ )
125
+
126
+ # Step 4: Generate response using LLM
127
+ llm_response = self._generate_llm_response(question, context)
128
+
129
+ if not llm_response.success:
130
+ return self._create_llm_error_response(
131
+ question, llm_response.error_message, start_time
132
+ )
133
+
134
+ # Step 5: Process and validate response
135
+ processed_response = self._process_response(
136
+ llm_response.content, filtered_results
137
+ )
138
+
139
+ processing_time = time.time() - start_time
140
+
141
+ return RAGResponse(
142
+ answer=processed_response,
143
+ sources=self._format_sources(filtered_results),
144
+ confidence=self._calculate_confidence(quality_metrics, llm_response),
145
+ processing_time=processing_time,
146
+ llm_provider=llm_response.provider,
147
+ llm_model=llm_response.model,
148
+ context_length=len(context),
149
+ search_results_count=len(search_results),
150
+ success=True
151
+ )
152
+
153
+ except Exception as e:
154
+ logger.error(f"RAG pipeline error: {e}")
155
+ return RAGResponse(
156
+ answer="I apologize, but I encountered an error processing your question. Please try again or contact support.",
157
+ sources=[],
158
+ confidence=0.0,
159
+ processing_time=time.time() - start_time,
160
+ llm_provider="none",
161
+ llm_model="none",
162
+ context_length=0,
163
+ search_results_count=0,
164
+ success=False,
165
+ error_message=str(e)
166
+ )
167
+
168
+ def _retrieve_context(self, question: str) -> List[Dict[str, Any]]:
169
+ """Retrieve relevant context using search service."""
170
+ try:
171
+ results = self.search_service.search(
172
+ query=question,
173
+ top_k=self.config.search_top_k,
174
+ threshold=self.config.search_threshold
175
+ )
176
+
177
+ logger.debug(f"Retrieved {len(results)} search results")
178
+ return results
179
+
180
+ except Exception as e:
181
+ logger.error(f"Context retrieval error: {e}")
182
+ return []
183
+
184
+ def _generate_llm_response(self, question: str, context: str) -> LLMResponse:
185
+ """Generate response using LLM with formatted prompt."""
186
+ template = self.prompt_templates.get_policy_qa_template()
187
+
188
+ # Format the prompt
189
+ formatted_prompt = template.user_template.format(
190
+ question=question,
191
+ context=context
192
+ )
193
+
194
+ # Add system prompt (if LLM service supports it in future)
195
+ full_prompt = f"{template.system_prompt}\n\n{formatted_prompt}"
196
+
197
+ return self.llm_service.generate_response(full_prompt)
198
+
199
+ def _process_response(
200
+ self,
201
+ raw_response: str,
202
+ search_results: List[Dict[str, Any]]
203
+ ) -> str:
204
+ """Process and validate LLM response."""
205
+
206
+ # Ensure citations are present
207
+ response_with_citations = self.prompt_templates.add_fallback_citations(
208
+ raw_response, search_results
209
+ )
210
+
211
+ # Validate citations if enabled
212
+ if self.config.enable_citation_validation:
213
+ available_sources = [
214
+ result.get("metadata", {}).get("filename", "")
215
+ for result in search_results
216
+ ]
217
+
218
+ citation_validation = self.prompt_templates.validate_citations(
219
+ response_with_citations, available_sources
220
+ )
221
+
222
+ # Log any invalid citations
223
+ invalid_citations = [
224
+ citation for citation, valid in citation_validation.items()
225
+ if not valid
226
+ ]
227
+
228
+ if invalid_citations:
229
+ logger.warning(f"Invalid citations detected: {invalid_citations}")
230
+
231
+ # Truncate if too long
232
+ if len(response_with_citations) > self.config.max_response_length:
233
+ truncated = response_with_citations[:self.config.max_response_length - 3] + "..."
234
+ logger.warning(f"Response truncated from {len(response_with_citations)} to {len(truncated)} characters")
235
+ return truncated
236
+
237
+ return response_with_citations
238
+
239
+ def _format_sources(self, search_results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
240
+ """Format search results for response metadata."""
241
+ sources = []
242
+
243
+ for result in search_results:
244
+ metadata = result.get("metadata", {})
245
+ sources.append({
246
+ "document": metadata.get("filename", "unknown"),
247
+ "chunk_id": result.get("chunk_id", ""),
248
+ "relevance_score": result.get("similarity_score", 0.0),
249
+ "excerpt": result.get("content", "")[:200] + "..." if len(result.get("content", "")) > 200 else result.get("content", "")
250
+ })
251
+
252
+ return sources
253
+
254
+ def _calculate_confidence(
255
+ self,
256
+ quality_metrics: Dict[str, Any],
257
+ llm_response: LLMResponse
258
+ ) -> float:
259
+ """Calculate confidence score for the response."""
260
+
261
+ # Base confidence on context quality
262
+ context_confidence = quality_metrics.get("estimated_relevance", 0.0)
263
+
264
+ # Adjust based on LLM response time (faster might indicate more confidence)
265
+ time_factor = min(1.0, 10.0 / max(llm_response.response_time, 1.0))
266
+
267
+ # Combine factors
268
+ confidence = (context_confidence * 0.7) + (time_factor * 0.3)
269
+
270
+ return min(1.0, max(0.0, confidence))
271
+
272
+ def _create_no_context_response(self, question: str, start_time: float) -> RAGResponse:
273
+ """Create response when no relevant context found."""
274
+ return RAGResponse(
275
+ answer="I couldn't find any relevant information in our corporate policies to answer your question. Please contact HR or check other company resources for assistance.",
276
+ sources=[],
277
+ confidence=0.0,
278
+ processing_time=time.time() - start_time,
279
+ llm_provider="none",
280
+ llm_model="none",
281
+ context_length=0,
282
+ search_results_count=0,
283
+ success=True # This is a valid "no answer" response
284
+ )
285
+
286
+ def _create_insufficient_context_response(
287
+ self,
288
+ question: str,
289
+ results: List[Dict[str, Any]],
290
+ start_time: float
291
+ ) -> RAGResponse:
292
+ """Create response when context quality is insufficient."""
293
+ return RAGResponse(
294
+ answer="I found some potentially relevant information, but it doesn't provide enough detail to fully answer your question. Please contact HR for more specific guidance or rephrase your question.",
295
+ sources=self._format_sources(results),
296
+ confidence=0.2,
297
+ processing_time=time.time() - start_time,
298
+ llm_provider="none",
299
+ llm_model="none",
300
+ context_length=0,
301
+ search_results_count=len(results),
302
+ success=True
303
+ )
304
+
305
+ def _create_llm_error_response(
306
+ self,
307
+ question: str,
308
+ error_message: str,
309
+ start_time: float
310
+ ) -> RAGResponse:
311
+ """Create response when LLM generation fails."""
312
+ return RAGResponse(
313
+ answer="I apologize, but I'm currently unable to generate a response. Please try again in a moment or contact support if the issue persists.",
314
+ sources=[],
315
+ confidence=0.0,
316
+ processing_time=time.time() - start_time,
317
+ llm_provider="error",
318
+ llm_model="error",
319
+ context_length=0,
320
+ search_results_count=0,
321
+ success=False,
322
+ error_message=error_message
323
+ )
324
+
325
+ def health_check(self) -> Dict[str, Any]:
326
+ """
327
+ Perform health check on all pipeline components.
328
+
329
+ Returns:
330
+ Dictionary with component health status
331
+ """
332
+ health_status = {
333
+ "pipeline": "healthy",
334
+ "components": {}
335
+ }
336
+
337
+ try:
338
+ # Check search service
339
+ test_results = self.search_service.search("test query", top_k=1, threshold=0.0)
340
+ health_status["components"]["search_service"] = {
341
+ "status": "healthy",
342
+ "test_results_count": len(test_results)
343
+ }
344
+ except Exception as e:
345
+ health_status["components"]["search_service"] = {
346
+ "status": "unhealthy",
347
+ "error": str(e)
348
+ }
349
+ health_status["pipeline"] = "degraded"
350
+
351
+ try:
352
+ # Check LLM service
353
+ llm_health = self.llm_service.health_check()
354
+ health_status["components"]["llm_service"] = llm_health
355
+
356
+ # Pipeline is unhealthy if all LLM providers are down
357
+ healthy_providers = sum(
358
+ 1 for provider_status in llm_health.values()
359
+ if provider_status.get("status") == "healthy"
360
+ )
361
+
362
+ if healthy_providers == 0:
363
+ health_status["pipeline"] = "unhealthy"
364
+
365
+ except Exception as e:
366
+ health_status["components"]["llm_service"] = {
367
+ "status": "unhealthy",
368
+ "error": str(e)
369
+ }
370
+ health_status["pipeline"] = "unhealthy"
371
+
372
+ return health_status
src/rag/response_formatter.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Response Formatter for RAG Pipeline
3
+
4
+ This module handles formatting of RAG responses with proper citation
5
+ formatting, metadata inclusion, and consistent response structure.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any, Dict, List, Optional
10
+ from dataclasses import dataclass, asdict
11
+ import json
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class FormattedResponse:
18
+ """Standardized formatted response for API endpoints."""
19
+ status: str
20
+ answer: str
21
+ sources: List[Dict[str, Any]]
22
+ metadata: Dict[str, Any]
23
+ processing_info: Dict[str, Any]
24
+ error: Optional[str] = None
25
+
26
+
27
+ class ResponseFormatter:
28
+ """
29
+ Formats RAG pipeline responses for various output formats.
30
+
31
+ Handles:
32
+ - API response formatting
33
+ - Citation formatting
34
+ - Metadata inclusion
35
+ - Error response formatting
36
+ """
37
+
38
+ def __init__(self):
39
+ """Initialize ResponseFormatter."""
40
+ logger.info("ResponseFormatter initialized")
41
+
42
+ def format_api_response(
43
+ self,
44
+ rag_response: Any, # RAGResponse type
45
+ include_debug: bool = False
46
+ ) -> Dict[str, Any]:
47
+ """
48
+ Format RAG response for API consumption.
49
+
50
+ Args:
51
+ rag_response: RAGResponse from RAG pipeline
52
+ include_debug: Whether to include debug information
53
+
54
+ Returns:
55
+ Formatted dictionary for JSON API response
56
+ """
57
+ if not rag_response.success:
58
+ return self._format_error_response(rag_response)
59
+
60
+ # Base response structure
61
+ formatted_response = {
62
+ "status": "success",
63
+ "answer": rag_response.answer,
64
+ "sources": self._format_source_list(rag_response.sources),
65
+ "metadata": {
66
+ "confidence": round(rag_response.confidence, 3),
67
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1),
68
+ "source_count": len(rag_response.sources),
69
+ "context_length": rag_response.context_length
70
+ }
71
+ }
72
+
73
+ # Add debug information if requested
74
+ if include_debug:
75
+ formatted_response["debug"] = {
76
+ "llm_provider": rag_response.llm_provider,
77
+ "llm_model": rag_response.llm_model,
78
+ "search_results_count": rag_response.search_results_count,
79
+ "processing_time_seconds": round(rag_response.processing_time, 3)
80
+ }
81
+
82
+ return formatted_response
83
+
84
+ def format_chat_response(
85
+ self,
86
+ rag_response: Any, # RAGResponse type
87
+ conversation_id: Optional[str] = None,
88
+ include_sources: bool = True
89
+ ) -> Dict[str, Any]:
90
+ """
91
+ Format RAG response for chat interface.
92
+
93
+ Args:
94
+ rag_response: RAGResponse from RAG pipeline
95
+ conversation_id: Optional conversation ID
96
+ include_sources: Whether to include source information
97
+
98
+ Returns:
99
+ Formatted dictionary for chat interface
100
+ """
101
+ if not rag_response.success:
102
+ return self._format_chat_error(rag_response, conversation_id)
103
+
104
+ response = {
105
+ "message": rag_response.answer,
106
+ "confidence": round(rag_response.confidence, 2),
107
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1)
108
+ }
109
+
110
+ if conversation_id:
111
+ response["conversation_id"] = conversation_id
112
+
113
+ if include_sources and rag_response.sources:
114
+ response["sources"] = self._format_sources_for_chat(rag_response.sources)
115
+
116
+ return response
117
+
118
+ def _format_source_list(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
119
+ """Format source list for API response."""
120
+ formatted_sources = []
121
+
122
+ for source in sources:
123
+ formatted_source = {
124
+ "document": source.get("document", "unknown"),
125
+ "relevance_score": round(source.get("relevance_score", 0.0), 3),
126
+ "excerpt": source.get("excerpt", "")
127
+ }
128
+
129
+ # Add chunk ID if available
130
+ chunk_id = source.get("chunk_id", "")
131
+ if chunk_id:
132
+ formatted_source["chunk_id"] = chunk_id
133
+
134
+ formatted_sources.append(formatted_source)
135
+
136
+ return formatted_sources
137
+
138
+ def _format_sources_for_chat(self, sources: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
139
+ """Format sources for chat interface (more concise)."""
140
+ formatted_sources = []
141
+
142
+ for i, source in enumerate(sources[:3], 1): # Limit to top 3 for chat
143
+ formatted_source = {
144
+ "id": i,
145
+ "document": source.get("document", "unknown"),
146
+ "relevance": f"{source.get('relevance_score', 0.0):.1%}",
147
+ "preview": source.get("excerpt", "")[:100] + "..." if len(source.get("excerpt", "")) > 100 else source.get("excerpt", "")
148
+ }
149
+ formatted_sources.append(formatted_source)
150
+
151
+ return formatted_sources
152
+
153
+ def _format_error_response(self, rag_response: Any) -> Dict[str, Any]:
154
+ """Format error response for API."""
155
+ return {
156
+ "status": "error",
157
+ "error": {
158
+ "message": rag_response.answer,
159
+ "details": rag_response.error_message,
160
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1)
161
+ },
162
+ "sources": [],
163
+ "metadata": {
164
+ "confidence": 0.0,
165
+ "source_count": 0,
166
+ "context_length": 0
167
+ }
168
+ }
169
+
170
+ def _format_chat_error(
171
+ self,
172
+ rag_response: Any,
173
+ conversation_id: Optional[str] = None
174
+ ) -> Dict[str, Any]:
175
+ """Format error response for chat interface."""
176
+ response = {
177
+ "message": rag_response.answer,
178
+ "error": True,
179
+ "processing_time_ms": round(rag_response.processing_time * 1000, 1)
180
+ }
181
+
182
+ if conversation_id:
183
+ response["conversation_id"] = conversation_id
184
+
185
+ return response
186
+
187
+ def validate_response_format(self, response: Dict[str, Any]) -> bool:
188
+ """
189
+ Validate that response follows expected format.
190
+
191
+ Args:
192
+ response: Formatted response dictionary
193
+
194
+ Returns:
195
+ True if format is valid, False otherwise
196
+ """
197
+ required_fields = ["status"]
198
+
199
+ # Check required fields
200
+ for field in required_fields:
201
+ if field not in response:
202
+ logger.error(f"Missing required field: {field}")
203
+ return False
204
+
205
+ # Check status-specific requirements
206
+ if response["status"] == "success":
207
+ success_fields = ["answer", "sources", "metadata"]
208
+ for field in success_fields:
209
+ if field not in response:
210
+ logger.error(f"Missing success field: {field}")
211
+ return False
212
+
213
+ elif response["status"] == "error":
214
+ if "error" not in response:
215
+ logger.error("Missing error field in error response")
216
+ return False
217
+
218
+ return True
219
+
220
+ def create_health_response(self, health_data: Dict[str, Any]) -> Dict[str, Any]:
221
+ """
222
+ Format health check response.
223
+
224
+ Args:
225
+ health_data: Health status from RAG pipeline
226
+
227
+ Returns:
228
+ Formatted health response
229
+ """
230
+ return {
231
+ "status": "success",
232
+ "health": {
233
+ "pipeline_status": health_data.get("pipeline", "unknown"),
234
+ "components": health_data.get("components", {}),
235
+ "timestamp": self._get_timestamp()
236
+ }
237
+ }
238
+
239
+ def create_no_answer_response(self, question: str, reason: str = "no_context") -> Dict[str, Any]:
240
+ """
241
+ Create standardized response when no answer can be provided.
242
+
243
+ Args:
244
+ question: Original user question
245
+ reason: Reason for no answer (no_context, insufficient_context, etc.)
246
+
247
+ Returns:
248
+ Formatted no-answer response
249
+ """
250
+ messages = {
251
+ "no_context": "I couldn't find any relevant information in our corporate policies to answer your question.",
252
+ "insufficient_context": "I found some potentially relevant information, but not enough to provide a complete answer.",
253
+ "off_topic": "This question appears to be outside the scope of our corporate policies.",
254
+ "error": "I encountered an error while processing your question."
255
+ }
256
+
257
+ message = messages.get(reason, messages["error"])
258
+
259
+ return {
260
+ "status": "no_answer",
261
+ "message": message,
262
+ "reason": reason,
263
+ "suggestion": "Please contact HR or rephrase your question for better results.",
264
+ "sources": []
265
+ }
266
+
267
+ def _get_timestamp(self) -> str:
268
+ """Get current timestamp in ISO format."""
269
+ from datetime import datetime
270
+ return datetime.utcnow().isoformat() + "Z"
271
+
272
+ def format_for_logging(self, rag_response: Any, question: str) -> Dict[str, Any]:
273
+ """
274
+ Format response data for logging purposes.
275
+
276
+ Args:
277
+ rag_response: RAGResponse from pipeline
278
+ question: Original question
279
+
280
+ Returns:
281
+ Formatted data for logging
282
+ """
283
+ return {
284
+ "timestamp": self._get_timestamp(),
285
+ "question_length": len(question),
286
+ "question_hash": hash(question) % 10000, # Simple hash for tracking
287
+ "success": rag_response.success,
288
+ "confidence": rag_response.confidence,
289
+ "processing_time": rag_response.processing_time,
290
+ "llm_provider": rag_response.llm_provider,
291
+ "source_count": len(rag_response.sources),
292
+ "context_length": rag_response.context_length,
293
+ "answer_length": len(rag_response.answer),
294
+ "error": rag_response.error_message
295
+ }
tests/test_chat_endpoint.py ADDED
@@ -0,0 +1,337 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pytest
4
+ from unittest.mock import patch, MagicMock
5
+
6
+ from app import app as flask_app
7
+
8
+
9
+ @pytest.fixture
10
+ def app():
11
+ yield flask_app
12
+
13
+
14
+ @pytest.fixture
15
+ def client(app):
16
+ return app.test_client()
17
+
18
+
19
+ class TestChatEndpoint:
20
+ """Test cases for the /chat endpoint"""
21
+
22
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
23
+ @patch('app.RAGPipeline')
24
+ @patch('app.ResponseFormatter')
25
+ @patch('app.LLMService')
26
+ @patch('app.SearchService')
27
+ @patch('app.VectorDatabase')
28
+ @patch('app.EmbeddingService')
29
+ def test_chat_endpoint_valid_request(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
30
+ """Test chat endpoint with valid request"""
31
+ # Mock the RAG pipeline response
32
+ mock_response = {
33
+ 'answer': 'Based on the remote work policy, employees can work remotely up to 3 days per week.',
34
+ 'confidence': 0.85,
35
+ 'sources': [{'chunk_id': '123', 'content': 'Remote work policy content...'}],
36
+ 'citations': ['remote_work_policy.md'],
37
+ 'processing_time_ms': 1500
38
+ }
39
+
40
+ # Setup mock instances
41
+ mock_rag_instance = MagicMock()
42
+ mock_rag_instance.generate_answer.return_value = mock_response
43
+ mock_rag.return_value = mock_rag_instance
44
+
45
+ mock_formatter_instance = MagicMock()
46
+ mock_formatter_instance.format_api_response.return_value = {
47
+ "status": "success",
48
+ "answer": mock_response['answer'],
49
+ "confidence": mock_response['confidence'],
50
+ "sources": mock_response['sources'],
51
+ "citations": mock_response['citations']
52
+ }
53
+ mock_formatter.return_value = mock_formatter_instance
54
+
55
+ # Mock LLMService.from_environment to return a mock instance
56
+ mock_llm_instance = MagicMock()
57
+ mock_llm.from_environment.return_value = mock_llm_instance
58
+
59
+ request_data = {
60
+ "message": "What is the remote work policy?",
61
+ "include_sources": True
62
+ }
63
+
64
+ response = client.post(
65
+ "/chat",
66
+ data=json.dumps(request_data),
67
+ content_type="application/json"
68
+ )
69
+
70
+ assert response.status_code == 200
71
+ data = response.get_json()
72
+
73
+ assert data["status"] == "success"
74
+ assert "answer" in data
75
+ assert "confidence" in data
76
+ assert "sources" in data
77
+ assert "citations" in data
78
+
79
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
80
+ @patch('app.RAGPipeline')
81
+ @patch('app.ResponseFormatter')
82
+ @patch('app.LLMService')
83
+ @patch('app.SearchService')
84
+ @patch('app.VectorDatabase')
85
+ @patch('app.EmbeddingService')
86
+ def test_chat_endpoint_minimal_request(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
87
+ """Test chat endpoint with minimal request (only message)"""
88
+ mock_response = {
89
+ 'answer': 'Employee benefits include health insurance, retirement plans, and PTO.',
90
+ 'confidence': 0.78,
91
+ 'sources': [],
92
+ 'citations': ['employee_benefits_guide.md'],
93
+ 'processing_time_ms': 1200
94
+ }
95
+
96
+ # Setup mock instances
97
+ mock_rag_instance = MagicMock()
98
+ mock_rag_instance.generate_answer.return_value = mock_response
99
+ mock_rag.return_value = mock_rag_instance
100
+
101
+ mock_formatter_instance = MagicMock()
102
+ mock_formatter_instance.format_api_response.return_value = {
103
+ "status": "success",
104
+ "answer": mock_response['answer']
105
+ }
106
+ mock_formatter.return_value = mock_formatter_instance
107
+
108
+ mock_llm.from_environment.return_value = MagicMock()
109
+
110
+ request_data = {"message": "What are the employee benefits?"}
111
+
112
+ response = client.post(
113
+ "/chat",
114
+ data=json.dumps(request_data),
115
+ content_type="application/json"
116
+ )
117
+
118
+ assert response.status_code == 200
119
+ data = response.get_json()
120
+ assert data["status"] == "success"
121
+
122
+ def test_chat_endpoint_missing_message(self, client):
123
+ """Test chat endpoint with missing message parameter"""
124
+ request_data = {"include_sources": True}
125
+
126
+ response = client.post(
127
+ "/chat",
128
+ data=json.dumps(request_data),
129
+ content_type="application/json"
130
+ )
131
+
132
+ assert response.status_code == 400
133
+ data = response.get_json()
134
+ assert data["status"] == "error"
135
+ assert "message parameter is required" in data["message"]
136
+
137
+ def test_chat_endpoint_empty_message(self, client):
138
+ """Test chat endpoint with empty message"""
139
+ request_data = {"message": ""}
140
+
141
+ response = client.post(
142
+ "/chat",
143
+ data=json.dumps(request_data),
144
+ content_type="application/json"
145
+ )
146
+
147
+ assert response.status_code == 400
148
+ data = response.get_json()
149
+ assert data["status"] == "error"
150
+ assert "non-empty string" in data["message"]
151
+
152
+ def test_chat_endpoint_non_string_message(self, client):
153
+ """Test chat endpoint with non-string message"""
154
+ request_data = {"message": 123}
155
+
156
+ response = client.post(
157
+ "/chat",
158
+ data=json.dumps(request_data),
159
+ content_type="application/json"
160
+ )
161
+
162
+ assert response.status_code == 400
163
+ data = response.get_json()
164
+ assert data["status"] == "error"
165
+ assert "non-empty string" in data["message"]
166
+
167
+ def test_chat_endpoint_non_json_request(self, client):
168
+ """Test chat endpoint with non-JSON request"""
169
+ response = client.post("/chat", data="not json", content_type="text/plain")
170
+
171
+ assert response.status_code == 400
172
+ data = response.get_json()
173
+ assert data["status"] == "error"
174
+ assert "application/json" in data["message"]
175
+
176
+ def test_chat_endpoint_no_llm_config(self, client):
177
+ """Test chat endpoint with no LLM configuration"""
178
+ with patch.dict(os.environ, {}, clear=True):
179
+ request_data = {"message": "What is the policy?"}
180
+
181
+ response = client.post(
182
+ "/chat",
183
+ data=json.dumps(request_data),
184
+ content_type="application/json"
185
+ )
186
+
187
+ assert response.status_code == 503
188
+ data = response.get_json()
189
+ assert data["status"] == "error"
190
+ assert "LLM service configuration error" in data["message"]
191
+
192
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
193
+ @patch('app.RAGPipeline')
194
+ @patch('app.ResponseFormatter')
195
+ @patch('app.LLMService')
196
+ @patch('app.SearchService')
197
+ @patch('app.VectorDatabase')
198
+ @patch('app.EmbeddingService')
199
+ def test_chat_endpoint_with_conversation_id(self, mock_embedding, mock_vector, mock_search, mock_llm, mock_formatter, mock_rag, client):
200
+ """Test chat endpoint with conversation_id parameter"""
201
+ mock_response = {
202
+ 'answer': 'The PTO policy allows 15 days of vacation annually.',
203
+ 'confidence': 0.9,
204
+ 'sources': [],
205
+ 'citations': ['pto_policy.md'],
206
+ 'processing_time_ms': 1100
207
+ }
208
+ mock_generate.return_value = mock_response
209
+ mock_llm_service.return_value = MagicMock()
210
+
211
+ request_data = {
212
+ "message": "What is the PTO policy?",
213
+ "conversation_id": "conv_123",
214
+ "include_sources": False
215
+ }
216
+
217
+ response = client.post(
218
+ "/chat",
219
+ data=json.dumps(request_data),
220
+ content_type="application/json"
221
+ )
222
+
223
+ assert response.status_code == 200
224
+ data = response.get_json()
225
+ assert data["status"] == "success"
226
+
227
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
228
+ @patch('src.llm.llm_service.LLMService.from_environment')
229
+ @patch('src.rag.rag_pipeline.RAGPipeline.generate_answer')
230
+ def test_chat_endpoint_with_debug(self, mock_generate, mock_llm_service, client):
231
+ """Test chat endpoint with debug information"""
232
+ mock_response = {
233
+ 'answer': 'The security policy requires 2FA authentication.',
234
+ 'confidence': 0.95,
235
+ 'sources': [{'chunk_id': '456', 'content': 'Security requirements...'}],
236
+ 'citations': ['information_security_policy.md'],
237
+ 'processing_time_ms': 1800,
238
+ 'search_results_count': 5,
239
+ 'context_length': 2048
240
+ }
241
+ mock_generate.return_value = mock_response
242
+ mock_llm_service.return_value = MagicMock()
243
+
244
+ request_data = {
245
+ "message": "What are the security requirements?",
246
+ "include_debug": True
247
+ }
248
+
249
+ response = client.post(
250
+ "/chat",
251
+ data=json.dumps(request_data),
252
+ content_type="application/json"
253
+ )
254
+
255
+ assert response.status_code == 200
256
+ data = response.get_json()
257
+ assert data["status"] == "success"
258
+
259
+
260
+ class TestChatHealthEndpoint:
261
+ """Test cases for the /chat/health endpoint"""
262
+
263
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
264
+ @patch('src.llm.llm_service.LLMService.from_environment')
265
+ @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
266
+ def test_chat_health_healthy(self, mock_health_check, mock_llm_service, client):
267
+ """Test chat health endpoint when all services are healthy"""
268
+ mock_health_data = {
269
+ "pipeline": "healthy",
270
+ "components": {
271
+ "search_service": {"status": "healthy"},
272
+ "llm_service": {"status": "healthy"},
273
+ "vector_db": {"status": "healthy"}
274
+ }
275
+ }
276
+ mock_health_check.return_value = mock_health_data
277
+ mock_llm_service.return_value = MagicMock()
278
+
279
+ response = client.get("/chat/health")
280
+
281
+ assert response.status_code == 200
282
+ data = response.get_json()
283
+ assert data["status"] == "success"
284
+
285
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
286
+ @patch('src.llm.llm_service.LLMService.from_environment')
287
+ @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
288
+ def test_chat_health_degraded(self, mock_health_check, mock_llm_service, client):
289
+ """Test chat health endpoint when services are degraded"""
290
+ mock_health_data = {
291
+ "pipeline": "degraded",
292
+ "components": {
293
+ "search_service": {"status": "healthy"},
294
+ "llm_service": {"status": "degraded", "warning": "High latency"},
295
+ "vector_db": {"status": "healthy"}
296
+ }
297
+ }
298
+ mock_health_check.return_value = mock_health_data
299
+ mock_llm_service.return_value = MagicMock()
300
+
301
+ response = client.get("/chat/health")
302
+
303
+ assert response.status_code == 200
304
+ data = response.get_json()
305
+ assert data["status"] == "success"
306
+
307
+ def test_chat_health_no_llm_config(self, client):
308
+ """Test chat health endpoint with no LLM configuration"""
309
+ with patch.dict(os.environ, {}, clear=True):
310
+ response = client.get("/chat/health")
311
+
312
+ assert response.status_code == 503
313
+ data = response.get_json()
314
+ assert data["status"] == "error"
315
+ assert "LLM configuration error" in data["message"]
316
+
317
+ @patch.dict(os.environ, {'OPENROUTER_API_KEY': 'test_key'})
318
+ @patch('src.llm.llm_service.LLMService.from_environment')
319
+ @patch('src.rag.rag_pipeline.RAGPipeline.health_check')
320
+ def test_chat_health_unhealthy(self, mock_health_check, mock_llm_service, client):
321
+ """Test chat health endpoint when services are unhealthy"""
322
+ mock_health_data = {
323
+ "pipeline": "unhealthy",
324
+ "components": {
325
+ "search_service": {"status": "unhealthy", "error": "Database connection failed"},
326
+ "llm_service": {"status": "unhealthy", "error": "API unreachable"},
327
+ "vector_db": {"status": "unhealthy"}
328
+ }
329
+ }
330
+ mock_health_check.return_value = mock_health_data
331
+ mock_llm_service.return_value = MagicMock()
332
+
333
+ response = client.get("/chat/health")
334
+
335
+ assert response.status_code == 503
336
+ data = response.get_json()
337
+ assert data["status"] == "success" # Still returns success, but 503 status code
tests/test_llm/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # LLM Service Tests
tests/test_llm/test_llm_service.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test LLM Service
3
+
4
+ Tests for LLM integration and service functionality.
5
+ """
6
+
7
+ import pytest
8
+ from unittest.mock import Mock, patch, MagicMock
9
+ import requests
10
+ from src.llm.llm_service import LLMService, LLMConfig, LLMResponse
11
+
12
+
13
+ class TestLLMConfig:
14
+ """Test LLMConfig dataclass."""
15
+
16
+ def test_llm_config_creation(self):
17
+ """Test basic LLMConfig creation."""
18
+ config = LLMConfig(
19
+ provider="openrouter",
20
+ api_key="test-key",
21
+ model_name="test-model",
22
+ base_url="https://test.com"
23
+ )
24
+
25
+ assert config.provider == "openrouter"
26
+ assert config.api_key == "test-key"
27
+ assert config.model_name == "test-model"
28
+ assert config.base_url == "https://test.com"
29
+ assert config.max_tokens == 1000 # Default value
30
+ assert config.temperature == 0.1 # Default value
31
+
32
+
33
+ class TestLLMResponse:
34
+ """Test LLMResponse dataclass."""
35
+
36
+ def test_llm_response_creation(self):
37
+ """Test basic LLMResponse creation."""
38
+ response = LLMResponse(
39
+ content="Test response",
40
+ provider="openrouter",
41
+ model="test-model",
42
+ usage={"tokens": 100},
43
+ response_time=1.5,
44
+ success=True
45
+ )
46
+
47
+ assert response.content == "Test response"
48
+ assert response.provider == "openrouter"
49
+ assert response.model == "test-model"
50
+ assert response.usage == {"tokens": 100}
51
+ assert response.response_time == 1.5
52
+ assert response.success is True
53
+ assert response.error_message is None
54
+
55
+
56
+ class TestLLMService:
57
+ """Test LLMService functionality."""
58
+
59
+ def test_initialization_with_configs(self):
60
+ """Test LLMService initialization with configurations."""
61
+ config = LLMConfig(
62
+ provider="openrouter",
63
+ api_key="test-key",
64
+ model_name="test-model",
65
+ base_url="https://test.com"
66
+ )
67
+
68
+ service = LLMService([config])
69
+
70
+ assert len(service.configs) == 1
71
+ assert service.configs[0] == config
72
+ assert service.current_config_index == 0
73
+
74
+ def test_initialization_empty_configs_raises_error(self):
75
+ """Test that empty configs raise ValueError."""
76
+ with pytest.raises(ValueError, match="At least one LLM configuration must be provided"):
77
+ LLMService([])
78
+
79
+ @patch.dict('os.environ', {'OPENROUTER_API_KEY': 'test-openrouter-key'})
80
+ def test_from_environment_with_openrouter_key(self):
81
+ """Test creating service from environment with OpenRouter key."""
82
+ service = LLMService.from_environment()
83
+
84
+ assert len(service.configs) >= 1
85
+ openrouter_config = next(
86
+ (config for config in service.configs if config.provider == "openrouter"),
87
+ None
88
+ )
89
+ assert openrouter_config is not None
90
+ assert openrouter_config.api_key == "test-openrouter-key"
91
+
92
+ @patch.dict('os.environ', {'GROQ_API_KEY': 'test-groq-key'})
93
+ def test_from_environment_with_groq_key(self):
94
+ """Test creating service from environment with Groq key."""
95
+ service = LLMService.from_environment()
96
+
97
+ assert len(service.configs) >= 1
98
+ groq_config = next(
99
+ (config for config in service.configs if config.provider == "groq"),
100
+ None
101
+ )
102
+ assert groq_config is not None
103
+ assert groq_config.api_key == "test-groq-key"
104
+
105
+ @patch.dict('os.environ', {}, clear=True)
106
+ def test_from_environment_no_keys_raises_error(self):
107
+ """Test that no environment keys raise ValueError."""
108
+ with pytest.raises(ValueError, match="No LLM API keys found in environment"):
109
+ LLMService.from_environment()
110
+
111
+ @patch('requests.post')
112
+ def test_successful_response_generation(self, mock_post):
113
+ """Test successful response generation."""
114
+ # Mock successful API response
115
+ mock_response = Mock()
116
+ mock_response.status_code = 200
117
+ mock_response.json.return_value = {
118
+ "choices": [
119
+ {"message": {"content": "Test response content"}}
120
+ ],
121
+ "usage": {"prompt_tokens": 50, "completion_tokens": 20}
122
+ }
123
+ mock_response.raise_for_status = Mock()
124
+ mock_post.return_value = mock_response
125
+
126
+ config = LLMConfig(
127
+ provider="openrouter",
128
+ api_key="test-key",
129
+ model_name="test-model",
130
+ base_url="https://api.openrouter.ai/api/v1"
131
+ )
132
+ service = LLMService([config])
133
+
134
+ result = service.generate_response("Test prompt")
135
+
136
+ assert result.success is True
137
+ assert result.content == "Test response content"
138
+ assert result.provider == "openrouter"
139
+ assert result.model == "test-model"
140
+ assert result.usage == {"prompt_tokens": 50, "completion_tokens": 20}
141
+ assert result.response_time > 0
142
+
143
+ # Verify API call
144
+ mock_post.assert_called_once()
145
+ args, kwargs = mock_post.call_args
146
+ assert args[0] == "https://api.openrouter.ai/api/v1/chat/completions"
147
+ assert kwargs["json"]["model"] == "test-model"
148
+ assert kwargs["json"]["messages"][0]["content"] == "Test prompt"
149
+
150
+ @patch('requests.post')
151
+ def test_api_error_handling(self, mock_post):
152
+ """Test handling of API errors."""
153
+ # Mock API error
154
+ mock_post.side_effect = requests.exceptions.RequestException("API Error")
155
+
156
+ config = LLMConfig(
157
+ provider="openrouter",
158
+ api_key="test-key",
159
+ model_name="test-model",
160
+ base_url="https://api.openrouter.ai/api/v1"
161
+ )
162
+ service = LLMService([config])
163
+
164
+ result = service.generate_response("Test prompt")
165
+
166
+ assert result.success is False
167
+ assert "API Error" in result.error_message
168
+ assert result.content == ""
169
+ assert result.provider == "openrouter"
170
+
171
+ @patch('requests.post')
172
+ def test_fallback_to_second_provider(self, mock_post):
173
+ """Test fallback to second provider when first fails."""
174
+ # Mock first provider failing, second succeeding
175
+ first_call = Mock()
176
+ first_call.side_effect = requests.exceptions.RequestException("First provider error")
177
+
178
+ second_call = Mock()
179
+ second_response = Mock()
180
+ second_response.status_code = 200
181
+ second_response.json.return_value = {
182
+ "choices": [{"message": {"content": "Second provider response"}}],
183
+ "usage": {}
184
+ }
185
+ second_response.raise_for_status = Mock()
186
+ second_call.return_value = second_response
187
+
188
+ mock_post.side_effect = [first_call.side_effect, second_response]
189
+
190
+ config1 = LLMConfig(
191
+ provider="openrouter",
192
+ api_key="key1",
193
+ model_name="model1",
194
+ base_url="https://api1.com"
195
+ )
196
+ config2 = LLMConfig(
197
+ provider="groq",
198
+ api_key="key2",
199
+ model_name="model2",
200
+ base_url="https://api2.com"
201
+ )
202
+
203
+ service = LLMService([config1, config2])
204
+ result = service.generate_response("Test prompt")
205
+
206
+ assert result.success is True
207
+ assert result.content == "Second provider response"
208
+ assert result.provider == "groq"
209
+ assert mock_post.call_count == 2
210
+
211
+ @patch('requests.post')
212
+ def test_all_providers_fail(self, mock_post):
213
+ """Test when all providers fail."""
214
+ mock_post.side_effect = requests.exceptions.RequestException("All providers down")
215
+
216
+ config1 = LLMConfig(provider="provider1", api_key="key1", model_name="model1", base_url="url1")
217
+ config2 = LLMConfig(provider="provider2", api_key="key2", model_name="model2", base_url="url2")
218
+
219
+ service = LLMService([config1, config2])
220
+ result = service.generate_response("Test prompt")
221
+
222
+ assert result.success is False
223
+ assert "All providers failed" in result.error_message
224
+ assert result.provider == "none"
225
+ assert result.model == "none"
226
+
227
+ @patch('requests.post')
228
+ def test_retry_logic(self, mock_post):
229
+ """Test retry logic for failed requests."""
230
+ # First call fails, second succeeds
231
+ first_response = Mock()
232
+ first_response.side_effect = requests.exceptions.RequestException("Temporary error")
233
+
234
+ second_response = Mock()
235
+ second_response.status_code = 200
236
+ second_response.json.return_value = {
237
+ "choices": [{"message": {"content": "Success after retry"}}],
238
+ "usage": {}
239
+ }
240
+ second_response.raise_for_status = Mock()
241
+
242
+ mock_post.side_effect = [first_response.side_effect, second_response]
243
+
244
+ config = LLMConfig(
245
+ provider="openrouter",
246
+ api_key="test-key",
247
+ model_name="test-model",
248
+ base_url="https://api.openrouter.ai/api/v1"
249
+ )
250
+ service = LLMService([config])
251
+
252
+ result = service.generate_response("Test prompt", max_retries=1)
253
+
254
+ assert result.success is True
255
+ assert result.content == "Success after retry"
256
+ assert mock_post.call_count == 2
257
+
258
+ def test_get_available_providers(self):
259
+ """Test getting list of available providers."""
260
+ config1 = LLMConfig(provider="openrouter", api_key="key1", model_name="model1", base_url="url1")
261
+ config2 = LLMConfig(provider="groq", api_key="key2", model_name="model2", base_url="url2")
262
+
263
+ service = LLMService([config1, config2])
264
+ providers = service.get_available_providers()
265
+
266
+ assert providers == ["openrouter", "groq"]
267
+
268
+ @patch('requests.post')
269
+ def test_health_check(self, mock_post):
270
+ """Test health check functionality."""
271
+ # Mock successful health check
272
+ mock_response = Mock()
273
+ mock_response.status_code = 200
274
+ mock_response.json.return_value = {
275
+ "choices": [{"message": {"content": "OK"}}],
276
+ "usage": {}
277
+ }
278
+ mock_response.raise_for_status = Mock()
279
+ mock_post.return_value = mock_response
280
+
281
+ config = LLMConfig(
282
+ provider="openrouter",
283
+ api_key="test-key",
284
+ model_name="test-model",
285
+ base_url="https://api.openrouter.ai/api/v1"
286
+ )
287
+ service = LLMService([config])
288
+
289
+ health_status = service.health_check()
290
+
291
+ assert "openrouter" in health_status
292
+ assert health_status["openrouter"]["status"] == "healthy"
293
+ assert health_status["openrouter"]["model"] == "test-model"
294
+ assert health_status["openrouter"]["response_time"] > 0
295
+
296
+ @patch('requests.post')
297
+ def test_openrouter_specific_headers(self, mock_post):
298
+ """Test that OpenRouter-specific headers are added."""
299
+ mock_response = Mock()
300
+ mock_response.status_code = 200
301
+ mock_response.json.return_value = {
302
+ "choices": [{"message": {"content": "Test"}}],
303
+ "usage": {}
304
+ }
305
+ mock_response.raise_for_status = Mock()
306
+ mock_post.return_value = mock_response
307
+
308
+ config = LLMConfig(
309
+ provider="openrouter",
310
+ api_key="test-key",
311
+ model_name="test-model",
312
+ base_url="https://api.openrouter.ai/api/v1"
313
+ )
314
+ service = LLMService([config])
315
+
316
+ service.generate_response("Test")
317
+
318
+ # Check headers
319
+ args, kwargs = mock_post.call_args
320
+ headers = kwargs["headers"]
321
+ assert "HTTP-Referer" in headers
322
+ assert "X-Title" in headers
323
+ assert headers["HTTP-Referer"] == "https://github.com/sethmcknight/msse-ai-engineering"
tests/test_rag/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # RAG Pipeline Tests