Spaces:
Running
Running
| """ | |
| Foundation 1.2 | |
| Clinical trial query system with 355M foundation model | |
| """ | |
| import gradio as gr | |
| import os | |
| from pathlib import Path | |
| import pickle | |
| import numpy as np | |
| from sentence_transformers import SentenceTransformer | |
| import logging | |
| from rank_bm25 import BM25Okapi | |
| import re | |
| from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m | |
# ---------------------------------------------------------------------------
# Logging and runtime configuration
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face access token, read once at import time (None when unset).
hf_token = os.getenv("HF_TOKEN")

# Directory holding this module; the data artefacts are looked up next to it.
_HERE = Path(__file__).parent

# Local data files. When one is missing it is fetched from the HF dataset
# repository below on first run.
DATASET_FILE = _HERE / "complete_dataset_WITH_RESULTS_FULL.txt"
CHUNKS_FILE = _HERE / "dataset_chunks_TRIAL_AWARE.pkl"
EMBEDDINGS_FILE = _HERE / "dataset_embeddings_TRIAL_AWARE_FIXED.npy"  # FIXED version to avoid cache
INVERTED_INDEX_FILE = _HERE / "inverted_index_TRIAL_AWARE.pkl"  # Pre-built inverted index (638MB)

# HF dataset repository that hosts the large files.
DATASET_REPO = "gmkdigitalmedia/foundation1.2-data"

# Process-wide state, populated lazily by the loader functions.
embedder = None         # sentence-embedding model
doc_chunks = []         # trial chunks, index-aligned with doc_embeddings
doc_embeddings = None   # embedding matrix for doc_chunks
bm25_index = None       # BM25 index for fast keyword search
inverted_index = None   # term -> chunk indices, for instant drug lookup
| # ============================================================================ | |
| # RAG FUNCTIONS | |
| # ============================================================================ | |
def load_embedder():
    """Lazily create the shared MiniLM-L6 sentence embedder (CPU only).

    The first call instantiates the model and stores it in the module-level
    ``embedder``; later calls return immediately. The model name is the one
    the stored document embeddings were generated with.
    """
    global embedder
    if embedder is not None:
        return
    logger.info("Loading MiniLM-L6 embedding model...")
    # Pinned to CPU so the main process never initialises CUDA.
    embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
    logger.info("L6 model loaded on CPU")
def _add_posting(inv_index, term, idx):
    """Record that chunk ``idx`` contains ``term``, at most once per chunk."""
    postings = inv_index.setdefault(term, [])
    # Chunks are visited in increasing index order, so a duplicate can only be
    # the last element. This replaces an O(n) ``idx in postings`` scan.
    if not postings or postings[-1] != idx:
        postings.append(idx)


def build_inverted_index(chunks):
    """Build a targeted inverted index for clinical search.

    Maps lower-cased terms to the list of chunk indices (ascending, no
    duplicates) in which they occur. Only selected fields are indexed:

    1. INTERVENTION - individual drug/device tokens (longer than 3 chars)
    2. CONDITIONS - every word of 4+ characters
    3. SPONSOR / COLLABORATOR / MANUFACTURER - whole company names
    4. OUTCOME - first 3 words of 5+ characters from the first 5 outcomes

    Trial names are deliberately not indexed.

    Args:
        chunks: sequence of chunk records; each is either the chunk text or a
            tuple whose second element is the chunk text.

    Returns:
        dict mapping term -> list of chunk indices.
    """
    import time

    t_start = time.time()
    inv_index = {}
    logger.info("Building targeted index: drugs, diseases, companies, endpoints...")

    # Generic words to skip
    skip_words = {
        'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial',
        'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active',
        'randomized', 'multicenter', 'open', 'label', 'crossover'
    }
    outcome_noise = {'outcome', 'measure', 'primary', 'secondary'}
    strip_chars = '.,;:() '

    # Patterns are compiled once instead of being looked up on every chunk.
    intervention_re = re.compile(r'intervention[:\s]+([^\n]+)')
    conditions_re = re.compile(r'conditions?[:\s]+([^\n]+)')
    company_res = (
        re.compile(r'sponsor[:\s]+([^\n]+)'),
        re.compile(r'collaborator[:\s]+([^\n]+)'),
        re.compile(r'manufacturer[:\s]+([^\n]+)'),
    )
    outcome_re = re.compile(r'outcome[:\s]+([^\n]+)')
    drug_split_re = re.compile(r'[,;\-\s]+')
    list_split_re = re.compile(r'[,;\|]+')
    word4_re = re.compile(r'\b\w{4,}\b')
    word5_re = re.compile(r'\b\w{5,}\b')

    for idx, chunk_data in enumerate(chunks):
        if idx % 100000 == 0 and idx > 0:
            logger.info(f" Indexed {idx:,}/{len(chunks):,} trials...")
        text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
        text_lower = text.lower()

        # 1. DRUGS from INTERVENTION field (first occurrence only)
        match = intervention_re.search(text_lower)
        if match:
            for drug in drug_split_re.split(match.group(1)):
                drug = drug.strip(strip_chars)
                if len(drug) > 3 and drug not in skip_words:
                    _add_posting(inv_index, drug, idx)

        # 2. DISEASES from CONDITIONS field: index each significant word
        match = conditions_re.search(text_lower)
        if match:
            for disease in list_split_re.split(match.group(1)):
                for word in word4_re.findall(disease.strip(strip_chars)):
                    if word not in skip_words:
                        _add_posting(inv_index, word, idx)

        # 3-5. COMPANIES from SPONSOR, COLLABORATOR and MANUFACTURER fields,
        # indexed as whole names (not split into words)
        for company_re in company_res:
            match = company_re.search(text_lower)
            if match:
                for company in list_split_re.split(match.group(1)):
                    company = company.strip(strip_chars)
                    if len(company) > 3:
                        _add_posting(inv_index, company, idx)

        # 6. ENDPOINTS from OUTCOME fields: first 5 outcomes, 3 words each
        for outcome_text in outcome_re.findall(text_lower)[:5]:
            for word in word5_re.findall(outcome_text)[:3]:
                if word not in skip_words and word not in outcome_noise:
                    _add_posting(inv_index, word, idx)

    t_elapsed = time.time() - t_start
    logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms")

    # Log sample entries for debugging (drugs, diseases, companies, endpoints)
    sample_terms = {
        'drugs': ['keytruda', 'opdivo', 'humira'],
        'diseases': ['cancer', 'diabetes', 'melanoma'],
        'companies': ['novartis', 'pfizer', 'merck'],
        'endpoints': ['survival', 'response', 'remission']
    }
    for category, terms in sample_terms.items():
        logger.info(f" {category.upper()} samples:")
        for term in terms:
            if term in inv_index:
                logger.info(f" '{term}' -> {len(inv_index[term])} trials")
    return inv_index
