import json
import logging
import os
from typing import List

import torch
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer

try:
    from datasets import load_dataset
except ImportError:
    load_dataset = None

logger = logging.getLogger(__name__)


def get_device():
    """
    Determine the appropriate device for PyTorch.

    Returns:
        str: Device name ('cuda', 'mps', or 'cpu').
    """
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _mock_guest_docs() -> List[Document]:
    """Return a single-guest mock dataset used as a fallback."""
    description = (
        "Dr. Nikola Tesla is an old friend from your university days. "
        "He's recently patented a new wireless energy transmission system."
    )
    return [
        Document(
            page_content="\n".join([
                "Name: Dr. Nikola Tesla",
                "Relation: old friend from university days",
                f"Description: {description}",
                "Email: [email protected]",
            ]),
            metadata={
                "name": "Dr. Nikola Tesla",
                "relation": "old friend from university days",
                "description": description,
                "email": "[email protected]",
            },
        )
    ]


def load_guest_dataset(dataset_path: str = "agents-course/unit3-invitees") -> List[Document]:
    """
    Load guest dataset from a local JSON file or Hugging Face dataset.

    Args:
        dataset_path (str): Path to local JSON file or Hugging Face dataset name.

    Returns:
        List[Document]: List of Document objects with guest information.
    """
    try:
        # Try the Hugging Face Hub first, if the datasets library is available
        # and the path does not point to a local file.
        if load_dataset and not os.path.exists(dataset_path):
            logger.info(f"Attempting to load Hugging Face dataset: {dataset_path}")
            guest_dataset = load_dataset(dataset_path, split="train")
            docs = [
                Document(
                    page_content="\n".join([
                        f"Name: {guest['name']}",
                        f"Relation: {guest['relation']}",
                        f"Description: {guest['description']}",
                        f"Email: {guest['email']}",
                    ]),
                    metadata={
                        "name": guest["name"],
                        "relation": guest["relation"],
                        "description": guest["description"],
                        "email": guest["email"],
                    },
                )
                for guest in guest_dataset
            ]
            logger.info(f"Loaded {len(docs)} guests from Hugging Face dataset")
            return docs

        # Try loading from a local JSON file.
        if os.path.exists(dataset_path):
            logger.info(f"Loading guest dataset from local path: {dataset_path}")
            with open(dataset_path, "r") as f:
                guests = json.load(f)
            docs = [
                Document(
                    page_content=guest.get("description", ""),
                    metadata={
                        "name": guest.get("name", ""),
                        "relation": guest.get("relation", ""),
                        "description": guest.get("description", ""),
                        "email": guest.get("email", ""),  # Optional email field
                    },
                )
                for guest in guests
            ]
            logger.info(f"Loaded {len(docs)} guests from local JSON")
            return docs

        # Fall back to the mock dataset if neither source is available.
        logger.warning(f"Dataset not found at {dataset_path}, using mock dataset")
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest")
        return docs
    except Exception as e:
        logger.error(f"Failed to load guest dataset: {e}")
        # Return the mock dataset as a final fallback.
        docs = _mock_guest_docs()
        logger.info("Loaded mock dataset with 1 guest due to error")
        return docs
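
# Usage sketch for load_guest_dataset (illustration only; the local path
# "guests.json" below is hypothetical and not part of this repository):
#
#     guest_docs = load_guest_dataset("guests.json")
#     print(guest_docs[0].metadata["name"])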


class BM25Retriever:
    """
    Retriever over the guest dataset.

    Loads a SentenceTransformer embedding model on the appropriate device and
    performs keyword search over guest metadata via langchain's BM25 retriever.
    """

    def __init__(self, dataset_path: str):
        """
        Initialize the retriever with a SentenceTransformer model.

        Args:
            dataset_path (str): Path to the dataset for retrieval.

        Raises:
            Exception: If embedder initialization fails.
        """
        try:
            self.model = SentenceTransformer("all-MiniLM-L6-v2", device=get_device())
            self.dataset_path = dataset_path
            logger.info("Initialized SentenceTransformer")
        except Exception as e:
            logger.error(f"Failed to initialize embedder: {e}")
            raise
    def search(self, query: str) -> List[dict]:
        """
        Search the dataset for relevant guest information.

        Args:
            query (str): Search query (e.g., guest name or relation).

        Returns:
            List[dict]: List of matching guest metadata dictionaries.
        """
        try:
            # Load the guest dataset.
            docs = load_guest_dataset(self.dataset_path)
            if not docs:
                logger.warning("No documents available for search")
                return []

            # Build one text per guest from its metadata for BM25 keyword matching.
            texts = [
                f"{doc.metadata['name']} {doc.metadata['relation']} {doc.metadata['description']}"
                for doc in docs
            ]
            # Imported locally and aliased to avoid shadowing this class's name.
            from langchain_community.retrievers import BM25Retriever as LCBM25Retriever

            retriever = LCBM25Retriever.from_texts(texts)
            retriever.k = 3  # Limit to top 3 results

            # Perform the search.
            results = retriever.invoke(query)

            # Map results back to the original metadata, preserving BM25 ranking order.
            text_to_metadata = {text: doc.metadata for text, doc in zip(texts, docs)}
            matches = [
                result_doc.page_content and text_to_metadata[result_doc.page_content]
                for result_doc in results
                if result_doc.page_content in text_to_metadata
            ]
            logger.info(f"Found {len(matches)} matches for query: {query}")
            return matches[:3]  # Return at most 3 matches
        except Exception as e:
            logger.error(f"Search failed for query '{query}': {e}")
            return []
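

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original
    # module). The dataset path matches the default Hugging Face dataset name;
    # swap in a local JSON path if preferred.
    logging.basicConfig(level=logging.INFO)
    retriever = BM25Retriever("agents-course/unit3-invitees")
    for guest in retriever.search("Who is Dr. Nikola Tesla?"):
        print(f"{guest['name']} <{guest['email']}>: {guest['relation']}")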