|
|
""" |
|
|
Test script for Advanced RAG features |
|
|
Demonstrates new capabilities: multiple texts/images indexing and advanced RAG chat |
|
|
""" |
|
|
|
|
|
import json
from contextlib import ExitStack
from typing import List, Optional

import requests
|
|
|
|
|
|
|
|
class AdvancedRAGTester:
    """Test client for the Advanced RAG API.

    Wraps the ``/index`` and ``/chat`` endpoints of a running server and
    prints a human-readable report for every call.
    """

    # Per-request limits this client enforces before uploading.
    MAX_TEXTS = 10
    MAX_IMAGES = 10

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        timeout: Optional[float] = 120.0,
    ):
        """Create a tester bound to one API server.

        Args:
            base_url: Root URL of the RAG API. A trailing slash is stripped so
                endpoint URLs are never built with ``//``.
            timeout: Seconds to wait for each HTTP request (forwarded to
                ``requests``). ``None`` waits indefinitely.
        """
        self.base_url = base_url.rstrip('/')
        self.timeout = timeout

    @staticmethod
    def _limit_items(items, limit, label):
        """Return at most ``limit`` items, printing a warning when truncating."""
        if len(items) > limit:
            print(f"WARNING: Maximum {limit} {label} allowed. Taking first {limit}.")
            return items[:limit]
        return items

    @staticmethod
    def _report_error(error):
        """Print a failed request and, when one exists, the server's response body."""
        print(f"\n✗ ERROR: {error}")
        # Connection errors/timeouts carry no response. Compare against None
        # explicitly: a requests Response is falsy for 4xx/5xx status codes.
        response = getattr(error, 'response', None)
        if response is not None:
            print(f" Response: {response.text}")

    @staticmethod
    def _format_confidence(value):
        """Format a relevance score as a percentage, or 'n/a' if it is not numeric."""
        if isinstance(value, (int, float)):
            return f"{value:.2%}"
        return "n/a"

    @staticmethod
    def _text_preview(ctx, length=100):
        """Return the first ``length`` characters of ``ctx['metadata']['text']``.

        Missing or null ``metadata``/``text`` fields yield an empty string
        instead of raising.
        """
        metadata = ctx.get('metadata') or {}
        text = metadata.get('text') or ''
        return str(text)[:length]

    def test_multiple_index(self, doc_id: str, texts: List[str], image_paths: Optional[List[str]] = None):
        """Index one document built from several texts and/or images.

        Args:
            doc_id: Document ID.
            texts: List of texts (only the first ``MAX_TEXTS`` are sent).
            image_paths: List of image file paths (only the first
                ``MAX_IMAGES`` are sent). Paths that cannot be opened are
                skipped with a warning.

        Returns:
            The parsed JSON response on success, or ``None`` if the request
            failed.
        """
        print(f"\n{'='*60}")
        print(f"TEST: Indexing document '{doc_id}' with multiple texts/images")
        print(f"{'='*60}")

        data = {'id': doc_id}

        if texts:
            texts = self._limit_items(texts, self.MAX_TEXTS, "texts")
            data['texts'] = texts
            print(f"✓ Texts: {len(texts)} items")

        # ExitStack closes every image handle that was opened, even when a
        # later open() or the HTTP request raises.
        with ExitStack() as stack:
            files = []
            if image_paths:
                image_paths = self._limit_items(image_paths, self.MAX_IMAGES, "images")
                for img_path in image_paths:
                    try:
                        files.append(('images', stack.enter_context(open(img_path, 'rb'))))
                    except FileNotFoundError:
                        print(f"WARNING: Image not found: {img_path}")
                    except OSError as exc:
                        # e.g. permission denied or a directory path.
                        print(f"WARNING: Cannot open image {img_path}: {exc}")
                print(f"✓ Images: {len(files)} files")

            try:
                response = requests.post(
                    f"{self.base_url}/index",
                    data=data,
                    files=files,
                    timeout=self.timeout,
                )
                response.raise_for_status()
                result = response.json()
            except requests.exceptions.RequestException as e:
                self._report_error(e)
                return None

        print("\n✓ SUCCESS")
        print(f" - Document ID: {result.get('id')}")
        print(f" - Message: {result.get('message')}")
        return result

    def test_advanced_rag_chat(
        self,
        message: str,
        hf_token: Optional[str] = None,
        use_advanced_rag: bool = True,
        use_reranking: bool = True,
        use_compression: bool = True,
        top_k: int = 3,
        score_threshold: float = 0.5,
    ):
        """Send one question to ``/chat`` with RAG enabled and print the result.

        Args:
            message: User question.
            hf_token: Hugging Face token (optional; only sent when truthy).
            use_advanced_rag: Use advanced RAG pipeline.
            use_reranking: Enable reranking.
            use_compression: Enable context compression.
            top_k: Number of documents to retrieve.
            score_threshold: Minimum relevance score.

        Returns:
            The parsed JSON response on success, or ``None`` if the request
            failed. Missing optional response fields are tolerated.
        """
        print(f"\n{'='*60}")
        print("TEST: Advanced RAG Chat")
        print(f"{'='*60}")
        print(f"Question: {message}")
        print(f"Advanced RAG: {use_advanced_rag}")
        print(f"Reranking: {use_reranking}")
        print(f"Compression: {use_compression}")

        payload = {
            'message': message,
            'use_rag': True,
            'use_advanced_rag': use_advanced_rag,
            'use_reranking': use_reranking,
            'use_compression': use_compression,
            'top_k': top_k,
            'score_threshold': score_threshold,
        }
        if hf_token:
            payload['hf_token'] = hf_token

        try:
            response = requests.post(
                f"{self.base_url}/chat",
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            result = response.json()
        except requests.exceptions.RequestException as e:
            self._report_error(e)
            return None

        print("\n✓ SUCCESS")
        print("\n--- Answer ---")
        print(result.get('response'))

        contexts = result.get('context_used') or []
        print(f"\n--- Retrieved Context ({len(contexts)} documents) ---")
        for i, ctx in enumerate(contexts, 1):
            confidence = self._format_confidence(ctx.get('confidence'))
            print(f"{i}. [{ctx.get('id')}] Confidence: {confidence}")
            print(f" Text: {self._text_preview(ctx)}...")

        stats = result.get('rag_stats')
        if stats:
            print("\n--- RAG Pipeline Statistics ---")
            print(f" Original query: {stats.get('original_query')}")
            print(f" Expanded queries: {stats.get('expanded_queries')}")
            print(f" Initial results: {stats.get('initial_results')}")
            print(f" After reranking: {stats.get('after_rerank')}")
            print(f" After compression: {stats.get('after_compression')}")

        return result

    def compare_basic_vs_advanced_rag(self, message: str, hf_token: Optional[str] = None):
        """Ask the same question with basic and advanced RAG, then summarize both.

        Args:
            message: User question sent to both pipelines.
            hf_token: Optional Hugging Face token forwarded to both calls.
        """
        print(f"\n{'='*60}")
        print("COMPARISON: Basic RAG vs Advanced RAG")
        print(f"{'='*60}")
        print(f"Question: {message}\n")

        print("\n--- BASIC RAG ---")
        basic_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=False,
        )

        print("\n--- ADVANCED RAG ---")
        advanced_result = self.test_advanced_rag_chat(
            message=message,
            hf_token=hf_token,
            use_advanced_rag=True,
        )

        print(f"\n{'='*60}")
        print("COMPARISON SUMMARY")
        print(f"{'='*60}")

        if basic_result is None or advanced_result is None:
            # Previously this case printed nothing, hiding the failure.
            print("Comparison unavailable: at least one request failed.")
            return

        print("Basic RAG:")
        print(f" - Retrieved docs: {len(basic_result.get('context_used') or [])}")

        print("\nAdvanced RAG:")
        print(f" - Retrieved docs: {len(advanced_result.get('context_used') or [])}")
        stats = advanced_result.get('rag_stats')
        if stats:
            print(f" - Query expansion: {len(stats.get('expanded_queries') or [])} variants")
            print(f" - Initial retrieval: {stats.get('initial_results', 0)} docs")
            print(f" - After reranking: {stats.get('after_rerank', 0)} docs")