def download_from_dataset(filename):
    """Fetch ``filename`` from the HF dataset repo unless it is already cached.

    Files are stored under ``/tmp/foundation_data`` (writable inside the
    Docker container).

    Args:
        filename: name of the file inside ``DATASET_REPO``.

    Returns:
        Path to the local copy, or ``None`` when the download failed (the
        failure is logged, not raised).
    """
    from huggingface_hub import hf_hub_download

    # Use /tmp for downloads (has write permissions in Docker)
    download_dir = Path("/tmp/foundation_data")
    download_dir.mkdir(parents=True, exist_ok=True)

    local_file = download_dir / filename
    if local_file.exists():
        logger.info(f"Found cached {filename}")
        return local_file

    try:
        logger.info(f"Downloading {filename} from {DATASET_REPO}...")
        downloaded_file = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            local_dir=download_dir,
            local_dir_use_symlinks=False
        )
        logger.info(f"Downloaded {filename}")
        return Path(downloaded_file)
    except Exception as e:
        # Best effort: callers treat None as "not available" and decide
        # themselves whether that is fatal.
        logger.error(f"Failed to download {filename}: {e}")
        return None
def _locate_data_file(local_path, filename):
    """Return ``local_path`` if it exists, else a downloaded copy if available."""
    if local_path.exists():
        return local_path
    downloaded = download_from_dataset(filename)
    return downloaded if downloaded else local_path


def _embeddings_list_to_array(loaded_embeddings, chunk_size=10000):
    """Convert a list of embeddings (or tuples holding one) to a float32 matrix.

    Accepted item formats: a list, a numpy vector, or a tuple in which the
    first list/ndarray element is the embedding. Conversion is done in slices
    of ``chunk_size`` rows to avoid a memory spike.

    Raises:
        ValueError: for an empty list, for (int, str) tuples (the chunks file
            saved under the embeddings name) or for an unknown item format.
    """
    total = len(loaded_embeddings)
    logger.info(f"DEBUG: Total embeddings: {total}")
    if total == 0:
        # Fail with a clear message instead of an IndexError on item 0.
        raise ValueError("Embeddings file contains an empty list")

    first_emb = loaded_embeddings[0]
    logger.info(f"DEBUG: Type of first item: {type(first_emb)}")

    # Check if this is actually the chunks file (wrong file uploaded)
    if isinstance(first_emb, tuple) and len(first_emb) == 2:
        if isinstance(first_emb[0], int) and isinstance(first_emb[1], str):
            raise ValueError(
                f"ERROR: The embeddings file contains (int, string) tuples!\n"
                f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n"
                f"First item: {first_emb[:2]}\n\n"
                f"Please re-upload the correct file:\n"
                f" CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n"
                f" WRONG: dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n"
                f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct."
            )

    emb_idx = None  # position of the vector inside each tuple, if items are tuples
    if isinstance(first_emb, tuple):
        logger.info(f"DEBUG: Tuple length: {len(first_emb)}")
        for i, item in enumerate(first_emb[:5]):
            logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}")
        # Try every position - could be (id, emb) or (emb, id)
        logger.info(f"DEBUG: Trying to find embedding vector in tuple...")
        emb_vector = None
        for idx, elem in enumerate(first_emb):
            if isinstance(elem, (list, np.ndarray)):
                emb_vector = elem
                emb_idx = idx
                logger.info(f"DEBUG: Found embedding at position {idx}")
                break
        if emb_vector is None:
            raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}")
        emb_dim = len(emb_vector)
        logger.info(f"DEBUG: Embedding dimension: {emb_dim}")
    elif isinstance(first_emb, list):
        emb_dim = len(first_emb)
    elif isinstance(first_emb, np.ndarray):
        emb_dim = first_emb.shape[0]
    else:
        raise ValueError(f"Unknown embedding format: {type(first_emb)}")

    logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}")
    # Pre-allocate, then fill slice by slice
    matrix = np.zeros((total, emb_dim), dtype=np.float32)
    for start in range(0, total, chunk_size):
        end = min(start + chunk_size, total)
        if emb_idx is not None:
            # Extract just the embedding vector from each tuple
            matrix[start:end] = [item[emb_idx] for item in loaded_embeddings[start:end]]
        else:
            matrix[start:end] = loaded_embeddings[start:end]
        if start % 50000 == 0:
            logger.info(f"Converted {start}/{total} embeddings...")
    return matrix


def _load_or_build_inverted_index(chunks):
    """Load the pre-built inverted index, falling back to building it."""
    if INVERTED_INDEX_FILE.exists():
        logger.info(f"Loading pre-built inverted index from {INVERTED_INDEX_FILE.name}...")
        try:
            # SECURITY: pickle executes arbitrary code on load; only use
            # index files produced by this project.
            with open(INVERTED_INDEX_FILE, 'rb') as f:
                index = pickle.load(f)
            logger.info(f"✓ Loaded pre-built inverted index with {len(index):,} terms (instant vs 15min build)")
            return index
        except Exception as e:
            logger.warning(f"Failed to load pre-built index: {e}, building from scratch...")
            return build_inverted_index(chunks)
    logger.info("Pre-built inverted index not found, building from scratch (this takes 15 minutes)...")
    return build_inverted_index(chunks)


