import os
import base64
import logging
import json
import re
import io
from contextlib import asynccontextmanager
from typing import Optional
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, status, Depends, Header, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from dotenv import load_dotenv
from openai import OpenAI
from elevenlabs.client import ElevenLabs
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_postgres.vectorstores import PGVector
from sqlalchemy import create_engine

# --- SETUP ---
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Load environment variables
load_dotenv()
NEON_DATABASE_URL = os.getenv("NEON_DATABASE_URL")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ELEVENLABS_API_KEY = os.getenv("ELEVENLABS_API_KEY")
SHARED_SECRET = os.getenv("SHARED_SECRET")
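
# Expected .env entries (names taken from the os.getenv calls above; the
# values shown are illustrative placeholders, not real credentials):
#
#   NEON_DATABASE_URL=postgresql+psycopg://user:password@host/dbname
#   OPENAI_API_KEY=sk-...
#   ELEVENLABS_API_KEY=...
#   SHARED_SECRET=some-long-random-string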

# --- CONFIGURATION ---
COLLECTION_NAME = "real_estate_embeddings"
EMBEDDING_MODEL = "hkunlp/instructor-large"
ELEVENLABS_VOICE_NAME = "Leo"
PLANNER_MODEL = "gpt-4o-mini"
ANSWERER_MODEL = "gpt-4o"

TABLE_DESCRIPTIONS = """
- "ongoing_projects_source": Details about projects currently under construction.
- "upcoming_projects_source": Information on future planned projects.
- "completed_projects_source": Facts about projects that are already finished.
- "historical_sales_source": Specific sales records, including price, date, and property ID.
- "past_customers_source": Information about previous customers.
- "feedback_source": Customer feedback and ratings for projects.
"""

# --- GLOBAL VARIABLES & CLIENTS ---
embeddings = None
vector_store = None
client_openai = OpenAI(api_key=OPENAI_API_KEY)
client_elevenlabs = ElevenLabs(api_key=ELEVENLABS_API_KEY)

# --- FASTAPI LIFESPAN MANAGEMENT ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manages application startup and shutdown logic."""
    global embeddings, vector_store
    logging.info(f"Initializing embedding model: '{EMBEDDING_MODEL}'...")
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
    logging.info("Embedding model loaded successfully.")
    logging.info(f"Connecting to vector store '{COLLECTION_NAME}'...")
    engine = create_engine(NEON_DATABASE_URL, pool_pre_ping=True)
    vector_store = PGVector(
        connection=engine,
        collection_name=COLLECTION_NAME,
        embeddings=embeddings,
    )
    logging.info("Successfully connected to the vector store.")
    yield
    logging.info("Application shutting down.")

# --- INITIALIZE FastAPI APP ---
app = FastAPI(lifespan=lifespan)

# --- PROMPTS ---
QUERY_FORMULATION_PROMPT = f"""
You are a query analysis agent. Your task is to transform a user's query into a precise search query for a vector database and determine the correct table to filter by.

**Available Tables:**
{TABLE_DESCRIPTIONS}

**User's Query:** "{{user_query}}"

**Your Task:**
1. Rephrase the user's query into a clear, keyword-focused English question suitable for a database search.
2. Analyze the user's query for keywords indicating project status (e.g., "ongoing", "under construction", "completed", "finished", "upcoming", "new launch").
3. If such status keywords are present, identify the single most relevant table from the list above to filter by.
4. If no specific status keywords are mentioned (e.g., the user asks generally about projects in a location), set the filter table to null.
5. Respond ONLY with a JSON object containing "search_query" and "filter_table" (which should be the table name string or null).
"""

ANSWER_SYSTEM_PROMPT = """
You are an expert AI assistant for a premier real estate developer.

## YOUR PERSONA
- You are professional, helpful, and highly knowledgeable. Your tone should be polite and articulate.

## CORE BUSINESS KNOWLEDGE
- **Operational Cities:** We are currently operational in Pune, Mumbai, Bengaluru, Delhi, Chennai, Hyderabad, Goa, Gurgaon, Kolkata.
- **Property Types:** We offer luxury apartments, villas, and commercial properties.
- **Budget Range:** Our residential properties typically range from 45 lakhs to 5 crores.

## CORE RULES
1. **Language Adaptation:** If the user's original query was in Hinglish, respond in Hinglish. If in English, respond in English.
2. **Fact-Based Answers:** Use the provided CONTEXT to answer the user's question. If the context is empty, use your Core Business Knowledge.
3. **Stay on Topic:** Only answer questions related to real estate.
"""

# --- HELPER FUNCTIONS (to be run in threadpool) ---
def transcribe_audio(audio_bytes: bytes) -> str:
    """
    Transcribes audio (WAV, MP3, WebM, Opus) from raw bytes.
    The filename extension hints the container format to the API.
    """
    for attempt in range(3):
        try:
            audio_file = io.BytesIO(audio_bytes)
            # The transcription API infers the format from the filename
            # extension; browser MediaRecorder blobs are typically WebM
            # (an assumption; adjust if your client records another container).
            audio_file.name = "input.webm"
            transcript = client_openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file
            )
            text = transcript.text
            # Check for Devanagari script and transliterate to Roman script
            if re.search(r'[\u0900-\u097F]', text):
                translit_prompt = f"Transliterate this Hindi text to Roman script (Hinglish style): {text}"
                response = client_openai.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "user", "content": translit_prompt}],
                    temperature=0.0
                )
                text = response.choices[0].message.content
            return text.strip()
        except Exception as e:
            logging.error(f"Error during transcription (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return ""

def generate_elevenlabs_sync(text: str, voice: str) -> bytes:
    """Synchronous ElevenLabs generation wrapper for run_in_threadpool."""
    for attempt in range(3):
        try:
            # client.generate() yields audio chunks; join them into a single
            # bytes object so the caller can base64-encode the result.
            audio = client_elevenlabs.generate(
                text=text,
                voice=voice,
                model="eleven_multilingual_v2",
                output_format="mp3_44100_128"
            )
            return b"".join(audio)
        except Exception as e:
            logging.error(f"Error in ElevenLabs generate (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return b''

# --- RAG/LLM FUNCTIONS (async) ---
async def formulate_search_plan(user_query: str) -> dict:
    """Asks the Planner LLM for a search query and an optional table filter."""
    logging.info("Formulating search plan with Planner LLM...")
    for attempt in range(3):
        try:
            response = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=PLANNER_MODEL,
                messages=[{"role": "user", "content": QUERY_FORMULATION_PROMPT.format(user_query=user_query)}],
                response_format={"type": "json_object"},
                temperature=0.0
            )
            plan = json.loads(response.choices[0].message.content)
            logging.info(f"Search plan received: {plan}")
            return plan
        except Exception as e:
            logging.error(f"Error in Planner LLM call (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return {"search_query": user_query, "filter_table": None}

async def get_agent_response(user_text: str) -> str:
    """Runs RAG and generation logic for a given text query with retries."""
    for attempt in range(3):
        try:
            search_plan = await formulate_search_plan(user_text)
            search_query = search_plan.get("search_query", user_text)
            filter_table = search_plan.get("filter_table")
            search_filter = {"source_table": filter_table} if filter_table else {}
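            # The filter keys must match the metadata written at ingestion
            # time; this code assumes each document carries a "source_table"
            # metadata field, e.g. {"source_table": "ongoing_projects_source"}.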
            if search_filter:
                logging.info(f"Applying initial filter: {search_filter}")
            # Run blocking DB call in threadpool
            retrieved_docs = await run_in_threadpool(
                vector_store.similarity_search,
                search_query, k=3, filter=search_filter
            )
            if not retrieved_docs:
                logging.info("Initial search returned no results. Performing a broader fallback search.")
                retrieved_docs = await run_in_threadpool(
                    vector_store.similarity_search,
                    search_query, k=3
                )
            context_text = "\n\n".join([doc.page_content for doc in retrieved_docs])
            logging.info(f"Retrieved Context (preview): {context_text[:500]}...")
            final_prompt_messages = [
                {"role": "system", "content": ANSWER_SYSTEM_PROMPT},
                {"role": "system", "content": f"Use the following CONTEXT to answer:\n{context_text}"},
                {"role": "user", "content": f"My original question was: '{user_text}'"}
            ]
            # Run blocking OpenAI call in threadpool
            final_response = await run_in_threadpool(
                client_openai.chat.completions.create,
                model=ANSWERER_MODEL,
                messages=final_prompt_messages
            )
            return final_response.choices[0].message.content
        except Exception as e:
            logging.error(f"Error in get_agent_response (attempt {attempt+1}): {e}", exc_info=True)
            if attempt == 2:
                return "Sorry, I couldn't generate a response. Please try again."

# --- AUTH / TEST ENDPOINT HELPERS ---
class TextQuery(BaseModel):
    query: str

async def verify_token(x_auth_token: str = Header(...)):
    """Dependency to verify the shared secret token, sent as the X-Auth-Token header."""
    if not SHARED_SECRET or x_auth_token != SHARED_SECRET:
        logging.warning("Authentication failed for /test-text-query.")
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing authentication token")
    logging.info("Authentication successful for /test-text-query.")

# --- API ENDPOINTS ---
@app.post("/test-text-query", dependencies=[Depends(verify_token)])
async def test_text_query_endpoint(query: TextQuery):
    """Endpoint for text-based testing via Swagger UI."""
    logging.info(f"Received text query: {query.query}")
    response_text = await get_agent_response(query.query)
    logging.info(f"Generated text response: {response_text}")
    return {"response": response_text}

# NOTE: the route path below is an assumption; adjust it to whatever path the
# browser client actually connects to.
@app.websocket("/ws/browser")
async def browser_websocket_endpoint(
    websocket: WebSocket,
    token: Optional[str] = Query(None)  # Token arrives as a ?token=... query param
):
    """
    Main WebSocket endpoint for browser-based audio.
    Authenticates using a query parameter.
    """
    # Authentication block
    if not token or token != SHARED_SECRET:
        logging.warning("Browser auth failed: invalid or missing token.")
        await websocket.accept()  # Accept first so the close code reaches the client
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    await websocket.accept()
    logging.info("Browser client connected and authenticated.")
    try:
        while True:
            # 1. Receive JSON message from browser
            message = await websocket.receive_json()
            audio_base64 = message.get("audio")
            if not audio_base64:
                continue
            logging.info("Received audio blob from browser.")
            audio_bytes = base64.b64decode(audio_base64)

            # 2. Transcribe (shared logic)
            user_text = await run_in_threadpool(transcribe_audio, audio_bytes)
            if not user_text:
                logging.info("Transcription empty; skipping.")
                continue
            logging.info(f"User said: {user_text}")

            # 3. Get AI response (shared logic)
            agent_response_text = await get_agent_response(user_text)
            if not agent_response_text:
                logging.warning("Agent generated empty response.")
                continue
            logging.info(f"AI responded (preview): {agent_response_text[:100]}...")

            # 4. Generate AI speech (shared logic)
            ai_audio_bytes = await run_in_threadpool(
                generate_elevenlabs_sync,
                agent_response_text,
                ELEVENLABS_VOICE_NAME
            )
            if not ai_audio_bytes:
                continue

            # 5. Send audio and text back to browser
            response_audio_base64 = base64.b64encode(ai_audio_bytes).decode('utf-8')
            await websocket.send_json({
                "text": agent_response_text,
                "audio": response_audio_base64
            })
            logging.info("Sent AI audio response back to browser.")
    except WebSocketDisconnect:
        logging.info("Browser client disconnected.")
    except Exception as e:
        logging.error(f"An error occurred in browser websocket: {e}", exc_info=True)
    finally:
        try:
            await websocket.close()
        except Exception:
            pass
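
# Message protocol (as implemented above): the browser sends
#   {"audio": "<base64-encoded recording>"}
# and receives
#   {"text": "<agent reply>", "audio": "<base64-encoded MP3>"}

# Local dev entry point. A sketch under the assumption that the app is served
# with uvicorn on port 8000; a hosted Space may launch it differently.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)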