import logging
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin, urlparse

import numpy as np
import requests
from bs4 import BeautifulSoup
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from langchain_community.llms import Ollama
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.models import Distance, PointStruct, VectorParams
from sentence_transformers import SentenceTransformer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
class Config:
    """
    Application configuration settings.
    Contains constants for storage, models, and the Qdrant connection.
    Credentials are read from the environment so API keys are never
    committed to source.
    """
    STORAGE_DIR = "data/qdrant_storage"
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    QDRANT_URL = os.getenv(
        "QDRANT_URL",
        "https://6fe012ee-5a7c-4304-a77c-293a1888a9cf.us-west-2-0.aws.cloud.qdrant.io",
    )
    QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
    QDRANT_COLLECTION_NAME = "Chat-Bot"

    @staticmethod
    def create_storage_dir():
        """Ensure the storage directory exists."""
        os.makedirs(Config.STORAGE_DIR, exist_ok=True)
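# The Qdrant credentials can be supplied via the environment; the variable
# names below match the os.getenv() calls above, and the values are
# placeholders:
#
#   export QDRANT_URL="https://<your-cluster>.cloud.qdrant.io"
#   export QDRANT_API_KEY="<your-api-key>"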
# Data classes
class Document:
    """Represents a document with text content and metadata."""

    def __init__(self, text: str, metadata: Dict[str, Any]):
        self.text = text
        self.metadata = metadata
# Session Manager
class SessionManager:
    """
    Manages user sessions and Qdrant collections.
    Handles session state, document storage, and conversation history.
    """

    def __init__(self):
        """Initialize with in-memory sessions and a Qdrant connection."""
        from src.llm import GeminiProvider  # local import to defer loading the provider

        self.llm = GeminiProvider()
        self.sessions = {}  # In-memory session store
        self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
        self.qdrant_client = QdrantClient(
            url=Config.QDRANT_URL,
            api_key=Config.QDRANT_API_KEY,
            timeout=30,
        )
    def get_collection_name(self, session_id: str) -> str:
        """Generate a standardized Qdrant collection name for a session."""
        return f"collection_{session_id}"

    def get_session(self, session_id: str) -> Dict:
        """
        Get or create the session with the given ID.
        Maintains the original interface while adding robustness.
        """
        if session_id not in self.sessions:
            self._initialize_new_session(session_id)
            logger.info(f"[SessionManager] Created new session: {session_id}")
        return self.sessions[session_id]

    def _initialize_new_session(self, session_id: str):
        """Internal method to handle new session creation."""
        self.sessions[session_id] = {
            'documents': [],
            'history': []
        }
        self._ensure_qdrant_collection(session_id)
        logger.info(f"[SessionManager] Initialized session {session_id} with Qdrant collection.")
    def _ensure_qdrant_collection(self, session_id: str):
        """Ensure a Qdrant collection exists for the session."""
        collection_name = self.get_collection_name(session_id)
        try:
            # First try to get the collection (it might already exist)
            self.qdrant_client.get_collection(collection_name)
            logger.debug(f"Using existing Qdrant collection: {collection_name}")
        except Exception:
            # Collection doesn't exist, create it
            try:
                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=self.embedding_model.get_sentence_embedding_dimension(),
                        distance=Distance.COSINE,
                    ),
                )
                logger.info(f"Created new Qdrant collection: {collection_name}")
            except UnexpectedResponse as e:
                if "already exists" in str(e):
                    logger.debug(f"Collection already exists: {collection_name}")
                else:
                    logger.error(f"Error creating collection: {e}")
                    raise
            except Exception as e:
                logger.error(f"Unexpected error ensuring collection: {e}")
                raise
    def add_to_history(self, session_id: str, question: str, answer: str):
        """Add a conversation turn to the session history."""
        if session_id not in self.sessions:
            logger.warning(f"Session {session_id} not found when adding history")
            return
        self.sessions[session_id]['history'].append({
            'question': question,
            'answer': answer,
            'timestamp': datetime.now().isoformat()
        })

    def get_history(self, session_id: str, limit: Optional[int] = None) -> List[Dict]:
        """Get conversation history with an optional limit."""
        if session_id not in self.sessions:
            logger.warning(f"Session {session_id} not found when getting history")
            return []
        history = self.sessions[session_id]['history']
        return history[-limit:] if limit else history
    def session_exists(self, session_id: str) -> bool:
        """Check whether a session exists, locally or as a Qdrant collection."""
        if session_id in self.sessions:
            return True
        collection_name = self.get_collection_name(session_id)
        try:
            self.qdrant_client.get_collection(collection_name)
            # Re-register the session locally if its collection survives remotely
            self.sessions[session_id] = {
                'documents': [],
                'history': []
            }
            return True
        except Exception:
            return False
    def cleanup_inactive_sessions(self, inactive_minutes: int = 60):
        """Clean up sessions inactive for the specified number of minutes."""
        current_time = datetime.now()
        for session_id in list(self.sessions.keys()):
            history = self.sessions[session_id]['history']
            if history:
                last_activity = datetime.fromisoformat(history[-1]['timestamp'])
                if (current_time - last_activity).total_seconds() > inactive_minutes * 60:
                    del self.sessions[session_id]
                    logger.info(f"Cleaned up inactive session: {session_id}")

    def save_session(self, session_id: str):
        """No-op: Qdrant persists data automatically."""
        pass

    def add_conversation(self, session_id: str, query: str, response: str):
        """Add a conversation turn to the session history."""
        # Include a timestamp so cleanup_inactive_sessions can read it
        self.sessions[session_id]['history'].append({
            "query": query,
            "response": response,
            "timestamp": datetime.now().isoformat()
        })

    def get_conversation_history(self, session_id: str):
        """Get the full conversation history."""
        return self.sessions[session_id]['history']
    def add_documents_to_qdrant(self, session_id: str, documents: List[Document]):
        """Add documents to the session's Qdrant collection with validation."""
        texts = [doc.text for doc in documents]
        try:
            embeddings = self.embedding_model.encode(texts, batch_size=32, show_progress_bar=True)
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()
            points = [
                PointStruct(
                    # Random UUIDs so repeated indexing doesn't overwrite earlier points
                    id=str(uuid.uuid4()),
                    vector=embedding,
                    payload={
                        "text": doc.text,
                        "metadata": doc.metadata
                    }
                )
                for embedding, doc in zip(embeddings, documents)
            ]
            collection_name = self.get_collection_name(session_id)
            operation_info = self.qdrant_client.upsert(
                collection_name=collection_name,
                points=points,
                wait=True  # Wait for operation confirmation
            )
            logger.info(f"Upsert operation status: {operation_info.status}")
            self.sessions[session_id]['documents'].extend(documents)
        except Exception as e:
            logger.error(f"Document insertion failed: {e}")
            raise
    def search_qdrant(self, session_id: str, query_embedding: np.ndarray, k: int = 3):
        """Search the session's Qdrant collection with error handling."""
        try:
            if isinstance(query_embedding, np.ndarray):
                query_embedding = query_embedding.tolist()
            collection_name = self.get_collection_name(session_id)
            return self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,
                with_payload=True,
                with_vectors=False
            )
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise
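# A minimal usage sketch for SessionManager (the "demo" session id is
# hypothetical; methods and signatures are the ones defined above):
#
#   manager = SessionManager()
#   manager.get_session("demo")  # creates the session and its Qdrant collection
#   docs = [Document("hello world", {"source_url": "https://example.com"})]
#   manager.add_documents_to_qdrant("demo", docs)
#   vec = manager.embedding_model.encode("greeting")
#   hits = manager.search_qdrant("demo", vec, k=3)  # scored points with payloads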
# Web Crawler
class WebCrawler:
    """Handles web crawling with depth control and duplicate prevention."""

    MAX_LINKS = 50  # Hard cap on the number of links collected per crawl

    def __init__(self, max_depth=2, delay=1):
        self.max_depth = max_depth
        self.delay = delay  # Seconds to wait between requests
        self.visited = set()
        self.collected_links = set()

    def crawl_recursive(self, url, depth=0):
        """Recursively crawl URLs up to max_depth, staying on the same domain."""
        logger.info(f"[WebCrawler] Crawling {url} at depth {depth}")
        if depth > self.max_depth or url in self.visited or len(self.collected_links) >= self.MAX_LINKS:
            return []
        self.visited.add(url)
        self.collected_links.add(url)
        links = [url]
        try:
            response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"})
            soup = BeautifulSoup(response.content, "html.parser")
            time.sleep(self.delay)  # Honor the politeness delay between requests
            for tag in soup.find_all("a", href=True):
                if len(self.collected_links) >= self.MAX_LINKS:
                    break  # Stop once the link cap is reached
                href = urljoin(url, tag["href"])
                # Only follow links on the same domain
                if urlparse(href).netloc == urlparse(url).netloc:
                    links.extend(self.crawl_recursive(href, depth + 1))
        except Exception as e:
            logger.warning(f"Failed to crawl {url}: {e}")
        return list(set(links))
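# A minimal usage sketch for WebCrawler (example.com is a placeholder URL):
#
#   crawler = WebCrawler(max_depth=2, delay=1)
#   urls = crawler.crawl_recursive("https://example.com")
#   # urls is a deduplicated list of same-domain links, capped at MAX_LINKS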
# Connection Manager
class ConnectionManager:
    """Manages active WebSocket connections."""

    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """Register a new WebSocket connection."""
        await websocket.accept()
        self.active_connections[session_id] = websocket

    async def disconnect(self, session_id: str):
        """Remove a WebSocket connection."""
        if session_id in self.active_connections:
            del self.active_connections[session_id]

    async def send_message(self, message: str, session_id: str):
        """Send a message to a specific WebSocket connection."""
        if session_id in self.active_connections:
            await self.active_connections[session_id].send_text(message)
# RAG System with Qdrant
class RAGSystem:
    """Main RAG system orchestrating crawling, indexing, and querying."""

    def __init__(self):
        self.session_manager = SessionManager()
        self.crawler = WebCrawler()
        self.llm = Ollama(base_url="http://localhost:11434", model="mistral")

    def crawl_and_index(self, session_id: str, start_url: str) -> Dict[str, Any]:
        """Crawl a website and index its content in Qdrant."""
        logger.info(f"[RAGSystem] Starting crawl and index for session {session_id} with URL: {start_url}")
        try:
            # Ensure the session and its collection exist
            self.session_manager.get_session(session_id)
            all_urls = self.crawler.crawl_recursive(start_url)
            documents, successful_urls = [], []
            logger.info(f"[RAGSystem] Crawled {len(all_urls)} URLs for session {session_id}")
            for url in all_urls[:20]:  # Limit indexing to the first 20 URLs
                try:
                    logger.info(f"[RAGSystem] Processing URL: {url}")
                    response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"})
                    soup = BeautifulSoup(response.content, "html.parser")
                    # Drop script/style elements before extracting visible text
                    for tag in soup(["script", "style"]):
                        tag.decompose()
                    text = " ".join(chunk.strip() for chunk in soup.get_text().splitlines() if chunk.strip())
                    if len(text) > 100:  # Skip near-empty pages
                        documents.append(Document(text, {"source_url": url, "session_id": session_id}))
                        successful_urls.append(url)
                except Exception as e:
                    logger.warning(f"Error processing {url}: {e}")
            if documents:
                self.session_manager.add_documents_to_qdrant(session_id, documents)
                return {
                    "status": "success",
                    "urls_processed": successful_urls,
                    "total_documents": len(documents)
                }
            return {"status": "error", "message": "No documents indexed"}
        except Exception as e:
            logger.error(f"crawl_and_index error: {e}")
            return {
                "status": "error",
                "message": f"Error during crawling and indexing: {str(e)}"
            }
    async def chat(
        self,
        session_id: str,
        question: str,
        model: str = "mistral",
        ollama_url: Optional[str] = None,
        gemini_api_key: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Handle chat requests with model selection.
        Supports both Mistral (via Ollama) and Gemini models.
        """
        try:
            # Get session data
            session = self.session_manager.get_session(session_id)
            if not session.get('documents'):
                return {
                    "status": "error",
                    "message": "No documents indexed for this session"
                }
            # Select the appropriate LLM
            if model == "mistral" and ollama_url:
                self.llm = Ollama(base_url=ollama_url, model="mistral")
            elif model == "gemini" and gemini_api_key:
                from src.llm import GeminiProvider
                self.llm = GeminiProvider()
            # Process the query
            result = self.process_query(session_id, question)
            # Add to conversation history if successful
            if result["status"] == "success":
                self.session_manager.add_conversation(
                    session_id,
                    question,
                    result["response"]
                )
            return result
        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return {
                "status": "error",
                "message": f"Chat error: {str(e)}"
            }
    def process_query(self, session_id: str, query: str) -> Dict[str, Any]:
        """Process a user query through the RAG pipeline."""
        try:
            # Lazy import of the prompt template
            from src.prompts.templates import rag_prompt_template
            # Encode the query
            query_embedding = self.session_manager.embedding_model.encode(query)
            if isinstance(query_embedding, np.ndarray):
                query_embedding = query_embedding.astype("float32")
            # Retrieve the most similar chunks
            search_result = self.session_manager.search_qdrant(
                session_id=session_id,
                query_embedding=query_embedding
            )
            # Generate a response using the retrieved context
            context = "\n\n".join(hit.payload["text"] for hit in search_result)
            prompt = rag_prompt_template(context, query)
            response = self.llm.generate([prompt])
            # LangChain LLMs return an LLMResult; other providers may return a plain string
            if hasattr(response, "generations"):
                response_text = response.generations[0][0].text
            else:
                response_text = response
            return {
                "status": "success",
                "response": response_text,
                "sources": [hit.payload["metadata"] for hit in search_result]
            }
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            return {"status": "error", "message": str(e)}
# FastAPI App
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the RAG system
rag = RAGSystem()
# Request models
class URLRequest(BaseModel):
    """Request model for URL crawling."""
    url: str
    session_id: Optional[str] = None


class ChatRequest(BaseModel):
    """Request model for chat queries."""
    session_id: str
    question: str


class SearchRequest(BaseModel):
    """Request model for direct searches."""
    session_id: str
    query: str
    limit: Optional[int] = 5
# API Endpoints
# (The route paths below are assumed, chosen to match the handler names;
# adjust them to whatever the client expects.)
@app.get("/")
async def root():
    """Health check endpoint."""
    return {"message": "RAG with Ollama Mistral and Qdrant is running"}


@app.post("/create_session")
async def create_session():
    """Create a new session ID."""
    session_id = str(uuid.uuid4())
    return {"session_id": session_id, "status": "success"}


@app.post("/crawl_and_index")
async def crawl_and_index(request: URLRequest):
    """Crawl and index a website."""
    session_id = request.session_id or str(uuid.uuid4())
    result = rag.crawl_and_index(session_id, request.url)
    return result


@app.post("/chat")
async def chat(request: ChatRequest):
    """Handle a chat request."""
    return await rag.chat(request.session_id, request.question)
@app.post("/search")
async def search(request: SearchRequest):
    """Handle a direct search request."""
    try:
        # Ensure the session exists before searching
        rag.session_manager.get_session(request.session_id)
        query_embedding = rag.session_manager.embedding_model.encode(request.query)
        if isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.tolist()
        collection_name = rag.session_manager.get_collection_name(request.session_id)
        search_results = rag.session_manager.qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_embedding,
            limit=request.limit
        )
        return {
            "status": "success",
            "results": [
                {
                    "text": hit.payload["text"],
                    "score": hit.score,
                    "metadata": hit.payload.get("metadata", {})
                }
                for hit in search_results
            ]
        }
    except Exception as e:
        logger.error(f"API search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
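# Example request against the /search route (the path is the assumed one
# above; the port is uvicorn's default and may differ):
#
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"session_id": "demo", "query": "pricing", "limit": 5}'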
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """WebSocket endpoint for real-time chat."""
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            uid = data.get("uid")
            question = data.get("question")
            if not uid or not question:
                await websocket.send_json({"error": "Missing 'uid' or 'question'"})
                continue
            # Get a response from the RAG system
            response = await rag.chat(uid, question)
            # Handle both success and error cases
            if response["status"] == "success":
                await websocket.send_json({
                    "uid": uid,
                    "question": question,
                    "answer": response["response"],
                    "sources": response.get("sources", [])
                })
            else:
                await websocket.send_json({
                    "uid": uid,
                    "error": response["message"]
                })
    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        await websocket.send_json({"error": str(e)})
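# Example WebSocket exchange (field names are taken from the handler above;
# the /ws path is the assumed one):
#
#   client sends:   {"uid": "<session_id>", "question": "What is this site about?"}
#   on success:     {"uid": ..., "question": ..., "answer": ..., "sources": [...]}
#   on error:       {"uid": ..., "error": "..."}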
# Main entry point
if __name__ == "__main__":
    Config.create_storage_dir()
    from src import launch_interface
    launch_interface()