def load_embeddings():
    """Load chunks, embeddings and the inverted index into module globals.

    Missing files are downloaded from the HF dataset first. On success the
    globals ``doc_chunks``, ``doc_embeddings`` and ``inverted_index`` are
    populated. BM25 is not built (too memory-heavy for Docker); search uses
    the inverted index plus semantic similarity.

    Returns:
        True on success.

    Raises:
        RuntimeError: when the chunk/embedding files cannot be found or fail
            to load (the original error is chained as ``__cause__``).
    """
    global doc_chunks, doc_embeddings, inverted_index

    chunks_path = _locate_data_file(CHUNKS_FILE, "dataset_chunks_TRIAL_AWARE.pkl")
    embeddings_path = _locate_data_file(EMBEDDINGS_FILE, "dataset_embeddings_TRIAL_AWARE_FIXED.npy")  # FIXED version
    # The raw dataset text is fetched into the download cache as well; nothing
    # in this function reads it.
    _locate_data_file(DATASET_FILE, "complete_dataset_WITH_RESULTS_FULL.txt")

    if not (chunks_path.exists() and embeddings_path.exists()):
        raise RuntimeError("Embeddings files not found - system cannot function without embeddings")

    try:
        logger.info("Loading embeddings from disk...")
        # SECURITY: pickle.load and np.load(allow_pickle=True) execute
        # arbitrary code from the file; only load artefacts from the trusted
        # dataset repository.
        with open(chunks_path, 'rb') as f:
            doc_chunks = pickle.load(f)

        loaded_embeddings = np.load(embeddings_path, allow_pickle=True)
        logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}")

        if isinstance(loaded_embeddings, np.ndarray) and loaded_embeddings.ndim == 2:
            doc_embeddings = loaded_embeddings
            logger.info(f"✓ Embeddings are proper numpy array with shape: {doc_embeddings.shape}")
        elif isinstance(loaded_embeddings, list):
            logger.info(f"Converting embeddings from list to numpy array (memory efficient)...")
            # Assigned only after a complete conversion, so a failure never
            # leaves a half-filled matrix in the global.
            doc_embeddings = _embeddings_list_to_array(loaded_embeddings)
            logger.info(f"✓ Converted to array with shape: {doc_embeddings.shape}")
        else:
            doc_embeddings = loaded_embeddings
            logger.info(f"Embeddings already numpy array with shape: {doc_embeddings.shape}")

        logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings")

        # Pre-built index (638MB) loads far faster than the 15 minute build.
        inverted_index = _load_or_build_inverted_index(doc_chunks)
        logger.info("Will use inverted index + semantic search (no BM25)")
        return True
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        raise RuntimeError("Embeddings are required but failed to load") from e
def filter_trial_for_clinical_summary(trial_text, *, max_outcomes=5,
                                      max_outcome_desc=5, max_result_values=5):
    """Reduce a trial record to the lines useful for a clinical summary.

    Filtering is line based and keyed on the field prefix of each (stripped)
    line:

    - always dropped: blank lines, baseline / adverse-event lines and the
      auxiliary ``OUTCOME_*`` fields (type, time frame, safety, other, number)
    - capped: ``OUTCOME:``, ``OUTCOME_DESCRIPTION:`` and ``RESULT_VALUE:``
      lines - only the first N of each are kept
    - always kept: core trial fields (ids, titles, summary, description,
      conditions, intervention, sponsor, collaborator, manufacturer,
      eligibility)
    - anything else (including continuation lines without a known prefix)
      is dropped

    Args:
        trial_text: raw trial record; falsy values are returned unchanged.
        max_outcomes: number of ``OUTCOME:`` lines to keep.
        max_outcome_desc: number of ``OUTCOME_DESCRIPTION:`` lines to keep.
        max_result_values: number of ``RESULT_VALUE:`` lines to keep.

    Returns:
        The kept lines (original, unstripped) joined with newlines.
    """
    if not trial_text:
        return trial_text

    # Built once per call; str.startswith accepts a tuple of prefixes.
    skip_prefixes = (
        'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
        'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
        'OUTCOME_OTHER:', 'OUTCOME_NUMBER:',
    )
    keep_prefixes = (
        'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
        'SUMMARY:', 'DESCRIPTION:',
        'CONDITIONS:', 'INTERVENTION:',                 # WHAT disease, WHAT drug
        'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:',   # WHO is running/funding
        'ELIGIBILITY:',
    )
    limits = {
        'OUTCOME:': max_outcomes,
        'OUTCOME_DESCRIPTION:': max_outcome_desc,
        'RESULT_VALUE:': max_result_values,
    }
    seen = dict.fromkeys(limits, 0)

    kept = []
    for line in trial_text.split('\n'):
        stripped = line.strip()
        if not stripped or stripped.startswith(skip_prefixes):
            continue

        capped_prefix = next((p for p in limits if stripped.startswith(p)), None)
        if capped_prefix is not None:
            seen[capped_prefix] += 1
            if seen[capped_prefix] <= limits[capped_prefix]:
                kept.append(line)
            continue

        if stripped.startswith(keep_prefixes):
            kept.append(line)

    return '\n'.join(kept)
