ashish-ninehertz commited on
Commit
e272f4f
·
1 Parent(s): f05a666
Alpha DELETED
@@ -1 +0,0 @@
1
- Subproject commit 959e8417af1dd5dba45ed870104ba1b5abacbe6b
 
 
app DELETED
@@ -1 +0,0 @@
1
- Subproject commit 9eff68fb1134d1f51c07aac9e0a88a77e60379c3
 
 
app.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import uuid
3
+ import logging
4
+ from typing import List, Tuple
5
+ from src.main import RAGSystem
6
+ import asyncio
7
+
8
+ # Configure logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ # Initialize the RAG system
13
+ rag = RAGSystem()
14
+
15
def create_session() -> str:
    """Mint a fresh, globally unique session identifier (UUID4 as text)."""
    fresh_id = uuid.uuid4()
    return str(fresh_id)
18
+
19
def index_website(url: str, session_id: str) -> Tuple[bool, str]:
    """Crawl and index *url* under *session_id* via the global RAG system.

    Returns:
        (ok, message): success flag plus a human-readable status string.
    """
    try:
        outcome = rag.crawl_and_index(session_id, url)
        if outcome["status"] != "success":
            return False, outcome.get("message", "Unknown error during indexing")
        page_count = len(outcome.get('urls_processed', []))
        return True, f"Successfully indexed {page_count} pages"
    except Exception as e:
        logger.error(f"Indexing error: {str(e)}")
        return False, f"Error during indexing: {str(e)}"
29
+
30
def chat_response(
    session_id: str,
    message: str,
    model_choice: str,
    ollama_url: str,
    gemini_api_key: str,
    chat_history: List[dict]
) -> Tuple[List[dict], str]:
    """Generate a chat response with proper error handling.

    Args:
        session_id: Active session identifier ("" means nothing indexed yet).
        message: The user's question.
        model_choice: "mistral" or "gemini" (lower-cased before dispatch).
        ollama_url: Ollama endpoint; forwarded only for the mistral model.
        gemini_api_key: API key; forwarded only for the gemini model.
        chat_history: Current Chatbot value as OpenAI-style role/content dicts.

    Returns:
        (updated chat history, "" to clear the input textbox)
    """
    def _append_turn(bot_text: str) -> None:
        # The Chatbot component is created with type="messages", so history
        # entries must be {"role", "content"} dicts — the original code
        # appended (user, bot) tuples, which Gradio rejects in this mode.
        chat_history.append({"role": "user", "content": "🧑 " + message})
        chat_history.append({"role": "assistant", "content": "🤖 " + bot_text})

    if not session_id:
        _append_turn("Please index a website first or enter a valid session ID")
        return chat_history, ""

    try:
        # rag.chat is a coroutine; asyncio.run is used here — presumably this
        # callback runs in a thread with no running event loop (TODO confirm
        # it is never invoked from async context, where asyncio.run raises).
        response = asyncio.run(rag.chat(
            session_id=session_id,
            question=message,
            model=model_choice.lower(),
            ollama_url=ollama_url if model_choice == "mistral" else None,
            gemini_api_key=gemini_api_key if model_choice == "gemini" else None
        ))

        if response["status"] == "success":
            answer = response["response"]
            sources = "\n\nSources:\n" + "\n".join(
                f"- {src['source_url']}" for src in response.get("sources", [])
            ) if response.get("sources") else ""
            full_response = answer + sources
        else:
            full_response = f"Error: {response.get('message', 'Unknown error')}"

        _append_turn(full_response)
        return chat_history, ""
    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        _append_turn(f"System error: {str(e)}")
        return chat_history, ""
67
+
68
def toggle_model_inputs(model_choice: str) -> List[gr.update]:
    """Show the Ollama URL box for mistral, the Gemini key box otherwise."""
    wants_ollama = model_choice == "mistral"
    return [gr.update(visible=wants_ollama), gr.update(visible=not wants_ollama)]
73
+
74
def load_session(existing_session_id: str) -> Tuple[str, str]:
    """Resume a previously created session by its ID.

    Returns:
        (session_id, status message); session_id is "" on invalid input.
    """
    if not existing_session_id:
        return "", "Please enter a valid session ID"
    # NOTE: no existence check is performed — any non-empty string is accepted.
    return existing_session_id, f"Loaded existing session: {existing_session_id}"
80
+
81
def get_session(self, session_id: str):
    """Return the in-memory session dict, rehydrating from Qdrant if needed.

    Raises:
        ValueError: when no documents are indexed for this session.

    NOTE(review): this is written as an instance method (takes `self`, reads
    self.sessions / self.qdrant_client / self.get_collection_name) but sits
    at module level in app.py and is never attached to a class — it looks
    like a misplaced paste from the RAG/session manager; confirm and relocate.
    """
    # If session exists in memory, return it
    if session_id in self.sessions:
        return self.sessions[session_id]
    # If not, check if Qdrant collection exists and has documents
    collection_name = self.get_collection_name(session_id)
    try:
        results = self.qdrant_client.scroll(collection_name=collection_name, limit=1)
        if results and results[0]:
            # Rehydrate session in memory
            self.sessions[session_id] = {
                "documents": [],  # Optionally, you can fetch all docs if needed
                "history": []
            }
            return self.sessions[session_id]
    except Exception as e:
        logger.warning(f"Session {session_id} not found in Qdrant: {e}")
    # Reached when the collection is missing/empty or the scroll failed.
    # If not found, return None or raise
    raise ValueError("No documents indexed for this session")
100
+
101
# Custom CSS for better styling.
# Injected via gr.Blocks(css=custom_css): constrains layout width, sets the
# dark-mode background, accents chat bubbles, and styles the primary button.
custom_css = """
.gradio-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
}
.dark .gradio-container {
    background: #1e1e2e !important;
}
#chatbot {
    min-height: 500px;
    border-radius: 12px !important;
}
.message.user {
    border-left: 4px solid #4f46e5 !important;
}
.message.assistant {
    border-left: 4px solid #10b981 !important;
}
.btn-primary {
    background: linear-gradient(to right, #4f46e5, #7c3aed) !important;
    border: none !important;
}
.btn-primary:hover {
    background: linear-gradient(to right, #4338ca, #6d28d9) !important;
}
.prose {
    max-width: 100% !important;
}
"""
131
+
132
# Build the Gradio UI. Components declared inside this context are laid out
# in order; event wiring happens at the bottom of the block.
with gr.Blocks(title="RAG Chat with Mistral/Gemini", css=custom_css, theme="soft") as demo:
    # Header section
    with gr.Row():
        gr.Markdown("""
        # 🌐 RAG Chat Assistant
        ### Chat with any website using Mistral or Gemini
        """)

    # Session state: "" until a site is indexed or a previous session loaded
    session_id = gr.State("")

    with gr.Tabs():
        with gr.TabItem("📚 Index Website"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Step 1: Configure and Index")
                    with gr.Group():
                        url_input = gr.Textbox(
                            label="Website URL to index",
                            placeholder="https://example.com",
                            interactive=True,
                            lines=1
                        )

                    with gr.Row():
                        model_choice = gr.Radio(
                            choices=["mistral", "gemini"],
                            label="Select Model",
                            value="mistral",
                            interactive=True
                        )

                        index_btn = gr.Button(
                            "🚀 Index Website",
                            variant="primary",
                            scale=0
                        )

                    # Per-model credentials; visibility toggled by
                    # toggle_model_inputs when the radio changes
                    with gr.Accordion("🔐 Model Settings", open=False):
                        ollama_url = gr.Textbox(
                            label="Ollama URL (required for Mistral)",
                            placeholder="http://localhost:11434",
                            visible=True
                        )

                        gemini_api_key = gr.Textbox(
                            label="Gemini API Key (required for Gemini)",
                            placeholder="your-api-key-here",
                            visible=False,
                            type="password"
                        )

                    status_output = gr.Textbox(
                        label="Status",
                        interactive=False,
                        elem_classes="prose"
                    )

                    gr.Markdown("""
                    **Instructions:**
                    1. Enter a website URL
                    2. Select your preferred model
                    3. Configure model settings if needed
                    4. Click 'Index Website'
                    """)

        with gr.TabItem("💬 Chat"):
            with gr.Row():
                with gr.Column(scale=2):
                    # New session ID input for resuming sessions
                    with gr.Accordion("🔍 Resume Previous Session", open=False):
                        existing_session_input = gr.Textbox(
                            label="Enter existing Session ID",
                            placeholder="Paste your session ID here...",
                            interactive=True
                        )
                        load_session_btn = gr.Button(
                            "🔁 Load Session",
                            variant="secondary"
                        )
                        session_status = gr.Textbox(
                            label="Session Status",
                            interactive=False
                        )

                    # NOTE(review): with type="messages" Gradio expects the
                    # history value as {"role", "content"} dicts — verify the
                    # chat callback produces that format, not (user, bot)
                    # tuples.
                    chatbot = gr.Chatbot(
                        label="Chat History",
                        height=500,
                        avatar_images=(None, None),
                        show_copy_button=True,
                        type="messages"  # Use OpenAI-style messages
                    )

                    with gr.Row():
                        message_input = gr.Textbox(
                            label="Type your message",
                            placeholder="Ask about the website content...",
                            interactive=True,
                            container=False,
                            scale=7,
                            autofocus=True
                        )

                        send_btn = gr.Button(
                            "Send",
                            variant="primary",
                            scale=1,
                            min_width=100
                        )

    # Event handlers
    # Swap visibility of the two credential boxes when the model changes
    model_choice.change(
        fn=toggle_model_inputs,
        inputs=model_choice,
        outputs=[ollama_url, gemini_api_key]
    )

    # Indexing: mint a fresh session id first, then (only on success) crawl
    index_btn.click(
        fn=create_session,
        outputs=session_id
    ).success(
        fn=index_website,
        inputs=[url_input, session_id],
        outputs=[status_output]
    )

    # New handler for loading existing sessions
    load_session_btn.click(
        fn=load_session,
        inputs=[existing_session_input],
        outputs=[session_id, session_status]
    )

    send_btn.click(
        fn=chat_response,
        inputs=[session_id, message_input, model_choice, ollama_url, gemini_api_key, chatbot],
        outputs=[chatbot, message_input]
    )

    # Allow submitting with Enter key
    message_input.submit(
        fn=chat_response,
        inputs=[session_id, message_input, model_choice, ollama_url, gemini_api_key, chatbot],
        outputs=[chatbot, message_input]
    )

if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container-friendly)
        server_port=7860,
        favicon_path="assets/favicon.ico"  # Local path, not URL
    )