|
|
|
|
|
|
|
|
def main():
    """Run the Advanced RAG demo suite with the tester's default base URL.

    Indexes two sample Vietnamese documents (a music festival and event
    safety rules), runs basic and advanced RAG chat queries without a
    Hugging Face token, and finishes with a basic-vs-advanced comparison.
    """
    rule = "=" * 60
    tester = AdvancedRAGTester()

    print(rule)
    print("ADVANCED RAG FEATURE TESTS")
    print(rule)

    # (section heading, document id, texts) for each indexing test.
    documents_to_index = (
        (
            "\n\n### TEST 1: Index Multiple Texts ###",
            "event_music_festival_2025",
            [
                "Festival âm nhạc quốc tế Hà Nội 2025",
                "Thời gian: 15-17 tháng 11 năm 2025",
                "Địa điểm: Công viên Thống Nhất, Hà Nội",
                "Line-up: Sơn Tùng MTP, Đen Vâu, Hoàng Thùy Linh, Mỹ Tâm",
                "Giá vé: Early bird 500.000đ, VIP 2.000.000đ",
                "Dự kiến 50.000 khán giả tham dự",
                "3 sân khấu chính, 5 food court, khu vực cắm trại",
            ],
        ),
        (
            "\n\n### TEST 2: Index Another Document ###",
            "safety_guidelines",
            [
                "Vũ khí và đồ vật nguy hiểm bị cấm mang vào sự kiện",
                "Dao, kiếm, súng và các loại vũ khí nguy hiểm nghiêm cấm",
                "An ninh sẽ kiểm tra tất cả túi xách và đồ mang theo",
                "Vi phạm sẽ bị tịch thu và có thể bị trục xuất khỏi sự kiện",
            ],
        ),
    )
    for heading, doc_id, doc_texts in documents_to_index:
        print(heading)
        tester.test_multiple_index(doc_id=doc_id, texts=doc_texts)

    print("\n\n### TEST 3: Basic RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào?",
        use_advanced_rag=False,
    )

    print("\n\n### TEST 4: Advanced RAG Chat (No LLM) ###")
    tester.test_advanced_rag_chat(
        message="Festival Hà Nội diễn ra khi nào và có những nghệ sĩ nào?",
        use_advanced_rag=True,
        use_reranking=True,
        use_compression=True,
    )

    print("\n\n### TEST 5: Comparison Test ###")
    tester.compare_basic_vs_advanced_rag(
        message="Dao có được mang vào sự kiện không?",
    )

    print("\n\n" + rule)
    print("ALL TESTS COMPLETED")
    print(rule)
    print("\nNOTE: To test with actual LLM responses, add your Hugging Face token:")
    print(" tester.test_advanced_rag_chat(message='...', hf_token='hf_xxxxx')")


if __name__ == "__main__":
    # Run the demo only when executed directly, not when imported.
    main()
|
|
|