def retrieve_context_with_embeddings(query, top_k=10):
    """Hybrid retrieval: inverted-index keyword scoring plus semantic similarity.

    Steps:
      1. Extract query terms (lower-cased words longer than 2 characters,
         stop words removed).
      2. Keyword score every chunk listed in the inverted index for a query
         term. Chunks containing a rare term (indexed in fewer than 100
         trials, treated as a specific drug name) get a very large score;
         other hits get a field boost (x3 intervention, x2 near the title).
      3. Semantic score every chunk by dot product with the query embedding.
      4. Normalise both, combine 50/50 where a keyword score exists and use
         the semantic score alone elsewhere.
      5. Take the best ``max(top_k * 3, 10)`` candidates, order them by NCT
         number (higher = newer), keep ``top_k`` and filter each record with
         ``filter_trial_for_clinical_summary``.

    Args:
        query: free-text user query.
        top_k: number of trials to include in the returned context.

    Returns:
        The filtered trial texts joined by ``"\\n\\n---\\n\\n"``, or ``""``
        when embeddings have not been loaded.
    """
    import time

    global doc_chunks, doc_embeddings, embedder, inverted_index

    if doc_embeddings is None or len(doc_chunks) == 0:
        logger.error("Embeddings not loaded!")
        return ""

    t0 = time.time()

    # Extract ALL meaningful words from query (stop words removed)
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
                  'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
                  'about', 'that', 'this', 'there', 'it'}
    query_lower = query.lower()
    words = re.findall(r'\b\w+\b', query_lower)
    query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
    logger.info(f"[HYBRID] Query terms extracted: {query_terms}")

    # 1. KEYWORD SCORING via the inverted index (no BM25)
    t_kw = time.time()
    keyword_scores = {}
    if inverted_index is not None:
        inv_index_candidates = set()
        for term in query_terms:
            if term in inverted_index:
                inv_index_candidates.update(inverted_index[term])
                logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'")

        if inv_index_candidates:
            logger.info(f"[FAST PATH] Checking {len(inv_index_candidates)} inverted index candidates")

            # Terms indexed in fewer than 100 trials are treated as specific
            # drug names.
            drug_specific_terms = set()
            for term in query_terms:
                if term in inverted_index and len(inverted_index[term]) < 100:
                    drug_specific_terms.add(term)
                    logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name")

            for idx in inv_index_candidates:
                chunk_data = doc_chunks[idx]
                chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
                chunk_lower = chunk_text.lower()

                if any(drug_term in chunk_lower for drug_term in drug_specific_terms):
                    # Drug-specific trials get GUARANTEED top ranking
                    score = 1000.0 + 1.0
                    logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}")
                else:
                    # Generic hit: base score 1.0 times the best field boost
                    max_field_boost = 1.0
                    for term in query_terms:
                        if term not in chunk_lower or term in drug_specific_terms:
                            continue
                        # INTERVENTION field - medium priority
                        if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower:
                            max_field_boost = max(max_field_boost, 3.0)
                        # TITLE field - low priority: first occurrence of the
                        # term within 200 characters after the first 'title:'
                        elif 'title:' in chunk_lower:
                            title_pos = chunk_lower.find('title:')
                            term_pos = chunk_lower.find(term)
                            if title_pos < term_pos < title_pos + 200:
                                max_field_boost = max(max_field_boost, 2.0)
                    score = 1.0 * max_field_boost
                keyword_scores[idx] = score
        else:
            logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search")
    logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)")

    # 2. SEMANTIC SCORING
    load_embedder()
    t_sem = time.time()
    query_embedding = embedder.encode([query])[0]
    semantic_similarities = np.dot(doc_embeddings, query_embedding)
    logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)")

    # 3. MERGE SCORES - normalise both to the 0-1 range
    if keyword_scores:
        max_kw = max(keyword_scores.values())
        keyword_scores_norm = {idx: score/max_kw for idx, score in keyword_scores.items()}
    else:
        keyword_scores_norm = {}

    max_sem = semantic_similarities.max()
    min_sem = semantic_similarities.min()
    semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)

    # Start from the semantic scores and overwrite only the keyword hits,
    # instead of looping in Python over every chunk.
    combined_scores = np.array(semantic_scores_norm[:len(doc_chunks)], dtype=np.float64)
    for idx, kw_score in keyword_scores_norm.items():
        if kw_score > 0:
            combined_scores[idx] = 0.5 * kw_score + 0.5 * semantic_scores_norm[idx]

    # Over-fetch candidates so the recency sort below has something to order.
    candidate_k = max(top_k * 3, 10)  # Get 3x requested, minimum 10
    top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
    logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}")

    # (score, text) tuples
    candidate_trials_for_ranking = [
        (combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
        for i in top_indices
    ]

    def extract_nct_number(trial_tuple):
        """Extract NCT number from trial text for sorting (higher = newer)"""
        _, text = trial_tuple
        match = re.search(r'NCT_ID:\s*NCT(\d+)', text)
        return int(match.group(1)) if match else 0

    # Sort candidates by NCT ID (descending = newest first)
    candidate_trials_for_ranking.sort(key=extract_nct_number, reverse=True)

    # Log top 5 NCT IDs to show recency sorting
    top_ncts = []
    for score, text in candidate_trials_for_ranking[:5]:
        match = re.search(r'NCT_ID:\s*(NCT\d+)', text)
        if match:
            top_ncts.append(match.group(1))
    logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}")

    # The 355M re-ranking step is skipped (it rated everything 0.50 and cost
    # about 10 seconds); the recency-sorted hybrid candidates are used as is.
    logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)")
    ranked_trials = candidate_trials_for_ranking

    top_ranked = ranked_trials[:top_k]
    logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)")

    raw_chunks = [trial_text for _, trial_text in top_ranked]
    # Apply clinical filter to each trial
    context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks]
    if context_chunks:
        first_trial_preview = context_chunks[0][:200]
        logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}")

    context = "\n\n---\n\n".join(context_chunks)  # Use --- as separator between trials
    logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s")
    logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)")
    return context