src/__init__.py ADDED
File without changes
src/config.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ import logging
4
+ from pathlib import Path
5
+
6
+ load_dotenv()
7
+
8
+ # Logging configuration
9
+ logging.basicConfig(
10
+ level=logging.INFO,
11
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12
+ )
13
+ logger = logging.getLogger(__name__)
14
+
15
class Config:
    """Central configuration: environment variables with in-code fallbacks."""

    # Crawling and content configuration
    MAX_PAGES_TO_CRAWL = int(os.getenv('MAX_PAGES_TO_CRAWL', 20))
    MAX_LINKS_PER_PAGE = int(os.getenv('MAX_LINKS_PER_PAGE', 10))
    MAX_CONTEXT_LENGTH = int(os.getenv('MAX_CONTEXT_LENGTH', 4000))
    MAX_HISTORY_MESSAGES = int(os.getenv('MAX_HISTORY_MESSAGES', 20))

    # Model configuration
    EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL', 'all-MiniLM-L6-v2')
    OLLAMA_BASE_URL = os.getenv('OLLAMA_BASE_URL', 'http://localhost:11434')
    OLLAMA_MODEL = os.getenv('OLLAMA_MODEL', 'mistral')

    # Path configuration: BASE_DIR is the repository root (two levels up)
    BASE_DIR = Path(__file__).parent.parent
    SESSIONS_DIR = os.path.join(BASE_DIR, "sessions")

    # Qdrant configuration (new)
    QDRANT_HOST = os.getenv('QDRANT_HOST', 'localhost')
    QDRANT_PORT = int(os.getenv('QDRANT_PORT', 6333))
    QDRANT_COLLECTION_PREFIX = os.getenv('QDRANT_COLLECTION_PREFIX', 'Chat-Bot')
    # NOTE(review): a concrete cluster URL is hard-coded as the fallback —
    # prefer failing fast when QDRANT_URL is unset.
    QDRANT_URL = os.getenv('QDRANT_URL', 'https://6fe012ee-5a7c-4304-a77c-293a1888a9cf.us-west-2-0.aws.cloud.qdrant.io')
    QDRANT_API_KEY = os.getenv('QDRANT_API_KEY', None)  # For cloud version

    # MongoDB configuration
    # SECURITY NOTE(review): live Atlas credentials are committed in this
    # default URI — rotate the password and remove the fallback so the app
    # refuses to start without MONGO_URI in the environment.
    MONGO_URI = os.getenv('MONGO_URI', "mongodb+srv://mehulxy21:[email protected]/")
    DATABASE_NAME = os.getenv('DATABASE_NAME', "rag_chat_history")
    HISTORY_COLLECTION = os.getenv('HISTORY_COLLECTION', "conversations")

    @staticmethod
    def create_storage_dirs():
        """Create necessary directories for storage (sessions/ and data/)."""
        os.makedirs(Config.SESSIONS_DIR, exist_ok=True)
        logger.info(f"Sessions directory created at: {Config.SESSIONS_DIR}")

        # Create data directory if it doesn't exist
        data_dir = os.path.join(Config.BASE_DIR, "data")
        os.makedirs(data_dir, exist_ok=True)
        logger.info(f"Data directory created at: {data_dir}")

    # Session Management — still class attributes despite sitting below the
    # staticmethod
    SESSION_INACTIVITY_TIMEOUT = 3600  # 1 hour in seconds
    SESSION_CLEANUP_INTERVAL = 600  # presumably seconds between sweeps — TODO confirm
src/crawler.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from bs4 import BeautifulSoup
3
+ from urllib.parse import urljoin, urlparse
4
+ from typing import List, Set, Optional, Dict
5
+ import logging
6
+ import re
7
+ from app.config import Config
8
+ import aiohttp
9
+
10
class URLCrawler:
    """
    A web crawler that extracts and processes content from websites.
    Handles both synchronous and asynchronous crawling operations.

    Features:
    - URL validation and sanitization
    - Content extraction with noise removal
    - Breadth-first crawling with configurable depth
    """

    def __init__(self):
        """Initialize the crawler with default settings."""
        self.visited_urls: Set[str] = set()  # Tracks crawled URLs to avoid duplicates
        self.logger = logging.getLogger(__name__)
        # Configure headers to mimic a real browser
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (compatible; RAGBot/1.0)',
            'Accept-Language': 'en-US,en;q=0.9'
        }

    def is_valid_url(self, url: str, base_domain: str) -> bool:
        """
        Validate if a URL should be crawled.

        Args:
            url: URL to validate
            base_domain: The target domain to stay within

        Returns:
            bool: True if URL is crawlable
        """
        parsed = urlparse(url)
        return (parsed.scheme in ('http', 'https') and  # Only HTTP/HTTPS
                parsed.netloc == base_domain and  # Stay within target domain
                not any(ext in url.lower()  # Skip binary files
                        for ext in ['.pdf', '.jpg', '.png', '.zip']) and
                url not in self.visited_urls)  # Avoid duplicates

    def sanitize_url(self, url: str) -> str:
        """
        Normalize URL by removing fragments, query parameters and a
        trailing slash.

        Args:
            url: URL to sanitize

        Returns:
            str: Normalized URL
        """
        parsed = urlparse(url)
        return f"{parsed.scheme}://{parsed.netloc}{parsed.path.rstrip('/')}"

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize extracted text content.

        Args:
            text: Raw extracted text

        Returns:
            str: Cleaned text, one surviving line per newline-separated input line
        """
        # Collapse runs of spaces/tabs only. The original used r'\s+', which
        # also destroyed every newline — that made the short-line filter
        # below a no-op because the whole document became a single line.
        text = re.sub(r'[ \t]+', ' ', text)
        # Remove common boilerplate
        text = re.sub(r'(\b(privacy policy|terms of service|cookie policy)\b|\b\d+\s*(comments|shares|likes)\b)', '', text, flags=re.I)
        # Remove short lines (likely not meaningful content)
        return '\n'.join(line for line in text.split('\n')
                         if len(line.strip()) > 30)

    def extract_main_content(self, soup: BeautifulSoup) -> str:
        """
        Extract primary content from HTML using semantic heuristics.

        Args:
            soup: BeautifulSoup parsed HTML document

        Returns:
            str: Extracted main content ("" when the page has no usable body)
        """
        # Remove unwanted elements that typically don't contain main content
        for element in soup(['script', 'style', 'nav', 'footer',
                             'header', 'iframe', 'aside', 'form']):
            element.decompose()

        # Prioritize semantic HTML containers that likely contain main content
        for tag in ['article', 'main', 'section[role="main"]', '.content']:
            content = soup.select_one(tag)
            if content:
                return self.clean_text(content.get_text(separator='\n'))

        # Fallback to body; guard against documents without a <body> — the
        # original dereferenced soup.body unconditionally and could raise
        # AttributeError on malformed pages.
        body = soup.body
        return self.clean_text(body.get_text(separator='\n')) if body else ""

    def get_page_content(self, url: str) -> Optional[Dict]:
        """
        Fetch and process a single web page.

        Args:
            url: URL to fetch

        Returns:
            Optional[Dict]: Structured page data, or None for non-HTML /
            thin / failed pages.
        """
        try:
            response = requests.get(url, headers=self.headers, timeout=15)
            response.raise_for_status()

            # Skip non-HTML content
            if 'text/html' not in response.headers.get('Content-Type', ''):
                return None

            soup = BeautifulSoup(response.text, 'lxml')
            # <title> may be absent or empty — fall back to the URL path
            title = soup.title.string if soup.title and soup.title.string else urlparse(url).path
            content = self.extract_main_content(soup)

            # Skip pages with insufficient content
            if len(content.split()) < 100:  # Minimum 100 words
                return None

            return {
                'url': url,
                'title': title,
                'content': content,
                'last_modified': response.headers.get('Last-Modified', '')
            }

        except Exception as e:
            self.logger.warning(f"Error processing {url}: {str(e)}")
            return None

    def extract_links(self, url: str, soup: BeautifulSoup) -> List[str]:
        """
        Extract all crawlable links from a page.

        Args:
            url: Base URL for relative link resolution
            soup: Parsed HTML document

        Returns:
            List[str]: Absolute, sanitized URLs (capped at MAX_LINKS_PER_PAGE)
        """
        base_domain = urlparse(url).netloc
        links = set()

        for link in soup.find_all('a', href=True):
            href = link['href'].split('#')[0]  # Remove fragments
            if not href or href.startswith('javascript:'):
                continue

            absolute_url = urljoin(url, href)
            sanitized_url = self.sanitize_url(absolute_url)

            if self.is_valid_url(sanitized_url, base_domain):
                links.add(sanitized_url)

        return sorted(links)[:Config.MAX_LINKS_PER_PAGE]  # Apply limit

    async def crawl(self, url: str) -> str:
        """
        Asynchronously crawl a single URL and return its text content.

        Args:
            url: URL to crawl

        Returns:
            str: Extracted text content

        Raises:
            Exception: If crawling fails
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    html = await response.text()
                    soup = BeautifulSoup(html, 'html.parser')
                    # Remove script and style elements
                    for script in soup(["script", "style"]):
                        script.decompose()
                    return soup.get_text()
        except Exception as e:
            self.logger.error(f"Crawling error: {str(e)}")
            raise

    def crawl_sync(self, start_url: str, max_pages: Optional[int] = None) -> List[Dict]:
        """
        Synchronously crawl a website using breadth-first search.

        Args:
            start_url: Initial URL to begin crawling
            max_pages: Maximum number of pages to crawl. Defaults to
                Config.MAX_PAGES_TO_CRAWL, now read at *call* time — the
                original default argument froze the value at import.

        Returns:
            List[Dict]: Structured documents from crawled pages
        """
        if max_pages is None:
            max_pages = Config.MAX_PAGES_TO_CRAWL

        base_domain = urlparse(start_url).netloc
        queue = [start_url]  # URLs to crawl (FIFO => breadth-first)
        documents = []  # Collected documents

        while queue and len(documents) < max_pages:
            current_url = queue.pop(0)
            sanitized_url = self.sanitize_url(current_url)

            if sanitized_url in self.visited_urls:
                continue

            self.visited_urls.add(sanitized_url)
            self.logger.info(f"Crawling: {sanitized_url}")

            page_data = self.get_page_content(sanitized_url)
            if not page_data:
                continue

            documents.append(page_data)

            # Get links for further crawling
            try:
                response = requests.get(sanitized_url, headers=self.headers, timeout=10)
                soup = BeautifulSoup(response.text, 'lxml')
                new_links = self.extract_links(sanitized_url, soup)
                queue.extend(link for link in new_links
                             if link not in self.visited_urls)
            except Exception as e:
                self.logger.warning(f"Error getting links from {sanitized_url}: {str(e)}")

        return documents
src/databases/__init__.py ADDED
File without changes
src/databases/models.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field
4
+
5
class Conversation(BaseModel):
    """One user⇄bot exchange as persisted to the history collection."""
    session_id: str      # owning chat session
    user_query: str      # raw user message
    bot_response: str    # generated answer
    # default_factory gives each instance its own creation time; the original
    # `= datetime.utcnow()` evaluated once at import, so every Conversation
    # shared the module-load timestamp.
    timestamp: Optional[datetime] = Field(default_factory=datetime.utcnow)
    metadata: Optional[dict] = None  # free-form extras (e.g. sources)