def _iter_trial_records(path):
    """Yield the raw text of each trial in the flat dataset file at ``path``.

    A new record starts at every line beginning with ``NCT_ID:`` or
    ``TRIAL NCT``; any text before the first such line is yielded as a record
    of its own.
    """
    buffered = []
    with open(path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            if buffered and line.startswith(("NCT_ID:", "TRIAL NCT")):
                yield "".join(buffered)
                buffered = []
            buffered.append(line)
    if buffered:
        yield "".join(buffered)


def _top_scoring_trials(path, score_trial, max_results):
    """Stream the dataset and keep only the ``max_results`` best trials.

    Args:
        path: dataset file to scan.
        score_trial: callable taking the lower-cased trial text and returning
            a numeric score; trials scoring 0 or less are ignored.
        max_results: number of trials to return.

    Returns:
        ``(top, matched)`` where ``top`` is a list of ``(score, trial_text)``
        sorted best first (file order among equal scores) and ``matched`` is
        the number of trials scanned with a positive score.
    """
    import heapq

    matched = 0

    def scored():
        nonlocal matched
        for trial in _iter_trial_records(path):
            score = score_trial(trial.lower())
            if score > 0:
                matched += 1
                yield score, trial

    # nlargest keeps at most max_results records in memory instead of every
    # matching trial, and orders ties like a stable descending sort.
    top = heapq.nlargest(max_results, scored(), key=lambda pair: pair[0])
    return top, matched


def keyword_search_query_text(query, max_results=10, hf_token=None):
    """Search the dataset file using all meaningful words of ``query``.

    Each trial is scored by the number of search terms it contains
    (case-insensitive substring match).

    Args:
        query: free-text user query.
        max_results: maximum number of trials to return.
        hf_token: unused; kept for call compatibility.

    Returns:
        A list of ``(score, trial_text)`` tuples, best first; ``[]`` when
        nothing matches or the scan fails. For legacy reasons an empty
        string is returned when the dataset file is missing or the query
        yields no search terms - callers should test truthiness.
    """
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return ""

    # Remove common stopwords but keep medical/clinical terms
    stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
                 'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
                 'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
                 'what', 'you', 'know', 'that', 'relevant'}

    # Strip punctuation BEFORE filtering: otherwise "is?" passes the length
    # check and becomes the term "is", and "..." becomes "" which is a
    # substring of every trial.
    search_terms = []
    for raw_word in query.lower().split():
        term = raw_word.strip('?.,!;:()[]{}')
        if len(term) >= 3 and term not in stopwords:
            search_terms.append(term)

    if not search_terms:
        logger.warning("No search terms extracted from query")
        return ""
    logger.info(f"Search terms from full query: {search_terms}")

    def score_trial(trial_lower):
        return sum(1 for term in search_terms if term in trial_lower)

    try:
        matching_trials, _ = _top_scoring_trials(DATASET_FILE, score_trial, max_results)
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return []

    if matching_trials:
        logger.info(f"Keyword search found {len(matching_trials)} trials")
        return matching_trials  # list of (score, trial) tuples
    logger.warning("Keyword search found no matching trials")
    return []


def keyword_search_in_dataset(entities, max_results=10):
    """Legacy: search the dataset file using extracted entities.

    When drugs are given, only trials mentioning at least one drug are
    returned, scored ``10 * drug_matches + condition_matches``; otherwise
    trials are scored by condition matches alone.

    Args:
        entities: dict with optional ``'drugs'`` and ``'conditions'`` lists
            of strings; blank entries are ignored.
        max_results: maximum number of trials to include.

    Returns:
        The matching trials joined by ``"\\n\\n---\\n\\n"`` and truncated to
        6000 characters (plus ``"..."``), or ``""`` when there is nothing to
        search for, nothing matches, or the scan fails.
    """
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return ""

    # Blank terms are dropped: an empty string is a substring of every trial.
    drugs = [d.lower() for d in entities.get('drugs', []) if d and d.strip()]
    conditions = [c.lower() for c in entities.get('conditions', []) if c and c.strip()]
    if not drugs and not conditions:
        logger.warning("No search terms for keyword search")
        return ""
    logger.info(f"Keyword search - Drugs: {drugs}, Conditions: {conditions}")

    def score_trial(trial_lower):
        drug_matches = sum(1 for d in drugs if d in trial_lower)
        condition_matches = sum(1 for c in conditions if c in trial_lower)
        if drugs:
            # A drug was specified: the trial must mention at least one.
            return drug_matches * 10 + condition_matches if drug_matches > 0 else 0
        return condition_matches

    try:
        top, candidate_count = _top_scoring_trials(DATASET_FILE, score_trial, max_results)
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return ""

    matching_trials = [trial for _, trial in top]
    if not matching_trials:
        logger.warning("Keyword search found no trials matching drug")
        return ""

    context = "\n\n---\n\n".join(matching_trials)
    if len(context) > 6000:
        context = context[:6000] + "..."
    logger.info(f"Keyword search found {len(matching_trials)} trials (from {candidate_count} candidates)")
    return context
| # ============================================================================ | |
| # ENTITY EXTRACTION | |
| # ============================================================================ | |
def parse_entities_from_query(conversation, hf_token=None):
    """Extract drug and condition entities from a user query.

    Combines three sources:
      1. output of ``extract_entities_with_small_model`` (355M model),
      2. output of ``extract_entities_with_8b`` (8B model),
      3. a regex fallback over the raw text for common drug-name shapes
         (``-mab`` / ``-nib`` suffixes, alphanumeric codes) and a short list
         of condition names.

    Model output is parsed line by line: lines containing ``drug:`` or
    ``medication:`` contribute a drug, lines containing ``condition:`` or
    ``disease:`` contribute a condition. Regex matches are added only when
    not already present (case-insensitive).

    Args:
        conversation: the query text.
        hf_token: token forwarded to the 8B extraction call.

    Returns:
        dict with keys ``'drugs'`` and ``'conditions'``, each a list of str.
    """
    entities = {'drugs': [], 'conditions': []}

    extracted_355m = extract_entities_with_small_model(conversation)
    extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)
    # Either model may return nothing; concatenate whatever is available.
    extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")

    for line in extracted.split('\n'):
        lower_line = line.lower()
        if 'drug:' in lower_line or 'medication:' in lower_line:
            drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip()
            if drug:
                entities['drugs'].append(drug)
        elif 'condition:' in lower_line or 'disease:' in lower_line:
            condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip()
            if condition:
                entities['conditions'].append(condition)

    # Regex fallback for standard drug naming patterns
    drug_patterns = [
        r'\b([A-Z][a-z]+mab)\b',     # Monoclonal antibodies: -mab suffix
        r'\b([A-Z][a-z]+nib)\b',     # Kinase inhibitors: -nib suffix
        r'\b([A-Z]\d+[A-Z]+\d+)\b'   # Alphanumeric codes like F8IL10
    ]
    for pattern in drug_patterns:
        for match in re.findall(pattern, conversation):
            if match.lower() not in [d.lower() for d in entities['drugs']]:
                entities['drugs'].append(match)

    condition_patterns = [
        r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
    ]
    for pattern in condition_patterns:
        for match in re.findall(pattern, conversation, re.IGNORECASE):
            # Compare lower-cased on both sides; the match keeps the query's
            # casing, so comparing it raw never detected duplicates such as
            # "Lupus" vs "lupus".
            if match.lower() not in [c.lower() for c in entities['conditions']]:
                entities['conditions'].append(match)

    logger.info(f"Extracted entities: {entities}")
    return entities
| # ============================================================================ | |
| # MAIN QUERY PROCESSING | |
| # ============================================================================ | |
def extract_entities_simple(query):
    """Extract drug and condition names with regex patterns only (no model).

    All patterns are applied case-insensitively. Results keep the spelling of
    the first occurrence and contain no case-insensitive duplicates.

    Args:
        query: The user's query text. ``None`` or an empty string yields
            empty result lists.

    Returns:
        dict: ``{'drugs': [...], 'conditions': [...]}`` with lists of strings.
    """
    drug_patterns = (
        r'\b([A-Z][a-z]+mab)\b',      # Monoclonal antibodies: ianalumab, rituximab, etc.
        r'\b([A-Z][a-z]+nib)\b',      # Kinase inhibitors: imatinib, etc.
        r'\b([A-Z]\d+[A-Z]+\d+)\b',   # Alphanumeric codes
        r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b',  # Common drugs
    )
    condition_patterns = (
        r"\b(sjogren'?s?\s+syndrome)\b",
        r'\b(rheumatoid arthritis)\b',
        r'\b(lupus)\b',
        r'\b(myelofibrosis)\b',
        r'\b(diabetes)\b',
        r'\b(cancer|carcinoma|melanoma)\b',
    )

    # re.findall() raises on None; treat a missing query as empty text.
    text = query or ""

    def collect(patterns):
        # Preserve pattern order, then match order, dropping repeats.
        found = []
        seen = set()
        for pattern in patterns:
            for match in re.findall(pattern, text, re.IGNORECASE):
                key = match.lower()
                if key not in seen:
                    seen.add(key)
                    found.append(match)
        return found

    entities = {
        'drugs': collect(drug_patterns),
        'conditions': collect(condition_patterns),
    }
    logger.info(f"Extracted entities: {entities}")
    return entities
def _clean_llm_value(raw):
    """Trim whitespace, markdown asterisks, quotes and the [ ] shown in the prompt format."""
    return raw.strip(' \t\r*[]"\'')


def _parse_llm_fields(parsed, query):
    """Convert the parser LLM's ``LABEL: value`` lines into the result dict."""
    result = {
        'raw_parsed': parsed,
        'drugs': [],
        'diseases': [],
        'companies': [],
        'endpoints': [],
        'search_terms': query,  # fallback when the model gives no usable terms
    }
    list_fields = {
        'DRUGS': 'drugs',
        'DISEASES': 'diseases',
        'COMPANIES': 'companies',
        'ENDPOINTS': 'endpoints',
    }
    for line in parsed.split('\n'):
        label, sep, value = line.partition(':')
        if not sep:
            continue
        label = label.strip(' \t\r*').upper()
        value = _clean_llm_value(value)
        is_none = not value or value.lower() == 'none'
        if label in list_fields:
            items = [] if is_none else [_clean_llm_value(item) for item in value.split(',')]
            # Drop blanks so "DRUGS:" or "a, , b" never produce '' entries.
            result[list_fields[label]] = [item for item in items if item]
        elif label == 'SEARCH_TERMS' and not is_none:
            result['search_terms'] = value
    return result


def parse_query_with_llm(query, hf_token=None):
    """
    Use fast LLM to parse query and extract structured information

    Extracts:
    - Drug names
    - Diseases/conditions
    - Companies (sponsors/manufacturers)
    - Endpoints (what's being measured)
    - Search terms (optimized for RAG)

    Args:
        query: The user's question.
        hf_token: Token passed to ``huggingface_hub.InferenceClient``.

    Returns: Dict with keys 'drugs', 'diseases', 'companies', 'endpoints'
    (lists of strings), 'search_terms' (string) and 'raw_parsed' (the model's
    raw reply). On any failure (import error, API error, unexpected reply)
    the lists are empty, 'search_terms' is the original query and
    'raw_parsed' is ''. This function does not raise.
    """
    try:
        # Imported lazily so the module loads even without huggingface_hub.
        from huggingface_hub import InferenceClient
        logger.info("[QUERY PARSER] Analyzing user query with LLM...")
        client = InferenceClient(token=hf_token, timeout=30)
        parse_prompt = f"""Extract key information from this clinical trial query.
Query: "{query}"
Extract and return in this EXACT format:
DRUGS: [list drug/treatment names, or "none"]
DISEASES: [list diseases/conditions, or "none"]
COMPANIES: [list company/sponsor names, or "none"]
ENDPOINTS: [list trial endpoints/outcomes, or "none"]
SEARCH_TERMS: [optimized search keywords]
Examples:
Query: "What Novartis drugs treat melanoma?"
DRUGS: none
DISEASES: melanoma
COMPANIES: Novartis
ENDPOINTS: none
SEARCH_TERMS: Novartis melanoma treatment drugs
Query: "Tell me about Keytruda for lung cancer"
DRUGS: Keytruda
DISEASES: lung cancer
COMPANIES: none
ENDPOINTS: none
SEARCH_TERMS: Keytruda lung cancer
Now parse the query above:"""
        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": parse_prompt}],
            max_tokens=256,
            temperature=0.1  # Low temp for consistent parsing
        )
        parsed = response.choices[0].message.content.strip()
        logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}")
        # Parse the response into dict
        result = _parse_llm_fields(parsed, query)
        logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}")
        return result
    except Exception as e:
        # Best-effort step: the caller continues with the untouched query.
        logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
        return {
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query,
            'raw_parsed': ''
        }