src/databases/mongo_handler.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pymongo import MongoClient
2
+ from pymongo.errors import ServerSelectionTimeoutError, AutoReconnect
3
+ import logging
4
+ from datetime import datetime
5
+ from app.config import Config
6
+ from typing import List, Dict, Optional
7
+
8
class MongoDBHandler:
    """Wrapper around the MongoDB chat-history collection.

    Connects eagerly in __init__ (5s server-selection timeout), creates
    query indexes, and exposes save/retrieve/clear helpers.
    """

    def __init__(self):
        """Connect, verify reachability, and prepare indexes.

        Raises:
            ServerSelectionTimeoutError / AutoReconnect: server unreachable.
        """
        self.logger = logging.getLogger(__name__)
        self._ensure_mongodb_running()
        try:
            self.client = MongoClient(
                Config.MONGO_URI,
                serverSelectionTimeoutMS=5000
            )
            # Test connection immediately
            self.client.server_info()
            self.db = self.client[Config.DATABASE_NAME]
            self.collection = self.db[Config.HISTORY_COLLECTION]
            self._create_indexes()
            self.logger.info("MongoDB connection established successfully")
        except (ServerSelectionTimeoutError, AutoReconnect) as e:
            self.logger.error(f"MongoDB connection failed: {e}")
            self._diagnose_connection_issue()
            raise

    def _ensure_mongodb_running(self):
        """Best-effort attempt to start a local MongoDB service.

        NOTE(review): `net start` is Windows-only; on other platforms (and
        for Atlas URIs) this logs a failure and continues harmlessly.
        """
        try:
            import subprocess
            result = subprocess.run(
                ['net', 'start', 'MongoDB'],
                capture_output=True,
                text=True
            )
            if "already running" not in result.stderr.lower():
                self.logger.info("Started MongoDB service")
        except Exception as e:
            self.logger.error(f"Failed to start MongoDB: {e}")

    def _diagnose_connection_issue(self):
        """Log likely causes of a failed connection (Windows-specific checks)."""
        import os
        issues = []

        # Check data directory
        if not os.path.exists("C:\\data\\db"):
            issues.append("Data directory missing")

        # Check log directory
        if not os.path.exists("C:\\data\\log"):
            issues.append("Log directory missing")

        # Check service status
        try:
            import subprocess
            result = subprocess.run(
                ['sc', 'query', 'MongoDB'],
                capture_output=True,
                text=True
            )
            if "RUNNING" not in result.stdout:
                issues.append("MongoDB service not running")
        except Exception:
            issues.append("Could not check service status")

        if issues:
            self.logger.error("MongoDB Issues Found:")
            for issue in issues:
                self.logger.error(f" - {issue}")

    def _create_indexes(self):
        """Create indexes for better query performance"""
        self.collection.create_index([("session_id", 1)])
        self.collection.create_index([("timestamp", -1)])
        self.collection.create_index([("session_id", 1), ("timestamp", -1)])

    def save_conversation(self, session_id: str, query: str, response: str, metadata: dict = None) -> str:
        """Save a conversation with automatic timestamp.

        Returns:
            str: the inserted ObjectId rendered as a *string* — it cannot be
            used directly in an `_id` query without converting back to
            ObjectId.
        """
        conversation = {
            "session_id": session_id,
            "user_query": query,
            "bot_response": response,
            "timestamp": datetime.utcnow(),
            "metadata": metadata or {}
        }
        result = self.collection.insert_one(conversation)
        return str(result.inserted_id)

    def verify_storage(self) -> bool:
        """Verify storage by inserting, retrieving and deleting a test doc.

        Bug fixed: the original queried `find_one({"_id": test_id})` with the
        *string* form of the ObjectId, which never matches the stored
        ObjectId — verification always returned False. We now look the test
        document up by its content instead.
        """
        try:
            # Insert test document
            self.save_conversation(
                session_id="test",
                query="test_query",
                response="test_response",
                metadata={"test": True}
            )

            # Verify retrieval by content (see docstring)
            test_doc = self.collection.find_one(
                {"session_id": "test", "metadata.test": True}
            )

            # Cleanup any test documents
            self.collection.delete_many({"session_id": "test", "metadata.test": True})

            return test_doc is not None
        except Exception as e:
            self.logger.error(f"Storage verification failed: {e}")
            return False

    def get_conversation_history(self, session_id: str, limit: int = 10) -> List[Dict]:
        """Retrieve up to *limit* most recent conversations (newest first)."""
        try:
            cursor = self.collection.find(
                {"session_id": session_id},
                {"_id": 0}  # Exclude _id field
            ).sort("timestamp", -1).limit(limit)

            return list(cursor)
        except Exception as e:
            self.logger.error(f"Error retrieving conversation history: {e}")
            return []

    def clear_session_history(self, session_id: str) -> int:
        """Delete all conversations for a session; returns the deleted count."""
        try:
            result = self.collection.delete_many({"session_id": session_id})
            return result.deleted_count
        except Exception as e:
            self.logger.error(f"Error clearing session history: {e}")
            return 0
src/databases/storage.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.databases.mongo_handler import MongoDBHandler
2
+ from app.main import SessionManager
3
+ import logging
4
+
5
class StorageService:
    """Persists conversations to both the in-memory session manager and MongoDB."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self.session_manager = SessionManager()
        self.mongo_handler = MongoDBHandler()

    async def add_conversation(self, session_id: str, query: str, response: str, sources: list = None):
        """Save conversation to both session manager and MongoDB"""
        try:
            # In-memory copy first (session manager) ...
            self.session_manager.add_conversation(session_id, query, response)

            # ... then the durable copy (MongoDB)
            extra = {"sources": sources} if sources else {}
            mongo_id = self.mongo_handler.save_conversation(
                session_id=session_id,
                query=query,
                response=response,
                metadata=extra,
            )

            self.logger.debug(f"Conversation saved to MongoDB with ID: {mongo_id}")
            return True
        except Exception as e:
            self.logger.error(f"Failed to save conversation: {e}")
            return False
src/databases/test_mongo.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
# Ad-hoc connectivity smoke test for the MongoDB Atlas cluster.
from pymongo import MongoClient

# SECURITY NOTE(review): a live username/password is committed here in plain
# text — rotate this credential and read the URI from an environment
# variable instead of hard-coding it.
client = MongoClient("mongodb+srv://mehulxy21:[email protected]/")
db = client["rag_chat_history"]  # handle only; connection is lazy until used
print("Databases:", client.list_database_names())
src/embeddings.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from sentence_transformers import SentenceTransformer
3
+ import numpy as np
4
+ import logging
5
+ from typing import List, Dict, Optional
6
+ from app.config import Config
7
+ from qdrant_client import QdrantClient
8
+ from qdrant_client.models import Distance, VectorParams, PointStruct
9
+ from qdrant_client.http.exceptions import UnexpectedResponse
10
+
11
class EmbeddingHandler:
    """
    Handles all embedding-related operations including:
    - Text embedding generation using SentenceTransformers
    - Vector storage and retrieval with Qdrant
    - Collection management for vector storage

    This serves as the central component for vector operations in the RAG system.
    """

    def __init__(self):
        """Initialize the embedding handler with model and vector store client.

        Raises:
            RuntimeError: when the model or Qdrant client cannot be created.
        """
        self.logger = logging.getLogger(__name__)
        try:
            # Initialize embedding model with configuration from Config
            self.model = SentenceTransformer(Config.EMBEDDING_MODEL)
            # Get embedding dimension from the model
            self.embedding_dim = self.model.get_sentence_embedding_dimension()

            # Initialize Qdrant client with configuration from Config
            self.qdrant_client = QdrantClient(
                url=Config.QDRANT_URL,
                api_key=Config.QDRANT_API_KEY,
                prefer_grpc=False,  # HTTP preferred over gRPC for compatibility
                timeout=30  # Connection timeout in seconds
            )

            # Connection test can be uncommented for local development
            # self._verify_connection()

        except Exception as e:
            self.logger.error(f"Error initializing embedding handler: {str(e)}", exc_info=True)
            raise RuntimeError("Failed to initialize embedding handler") from e

    def generate_embeddings(self, texts: List[str]) -> np.ndarray:
        """
        Generate embeddings for a list of text strings.

        Args:
            texts: List of text strings to embed

        Returns:
            np.ndarray: Array of embeddings (2D numpy array)

        Raises:
            Exception: If embedding generation fails
        """
        try:
            return self.model.encode(
                texts,
                show_progress_bar=True,  # Visual progress indicator
                batch_size=32,  # Batch size used for encoding
                convert_to_numpy=True  # Return as numpy array for efficiency
            )
        except Exception as e:
            self.logger.error(f"Error generating embeddings: {str(e)}", exc_info=True)
            raise

    def create_collection(self, collection_name: str) -> bool:
        """
        Create a new Qdrant collection for storing vectors.

        Args:
            collection_name: Name of the collection to create

        Returns:
            bool: True if collection was created or already exists

        Raises:
            Exception: If collection creation fails (except for already exists case)
        """
        try:
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_dim,  # Must match model's embedding dimension
                    distance=Distance.COSINE  # Using cosine similarity
                )
            )
            self.logger.info(f"Created collection {collection_name}")
            return True

        except UnexpectedResponse as e:
            # Handle case where collection already exists
            if "already exists" in str(e):
                self.logger.info(f"Collection {collection_name} already exists")
                return True
            else:
                self.logger.error(f"Error creating collection: {e}")
                raise
        except Exception as e:
            self.logger.error(f"Error creating collection: {str(e)}", exc_info=True)
            raise

    def add_to_collection(self, collection_name: str, embeddings: np.ndarray, payloads: List[dict]) -> bool:
        """
        Add embeddings and associated metadata to a Qdrant collection.

        Args:
            collection_name: Target collection name
            embeddings: Numpy array of embeddings to add
            payloads: List of metadata dictionaries corresponding to each embedding

        Returns:
            bool: True if operation succeeded

        Raises:
            Exception: If operation fails
        """
        try:
            # Convert numpy arrays to lists for Qdrant compatibility
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()

            # Prepare points in batches for efficient processing
            batch_size = 100  # Batch size used for upserts
            points = [
                PointStruct(
                    id=idx,  # Sequential ID
                    # NOTE(review): ids restart at 0 on every call — a second
                    # add_to_collection on the same collection upserts over
                    # (replaces) the first batch's points. Use stable/unique
                    # ids if incremental indexing is intended.
                    vector=embedding,
                    payload=payload  # Associated metadata
                )
                for idx, (embedding, payload) in enumerate(zip(embeddings, payloads))
            ]

            # Process in batches to avoid overwhelming the server
            for i in range(0, len(points), batch_size):
                batch = points[i:i + batch_size]
                self.qdrant_client.upsert(
                    collection_name=collection_name,
                    points=batch,
                    wait=True  # Ensure immediate persistence
                )

            self.logger.info(f"Added {len(points)} vectors to collection {collection_name}")
            return True

        except Exception as e:
            self.logger.error(f"Error adding to collection: {str(e)}", exc_info=True)
            raise

    async def search_collection(self, collection_name: str, query: str, k: int = 5) -> Dict:
        """
        Search a Qdrant collection for similar vectors to the query.

        Args:
            collection_name: Name of collection to search
            query: Text query to search for
            k: Number of similar results to return (default: 5)

        Returns:
            Dict: {
                "status": "success"|"error",
                "results": List[Dict] (if success),
                "message": str (if error)
            }
        """
        try:
            # Generate embedding for the query text
            query_embedding = self.model.encode(query).tolist()

            # Perform similarity search in Qdrant
            # NOTE(review): newer qdrant-client releases deprecate search()
            # in favour of query_points — confirm the pinned client version.
            results = self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,  # Number of results to return
                with_payload=True,  # Include metadata
                with_vectors=False  # Exclude raw vectors to save bandwidth
            )

            # Format results for consistent API response
            formatted_results = []
            for hit in results:
                formatted_results.append({
                    "id": hit.id,
                    "score": float(hit.score),  # Similarity score
                    "payload": hit.payload or {},  # Associated metadata
                    "text": hit.payload.get("text", "") if hit.payload else ""  # Extracted text
                })

            return {
                "status": "success",
                "results": formatted_results
            }

        except Exception as e:
            self.logger.error(f"Search error: {str(e)}", exc_info=True)
            return {
                "status": "error",
                "message": str(e),
                "results": []
            }

    # Deprecated FAISS methods (maintained for backward compatibility)
    def create_faiss_index(self, *args, **kwargs):
        """Deprecated method - FAISS support has been replaced by Qdrant."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Use Qdrant collections instead of FAISS")

    def save_index(self, *args, **kwargs):
        """Deprecated method - Qdrant persists data automatically."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Qdrant persists data automatically")

    def load_index(self, *args, **kwargs):
        """Deprecated method - Access Qdrant collections directly."""
        self.logger.warning("FAISS operations are deprecated")
        raise NotImplementedError("Access Qdrant collections directly")
src/handlers.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from fastapi import WebSocket, WebSocketDisconnect
4
+ from pydantic import BaseModel
5
+ from typing import Optional, Dict
6
+ from app.rag import RAGSystem
7
+
8
+ # Configure logging
9
+ logger = logging.getLogger(__name__)
10
+ logging.basicConfig(level=logging.INFO)
11
+
12
class ConnectionManager:
    """
    Registry of live WebSocket connections, keyed by session id.

    Provides connect/disconnect bookkeeping so handlers can look up the
    socket belonging to a session.
    """

    def __init__(self):
        # session_id -> WebSocket for every currently-connected client
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """
        Complete the WebSocket handshake and register the socket.

        Args:
            websocket: The WebSocket connection object
            session_id: Unique identifier for the session
        """
        await websocket.accept()
        self.active_connections[session_id] = websocket
        logger.info(f"WebSocket connected: {session_id}")

    def disconnect(self, session_id: str):
        """
        Forget the socket registered under *session_id*.

        Safe to call for unknown ids (pop with a default never raises).

        Args:
            session_id: The session ID to disconnect
        """
        self.active_connections.pop(session_id, None)
        logger.info(f"WebSocket disconnected: {session_id}")
43
+
44
+ # Initialize system components
45
+ manager = ConnectionManager() # Manages WebSocket connections
46
+ rag = RAGSystem() # The RAG processing system
47
+
48
class ResponseFormatter:
    """
    Shapes outgoing payloads before they are sent to clients.

    Currently an identity pass-through; exists as an extension point so a
    standardized response envelope can be introduced without touching the
    handlers.
    """

    def __init__(self):
        # No configuration yet.
        pass

    def format_response(self, response: dict) -> dict:
        """
        Return *response* unchanged.

        Args:
            response: Raw response dictionary

        Returns:
            dict: The same dictionary, untouched
        """
        return response
68
+
69
# Module-level formatter instance (currently a pass-through hook).
formatter = ResponseFormatter() # Create formatter instance
70
+
71
class ChatMessage(BaseModel):
    """
    Pydantic model for validating incoming chat messages.

    Fields:
        text: The message text/content (required).
        url: Optional URL giving extra context — NOTE(review): the current
             websocket handler never reads this field; confirm whether it
             should trigger indexing or be removed.
    """
    text: str  # The message text/content
    url: Optional[str] = None  # Optional URL for context
78
+
79
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for handling real-time chat interactions.

    Args:
        websocket: The WebSocket connection
        session_id: Unique identifier for the chat session

    Handles:
        - Connection registration/cleanup (always removed from the manager,
          even on unexpected errors — the original leaked the entry on any
          exception other than WebSocketDisconnect)
        - Message validation and RAG processing
        - Best-effort error reporting to the client
    """
    # Register the new connection
    await manager.connect(websocket, session_id)

    try:
        while True:
            # Wait for and receive incoming message
            data = await websocket.receive_json()
            logger.info(f"Received message: {data}")

            # Parse and validate message using the Pydantic model
            message = ChatMessage(**data)

            # Process the message through the RAG system.
            # NOTE(review): verify that app.rag.RAGSystem.process_query really
            # takes (text, session_id) in this order — the sibling RAGSystem in
            # src/main.py uses (session_id, query).
            response = await rag.process_query(message.text, session_id)

            # Log and send the response
            logger.info(f"Sending response: {response}")
            await websocket.send_json(response)

    except WebSocketDisconnect:
        # Client went away; nothing to send.
        logger.info(f"Client disconnected: {session_id}")
    except Exception as e:
        logger.error(f"Error in websocket: {str(e)}")
        # The socket may already be closed; treat delivery as best-effort so a
        # secondary send failure cannot mask the original error.
        try:
            await websocket.send_json({
                "status": "error",
                "message": str(e)
            })
        except Exception:
            logger.debug("Could not deliver error to client", exc_info=True)
    finally:
        # Always drop the connection from the registry; disconnect() is a
        # no-op for ids that were already removed.
        manager.disconnect(session_id)
src/llm.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import ollama
3
+ from typing import List, Dict, Optional
4
+ from app.config import Config
5
+ import os
6
+
7
+
8
class OllamaMistral:
    """
    A client wrapper for the Ollama API using the Mistral model.

    Handles chat completions and embedding generation against a local
    Ollama server (default host http://localhost:11434).
    """

    def __init__(self):
        """Initialize the Ollama Mistral client with default settings."""
        self.logger = logging.getLogger(__name__)
        # Initialize Ollama client with default host
        self.client = ollama.Client(host='http://localhost:11434')
        self.model = 'mistral'  # Default model name

    async def generate_response(self, prompt: str) -> str:
        """
        Asynchronously generate a text response from the Mistral model.

        Args:
            prompt: The input text prompt for the model

        Returns:
            Generated response text, or an error message string on failure
            (this method never raises).
        """
        try:
            print(f"[Ollama] Sending prompt:\n{prompt}\n")
            # Send chat request to Ollama API
            response = self.client.chat(
                model=self.model,
                messages=[{
                    'role': 'user',
                    'content': prompt
                }]
            )
            print(f"[Ollama] Received response:\n{response}\n")

            # Handle different response formats from Ollama (plain dict vs
            # typed response object, depending on client version).
            if isinstance(response, dict):
                if 'message' in response and 'content' in response['message']:
                    return response['message']['content']
            elif hasattr(response, 'message') and hasattr(response.message, 'content'):
                return response.message.content
            # Fallback: try to convert to string
            return str(response)

        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
        """
        Generate an embedding vector for *text* using the specified model.

        Args:
            text: Input text to generate embeddings for
            model: Model name to use for embeddings (default from Config)

        Returns:
            List of floats, or None on failure.
        """
        try:
            print(f"[Ollama] Generating embedding for: {text[:60]}...")
            # FIXED: the ollama client's embeddings() takes a single 'prompt'
            # string; the previous 'prompts=[text]' keyword is not part of the
            # API and made every call fail with a TypeError.
            response = self.client.embeddings(
                model=model,
                prompt=text
            )
            print(f"[Ollama] Embedding response: {response}")

            # Handle both plural and singular response shapes defensively.
            if isinstance(response, dict) and 'embeddings' in response:
                return response['embeddings'][0]
            elif isinstance(response, dict) and 'embedding' in response:
                return response['embedding']
            else:
                self.logger.warning(f"Unexpected embedding response format: {response}")
                return None

        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
            return None

    def generate(self, prompt: str) -> str:
        """
        Synchronous wrapper for generate_response().

        NOTE(review): asyncio.run() raises RuntimeError when invoked from an
        already-running event loop — only call this from synchronous code
        paths (the error is caught and returned as a string).

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text or an error message string.
        """
        import asyncio
        try:
            return asyncio.run(self.generate_response(prompt))
        except Exception as e:
            self.logger.error(f"Error in synchronous generate: {e}")
            return f"Error generating response: {str(e)}"