def generate_llama_response(query, rag_context, hf_token=None):
    """
    Generate the final answer with Llama-3.1-70B.

    Groq is tried first when the GROQ_API_KEY environment variable is set
    (the original notes describe it as roughly 10x faster than the
    HuggingFace endpoint). If no key is set, or the Groq call fails for any
    reason, the HuggingFace Inference API is used instead.

    Args:
        query: The user's question.
        rag_context: Retrieved trial text; only the first 6000 characters are
            sent to the model. ``None`` is treated as empty.
        hf_token: Token passed to ``huggingface_hub.InferenceClient``.

    Returns:
        The model's answer as a string, or a string starting with
        "Llama API error:" when generation fails. This function does not
        raise.
    """
    try:
        # Both backends receive exactly the same prompt.
        context_text = rag_context or ""
        system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""
        user_prompt = (
            "Clinical trials:\n"
            f"{context_text[:6000]}\n"
            f"Question: {query}\n"
            "Provide a concise answer citing specific NCT trial IDs."
        )
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]

        # Try Groq first (much faster), fallback to HuggingFace
        groq_api_key = os.getenv("GROQ_API_KEY")
        if groq_api_key:
            try:
                logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...")
                from groq import Groq
                client = Groq(api_key=groq_api_key)
                # NOTE(review): Groq has announced deprecation of
                # llama-3.1-70b-versatile - confirm this model id is still served.
                response = client.chat.completions.create(
                    model="llama-3.1-70b-versatile",  # Groq's optimized 70B
                    messages=messages,
                    max_tokens=512,  # Shorter for speed
                    temperature=0.3,
                    timeout=30
                )
                return response.choices[0].message.content.strip()
            except Exception as groq_error:
                # Previously a Groq failure ended the request with an error
                # string; now the slower backend gets a chance.
                logger.warning(f"Groq generation failed ({groq_error}); falling back to HuggingFace")

        # Fallback to HuggingFace (slower)
        logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...")
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=hf_token, timeout=120)
        response = client.chat_completion(
            model="meta-llama/Meta-Llama-3.1-70B-Instruct",
            messages=messages,
            max_tokens=512,  # Reduced from 2048 for speed
            temperature=0.3
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Llama error: {e}")
        return f"Llama API error: {str(e)}"
def process_query_simple_test(conversation):
    """Smoke-test the RAG retrieval only - no LLM calls.

    Checks that the module-level ``doc_embeddings`` / ``doc_chunks`` are
    populated, runs ``retrieve_context_with_embeddings`` with ``top_k=3`` and
    reports timing plus the first 1000 characters of the retrieved context.

    Args:
        conversation: The query text to search for.

    Returns:
        A human-readable report string; failures are reported in the string
        ("FAIL: ..." or "ERROR IN RAG TEST: ...") rather than raised.
    """
    import time
    import traceback

    try:
        output = [f"QUERY: {conversation}\n"]

        # Check if embeddings loaded
        if doc_embeddings is None or len(doc_chunks) == 0:
            return "FAIL: Embeddings not loaded"
        output.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n")
        output.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n")

        # Try to search
        start = time.time()
        context = retrieve_context_with_embeddings(conversation, top_k=3)
        search_time = time.time() - start
        if not context:
            return "".join(output) + "\nFAIL: RAG returned empty"

        # Count distinct NCT identifiers ("NCT" + 8 digits). Counting the bare
        # substring "NCT" inflated the number whenever an ID was repeated.
        trial_count = len(set(re.findall(r'NCT\d{8}', context)))

        output.append(f"✓ RAG search took: {search_time:.2f}s\n")
        output.append(f"✓ Retrieved {trial_count} trials\n\n")
        output.append("FIRST 1000 CHARS:\n")
        output.append(context[:1000])
        return "".join(output)
    except Exception as e:
        return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}"