105
+
106
+
107
class GeminiProvider:
    """
    Client wrapper around Google's Gemini API.

    Requires the GEMINI_API_KEY environment variable and the
    google-generativeai package; construction fails fast when either is
    missing.
    """

    def __init__(self):
        """Read the API key, configure the SDK, and build the model handle."""
        self.logger = logging.getLogger(__name__)
        self.api_key = os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required for Gemini provider")

        try:
            import google.generativeai as genai
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel('gemini-1.5-flash')
        except ImportError:
            raise ImportError("google-generativeai package is required for Gemini provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the Gemini model.

        Args:
            prompt: Input text prompt

        Returns:
            The model's completion text, or an error message string on
            failure (never raises).
        """
        try:
            result = self.model.generate_content(prompt)
            return result.text
        except Exception as e:
            self.logger.error(f"[Gemini] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"
144
+
145
+
146
class OpenChatProvider:
    """
    Local OpenChat text generation via HuggingFace transformers.

    Downloads/loads the openchat-3.5-0106 checkpoint at construction time;
    requires the transformers package.
    """

    def __init__(self):
        """Load the OpenChat tokenizer and model weights."""
        self.logger = logging.getLogger(__name__)
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            model_id = "openchat/openchat-3.5-0106"
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.model = AutoModelForCausalLM.from_pretrained(model_id)
        except ImportError:
            raise ImportError("transformers package is required for OpenChat provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a completion for *prompt* (max 512 tokens, temperature 0.7).

        Args:
            prompt: Input text prompt

        Returns:
            Decoded response text, or an error message string on failure.
        """
        try:
            encoded = self.tokenizer(prompt, return_tensors="pt")
            generated = self.model.generate(**encoded, max_length=512, temperature=0.7)
            decoded = self.tokenizer.decode(generated[0], skip_special_tokens=True)
            return decoded
        except Exception as e:
            self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"
182
+
183
+
184
class LLMFactory:
    """
    Factory for LLM providers (Factory design pattern).

    FIXED: the original annotated get_provider's return as ``-> any`` — that
    is the *builtin function* ``any``, not ``typing.Any`` — a typing bug; the
    annotation has been removed and the return is documented instead.
    """

    @staticmethod
    def get_provider(model_name: Optional[str] = None):
        """
        Get the appropriate LLM provider for *model_name*.

        Args:
            model_name: Name of the model ('mistral', 'gemini', 'openchat').
                        None, empty, or unknown names default to 'mistral'.

        Returns:
            An OllamaMistral, GeminiProvider, or OpenChatProvider instance.

        Raises:
            ValueError/ImportError: propagated from a provider whose required
            dependencies or credentials are missing.
        """
        # Dispatch table: model name -> provider class.
        providers = {
            "mistral": OllamaMistral,
            "gemini": GeminiProvider,
            "openchat": OpenChatProvider,
        }

        name = (model_name or "mistral").lower()
        provider_cls = providers.get(name)
        if provider_cls is None:
            # Default to mistral if an unknown model is specified.
            logging.warning(f"Unknown model '{name}', defaulting to mistral")
            provider_cls = OllamaMistral
        return provider_cls()
src/main.py ADDED
@@ -0,0 +1,535 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import numpy as np
3
+ import uuid
4
+ import os
5
+ from datetime import datetime
6
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from pydantic import BaseModel
9
+ from typing import Optional, List, Dict, Any
10
+ from bs4 import BeautifulSoup
11
+ import requests
12
+ from urllib.parse import urljoin, urlparse
13
+ import re
14
+ from qdrant_client import QdrantClient
15
+ from qdrant_client.models import Distance, VectorParams, PointStruct
16
+ from app.prompts.templates import rag_prompt_template
17
+ from sentence_transformers import SentenceTransformer
18
+ from langchain_ollama import OllamaLLM
19
+ import json
20
+ import asyncio
21
+ from app.llm import GeminiProvider
22
+ from app.config import Config
23
+ from qdrant_client.http.exceptions import UnexpectedResponse
24
+
25
+ # Configure logging
26
+ logger = logging.getLogger(__name__)
27
+
28
# Configuration
class Config:
    """
    Application configuration settings.

    SECURITY(review): the Qdrant URL and API key were previously committed
    as literals in source. They are now read from the environment first; the
    literal fallbacks are kept only for backward compatibility and the
    exposed key MUST be rotated on the Qdrant side.
    """
    STORAGE_DIR = "data/qdrant_storage"
    EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
    # Lowercase attribute name kept: SessionManager references Config.url.
    url = os.getenv(
        "QDRANT_URL",
        "https://6fe012ee-5a7c-4304-a77c-293a1888a9cf.us-west-2-0.aws.cloud.qdrant.io"
    )
    QDRANT_API_KEY = os.getenv(
        "QDRANT_API_KEY",
        "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.NUKB9m360LPEBTnpdo2TJpJmEIttumHLz-9ZbAUBKIM"
    )
    QDRANT_COLLECTION_NAME = "Chat-Bot"

    @staticmethod
    def create_storage_dir():
        """Ensure the local storage directory exists (idempotent)."""
        os.makedirs(Config.STORAGE_DIR, exist_ok=True)
44
+
45
# Data classes
class Document:
    """A unit of indexable content: raw text plus arbitrary metadata."""

    def __init__(self, text: str, metadata: Dict[str, Any]):
        # Values are stored verbatim; no copying or validation is performed.
        self.text = text
        self.metadata = metadata
51
+
52
# Session Manager
class SessionManager:
    """
    Manages user sessions and Qdrant collections.

    Responsibilities:
    - In-memory session state ('documents' and 'history' per session)
    - Per-session Qdrant collection lifecycle
    - Conversation history bookkeeping

    NOTE(review): sessions live only in process memory; Qdrant holds the
    vectors, so after a restart a session's 'documents' list is empty even
    when its collection still has data.
    """

    def __init__(self):
        """Initialize in-memory sessions, the embedding model, and a Qdrant client."""
        self.sessions = {}  # session_id -> {'documents': [...], 'history': [...]}
        self.embedding_model = SentenceTransformer(Config.EMBEDDING_MODEL)
        self.qdrant_client = QdrantClient(
            url=Config.url,
            api_key=Config.QDRANT_API_KEY,
            timeout=30
        )

    def get_collection_name(self, session_id: str) -> str:
        """Generate standardized Qdrant collection name for a session."""
        return f"collection_{session_id}"

    def get_session(self, session_id: str) -> Dict:
        """Get or create the session dict for *session_id*."""
        if session_id not in self.sessions:
            self._initialize_new_session(session_id)
            print(f"[SessionManager] Created new session: {session_id}")
        return self.sessions[session_id]

    def _initialize_new_session(self, session_id: str):
        """Create empty session state and ensure its Qdrant collection exists."""
        self.sessions[session_id] = {
            'documents': [],
            'history': []
        }
        self._ensure_qdrant_collection(session_id)
        print(f"[SessionManager] Initialized session {session_id} with Qdrant collection.")

    def _ensure_qdrant_collection(self, session_id: str):
        """Ensure a Qdrant collection exists for the session (create if missing)."""
        collection_name = self.get_collection_name(session_id)
        try:
            # Cheap existence probe; raises when the collection is absent.
            self.qdrant_client.get_collection(collection_name)
            logger.debug(f"Using existing Qdrant collection: {collection_name}")
        except Exception:
            # Collection doesn't exist, create it
            try:
                self.qdrant_client.create_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=self.embedding_model.get_sentence_embedding_dimension(),
                        distance=Distance.COSINE
                    )
                )
                logger.info(f"Created new Qdrant collection: {collection_name}")
            except UnexpectedResponse as e:
                # Benign race: another worker created it between probe and create.
                if "already exists" in str(e):
                    logger.debug(f"Collection already exists: {collection_name}")
                else:
                    logger.error(f"Error creating collection: {e}")
                    raise
            except Exception as e:
                logger.error(f"Unexpected error ensuring collection: {e}")
                raise

    def add_to_history(self, session_id: str, question: str, answer: str):
        """Append a question/answer pair (with timestamp) to session history."""
        if session_id not in self.sessions:
            logger.warning(f"Session {session_id} not found when adding history")
            return

        self.sessions[session_id]['history'].append({
            'question': question,
            'answer': answer,
            'timestamp': datetime.now().isoformat()
        })

    def get_history(self, session_id: str, limit: Optional[int] = None) -> List[Dict]:
        """Return conversation history, optionally only the last *limit* entries."""
        if session_id not in self.sessions:
            logger.warning(f"Session {session_id} not found when getting history")
            return []

        history = self.sessions[session_id]['history']
        return history[-limit:] if limit else history

    def session_exists(self, session_id: str) -> bool:
        """Check memory first, then Qdrant; rehydrate an empty session if its collection exists."""
        if session_id in self.sessions:
            return True
        collection_name = self.get_collection_name(session_id)
        try:
            self.qdrant_client.get_collection(collection_name)
            # Re-register the session so later calls find it in memory.
            self.sessions[session_id] = {
                'documents': [],
                'history': []
            }
            return True
        except Exception:
            return False

    def cleanup_inactive_sessions(self, inactive_minutes: int = 60):
        """Drop in-memory sessions whose last activity is older than *inactive_minutes*."""
        current_time = datetime.now()
        for session_id in list(self.sessions.keys()):
            history = self.sessions[session_id]['history']
            if not history:
                continue
            # FIXED: entries written by add_conversation() previously had no
            # 'timestamp' key, so this raised KeyError; entries of unknown age
            # are now skipped instead.
            last_ts = history[-1].get('timestamp')
            if last_ts is None:
                continue
            last_activity = datetime.fromisoformat(last_ts)
            if (current_time - last_activity).total_seconds() > inactive_minutes * 60:
                del self.sessions[session_id]
                logger.info(f"Cleaned up inactive session: {session_id}")

    def save_session(self, session_id: str):
        """No-op: Qdrant persists data automatically."""
        pass

    def add_conversation(self, session_id: str, query: str, response: str):
        """Append a query/response pair to session history (with timestamp)."""
        if session_id not in self.sessions:
            logger.warning(f"Session {session_id} not found when adding conversation")
            return
        self.sessions[session_id]['history'].append({
            "query": query,
            "response": response,
            # FIXED: timestamp added so cleanup_inactive_sessions() can age
            # these entries the same way as add_to_history() ones.
            "timestamp": datetime.now().isoformat()
        })

    def get_conversation_history(self, session_id: str):
        """Return the full conversation history (empty list for unknown sessions)."""
        session = self.sessions.get(session_id)
        return session['history'] if session else []

    def add_documents_to_qdrant(self, session_id: str, documents: List[Document]):
        """Embed and upsert *documents* into the session's Qdrant collection."""
        texts = [doc.text for doc in documents]
        try:
            embeddings = self.embedding_model.encode(texts, batch_size=32, show_progress_bar=True)
            if isinstance(embeddings, np.ndarray):
                embeddings = embeddings.tolist()
            # FIXED: point ids were sequential from 0 on *every* call, so a
            # second indexing run silently overwrote the first batch's points.
            # Random UUIDs keep every batch.
            points = [
                PointStruct(
                    id=str(uuid.uuid4()),
                    vector=embedding,
                    payload={
                        "text": doc.text,
                        "metadata": doc.metadata
                    }
                )
                for embedding, doc in zip(embeddings, documents)
            ]
            collection_name = self.get_collection_name(session_id)
            operation_info = self.qdrant_client.upsert(
                collection_name=collection_name,
                points=points,
                wait=True  # Wait for operation confirmation
            )
            logger.info(f"Upsert operation status: {operation_info.status}")
            self.sessions[session_id]['documents'].extend(documents)
        except Exception as e:
            logger.error(f"Document insertion failed: {e}")
            raise

    def search_qdrant(self, session_id: str, query_embedding: np.ndarray, k: int = 3):
        """Similarity-search the session's collection; returns raw Qdrant hits."""
        try:
            if isinstance(query_embedding, np.ndarray):
                query_embedding = query_embedding.tolist()
            collection_name = self.get_collection_name(session_id)
            return self.qdrant_client.search(
                collection_name=collection_name,
                query_vector=query_embedding,
                limit=k,
                with_payload=True,
                with_vectors=False
            )
        except Exception as e:
            logger.error(f"Search failed: {e}")
            raise
227
+
228
# Web Crawler
class WebCrawler:
    """
    Same-domain recursive crawler with depth and total-link limits.

    FIXED:
    - The total-link cap was inconsistent: the top-level guard stopped at 50
      while the inner loop broke at 10 (next to a comment claiming 50). Both
      now use the configurable ``max_links`` (default 50).
    - ``delay`` was declared but never used; a politeness sleep is now
      applied before each HTTP request.
    """

    def __init__(self, max_depth=2, delay=1, max_links=50):
        self.max_depth = max_depth
        self.delay = delay          # seconds to pause before each request
        self.max_links = max_links  # hard cap on total URLs collected
        self.visited = set()

    def crawl_recursive(self, url, depth=0):
        """
        Recursively crawl same-domain links from *url* up to max_depth.

        Returns:
            De-duplicated list of URLs reachable from *url* (including it).
        """
        print(f"[WebCrawler] Crawling {url} at depth {depth}")

        # Lazily created for backward compatibility with instances pickled or
        # constructed before this attribute existed.
        if not hasattr(self, "collected_links"):
            self.collected_links = set()

        if depth > self.max_depth or url in self.visited or len(self.collected_links) >= self.max_links:
            return []

        self.visited.add(url)
        self.collected_links.add(url)
        links = [url]

        try:
            import time
            time.sleep(self.delay)  # politeness delay between requests
            response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"})
            soup = BeautifulSoup(response.content, "html.parser")

            for tag in soup.find_all("a", href=True):
                if len(self.collected_links) >= self.max_links:
                    break  # Stop once the configured link cap is reached

                href = urljoin(url, tag["href"])
                # Only follow links on the same host as the page being crawled.
                if urlparse(href).netloc == urlparse(url).netloc:
                    links.extend(self.crawl_recursive(href, depth + 1))
        except Exception as e:
            logger.warning(f"Failed to crawl {url}: {e}")

        return list(set(links))
266
+
267
# Connection Manager
class ConnectionManager:
    """Registry of active WebSocket connections keyed by session id."""

    def __init__(self):
        # session_id -> WebSocket for every connected client
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, session_id: str):
        """Accept the handshake and register the socket under *session_id*."""
        await websocket.accept()
        self.active_connections[session_id] = websocket

    async def disconnect(self, session_id: str):
        """Forget the socket for *session_id* (no-op when unknown)."""
        if session_id in self.active_connections:
            del self.active_connections[session_id]

    async def send_message(self, message: str, session_id: str):
        """Deliver *message* to the client for *session_id*, if still connected."""
        connection = self.active_connections.get(session_id)
        if connection is not None:
            await connection.send_text(message)
288
+
289
# RAG System with Qdrant
class RAGSystem:
    """Main RAG system orchestrating crawling, indexing and querying."""

    def __init__(self):
        self.session_manager = SessionManager()
        self.crawler = WebCrawler()
        # Default backend; chat() may swap this per request.
        self.llm = OllamaLLM(base_url="http://localhost:11434", model="mistral")

    def _generate_text(self, prompt: str) -> str:
        """
        Run *prompt* through the active LLM and return plain text.

        FIXED: the previous ``self.llm.generate([prompt])`` returned a
        LangChain LLMResult for OllamaLLM (not JSON-serializable downstream)
        and handed a *list* to GeminiProvider.generate(), which expects a
        string.
        """
        if hasattr(self.llm, "invoke"):
            # LangChain runnable (OllamaLLM): invoke() returns a string.
            return self.llm.invoke(prompt)
        # Custom providers (GeminiProvider): generate(prompt) -> str.
        return self.llm.generate(prompt)

    def crawl_and_index(self, session_id: str, start_url: str) -> Dict[str, Any]:
        """
        Crawl *start_url* (same-domain) and index page text into Qdrant.

        Returns:
            On success: {'status', 'urls_processed', 'total_documents'};
            on failure: {'status': 'error', 'message': ...}.
        """
        print(f"[RAGSystem] Starting crawl and index for session {session_id} with URL: {start_url}")
        try:
            session = self.session_manager.get_session(session_id)
            all_urls = self.crawler.crawl_recursive(start_url)
            documents, successful_urls = [], []

            print(f"[RAGSystem] Crawled {len(all_urls)} URLs for session {session_id}")

            for url in all_urls[:20]:  # Limit to 20 URLs per indexing run
                try:
                    print(f"[RAGSystem] Processing URL: {url}")
                    response = requests.get(url, timeout=10, headers={"User-Agent": "Mozilla/5.0"})
                    soup = BeautifulSoup(response.content, "html.parser")
                    # Strip non-content elements before extracting text.
                    for tag in soup(["script", "style"]):
                        tag.decompose()
                    text = " ".join(chunk.strip() for chunk in soup.get_text().splitlines() if chunk.strip())
                    if len(text) > 100:  # Skip near-empty pages
                        documents.append(Document(text, {"source_url": url, "session_id": session_id}))
                        successful_urls.append(url)
                except Exception as e:
                    logger.warning(f"Error processing {url}: {e}")

            if documents:
                self.session_manager.add_documents_to_qdrant(session_id, documents)
                return {
                    "status": "success",
                    "urls_processed": successful_urls,
                    "total_documents": len(documents)
                }
            return {"status": "error", "message": "No documents indexed"}
        except Exception as e:
            logger.error(f"crawl_and_index error: {e}")
            return {
                "status": "error",
                "message": f"Error during crawling and indexing: {str(e)}"
            }

    async def chat(
        self,
        session_id: str,
        question: str,
        model: str = "mistral",
        ollama_url: str = None,
        gemini_api_key: str = None
    ) -> Dict[str, Any]:
        """
        Handle a chat request, optionally switching LLM backend per call.

        NOTE(review): the 'documents' check below only sees in-memory state;
        after a restart, a session whose vectors still live in Qdrant is
        rejected — confirm whether that is intended.
        """
        try:
            # Get session data
            session = self.session_manager.get_session(session_id)
            if not session.get('documents'):
                return {
                    "status": "error",
                    "message": "No documents indexed for this session"
                }

            # Select the appropriate LLM for this request.
            if model == "mistral" and ollama_url:
                self.llm = OllamaLLM(base_url=ollama_url, model="mistral")
            elif model == "gemini" and gemini_api_key:
                self.llm = GeminiProvider()

            # Process the query
            result = self.process_query(session_id, question)

            # Add to conversation history if successful
            if result["status"] == "success":
                self.session_manager.add_conversation(
                    session_id,
                    question,
                    result["response"]
                )

            return result

        except Exception as e:
            logger.error(f"Chat error: {str(e)}")
            return {
                "status": "error",
                "message": f"Chat error: {str(e)}"
            }

    def process_query(self, session_id: str, query: str) -> Dict[str, Any]:
        """Run the retrieval + generation pipeline for one query."""
        try:
            # Encode the query for similarity search.
            query_embedding = self.session_manager.embedding_model.encode(query)
            if isinstance(query_embedding, np.ndarray):
                query_embedding = query_embedding.astype("float32")

            search_result = self.session_manager.search_qdrant(
                session_id=session_id,
                query_embedding=query_embedding
            )

            # Build the grounded prompt from retrieved chunks; robust to
            # points with missing or None payloads.
            context = "\n\n".join((hit.payload or {}).get("text", "") for hit in search_result)
            prompt = rag_prompt_template(context, query)
            response = self._generate_text(prompt)

            return {
                "status": "success",
                "response": response,
                "sources": [(hit.payload or {}).get("metadata", {}) for hit in search_result]
            }
        except Exception as e:
            logger.error(f"Query processing failed: {e}")
            return {"status": "error", "message": str(e)}
411
+
412
# FastAPI App
app = FastAPI()
# NOTE(review): wildcard allow_origins together with allow_credentials=True
# is rejected by browsers for credentialed requests — confirm whether
# credentials are actually needed, and pin origins if so.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize RAG system (module-level singleton shared by all endpoints)
rag = RAGSystem()
424
+
425
+ # Request models
426
class URLRequest(BaseModel):
    """Request model for URL crawling."""
    url: str  # Start URL to crawl and index
    session_id: Optional[str] = None  # Existing session id; a new one is generated if omitted
430
+
431
class ChatRequest(BaseModel):
    """Request model for chat queries."""
    session_id: str  # Session whose indexed documents ground the answer
    question: str  # The user's question
435
+
436
class SearchRequest(BaseModel):
    """Request model for direct similarity searches."""
    session_id: str  # Session whose Qdrant collection to search
    query: str  # Free-text query to embed and match
    limit: Optional[int] = 5  # Maximum number of hits to return
441
+
442
# API Endpoints
@app.get("/")
async def root():
    """Health check endpoint."""
    status_message = "RAG with Ollama Mistral and Qdrant is running"
    return {"message": status_message}
447
+
448
@app.post("/create_session")
async def create_session():
    """Create a new session ID."""
    return {"session_id": str(uuid.uuid4()), "status": "success"}
453
+
454
@app.post("/crawl_and_index")
async def crawl_and_index(request: URLRequest):
    """Crawl and index a website, generating a session id when none is supplied."""
    sid = request.session_id or str(uuid.uuid4())
    return rag.crawl_and_index(sid, request.url)
460
+
461
@app.post("/chat")
async def chat(request: ChatRequest):
    """Handle chat request."""
    result = await rag.chat(request.session_id, request.question)
    return result