def process_query(conversation):
    """
    Complete pipeline with LLM query parsing and natural language generation

    Flow:
    0. LLM Parser - ``parse_query_with_llm`` extracts drugs, diseases,
       companies and optimized search terms
    1. RAG Search - ``retrieve_context_with_embeddings`` with ``top_k=3``
    2. Skipped - 355M ranking removed (was broken)
    3. LLM Response - ``generate_llama_response`` writes the answer

    The original notes estimate about 20 seconds end to end.

    Args:
        conversation: The user's question.

    Returns:
        A report string containing the step log, the generated summary, the
        retrieved trials and the total time. Errors are returned as text;
        this function does not raise.
    """
    import time
    import traceback
    import sys

    # MASTER try/except - catches EVERYTHING
    try:
        start_time = time.time()
        output_parts = [f"QUERY: {conversation}\n\n"]

        # Step 0: Parse query with LLM to extract structured info
        try:
            step0_start = time.time()
            logger.info("Step 0: Parsing query with LLM...")
            output_parts.append("✓ Step 0: LLM query parser started...\n")
            parsed_query = parse_query_with_llm(conversation, hf_token=hf_token)
            # Use optimized search terms from parser
            search_query = parsed_query['search_terms']
            step0_time = time.time() - step0_start
            output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n")
            output_parts.append(f" Drugs: {parsed_query['drugs']}\n")
            output_parts.append(f" Diseases: {parsed_query['diseases']}\n")
            output_parts.append(f" Companies: {parsed_query['companies']}\n")
            output_parts.append(f" Optimized search: {search_query}\n")
            logger.info(f"Query parsing successful in {step0_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query"
            logger.warning(error_msg)
            output_parts.append(f"{error_msg}\n")
            search_query = conversation  # Fallback to original

        # Step 1: RAG search (using optimized search query)
        try:
            step1_start = time.time()
            logger.info("Step 1: RAG search...")
            output_parts.append("✓ Step 1: RAG search started...\n")
            context = retrieve_context_with_embeddings(search_query, top_k=3)
            if not context:
                return "No matching trials found in RAG search."
            # No limit - use complete trials
            step1_time = time.time() - step1_start
            # Count distinct NCT identifiers ("NCT" + 8 digits); counting the
            # bare substring "NCT" over-reported repeated IDs.
            trial_count = len(set(re.findall(r'NCT\d{8}', context)))
            output_parts.append(f"✓ Step 1 Complete: Found {trial_count} trials ({step1_time:.1f}s)\n")
            logger.info(f"RAG search successful - found trials in {step1_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return error_msg

        # Step 2: Skipped (355M ranking removed - was broken)
        output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n")

        # Step 3: Llama 70B
        try:
            step3_start = time.time()
            logger.info("Step 3: Generating response with Llama-3.1-70B...")
            output_parts.append("✓ Step 3: Llama 70B generation started...\n")
            llama_response = generate_llama_response(conversation, context, hf_token=hf_token)
            step3_time = time.time() - step3_start
            if str(llama_response).startswith("Llama API error:"):
                # generate_llama_response reports failures as a string instead
                # of raising, so a failure must not be logged as success.
                logger.error(f"Llama 70B generation failed: {llama_response}")
                output_parts.append(f"✗ Step 3 Failed: {llama_response}\n")
            else:
                output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n")
                logger.info(f"Llama 70B generation successful in {step3_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            llama_response = f"[Llama 70B error: {str(e)}]"
            output_parts.append(f"✗ Step 3 Failed: {str(e)}\n")

        total_time = time.time() - start_time

        # `context` and `llama_response` are always bound here: Step 1 returns
        # early when it fails, and Step 3 assigns a placeholder on error.
        return (
            f"{''.join(output_parts)}\n"
            "CLINICAL SUMMARY (Llama-3.1-70B-Instruct):\n"
            f"{llama_response}\n"
            "---\n"
            "RAG RETRIEVED TRIALS (Top 3 Most Relevant):\n"
            f"{context}\n"
            "---\n"
            f"Total Time: {total_time:.1f}s\n"
        )

    # MASTER EXCEPTION HANDLER - catches ANY unhandled error
    except Exception as master_error:
        error_tb = sys.exc_info()[2]
        master_error_msg = (
            "\n"
            "========================================\n"
            "MASTER ERROR HANDLER CAUGHT EXCEPTION\n"
            "========================================\n"
            f"Error Type: {type(master_error).__name__}\n"
            f"Error Message: {str(master_error)}\n"
            "FULL TRACEBACK:\n"
            f"{traceback.format_exc()}\n"
            "System Info:\n"
            f"- Python version: {sys.version}\n"
            f"- Error at line: {error_tb.tb_lineno if error_tb else 'unknown'}\n"
            "========================================\n"
        )
        logger.error(master_error_msg)
        return master_error_msg
| # ============================================================================ | |
| # GRADIO INTERFACE | |
| # ============================================================================ | |
# Single-page UI: one query box, one button, one read-only answer box.
with gr.Blocks(title="Foundation 1.2") as demo:
    gr.Markdown("# Foundation 1.2 - Clinical Trial AI")
    query_input = gr.Textbox(
        label="Ask about clinical trials",
        placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?",
        lines=3
    )
    submit_btn = gr.Button("Generate Response", variant="primary")
    output = gr.Textbox(
        label="AI Response",
        lines=30
    )
    submit_btn.click(
        fn=process_query,  # Full pipeline: LLM query parser + RAG search + Llama 70B
        inputs=query_input,
        outputs=output
    )
    # The description below mirrors what process_query actually runs; the
    # earlier text still advertised Qwen2.5-14B and 355M re-ranking.
    gr.Markdown("""
**Production RAG Pipeline - Optimized for Clinical Accuracy**

**Search (Hybrid):**

1. An LLM query parser (Llama-3.1-70B) extracts drugs, diseases, companies and optimized search terms
2. Hybrid keyword + semantic search over the clinical trial data
3. Returns the top 3 most relevant trials (355M re-ranking is currently disabled)

**Generation (Llama-3.1-70B-Instruct):**

- Served via Groq when `GROQ_API_KEY` is set, otherwise via the HuggingFace Inference API
- Answers are based only on the retrieved clinical trial data
- Cites specific NCT trial IDs
""")
| # ============================================================================ | |
| # STARTUP | |
| # ============================================================================ | |
# Import-time status report only. The embeddings themselves are loaded by the
# FastAPI startup event in app.py - loading them here caused Docker
# permission errors.
logger.info("=== Foundation 1.2 Module Loaded ===")
logger.info("Call load_embeddings() to initialize the system")

if not DATASET_FILE.exists():
    logger.error("✗ Dataset file not found!")
else:
    # Size in MiB, logged rounded to a whole number.
    file_size_mb = DATASET_FILE.stat().st_size / (1024 * 1024)
    logger.info(f"✓ Dataset file found: {file_size_mb:.0f}MB")

logger.info("=== Startup Complete ===")

# Launch the Gradio UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()