465
+
466
@app.post("/search")
async def search(request: SearchRequest):
    """
    Handle direct search request against the session's Qdrant collection.

    Raises:
        HTTPException(500): on any embedding/search failure.
    """
    try:
        # Ensure the session (and its collection) exists before searching.
        session = rag.session_manager.get_session(request.session_id)
        query_embedding = rag.session_manager.embedding_model.encode(request.query)
        if isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.tolist()
        collection_name = rag.session_manager.get_collection_name(request.session_id)
        search_results = rag.session_manager.qdrant_client.search(
            collection_name=collection_name,
            query_vector=query_embedding,
            limit=request.limit
        )
        return {
            "status": "success",
            "results": [
                {
                    # FIXED: direct payload["text"] indexing raised KeyError /
                    # TypeError for points with missing or None payloads.
                    "text": (hit.payload or {}).get("text", ""),
                    "score": hit.score,
                    "metadata": (hit.payload or {}).get("metadata", {})
                }
                for hit in search_results
            ]
        }
    except Exception as e:
        logger.error(f"API search failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
494
+
495
@app.websocket("/ws/chat")
async def websocket_endpoint(websocket: WebSocket):
    """
    WebSocket endpoint for real-time chat.

    Expects JSON frames with 'uid' and 'question'; replies with either an
    answer payload or an error payload. Error delivery is best-effort: the
    original unconditionally sent on the socket from the generic except
    clause, which raised a secondary unhandled exception when the socket was
    already closed.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            uid = data.get("uid")
            question = data.get("question")

            if not uid or not question:
                await websocket.send_json({"error": "Missing 'uid' or 'question'"})
                continue

            # Get response from RAG system
            response = await rag.chat(uid, question)

            # Handle both success and error cases
            if response["status"] == "success":
                await websocket.send_json({
                    "uid": uid,
                    "question": question,
                    "answer": response["response"],
                    "sources": response.get("sources", [])
                })
            else:
                await websocket.send_json({
                    "uid": uid,
                    "error": response["message"]
                })

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected")
    except Exception as e:
        # FIXED: log the original failure and only attempt delivery if the
        # socket will still take a frame.
        logger.error(f"WebSocket chat error: {e}")
        try:
            await websocket.send_json({"error": str(e)})
        except Exception:
            logger.debug("Could not deliver error to client", exc_info=True)
530
+
531
# Main entry point
if __name__ == "__main__":
    # FIXED: the original called Config.create_dirs(), a method that does not
    # exist — Config only defines create_storage_dir().
    Config.create_storage_dir()
    from app import launch_interface
    launch_interface()
src/models.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional, Dict, Any
3
+ from datetime import datetime
4
+
5
class SearchRequest(BaseModel):
    """Request body for a similarity search."""
    query: str  # Free-text query to embed and match
    limit: Optional[int] = 5  # Maximum number of hits to return
8
+
9
class SearchResponse(BaseModel):
    """Envelope for search results: 'status' plus either 'results' or 'message'."""
    status: str  # presumably "success" or "error" — matches codes used elsewhere; confirm
    results: Optional[List[Dict[str, Any]]] = None  # Hits when the search succeeded
    message: Optional[str] = None  # Error detail when the search failed
13
+
14
class ChatSession(BaseModel):
    """One chat session's identity, creation time, and running history."""
    session_id: str  # Unique session identifier
    created_at: datetime  # When the session was created
    history: List[Dict[str, Any]] = []  # Safe here: pydantic deep-copies field defaults per instance
    metadata: Dict[str, Any] = {}  # Arbitrary extra session data
src/prompts/templates.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
def rag_prompt_template(context: str, question: str) -> str:
    """Build the grounding prompt used for RAG answering.

    The model is instructed to rely exclusively on *context* and to admit
    ignorance when the context does not cover the question.
    """
    template = (
        'You are an expert assistant. Use ONLY the information from the context below to answer the question.\n'
        'If the context does not contain the answer, say "I don\'t know based on the provided content."\n'
        '\n'
        'Context:\n'
        '{ctx}\n'
        '\n'
        'Question: {q}\n'
        'Answer:'
    )
    return template.format(ctx=context, q=question)
src/rag.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Optional
2
+ from datetime import datetime
3
+ import logging
4
+ from app.models import SearchResponse
5
+ from app.main import SessionManager
6
+ from app.vectorstore import VectorStore
7
+ from app.llm import OllamaMistral
8
+ from app.embeddings import EmbeddingHandler
9
+
10
class RAGSystem:
    """
    Main RAG (Retrieval-Augmented Generation) system class that orchestrates:
    - Session management
    - Vector store operations
    - LLM interactions
    - Chat history management
    """

    def __init__(self):
        """Initialize the RAG system components."""
        self.logger = logging.getLogger(__name__)
        # Session manager for tracking conversations
        self.session_manager = SessionManager()
        # Vector store for document storage and retrieval
        self.vectorstore = VectorStore()  # This will create its own EmbeddingHandler
        # LLM backend used to generate responses
        self.llm = OllamaMistral()

    async def chat(self, session_id: str, query: str) -> SearchResponse:
        """
        Handle a chat query using the RAG pipeline:
        1. Retrieve relevant documents
        2. Generate response using LLM
        3. Store conversation history

        Args:
            session_id: Unique identifier for the conversation session
            query: User's question/input

        Returns:
            SearchResponse: Contains answer, context, and sources
        """
        # Trace through the logger (fix: the original used bare print(),
        # bypassing the logging configuration); lazy %-args avoid
        # formatting cost when DEBUG is disabled.
        self.logger.debug("Starting chat for session %s with query: %s", session_id, query)
        try:
            # Search for relevant documents in the vector store
            search_results = await self.vectorstore.search_similar(
                session_id=session_id,
                query=query,
                k=3  # Number of similar documents to retrieve
            )
            self.logger.debug("Search results: %s", search_results)

            if search_results["status"] == "error":
                return SearchResponse(**search_results)

            # Prepare context from search results; fall back to a sentinel
            # string when no chunk carried a "text" payload.
            context = "\n".join(
                result["text"]
                for result in search_results["results"]
                if "text" in result
            ) or "No relevant context found"

            # Generate response using LLM with context
            prompt = (
                f"You are assisting with a website analysis. Here's relevant context from the website:\n"
                f"{context}\n\n"
                f"Question: {query}\n"
                f"Please provide a detailed answer based on the website content:"
            )
            response = await self.llm.generate_response(prompt)

            # Save conversation to history
            self.session_manager.add_to_history(
                session_id=session_id,
                question=query,
                answer=response
            )

            return SearchResponse(
                status="success",
                results=[{
                    "answer": response,
                    "context": context,
                    "sources": search_results["results"]
                }]
            )

        except Exception as e:
            self.logger.error(f"Error in chat: {str(e)}", exc_info=True)
            return SearchResponse(
                status="error",
                message=f"Chat error: {str(e)}"
            )

    def _validate_session(self, session_id: str) -> bool:
        """
        Validate and potentially initialize a session.

        Args:
            session_id: Session identifier to validate

        Returns:
            bool: True if session is valid/exists, False otherwise
        """
        try:
            # Initialize session if it doesn't exist
            if not self.session_manager.session_exists(session_id):
                self.logger.info(f"Session {session_id} not found, initializing.")
                self.session_manager.get_session(session_id)
            # Verify the vectorstore collection exists for this session
            if not self.vectorstore.collection_exists(session_id):
                self.logger.warning(f"Vectorstore collection missing for session: {session_id}")
                return False
            return True
        except Exception as e:
            self.logger.error(f"Session validation failed: {str(e)}")
            return False

    async def get_chat_history(self, session_id: str, limit: int = 100) -> List[Dict]:
        """
        Retrieve chat history for a session.

        Args:
            session_id: Session identifier
            limit: Maximum number of history items to return

        Returns:
            List[Dict]: Chat history entries or empty list if error occurs
        """
        try:
            if not self._validate_session(session_id):
                return []

            return self.session_manager.get_history(session_id, limit)
        except Exception as e:
            self.logger.error(f"Error getting chat history: {str(e)}")
            return []
src/services/__init__.py ADDED
File without changes
src/services/history_services.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from ..database.mongo_handler import MongoDBHandler
3
+
4
class HistoryService:
    """Thin async service layer over MongoDBHandler for conversation history."""

    def __init__(self):
        # Single database handler shared by all operations.
        self.db_handler = MongoDBHandler()

    async def save_conversation(self, session_id: str, user_query: str, bot_response: str, metadata: dict = None):
        """Save conversation to database"""
        record = dict(
            session_id=session_id,
            user_query=user_query,
            bot_response=bot_response,
            timestamp=datetime.utcnow(),
            metadata=metadata if metadata else {},
        )
        return self.db_handler.insert_conversation(record)

    async def get_session_history(self, session_id: str, limit: int = 100) -> list:
        """Retrieve conversation history for a session"""
        return self.db_handler.get_conversations_by_session(session_id, limit)

    async def update_metadata(self, session_id: str, query: str, metadata: dict) -> bool:
        """Update metadata for a specific query"""
        return self.db_handler.update_conversation_metadata(session_id, query, metadata)

    async def clear_session_history(self, session_id: str) -> int:
        """Clear all history for a session"""
        return self.db_handler.delete_session_history(session_id)
src/services/qdrant_handler.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from qdrant_client import QdrantClient
2
+ from config import QDRANT_HOST, QDRANT_PORT, QDRANT_COLLECTION_NAME
3
+
4
class QdrantHandler:
    """Holds a configured Qdrant client and the target collection name."""

    def __init__(self):
        # Target collection comes straight from configuration.
        self.collection_name = QDRANT_COLLECTION_NAME
        # Connect to the Qdrant server declared in config.
        self.client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
src/storage.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pickle
3
+ import logging
4
+ from typing import Dict, List, Any
5
+ from app.config import Config
6
+ from qdrant_client import QdrantClient
7
+ from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
8
+
9
class SessionStorage:
    """
    Manages session persistence using a hybrid storage approach:
    - Stores session metadata in local pickle files
    - Stores vector data in Qdrant collections
    - Maintains connection between the two
    """

    def __init__(self):
        """
        Initialize the session storage system.
        Sets up Qdrant client connection and ensures storage directory exists.

        Raises:
            RuntimeError: If the storage directory or Qdrant client cannot
                be initialized.
        """
        # Fix: the logger must exist BEFORE the fallible calls below —
        # previously it was assigned inside the try after
        # Config.create_storage_dir(), so a failure there raised
        # AttributeError in the except block and masked the real error.
        self.logger = logging.getLogger(__name__)
        try:
            Config.create_storage_dir()
            # Initialize Qdrant client with configuration from Config
            self.qdrant_client = QdrantClient(
                host=Config.QDRANT_HOST,
                port=Config.QDRANT_PORT,
                prefer_grpc=True  # Use gRPC for better performance
            )
            self.logger.info("Qdrant client initialized")
        except Exception as e:
            self.logger.error(f"Storage initialization error: {str(e)}")
            raise RuntimeError("Storage initialization failed") from e

    def get_session_path(self, session_id: str) -> str:
        """
        Get the filesystem path for a session's pickle file.

        Args:
            session_id: Unique session identifier

        Returns:
            str: Full path to session file
        """
        return os.path.join(Config.STORAGE_DIR, f"{session_id}.pkl")

    def save_session(self, session_id: str, data: Dict):
        """
        Persist session data to disk (excluding Qdrant references).

        Args:
            session_id: Session identifier
            data: Session data dictionary
        """
        session_path = self.get_session_path(session_id)

        # Remove Qdrant collection reference before saving to avoid
        # serialization issues; work on a shallow copy so the caller's
        # dict is not mutated.
        data = data.copy()
        data.pop('qdrant_collection', None)

        with open(session_path, 'wb') as f:
            pickle.dump(data, f)

    def load_session(self, session_id: str) -> Dict:
        """
        Load session data from disk and reconnect to Qdrant collection.

        Args:
            session_id: Session identifier

        Returns:
            Dict: Session data with restored Qdrant collection reference,
            or None when no file exists for this session.
        """
        session_path = self.get_session_path(session_id)

        if not os.path.exists(session_path):
            return None

        # NOTE(security): pickle deserialization executes arbitrary code;
        # session files must never come from an untrusted source.
        with open(session_path, 'rb') as f:
            data = pickle.load(f)

        # Restore Qdrant collection reference
        collection_name = f"session_{session_id}"
        data['qdrant_collection'] = collection_name

        # Ensure collection exists in Qdrant (create if missing)
        if not self.qdrant_client.collection_exists(collection_name):
            self.logger.warning(f"Qdrant collection {collection_name} missing, creating new")
            self.qdrant_client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=Config.EMBEDDING_SIZE,
                    distance=Distance.COSINE
                )
            )

        return data

    def delete_session(self, session_id: str):
        """
        Completely remove a session (both disk and Qdrant storage).

        Args:
            session_id: Session identifier to delete
        """
        session_path = self.get_session_path(session_id)

        # Delete Qdrant collection first; a failure here is logged but
        # does not prevent removal of the local file (best-effort cleanup).
        collection_name = f"session_{session_id}"
        try:
            self.qdrant_client.delete_collection(collection_name)
            self.logger.info(f"Deleted Qdrant collection: {collection_name}")
        except Exception as e:
            self.logger.error(f"Error deleting Qdrant collection: {str(e)}")

        # Delete session file
        if os.path.exists(session_path):
            os.remove(session_path)
121
+
122
class QdrantStorage:
    """
    Manages vector storage operations using Qdrant.
    Handles collection management and vector operations.
    """

    def __init__(self, collection_name: str, vector_size: int,
                 host: str = Config.QDRANT_HOST, port: int = Config.QDRANT_PORT):
        """
        Initialize Qdrant storage for a specific collection.

        Args:
            collection_name: Name of the Qdrant collection
            vector_size: Dimensionality of vectors to store
            host: Qdrant server host (default from Config)
            port: Qdrant server port (default from Config)
        """
        self.logger = logging.getLogger(__name__)
        self.collection_name = collection_name
        self.vector_size = vector_size
        # Initialize Qdrant client with gRPC preference
        self.qdrant = QdrantClient(host=host, port=port, prefer_grpc=True)
        # Collection is guaranteed to exist once construction completes.
        self._ensure_collection()

    def _ensure_collection(self):
        """
        Ensure the collection exists in Qdrant.
        Creates it if missing, otherwise verifies configuration.

        NOTE(review): an existing-but-empty collection (vectors_count == 0)
        is accepted silently without logging; only a get_collection failure
        triggers (re)creation. Confirm this is the intended behavior.
        """
        try:
            collection_info = self.qdrant.get_collection(self.collection_name)
            if collection_info.vectors_count > 0:
                self.logger.info(f"Using existing Qdrant collection: {self.collection_name}")
        except Exception:
            # get_collection raised — assume the collection is missing.
            # recreate_collection would drop any existing data of the same
            # name, but on this path none was reachable.
            self.logger.info(f"Creating Qdrant collection: {self.collection_name}")
            self.qdrant.recreate_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE  # Using cosine similarity
                )
            )

    def add_vectors(self, vectors: List[List[float]], payloads: List[Dict[str, Any]], offset: int = 0):
        """
        Add vectors and associated metadata to the collection.

        Args:
            vectors: List of vector embeddings
            payloads: List of metadata dictionaries (parallel to vectors)
            offset: Starting ID for new points (default 0)

        NOTE(review): ids are offset + index, so re-adding a batch with the
        same offset upserts over (replaces) the earlier points — confirm
        callers pass a fresh offset when appending.
        """
        points = [
            PointStruct(
                id=offset + idx,  # Sequential IDs with optional offset
                vector=vector,
                payload=payload
            )
            for idx, (vector, payload) in enumerate(zip(vectors, payloads))
        ]
        self.qdrant.upsert(
            collection_name=self.collection_name,
            points=points,
            wait=True  # Ensure immediate persistence
        )
        self.logger.info(f"Added {len(points)} vectors to Qdrant collection '{self.collection_name}'")

    def search(self, query_vector: List[float], session_id: str, limit: int = 5) -> List[Dict[str, Any]]:
        """
        Search the collection for similar vectors, filtered by session.

        Args:
            query_vector: The vector to compare against
            session_id: Session identifier to filter results
            limit: Maximum number of results to return

        Returns:
            List[Dict]: Search results with scores and metadata
        """
        # Add session filter to ensure only current session results
        results = self.qdrant.search(
            collection_name=self.collection_name,
            query_vector=query_vector,
            query_filter=Filter(
                must=[
                    FieldCondition(
                        key="session_id",  # payloads must carry this key for filtering to work
                        match=MatchValue(value=session_id)
                    )
                ]
            ),
            limit=limit
        )
        # Flatten the client's hit objects into plain dictionaries.
        return [
            {
                "id": hit.id,
                "score": hit.score,
                "payload": hit.payload
            }
            for hit in results
        ]
src/tests/test_connection.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import socket
2
+ from qdrant_client import QdrantClient
3
+
4
def check_port(host, port, timeout=None):
    """Return True if a TCP connection to (host, port) can be initiated.

    Args:
        host: Hostname or IP address to probe.
        port: TCP port number.
        timeout: Optional connect timeout in seconds. The default of None
            keeps the OS default blocking behavior (the original behavior);
            pass a small value to fail fast against unreachable hosts.

    Returns:
        bool: True when connect_ex reports success (0), False otherwise.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        if timeout is not None:
            s.settimeout(timeout)
        return s.connect_ex((host, port)) == 0
7
+
8
# Connection smoke-test script: first verifies that the Qdrant TCP port is
# reachable, then exercises the API by listing collections.
host = "localhost"
port = 6333  # Qdrant's default HTTP port

if check_port(host, port):
    print(f"Port {port} is open. Testing Qdrant API...")
    try:
        client = QdrantClient(host=host, port=port)
        print("Success! Collections:", client.get_collections())
    except Exception as e:
        # Port open but API call failed (e.g. wrong service on the port).
        print(f"API Error: {e}")
else:
    print(f"ERROR: Port {port} is closed. Check if Qdrant is running.")
src/tests/test_qdrant_integration.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from qdrant_client import QdrantClient
3
+ from qdrant_client.models import VectorParams, Distance
4
+
5
@pytest.fixture
def qdrant_client():
    """Pytest fixture: a client pointed at a locally running Qdrant server."""
    return QdrantClient(host="localhost", port=6333)
8
+
9
def test_collection_creation(qdrant_client):
    """Smoke test: (re)create a collection and confirm Qdrant reports it."""
    test_collection = "test_collection"
    # recreate_collection drops any same-named collection before creating.
    qdrant_client.recreate_collection(test_collection, vectors_config=VectorParams(size=384, distance=Distance.COSINE))
    assert qdrant_client.collection_exists(test_collection)
src/tests/test_storage.py ADDED
File without changes
src/tests/test_ws.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import websockets
3
+ import json
4
+
5
async def test_ws():
    """Manually exercise the chat WebSocket: send one query, print the reply."""
    # NOTE(review): this connects to /ws/test-session, while the server in
    # this commit registers /ws/chat and expects {"uid", "question"} —
    # confirm which route/payload shape is current.
    uri = "ws://localhost:8000/ws/test-session"
    async with websockets.connect(uri) as ws:
        await ws.send(json.dumps({"query": "What is AI?"}))
        response = await ws.recv()
        print("Response:", response)
11

# Guarded entry point (fix: the bare asyncio.run() executed on import,
# e.g. during pytest collection, opening a socket as a side effect).
if __name__ == "__main__":
    asyncio.run(test_ws())
src/utils/__init__.py ADDED
File without changes
src/utils/response_formatter.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from typing import Dict, List
3
+ import re
4
+
5
class ResponseFormatter:
    """Formats raw API/RAG output into clean, user-facing text."""

    # Patterns compiled once at class-definition time instead of on every
    # call (fix: previously recompiled per clean_text invocation).
    _WHITESPACE_RE = re.compile(r'\s+')
    _BOILERPLATE_RE = re.compile(r'SUBSCRIBE|RECENT|POPULAR|TRENDY', re.I)
    _COPYRIGHT_RE = re.compile(r'Copyright © \d{4}.*')
    _LEGAL_RE = re.compile(r'Privacy Policy|Terms of Service')

    @staticmethod
    def clean_text(text: str) -> str:
        """Clean up raw text by removing excessive whitespace and common boilerplate."""
        # Collapse every whitespace run (spaces, tabs, newlines) to a single
        # space. (The original first collapsed '\n+' to '\n', but that was
        # dead code: '\s+' already matches newlines, so results are identical.)
        text = ResponseFormatter._WHITESPACE_RE.sub(' ', text)

        # Remove common website chrome and legal boilerplate.
        text = ResponseFormatter._BOILERPLATE_RE.sub('', text)
        text = ResponseFormatter._COPYRIGHT_RE.sub('', text)
        text = ResponseFormatter._LEGAL_RE.sub('', text)

        return text.strip()

    @staticmethod
    def format_sources(sources: List[Dict]) -> str:
        """Format source URLs into a numbered reference list.

        Returns an empty string when there are no sources; otherwise a
        block starting with a blank line, a "Sources:" header, and one
        numbered line per source URL.
        """
        if not sources:
            return ""

        formatted_sources = "\n\nSources:\n"
        for i, source in enumerate(sources, 1):
            formatted_sources += f"{i}. {source['url']}\n"
        return formatted_sources

    @staticmethod
    def format_response(api_response: Dict) -> str:
        """Convert an API response dict to natural-language text.

        Returns an apology when the dict carries "error", a fallback when
        "response" is absent, otherwise the cleaned response text with the
        formatted source list appended (when "sources" is present).
        """
        if "error" in api_response:
            return f"Sorry, I encountered an error: {api_response['error']}"

        if "response" not in api_response:
            return "I couldn't find any relevant information."

        # Clean and format the main response
        clean_response = ResponseFormatter.clean_text(api_response["response"])

        # Add sources if available
        if "sources" in api_response:
            clean_response += ResponseFormatter.format_sources(api_response["sources"])

        return clean_response
src/vectorstore.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import List, Dict
4
+ from qdrant_client import QdrantClient
5
+ from qdrant_client.models import Distance, VectorParams, PointStruct, Filter
6
+ import uuid
7
+ from langchain.embeddings import HuggingFaceEmbeddings
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+ from langchain.schema import Document
10
+
11
+ from app.config import Config
12
+ from app.crawler import URLCrawler
13
+ from app.models import SearchResponse
14
+ from app.embeddings import EmbeddingHandler
15
+
16
class VectorStore:
    """
    A class to handle vector storage operations using Qdrant.
    Manages document storage, retrieval, and similarity search in vector space.
    """

    def __init__(self):
        """Initialize the VectorStore with Qdrant client and embedding handlers."""
        self.logger = logging.getLogger(__name__)
        # Initialize Qdrant client with configuration from Config
        self.client = QdrantClient(
            url=Config.QDRANT_URL,
            api_key=Config.QDRANT_API_KEY,
            prefer_grpc=False,
            timeout=30
        )
        # Initialize embedding handler and text splitter.
        # NOTE(review): two embedding components coexist — EmbeddingHandler
        # (used by search_similar) and HuggingFaceEmbeddings (used by
        # create_from_url); confirm both are backed by the same model and
        # embedding dimension, otherwise queries won't match stored vectors.
        self.embedding_handler = EmbeddingHandler()
        self.embeddings = HuggingFaceEmbeddings(model_name=Config.EMBEDDING_MODEL)
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,  # Size of each text chunk
            chunk_overlap=200  # Overlap between chunks for context preservation
        )

    def collection_exists(self, session_id: str) -> bool:
        """
        Check if a collection exists for the given session ID.
        Attempts to create the collection if it doesn't exist (so despite
        the name, this method has a creation side effect).

        Args:
            session_id: Unique identifier for the session

        Returns:
            bool: True if collection exists or was created successfully, False otherwise
        """
        collection_name = self._get_collection_name(session_id)
        try:
            self.client.get_collection(collection_name=collection_name)
            return True
        except Exception:
            # Try to create the collection if it doesn't exist
            try:
                self.client.recreate_collection(
                    collection_name=collection_name,
                    vectors_config=VectorParams(
                        size=self.embedding_handler.embedding_dim,
                        distance=Distance.COSINE  # Using cosine similarity
                    )
                )
                self.logger.info(f"Created collection {collection_name} automatically.")
                return True
            except Exception as e:
                self.logger.error(f"Failed to create collection {collection_name}: {e}")
                return False

    def _get_collection_name(self, session_id: str) -> str:
        """
        Generate a standardized collection name from session ID.

        Args:
            session_id: Unique session identifier

        Returns:
            str: Formatted collection name
        """
        return f"collection_{session_id}"

    async def search_similar(self, session_id: str, query: str, k: int = 5) -> Dict:
        """
        Search for similar documents in the vector store.

        Args:
            session_id: Session identifier for the collection
            query: Search query text
            k: Number of similar documents to return (default: 5)

        Returns:
            Dict: Search results on success, or {"status": "error", "message": ...}
        """
        try:
            if not self.collection_exists(session_id):
                return {"status": "error", "message": "Collection not found"}
            # Delegate the actual embedding + query to the handler.
            return await self.embedding_handler.search_collection(
                collection_name=self._get_collection_name(session_id),
                query=query,
                k=k
            )

        except Exception as e:
            self.logger.error(f"Search failed: {str(e)}")
            return {"status": "error", "message": str(e)}

    def create_from_url(self, url: str, session_id: str) -> None:
        """
        Crawl a website and create a vector store from its content.

        Args:
            url: Website URL to crawl
            session_id: Unique session identifier for storage

        Raises:
            Exception: If vector store creation fails
        """
        try:
            # Initialize crawler and fetch pages
            crawler = URLCrawler()
            raw_pages = crawler.crawl_sync(url, Config.MAX_PAGES_TO_CRAWL)

            # Convert crawled pages to LangChain Document format
            documents: List[Document] = [
                Document(
                    page_content=page["content"],
                    metadata={
                        "source": page["url"],
                        "title": page["title"],
                        "last_modified": page.get("last_modified", "")
                    }
                ) for page in raw_pages
            ]

            # Split documents into chunks
            texts = self.text_splitter.split_documents(documents)
            collection_name = self._get_collection_name(session_id)

            # Create or recreate collection with proper vector configuration.
            # NOTE: recreate_collection drops any existing data for this
            # session — re-indexing replaces, not appends.
            self.client.recreate_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedding_handler.embedding_dim,
                    distance=Distance.COSINE
                )
            )

            # Prepare points for batch insertion.
            # NOTE(review): embed_query is invoked once per chunk;
            # HuggingFaceEmbeddings.embed_documents would batch this —
            # confirm equivalence before optimizing.
            points = [
                PointStruct(
                    id=str(uuid.uuid4()),  # Generate unique ID for each point
                    vector=self.embeddings.embed_query(doc.page_content),
                    payload={
                        "page_content": doc.page_content,
                        "metadata": doc.metadata
                    }
                ) for doc in texts
            ]

            # Upsert all points into the collection
            self.client.upsert(
                collection_name=collection_name,
                points=points
            )
            self.logger.info(f"Created vector store for session {session_id}")
        except Exception as e:
            self.logger.error(f"Vector store creation failed: {str(e)}")
            raise

    def save_vectorstore(self, vectorstore: object, session_id: str):
        """
        Placeholder method since Qdrant persists data automatically.

        Args:
            vectorstore: Not used (Qdrant handles persistence)
            session_id: Session identifier for logging
        """
        self.logger.debug(f"Data automatically persisted for session {session_id}")

    def load_vectorstore(self, session_id: str) -> None:
        """
        Verify that a collection exists for the given session ID.

        Args:
            session_id: Session identifier to check

        Raises:
            ValueError: If collection doesn't exist (and could not be
                auto-created by collection_exists)
        """
        if not self.collection_exists(session_id):
            raise ValueError(f"Collection for session {session_id} not found")
src/web/__init__.py ADDED
File without changes