ctapi / foundation_engine.py
Your Name
Clone api2 for experimentation
d78f02a
raw
history blame
54 kB
"""
Foundation 1.2
Clinical trial query system with 355M foundation model
"""
import gradio as gr
import os
from pathlib import Path
import pickle
import numpy as np
from sentence_transformers import SentenceTransformer
import logging
from rank_bm25 import BM25Okapi
import re
from two_llm_system_FIXED import expand_query_with_355m, generate_clinical_response_with_xupract, rank_trials_with_355m
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face API token from the environment (None when HF_TOKEN is unset).
hf_token = os.getenv("HF_TOKEN")

# Local data files, all expected next to this module. When missing they are
# fetched from the Hugging Face dataset named by DATASET_REPO on first run.
DATASET_FILE, CHUNKS_FILE, EMBEDDINGS_FILE, INVERTED_INDEX_FILE = (
    Path(__file__).parent / file_name
    for file_name in (
        "complete_dataset_WITH_RESULTS_FULL.txt",
        "dataset_chunks_TRIAL_AWARE.pkl",
        "dataset_embeddings_TRIAL_AWARE_FIXED.npy",  # FIXED version to avoid cache
        "inverted_index_TRIAL_AWARE.pkl",  # pre-built inverted index (638MB)
    )
)

# Hugging Face dataset repository holding the large data files.
DATASET_REPO = "gmkdigitalmedia/foundation1.2-data"

# Lazily populated module state shared by the search functions.
embedder = None        # sentence embedder, set by load_embedder()
doc_chunks = []        # trial chunks, set by load_embeddings()
doc_embeddings = None  # embedding matrix, set by load_embeddings()
bm25_index = None      # BM25 index for keyword search (not built by the visible code)
inverted_index = None  # term -> trial indices, set by load_embeddings()
# ============================================================================
# RAG FUNCTIONS
# ============================================================================
def load_embedder():
    """Ensure the module-level sentence embedder is loaded; later calls are no-ops.

    Loads 'all-MiniLM-L6-v2' on the CPU the first time it is called. It must
    be the same model that produced the stored document embeddings.
    """
    global embedder
    if embedder is not None:
        return
    logger.info("Loading MiniLM-L6 embedding model...")
    # CPU only, so CUDA is never initialised in the main process.
    embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
    logger.info("L6 model loaded on CPU")
def build_inverted_index(chunks):
    """
    Build a targeted inverted index for clinical search.

    Maps lower-cased terms from a few clinically meaningful fields to the
    indices (positions in ``chunks``) of the trials that mention them, so a
    term lookup is a single dict access. Only these fields are indexed:

    1. INTERVENTION - drug/device names (split on commas, semicolons, hyphens
       and whitespace; tokens longer than 3 characters)
    2. CONDITIONS - every word of 4+ characters in each listed condition
    3. SPONSOR / COLLABORATOR / MANUFACTURER - whole company names longer
       than 3 characters
    4. OUTCOME - the first 3 words of 5+ characters from each of the first
       5 outcome lines (trial endpoints)

    Trial names are deliberately not indexed (unnecessary noise), and generic
    trial-design words such as "placebo" or "randomized" are skipped.

    Args:
        chunks: sequence whose items are either the trial text or a tuple
            whose second element is the trial text.

    Returns:
        dict mapping each term to the ascending, duplicate-free list of chunk
        indices whose text contains it.
    """
    import time

    t_start = time.time()
    inv_index = {}
    logger.info("Building targeted index: drugs, diseases, companies, endpoints...")

    # Generic words to skip
    skip_words = {
        'with', 'versus', 'combination', 'treatment', 'therapy', 'study', 'trial',
        'phase', 'double', 'blind', 'placebo', 'group', 'control', 'active',
        'randomized', 'multicenter', 'open', 'label', 'crossover'
    }
    # Outcome lines also skip words that merely describe the field itself
    outcome_skip_words = skip_words | {'outcome', 'measure', 'primary', 'secondary'}

    # Compiled once instead of being looked up for every chunk
    intervention_re = re.compile(r'intervention[:\s]+([^\n]+)')
    conditions_re = re.compile(r'conditions?[:\s]+([^\n]+)')
    company_res = (
        re.compile(r'sponsor[:\s]+([^\n]+)'),
        re.compile(r'collaborator[:\s]+([^\n]+)'),
        re.compile(r'manufacturer[:\s]+([^\n]+)'),
    )
    outcome_re = re.compile(r'outcome[:\s]+([^\n]+)')
    drug_split_re = re.compile(r'[,;\-\s]+')
    list_split_re = re.compile(r'[,;\|]+')
    word4_re = re.compile(r'\b\w{4,}\b')
    word5_re = re.compile(r'\b\w{5,}\b')

    def add(term, idx):
        """Record that chunk ``idx`` mentions ``term`` (at most once per chunk)."""
        postings = inv_index.setdefault(term, [])
        # Chunks are visited in ascending index order, so a repeat for this
        # chunk can only be the last entry. This O(1) check avoids scanning
        # the whole posting list, which is huge for common terms.
        if not postings or postings[-1] != idx:
            postings.append(idx)

    total = len(chunks)
    for idx, chunk_data in enumerate(chunks):
        if idx % 100000 == 0 and idx > 0:
            logger.info(f" Indexed {idx:,}/{total:,} trials...")
        text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
        text_lower = text.lower()

        # 1. DRUGS from INTERVENTION field
        match = intervention_re.search(text_lower)
        if match:
            for drug in drug_split_re.split(match.group(1)):
                drug = drug.strip('.,;:() ')
                if len(drug) > 3 and drug not in skip_words:
                    add(drug, idx)

        # 2. DISEASES from CONDITIONS field: index each significant word of
        #    multi-word conditions
        match = conditions_re.search(text_lower)
        if match:
            for disease in list_split_re.split(match.group(1)):
                disease = disease.strip('.,;:() ')
                for word in word4_re.findall(disease):
                    if word not in skip_words:
                        add(word, idx)

        # 3. COMPANIES from SPONSOR, COLLABORATOR and MANUFACTURER fields
        for company_re in company_res:
            match = company_re.search(text_lower)
            if match:
                for company in list_split_re.split(match.group(1)):
                    company = company.strip('.,;:() ')
                    if len(company) > 3:
                        add(company, idx)

        # 4. ENDPOINTS from OUTCOME fields: first 5 outcomes, first 3 long words each
        for outcome_text in outcome_re.findall(text_lower)[:5]:
            for word in word5_re.findall(outcome_text)[:3]:
                if word not in outcome_skip_words:
                    add(word, idx)

    t_elapsed = time.time() - t_start
    logger.info(f"✓ Targeted index built in {t_elapsed:.1f}s with {len(inv_index):,} terms")

    # Log sample entries for debugging (drugs, diseases, companies, endpoints)
    sample_terms = {
        'drugs': ['keytruda', 'opdivo', 'humira'],
        'diseases': ['cancer', 'diabetes', 'melanoma'],
        'companies': ['novartis', 'pfizer', 'merck'],
        'endpoints': ['survival', 'response', 'remission']
    }
    for category, terms in sample_terms.items():
        logger.info(f" {category.upper()} samples:")
        for term in terms:
            if term in inv_index:
                logger.info(f" '{term}' -> {len(inv_index[term])} trials")
    return inv_index
def download_from_dataset(filename, download_dir=None):
    """
    Return a local path to ``filename`` from the HF dataset, downloading it if needed.

    Files are cached in ``download_dir``, which defaults to
    ``/tmp/foundation_data`` because /tmp is writable inside the Docker image.
    A cached copy is returned without importing or contacting the Hub.

    Args:
        filename: file name inside the ``DATASET_REPO`` dataset repository.
        download_dir: optional cache directory (str or Path); created,
            including parents, if missing.

    Returns:
        Path to the local file, or None if the download failed (the error is
        logged).
    """
    # Use /tmp for downloads (has write permissions in Docker)
    download_dir = Path("/tmp/foundation_data") if download_dir is None else Path(download_dir)
    download_dir.mkdir(parents=True, exist_ok=True)
    local_file = download_dir / filename
    if local_file.exists():
        logger.info(f"Found cached {filename}")
        return local_file

    # Imported only when a download is actually needed.
    from huggingface_hub import hf_hub_download

    try:
        logger.info(f"Downloading {filename} from {DATASET_REPO}...")
        # NOTE(review): local_dir_use_symlinks is deprecated in newer
        # huggingface_hub releases; kept for compatibility with older ones.
        downloaded_file = hf_hub_download(
            repo_id=DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            local_dir=download_dir,
            local_dir_use_symlinks=False
        )
        logger.info(f"Downloaded {filename}")
        return Path(downloaded_file)
    except Exception as e:
        logger.error(f"Failed to download {filename}: {e}")
        return None
def _resolve_data_file(local_path, filename):
    """Return ``local_path`` if it exists, else a downloaded copy of ``filename``.

    Falls back to ``local_path`` (which then does not exist) when the download
    fails, so callers can still test ``.exists()`` on the result.
    """
    if local_path.exists():
        return local_path
    downloaded = download_from_dataset(filename)
    return downloaded if downloaded else local_path


def _embeddings_sequence_to_matrix(loaded_embeddings):
    """
    Convert a per-chunk sequence of embeddings into a dense float32 matrix.

    Each entry is a vector (list or ndarray) or a tuple holding one; for
    tuples the vector's position is detected from the first entry. Rows are
    copied in slices of 10,000 to limit peak memory.

    Raises:
        ValueError: if the entries are (int, str) tuples (the chunks file was
            uploaded in place of the embeddings file), if no vector is found
            in a tuple entry, or if the entry type is unknown.
    """
    logger.info(f"Converting embeddings from list to numpy array (memory efficient)...")
    # Convert in chunks to avoid memory spike
    chunk_size = 10000
    total = len(loaded_embeddings)
    # DEBUG: log the size and the type of the first entry to diagnose the format
    logger.info(f"DEBUG: Total embeddings: {total}")
    first_emb = loaded_embeddings[0]
    logger.info(f"DEBUG: Type of first item: {type(first_emb)}")

    # Check if this is actually the chunks file (wrong file uploaded)
    if (isinstance(first_emb, tuple) and len(first_emb) == 2
            and isinstance(first_emb[0], int) and isinstance(first_emb[1], str)):
        raise ValueError(
            f"ERROR: The embeddings file contains (int, string) tuples!\n"
            f"This looks like the CHUNKS file was uploaded as the embeddings file.\n\n"
            f"First item: {first_emb[:2]}\n\n"
            f"Please re-upload the correct file:\n"
            f" CORRECT: dataset_embeddings_TRIAL_AWARE.npy (numpy array, 855 MB)\n"
            f" WRONG: dataset_chunks_TRIAL_AWARE.pkl (tuples, 2.8 GB)\n\n"
            f"The local file at /mnt/c/Users/ibm/Documents/HF/kg_to_model/dataset_embeddings_TRIAL_AWARE.npy is correct."
        )
    if isinstance(first_emb, tuple):
        logger.info(f"DEBUG: Tuple length: {len(first_emb)}")
        for i, item in enumerate(first_emb[:5]):
            logger.info(f"DEBUG: Tuple element {i}: type={type(item)}, preview={str(item)[:100]}")

    # Position of the vector inside tuple entries (None for plain vectors)
    emb_idx = None
    if isinstance(first_emb, tuple):
        # Try every position - could be (id, emb) or (emb, id)
        logger.info(f"DEBUG: Trying to find embedding vector in tuple...")
        emb_vector = None
        for idx, elem in enumerate(first_emb):
            if isinstance(elem, (list, np.ndarray)):
                emb_vector = elem
                emb_idx = idx
                logger.info(f"DEBUG: Found embedding at position {idx}")
                break
        if emb_vector is None:
            raise ValueError(f"No embedding vector found in tuple. Tuple contains: {[type(x) for x in first_emb]}")
        emb_dim = len(emb_vector)
        logger.info(f"DEBUG: Embedding dimension: {emb_dim}")
    elif isinstance(first_emb, list):
        emb_dim = len(first_emb)
    elif isinstance(first_emb, np.ndarray):
        emb_dim = first_emb.shape[0]
    else:
        raise ValueError(f"Unknown embedding format: {type(first_emb)}")

    logger.info(f"Creating array for {total} embeddings of dimension {emb_dim}")
    matrix = np.zeros((total, emb_dim), dtype=np.float32)
    for i in range(0, total, chunk_size):
        end = min(i + chunk_size, total)
        rows = loaded_embeddings[i:end]
        # Materialise as a list so a slice of an object-dtype ndarray is
        # unpacked row by row instead of being broadcast as one object column.
        if emb_idx is not None:
            batch = [item[emb_idx] for item in rows]
        else:
            batch = list(rows)
        matrix[i:end] = batch
        if i % 50000 == 0:
            logger.info(f"Converted {i}/{total} embeddings...")
    logger.info(f"✓ Converted to array with shape: {matrix.shape}")
    return matrix


def _coerce_embeddings(loaded_embeddings):
    """Turn the object loaded from the embeddings file into the matrix used for search."""
    is_array = isinstance(loaded_embeddings, np.ndarray)
    is_object_array = is_array and loaded_embeddings.dtype == object
    # A numeric 2-D array is already usable for dot-product scoring.
    if is_array and loaded_embeddings.ndim == 2 and not is_object_array:
        logger.info(f"✓ Embeddings are proper numpy array with shape: {loaded_embeddings.shape}")
        return loaded_embeddings
    # A plain pickle yields a list; np.save of tuples/ragged vectors yields a
    # 1-D object array. Both hold one entry per chunk and convert the same way.
    if isinstance(loaded_embeddings, list) or (is_object_array and loaded_embeddings.ndim == 1):
        return _embeddings_sequence_to_matrix(loaded_embeddings)
    if is_object_array:
        logger.warning(
            f"Embeddings array has object dtype with shape {loaded_embeddings.shape}; "
            f"semantic scoring expects a numeric matrix"
        )
    logger.info(f"Embeddings already numpy array with shape: {loaded_embeddings.shape}")
    return loaded_embeddings


def _load_or_build_inverted_index(chunks):
    """Load the pre-built inverted index from INVERTED_INDEX_FILE, or build it from ``chunks``."""
    # Pre-built index (638MB) loads MUCH faster than building (15 minutes)
    if INVERTED_INDEX_FILE.exists():
        logger.info(f"Loading pre-built inverted index from {INVERTED_INDEX_FILE.name}...")
        try:
            with open(INVERTED_INDEX_FILE, 'rb') as f:
                index = pickle.load(f)
            logger.info(f"✓ Loaded pre-built inverted index with {len(index):,} terms (instant vs 15min build)")
            return index
        except Exception as e:
            logger.warning(f"Failed to load pre-built index: {e}, building from scratch...")
            return build_inverted_index(chunks)
    logger.info("Pre-built inverted index not found, building from scratch (this takes 15 minutes)...")
    return build_inverted_index(chunks)


def load_embeddings():
    """
    Load trial chunks, their embedding matrix and the inverted index into module globals.

    Missing local files are fetched with ``download_from_dataset``. Embeddings
    stored as a list (or 1-D object array) of vectors, or of tuples holding a
    vector, are converted to a dense float32 matrix. The inverted index is
    loaded from ``INVERTED_INDEX_FILE`` when present, otherwise built from the
    chunks with ``build_inverted_index``. BM25 is not built.

    Returns:
        True once everything is loaded.

    Raises:
        RuntimeError: if the chunk/embedding files are missing or loading fails.
    """
    global doc_chunks, doc_embeddings, inverted_index

    chunks_path = _resolve_data_file(CHUNKS_FILE, "dataset_chunks_TRIAL_AWARE.pkl")
    embeddings_path = _resolve_data_file(
        EMBEDDINGS_FILE, "dataset_embeddings_TRIAL_AWARE_FIXED.npy"  # FIXED version
    )
    # NOTE(review): the dataset text file is fetched here, but the keyword
    # search functions read DATASET_FILE directly, so a copy downloaded to
    # another directory is not used by them. Confirm this is intended.
    _resolve_data_file(DATASET_FILE, "complete_dataset_WITH_RESULTS_FULL.txt")

    if not (chunks_path.exists() and embeddings_path.exists()):
        raise RuntimeError("Embeddings files not found - system cannot function without embeddings")

    try:
        logger.info("Loading embeddings from disk...")
        # SECURITY: pickle.load and allow_pickle=True execute code embedded in
        # the files; they must only come from a trusted source.
        with open(chunks_path, 'rb') as f:
            doc_chunks = pickle.load(f)
        loaded_embeddings = np.load(embeddings_path, allow_pickle=True)
        logger.info(f"Loaded embeddings type: {type(loaded_embeddings)}")
        doc_embeddings = _coerce_embeddings(loaded_embeddings)
        logger.info(f"Loaded {len(doc_chunks)} chunks with embeddings")

        # Retrieval indexes chunks and embedding rows with the same index.
        n_rows = len(doc_embeddings) if getattr(doc_embeddings, 'ndim', 0) >= 1 else None
        if n_rows is not None and n_rows != len(doc_chunks):
            logger.warning(
                f"Embedding rows ({n_rows}) do not match chunk count ({len(doc_chunks)}); "
                f"search results may be misaligned"
            )

        # Skip BM25 (too memory-heavy for Docker), use inverted index only
        inverted_index = _load_or_build_inverted_index(doc_chunks)
        logger.info("Will use inverted index + semantic search (no BM25)")
        return True
    except Exception as e:
        logger.error(f"Failed to load embeddings: {e}")
        raise RuntimeError("Embeddings are required but failed to load") from e
def filter_trial_for_clinical_summary(trial_text):
    """
    Reduce a trial record to the lines an LLM needs for a clinical summary.

    A line is kept or dropped based on the field prefix it starts with
    (after stripping whitespace). Kept lines are returned unstripped:

    - Always dropped: blank lines and statistical noise (BASELINE,
      SERIOUS_ADVERSE_EVENT, OTHER_ADVERSE_EVENT, and the OUTCOME_TYPE /
      TIME_FRAME / SAFETY / OTHER / NUMBER fields).
    - Capped at 5 lines each: OUTCOME, OUTCOME_DESCRIPTION, RESULT_VALUE.
    - Always kept: NCT_ID, titles, summary/description, conditions,
      interventions, sponsor/collaborator/manufacturer, eligibility.
    - Everything else is dropped.

    Falsy input is returned unchanged.
    """
    if not trial_text:
        return trial_text

    noise_prefixes = (
        'BASELINE:', 'SERIOUS_ADVERSE_EVENT:', 'OTHER_ADVERSE_EVENT:',
        'OUTCOME_TYPE:', 'OUTCOME_TIME_FRAME:', 'OUTCOME_SAFETY:',
        'OUTCOME_OTHER:', 'OUTCOME_NUMBER:',
    )
    core_prefixes = (
        'NCT_ID:', 'TITLE:', 'OFFICIAL_TITLE:',
        'SUMMARY:', 'DESCRIPTION:',
        'CONDITIONS:', 'INTERVENTION:',                  # WHAT disease, WHAT drug
        'SPONSOR:', 'COLLABORATOR:', 'MANUFACTURER:',    # WHO is running/funding
        'ELIGIBILITY:',
    )
    # prefix -> how many matching lines may survive
    caps = {'OUTCOME:': 5, 'OUTCOME_DESCRIPTION:': 5, 'RESULT_VALUE:': 5}
    seen = dict.fromkeys(caps, 0)

    kept = []
    for raw_line in trial_text.split('\n'):
        content = raw_line.strip()
        if not content or content.startswith(noise_prefixes):
            continue
        capped_prefix = next((p for p in caps if content.startswith(p)), None)
        if capped_prefix is not None:
            seen[capped_prefix] += 1
            if seen[capped_prefix] <= caps[capped_prefix]:
                kept.append(raw_line)
            continue
        if content.startswith(core_prefixes):
            kept.append(raw_line)
    return '\n'.join(kept)
def _inverted_index_keyword_scores(query_terms):
    """
    Score the trials that the inverted index returns for ``query_terms``.

    Returns {chunk index: score}. A term indexed for fewer than 100 trials is
    treated as a specific drug name, and any candidate whose text contains
    such a term scores 1001. Other candidates score 1.0 times a field boost:
    3.0 if a non-specific term directly follows "intervention:", otherwise
    2.0 if the term's first occurrence lies within 200 characters after the
    first "title:". Returns {} when the index is missing or has no hits.
    """
    keyword_scores = {}
    if inverted_index is None:
        return keyword_scores

    # Collect every trial indexed under any query term
    candidates = set()
    for term in query_terms:
        if term in inverted_index:
            candidates.update(inverted_index[term])
            logger.info(f"[INVERTED INDEX] Found {len(inverted_index[term])} trials for '{term}'")

    if not candidates:
        logger.info(f"[FALLBACK] No inverted index hits, using pure semantic search")
        return keyword_scores

    logger.info(f"[FAST PATH] Checking {len(candidates)} inverted index candidates")

    # CRITICAL: rare terms (<100 trials) are likely specific drug names
    drug_specific_terms = set()
    for term in query_terms:
        if term in inverted_index and len(inverted_index[term]) < 100:
            drug_specific_terms.add(term)
            logger.info(f"[DRUG SPECIFIC] '{term}' found in {len(inverted_index[term])} trials - treating as drug name")

    for idx in candidates:
        chunk_data = doc_chunks[idx]
        chunk_text = chunk_data[1] if isinstance(chunk_data, tuple) else chunk_data
        chunk_lower = chunk_text.lower()
        # No BM25: every candidate starts from the same base score
        base_score = 1.0

        if any(drug_term in chunk_lower for drug_term in drug_specific_terms):
            # Drug-specific trials get a score no generic match can reach
            score = 1000.0 + base_score
            logger.info(f"[DRUG PRIORITY] Trial {idx} contains specific drug - score={score:.1f}")
        else:
            # Field-specific boosting for non-drug terms
            max_field_boost = 1.0
            for term in query_terms:
                if term not in chunk_lower or term in drug_specific_terms:
                    continue
                # INTERVENTION field - medium priority
                if f'intervention: {term}' in chunk_lower or f'intervention:{term}' in chunk_lower:
                    max_field_boost = max(max_field_boost, 3.0)
                # TITLE field - low priority
                elif 'title:' in chunk_lower:
                    title_pos = chunk_lower.find('title:')
                    term_pos = chunk_lower.find(term)
                    if title_pos < term_pos < title_pos + 200:
                        max_field_boost = max(max_field_boost, 2.0)
            score = base_score * max_field_boost
        keyword_scores[idx] = score
    return keyword_scores


def retrieve_context_with_embeddings(query, top_k=10):
    """
    Hybrid retrieval: inverted-index keyword scores merged with semantic similarity.

    Steps:
    1. Lower-case the query and keep words longer than 2 characters that are
       not stop words.
    2. Score trials found in the inverted index for those terms.
    3. Score every trial by the dot product of its stored embedding with the
       query embedding, min-max normalised to [0, 1].
    4. Trials with a keyword score get 0.5 * keyword + 0.5 * semantic (keyword
       scores are divided by their maximum first); all other trials keep the
       semantic score.
    5. Take the best ``max(3 * top_k, 10)`` trials, re-order them by NCT
       number (descending, used as a recency proxy), keep ``top_k`` and pass
       each through ``filter_trial_for_clinical_summary``.

    Args:
        query: free-text user query.
        top_k: number of trials to return.

    Returns:
        The filtered trial texts joined with a ``---`` separator line, or ""
        when the embeddings have not been loaded.
    """
    import time

    if doc_embeddings is None or len(doc_chunks) == 0:
        logger.error("Embeddings not loaded!")
        return ""

    t0 = time.time()

    # Extract meaningful words from query (stop words and short words removed)
    stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with',
                  'is', 'are', 'was', 'were', 'be', 'been', 'being', 'what', 'how', 'do', 'you', 'know',
                  'about', 'that', 'this', 'there', 'it'}
    words = re.findall(r'\b\w+\b', query.lower())
    query_terms = [w for w in words if len(w) > 2 and w not in stop_words]
    logger.info(f"[HYBRID] Query terms extracted: {query_terms}")

    # 1. KEYWORD SCORING via the inverted index (BM25 is not used)
    t_kw = time.time()
    keyword_scores = _inverted_index_keyword_scores(query_terms)
    logger.info(f"[HYBRID] Inverted index scoring: {len(keyword_scores)} trials matched ({time.time()-t_kw:.2f}s)")

    # 2. SEMANTIC SCORING
    load_embedder()
    t_sem = time.time()
    query_embedding = embedder.encode([query])[0]
    semantic_similarities = np.dot(doc_embeddings, query_embedding)
    logger.info(f"[HYBRID] Semantic scoring complete ({time.time()-t_sem:.2f}s)")

    # 3. MERGE SCORES: normalise both to the 0-1 range
    if keyword_scores:
        max_kw = max(keyword_scores.values())
        keyword_scores_norm = {idx: score / max_kw for idx, score in keyword_scores.items()}
    else:
        keyword_scores_norm = {}
    max_sem = semantic_similarities.max()
    min_sem = semantic_similarities.min()
    semantic_scores_norm = (semantic_similarities - min_sem) / (max_sem - min_sem + 1e-10)

    # Pure semantic score by default; only keyword hits need a per-trial blend,
    # so there is no Python-level loop over every chunk.
    combined_scores = semantic_scores_norm[:len(doc_chunks)].astype(np.float64)
    for idx, kw_score in keyword_scores_norm.items():
        if kw_score > 0:
            combined_scores[idx] = 0.5 * kw_score + 0.5 * semantic_scores_norm[idx]

    # Take 3x the requested number (minimum 10) so the recency sort has room
    candidate_k = max(top_k * 3, 10)
    top_indices = np.argsort(combined_scores)[-candidate_k:][::-1]
    logger.info(f"[HYBRID] Top 3 combined scores: {combined_scores[top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 keyword scores: {[keyword_scores_norm.get(i, 0.0) for i in top_indices[:3]]}")
    logger.info(f"[HYBRID] Top 3 semantic scores: {[semantic_scores_norm[i] for i in top_indices[:3]]}")

    # (score, text) pairs for the candidates
    candidates = [
        (combined_scores[i], doc_chunks[i][1] if isinstance(doc_chunks[i], tuple) else doc_chunks[i])
        for i in top_indices
    ]

    def extract_nct_number(candidate):
        """NCT number of a (score, text) candidate (higher = newer), 0 if absent."""
        match = re.search(r'NCT_ID:\s*NCT(\d+)', candidate[1])
        return int(match.group(1)) if match else 0

    # Newest first; this ordering replaces the hybrid score ordering
    candidates.sort(key=extract_nct_number, reverse=True)

    top_ncts = []
    for _, text in candidates[:5]:
        match = re.search(r'NCT_ID:\s*(NCT\d+)', text)
        if match:
            top_ncts.append(match.group(1))
    logger.info(f"[NCT SORT] Top 5 candidates by recency: {top_ncts}")

    # The 355M ranking step is skipped: it scored everything 0.50 and cost ~10s
    logger.info(f"[FAST MODE] Using hybrid search + recency sort (skipping broken 355M ranking)")
    top_ranked = candidates[:top_k]
    logger.info(f"[FAST MODE] Selected top {len(top_ranked)} trials (hybrid score + recency)")

    raw_chunks = [trial_text for _, trial_text in top_ranked]
    context_chunks = [filter_trial_for_clinical_summary(chunk) for chunk in raw_chunks]
    if context_chunks:
        first_trial_preview = context_chunks[0][:200]
        logger.info(f"[HYBRID] First result (filtered): {first_trial_preview}")

    context = "\n\n---\n\n".join(context_chunks)  # Use --- as separator between trials
    logger.info(f"[HYBRID] TOTAL TIME: {time.time()-t0:.2f}s")
    logger.info(f"[HYBRID] Filtered context length: {len(context)} chars (was ~{sum(len(c) for c in raw_chunks)} chars)")
    return context
def keyword_search_query_text(query, max_results=10, hf_token=None):
    """
    Search the raw dataset file for trials mentioning words from ``query``.

    The query is lower-cased, split on whitespace and stripped of surrounding
    punctuation. Words shorter than 3 characters or in a stop-word list are
    dropped. Each trial (a block of lines starting at a "NCT_ID:" or
    "TRIAL NCT" line) scores one point per search term it contains as a
    substring, case-insensitively.

    Args:
        query: free-text user query.
        max_results: maximum number of trials to return.
        hf_token: unused; kept for call compatibility.

    Returns:
        list of (score, trial_text) tuples, highest score first (ties keep
        file order). Returns an empty list when the dataset file is missing,
        no search terms remain, nothing matches, or reading fails.
    """
    import heapq

    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return []

    # Remove common stopwords but keep medical/clinical terms
    stopwords = {'the', 'a', 'an', 'is', 'are', 'was', 'were', 'be', 'been', 'being',
                 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'should',
                 'could', 'may', 'might', 'must', 'can', 'of', 'at', 'by', 'for', 'with',
                 'about', 'as', 'into', 'through', 'during', 'to', 'from', 'in', 'on',
                 'what', 'you', 'know', 'that', 'relevant'}

    # Strip punctuation BEFORE filtering: otherwise "what?" escapes the stop
    # list and "???" becomes the empty term, which matches every trial.
    search_terms = []
    for word in query.lower().split():
        term = word.strip('?.,!;:()[]{}')
        if len(term) >= 3 and term not in stopwords:
            search_terms.append(term)

    if not search_terms:
        logger.warning("No search terms extracted from query")
        return []

    logger.info(f"Search terms from full query: {search_terms}")

    def iter_trials(handle):
        """Yield each trial's text; a new trial starts at an NCT_ID/TRIAL NCT line."""
        current = []
        for line in handle:
            if line.startswith(("NCT_ID:", "TRIAL NCT")) and current:
                yield "".join(current)
                current = []
            current.append(line)
        if current:
            yield "".join(current)

    def scored_trials(handle):
        """Yield (score, trial) for every trial matching at least one term."""
        for trial in iter_trials(handle):
            trial_lower = trial.lower()
            score = sum(1 for term in search_terms if term in trial_lower)
            if score > 0:
                yield score, trial

    try:
        with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as f:
            # Equivalent to sorting all matches by score (descending, stable)
            # and slicing, but keeps only max_results trials in memory.
            matching_trials = heapq.nlargest(max_results, scored_trials(f), key=lambda x: x[0])
        if matching_trials:
            logger.info(f"Keyword search found {len(matching_trials)} trials")
            return matching_trials
        logger.warning("Keyword search found no matching trials")
        return []
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return []
def keyword_search_in_dataset(entities, max_results=10):
    """
    Legacy: search the dataset file for trials matching extracted entities.

    ``entities`` may provide 'drugs' and 'conditions' lists. Each trial (a
    block of lines starting at a "NCT_ID:" or "TRIAL NCT" line) is matched
    case-insensitively by substring. When drugs are given, only trials
    containing at least one drug qualify, scored 10 per drug plus 1 per
    condition. Otherwise trials score 1 per condition they contain.

    Returns:
        Up to ``max_results`` best trials joined by "---" separators, cut to
        6000 characters (plus "..."), or "" when nothing matches, the file is
        missing, no entities were given, or reading fails.
    """
    if not DATASET_FILE.exists():
        logger.error("Dataset file not found")
        return ""

    drugs = [name.lower() for name in entities.get('drugs', [])]
    conditions = [name.lower() for name in entities.get('conditions', [])]
    if not (drugs or conditions):
        logger.warning("No search terms for keyword search")
        return ""

    logger.info(f"Keyword search - Drugs: {drugs}, Conditions: {conditions}")

    def score_of(trial):
        """Return the trial's score, or None when it does not qualify."""
        lowered = trial.lower()
        n_drug = sum(drug in lowered for drug in drugs)
        n_cond = sum(cond in lowered for cond in conditions)
        if drugs:
            # A drug was requested: the trial must mention it
            return n_drug * 10 + n_cond if n_drug > 0 else None
        return n_cond if n_cond > 0 else None

    scored = []

    def consider(trial):
        score = score_of(trial)
        if score is not None:
            scored.append((score, trial))

    try:
        with open(DATASET_FILE, 'r', encoding='utf-8', errors='ignore') as handle:
            buffer = ""
            for line in handle:
                if not line.startswith(("NCT_ID:", "TRIAL NCT")):
                    buffer += line
                    continue
                # A new trial begins: score the one collected so far
                if buffer:
                    consider(buffer)
                buffer = line
            if buffer:
                consider(buffer)

        best = [trial for _, trial in sorted(scored, key=lambda pair: pair[0], reverse=True)[:max_results]]
        if not best:
            logger.warning("Keyword search found no trials matching drug")
            return ""

        context = "\n\n---\n\n".join(best)
        if len(context) > 6000:
            context = context[:6000] + "..."
        logger.info(f"Keyword search found {len(best)} trials (from {len(scored)} candidates)")
        return context
    except Exception as e:
        logger.error(f"Keyword search failed: {e}")
        return ""
# ============================================================================
# ENTITY EXTRACTION
# ============================================================================
def parse_entities_from_query(conversation, hf_token=None):
    """
    Extract drug and condition names from ``conversation``.

    Combines the line-oriented output of two model-based extractors
    (``extract_entities_with_small_model`` and ``extract_entities_with_8b``)
    with regex fallbacks: -mab / -nib drug-name suffixes, alphanumeric codes
    such as F8IL10, and a few hard-coded conditions.

    NOTE(review): neither extractor is defined or imported in the visible part
    of this file; confirm they exist, otherwise this raises NameError.

    Args:
        conversation: user query / conversation text.
        hf_token: passed through to the 8B extractor.

    Returns:
        dict with 'drugs' and 'conditions' lists. Model output lines are
        appended as-is; regex matches are added only if not already present
        (compared case-insensitively).
    """
    entities = {'drugs': [], 'conditions': []}

    def add_unique(bucket, value):
        """Append ``value`` unless an equal string (ignoring case) is already in ``bucket``."""
        if value.lower() not in [item.lower() for item in bucket]:
            bucket.append(value)

    # Use 355M model for entity extraction
    extracted_355m = extract_entities_with_small_model(conversation)
    # Also use 8B model for more reliable extraction
    extracted_8b = extract_entities_with_8b(conversation, hf_token=hf_token)
    # Combine both extractions
    extracted = (extracted_355m or "") + "\n" + (extracted_8b or "")

    # Parse "drug:"/"medication:" and "condition:"/"disease:" lines of model output
    for line in extracted.split('\n'):
        lower_line = line.lower()
        if 'drug:' in lower_line or 'medication:' in lower_line:
            drug = re.sub(r'(drug:|medication:)', '', line, flags=re.IGNORECASE).strip()
            if drug:
                entities['drugs'].append(drug)
        elif 'condition:' in lower_line or 'disease:' in lower_line:
            condition = re.sub(r'(condition:|disease:)', '', line, flags=re.IGNORECASE).strip()
            if condition:
                entities['conditions'].append(condition)

    # Regex fallback for standard drug naming patterns
    drug_patterns = [
        r'\b([A-Z][a-z]+mab)\b',  # Monoclonal antibodies: -mab suffix
        r'\b([A-Z][a-z]+nib)\b',  # Kinase inhibitors: -nib suffix
        r'\b([A-Z]\d+[A-Z]+\d+)\b'  # Alphanumeric codes like F8IL10
    ]
    for pattern in drug_patterns:
        for match in re.findall(pattern, conversation):
            add_unique(entities['drugs'], match)

    condition_patterns = [
        r'\b(sjogren\'?s?|lupus|myelofibrosis|rheumatoid arthritis)\b'
    ]
    for pattern in condition_patterns:
        # IGNORECASE matches keep the query's casing, so dedup must ignore case
        # too (otherwise "Lupus" is added next to an existing "lupus").
        for match in re.findall(pattern, conversation, re.IGNORECASE):
            add_unique(entities['conditions'], match)

    logger.info(f"Extracted entities: {entities}")
    return entities
# ============================================================================
# MAIN QUERY PROCESSING
# ============================================================================
def extract_entities_simple(query):
    """
    Extract drug and condition names from ``query`` with regexes only (no model).

    All patterns are matched case-insensitively. Drugs cover -mab / -nib
    suffixes, alphanumeric codes and a few named antibodies. Conditions cover
    Sjogren's syndrome, rheumatoid arthritis, lupus, myelofibrosis, diabetes
    and cancer/carcinoma/melanoma. Within each list a match is kept only if
    no case-insensitively equal entry was found before it.

    Returns:
        dict with 'drugs' and 'conditions' lists, in match order.
    """
    drug_patterns = (
        r'\b([A-Z][a-z]+mab)\b',  # Monoclonal antibodies: ianalumab, rituximab, etc.
        r'\b([A-Z][a-z]+nib)\b',  # Kinase inhibitors: imatinib, etc.
        r'\b([A-Z]\d+[A-Z]+\d+)\b',  # Alphanumeric codes
        r'\b(ianalumab|rituximab|tocilizumab|adalimumab|infliximab)\b',  # Common drugs
    )
    condition_patterns = (
        r'\b(sjogren\'?s?\s+syndrome)\b',
        r'\b(rheumatoid arthritis)\b',
        r'\b(lupus)\b',
        r'\b(myelofibrosis)\b',
        r'\b(diabetes)\b',
        r'\b(cancer|carcinoma|melanoma)\b',
    )

    def collect(patterns):
        """Run each pattern over the query, keeping first-seen matches only."""
        found = []
        for pattern in patterns:
            for hit in re.findall(pattern, query, re.IGNORECASE):
                if hit.lower() not in [existing.lower() for existing in found]:
                    found.append(hit)
        return found

    entities = {'drugs': collect(drug_patterns), 'conditions': collect(condition_patterns)}
    logger.info(f"Extracted entities: {entities}")
    return entities
def _split_entity_list(value):
    """Split one parser field value (text after "LABEL:") into a clean list of names.

    Accepts "a, b", "[a, b]" and quoted items. "none" (any case, optionally
    bracketed, quoted or followed by a period) and empty values yield [].
    """
    cleaned = value.strip().strip('[]').strip()
    if cleaned.strip('"\'').rstrip('.').strip().lower() == 'none':
        return []
    items = (item.strip().strip('"\'').strip() for item in cleaned.split(','))
    # Drop empties produced by trailing/double commas.
    return [item for item in items if item]


def _parse_entity_response(parsed, query):
    """Turn the parser LLM's "LABEL: value" lines into the result dict.

    Args:
        parsed: Raw text returned by the model.
        query: Original user query, kept as ``search_terms`` unless the model
            supplies non-empty search terms.

    Returns:
        Dict with keys raw_parsed, drugs, diseases, companies, endpoints,
        search_terms. If a label appears more than once, the last one wins.
    """
    result = {
        'raw_parsed': parsed,
        'drugs': [],
        'diseases': [],
        'companies': [],
        'endpoints': [],
        'search_terms': query,  # fallback
    }
    list_fields = {
        'DRUGS': 'drugs',
        'DISEASES': 'diseases',
        'COMPANIES': 'companies',
        'ENDPOINTS': 'endpoints',
    }
    for raw_line in parsed.splitlines():
        # Tolerate markdown decoration the model sometimes adds, e.g. "- **DRUGS:** x".
        line = raw_line.strip().lstrip('-*• ').replace('**', '')
        label, sep, value = line.partition(':')
        if not sep:
            continue
        label = label.strip().upper()
        if label in list_fields:
            result[list_fields[label]] = _split_entity_list(value)
        elif label == 'SEARCH_TERMS':
            terms = value.strip().strip('[]').strip().strip('"\'').strip()
            # Never replace the fallback with an empty string: searching on ""
            # would make the downstream RAG step useless.
            if terms:
                result['search_terms'] = terms
    return result


def parse_query_with_llm(query, hf_token=None):
    """
    Use fast LLM to parse query and extract structured information
    Extracts:
    - Drug names
    - Diseases/conditions
    - Companies (sponsors/manufacturers)
    - Endpoints (what's being measured)
    - Search terms (optimized for RAG)

    Args:
        query: User question text.
        hf_token: HuggingFace token passed to ``InferenceClient`` (may be None).

    Returns: Dict with keys drugs, diseases, companies, endpoints (lists of
    strings), search_terms (str; falls back to *query*) and raw_parsed (model
    text, '' on failure). Any exception (import, network, malformed response)
    is logged and the fallback dict is returned; this function does not raise.
    """
    try:
        from huggingface_hub import InferenceClient
        logger.info("[QUERY PARSER] Analyzing user query with LLM...")
        client = InferenceClient(token=hf_token, timeout=30)
        parse_prompt = f"""Extract key information from this clinical trial query.
Query: "{query}"
Extract and return in this EXACT format:
DRUGS: [list drug/treatment names, or "none"]
DISEASES: [list diseases/conditions, or "none"]
COMPANIES: [list company/sponsor names, or "none"]
ENDPOINTS: [list trial endpoints/outcomes, or "none"]
SEARCH_TERMS: [optimized search keywords]
Examples:
Query: "What Novartis drugs treat melanoma?"
DRUGS: none
DISEASES: melanoma
COMPANIES: Novartis
ENDPOINTS: none
SEARCH_TERMS: Novartis melanoma treatment drugs
Query: "Tell me about Keytruda for lung cancer"
DRUGS: Keytruda
DISEASES: lung cancer
COMPANIES: none
ENDPOINTS: none
SEARCH_TERMS: Keytruda lung cancer
Now parse the query above:"""
        response = client.chat_completion(
            model="meta-llama/Llama-3.1-70B-Instruct",
            messages=[{"role": "user", "content": parse_prompt}],
            max_tokens=256,
            temperature=0.1,  # Low temp for consistent parsing
        )
        # content can be None on some provider responses; treat that as empty output.
        parsed = (response.choices[0].message.content or "").strip()
        logger.info(f"[QUERY PARSER] Extracted entities:\n{parsed}")
        result = _parse_entity_response(parsed, query)
        logger.info(f"[QUERY PARSER] ✓ Drugs: {result['drugs']}, Diseases: {result['diseases']}, Companies: {result['companies']}")
        return result
    except Exception as e:
        # Best-effort step: the caller can proceed with the raw query.
        logger.warning(f"[QUERY PARSER] Failed: {e}, using original query")
        return {
            'drugs': [],
            'diseases': [],
            'companies': [],
            'endpoints': [],
            'search_terms': query,
            'raw_parsed': '',
        }
def generate_llama_response(query, rag_context, hf_token=None):
    """
    Answer *query* from the retrieved trial text with Llama-3.1-70B.

    Groq is tried first when GROQ_API_KEY is set, because it is much faster:
    - HuggingFace: ~40 tokens/sec = 15 seconds
    - Groq: ~300 tokens/sec = 2 seconds
    If Groq is not configured OR its call fails, the HuggingFace Inference API
    is used instead.

    Args:
        query: User question.
        rag_context: Retrieved trial text. Only its first 6000 characters are sent.
        hf_token: HuggingFace token for the fallback path (may be None).

    Returns:
        The model's answer text. If the HuggingFace path fails, returns a string
        starting with "Llama API error:" instead of raising.
    """
    try:
        system_prompt = """You are a medical research assistant. Answer based ONLY on the provided clinical trial data. Be concise and cite NCT IDs."""
        # Context is truncated to bound prompt size / latency.
        user_prompt = f"""Clinical trials:
{rag_context[:6000]}
Question: {query}
Provide a concise answer citing specific NCT trial IDs."""
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        groq_api_key = os.getenv("GROQ_API_KEY")
        if groq_api_key:
            try:
                logger.info("Generating response with Llama-3.1-70B via GROQ (fast)...")
                from groq import Groq
                client = Groq(api_key=groq_api_key)
                response = client.chat.completions.create(
                    # NOTE(review): Groq appears to have retired
                    # "llama-3.1-70b-versatile" (successor: "llama-3.3-70b-versatile").
                    # Confirm, and set GROQ_MODEL to override without a code change.
                    model=os.getenv("GROQ_MODEL", "llama-3.1-70b-versatile"),
                    messages=messages,
                    max_tokens=512,  # Shorter for speed
                    temperature=0.3,
                    timeout=30,
                )
                return response.choices[0].message.content.strip()
            except Exception as groq_error:
                # Honour the "fallback to HuggingFace" contract instead of
                # surfacing the Groq failure to the user.
                logger.warning(f"Groq generation failed ({groq_error}); falling back to HuggingFace")

        # Fallback to HuggingFace (slower)
        logger.info("Generating response with Llama-3.1-70B via HuggingFace (slow)...")
        from huggingface_hub import InferenceClient
        client = InferenceClient(token=hf_token, timeout=120)
        response = client.chat_completion(
            model="meta-llama/Meta-Llama-3.1-70B-Instruct",
            messages=messages,
            max_tokens=512,  # Reduced from 2048 for speed
            temperature=0.3,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"Llama error: {e}")
        return f"Llama API error: {str(e)}"
def process_query_simple_test(conversation):
    """TEST JUST THE RAG - no models.

    Checks that the module-level ``doc_chunks`` / ``doc_embeddings`` are loaded.
    It then runs ``retrieve_context_with_embeddings(conversation, top_k=3)`` and
    reports:
    - the search time,
    - the number of distinct NCT IDs in the result,
    - the first 1000 characters of context.

    Returns:
        A human-readable report string. Problems are reported in the string
        ("FAIL: ..." or "ERROR IN RAG TEST: ..." with traceback), never raised.
    """
    try:
        import time
        report = [f"QUERY: {conversation}\n"]
        # Check if embeddings loaded
        if doc_embeddings is None or len(doc_chunks) == 0:
            return "FAIL: Embeddings not loaded"
        report.append(f"✓ Embeddings loaded: {len(doc_chunks)} chunks\n")
        report.append(f"✓ Embeddings shape: {doc_embeddings.shape}\n")

        started = time.time()
        context = retrieve_context_with_embeddings(conversation, top_k=3)
        elapsed = time.time() - started
        if not context:
            return "".join(report) + "\nFAIL: RAG returned empty"

        # Count distinct NCT identifiers (NCT + 8 digits). A raw 'NCT' substring
        # count over-reports whenever a trial's ID or an "NCT" label repeats.
        trial_ids = set(re.findall(r'NCT\d{8}', context))
        report.append(f"✓ RAG search took: {elapsed:.2f}s\n")
        report.append(f"✓ Retrieved {len(trial_ids)} trials\n\n")
        report.append("FIRST 1000 CHARS:\n")
        report.append(context[:1000])
        return "".join(report)
    except Exception as e:
        import traceback
        return f"ERROR IN RAG TEST:\n{str(e)}\n\nTRACEBACK:\n{traceback.format_exc()}"
def process_query(conversation):
    """
    Complete pipeline with LLM query parsing and natural language generation.

    Flow:
    0. LLM Parser - Extract drugs, diseases, companies, endpoints (~2-3s)
    1. RAG Search - Hybrid search using optimized query (~2s)
    2. Skipped - 355M ranking removed (was broken)
    3. LLM Response - Llama 70B generates natural language (~15s)
    Total: ~20 seconds

    Args:
        conversation: The user's question as typed into the UI.

    Returns:
        A report string containing:
        - per-step progress lines,
        - the generated summary,
        - the retrieved trial context,
        - the total time.
        Exceptions are caught and reported inside the returned string, not raised.
    """
    import time
    import traceback
    import sys
    # MASTER try/except - the UI must always receive a string
    try:
        # Avoid spending LLM calls on blank input.
        if not conversation or not conversation.strip():
            return "Please enter a question about clinical trials."
        start_time = time.time()
        output_parts = [f"QUERY: {conversation}\n\n"]

        # Step 0: Parse query with LLM to extract structured info
        try:
            step0_start = time.time()
            logger.info("Step 0: Parsing query with LLM...")
            output_parts.append("✓ Step 0: LLM query parser started...\n")
            parsed_query = parse_query_with_llm(conversation, hf_token=hf_token)
            # Fall back to the user's wording if the parser produced no search terms;
            # searching on "" would retrieve nothing useful.
            search_query = parsed_query.get('search_terms') or conversation
            step0_time = time.time() - step0_start
            output_parts.append(f"✓ Step 0 Complete: Extracted entities ({step0_time:.1f}s)\n")
            output_parts.append(f" Drugs: {parsed_query.get('drugs', [])}\n")
            output_parts.append(f" Diseases: {parsed_query.get('diseases', [])}\n")
            output_parts.append(f" Companies: {parsed_query.get('companies', [])}\n")
            output_parts.append(f" Optimized search: {search_query}\n")
            logger.info(f"Query parsing successful in {step0_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 0 WARNING (LLM Parser): {str(e)}, using original query"
            logger.warning(error_msg)
            output_parts.append(f"{error_msg}\n")
            search_query = conversation  # Fallback to original

        # Step 1: RAG search (using optimized search query)
        try:
            step1_start = time.time()
            logger.info("Step 1: RAG search...")
            output_parts.append("✓ Step 1: RAG search started...\n")
            context = retrieve_context_with_embeddings(search_query, top_k=3)
            if not context:
                return "No matching trials found in RAG search."
            step1_time = time.time() - step1_start
            # Count distinct NCT IDs (NCT + 8 digits) rather than raw "NCT"
            # substrings, which over-count repeated IDs/labels.
            trial_count = len(set(re.findall(r'NCT\d{8}', context)))
            output_parts.append(f"✓ Step 1 Complete: Found {trial_count} trials ({step1_time:.1f}s)\n")
            logger.info(f"RAG search successful - found trials in {step1_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 1 FAILED (RAG search): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            return error_msg

        # Step 2: Skipped (355M ranking removed - was broken)
        output_parts.append("✓ Step 2: Skipped (using hybrid search + recency)\n")

        # Step 3: Llama 70B
        try:
            step3_start = time.time()
            logger.info("Step 3: Generating response with Llama-3.1-70B...")
            output_parts.append("✓ Step 3: Llama 70B generation started...\n")
            llama_response = generate_llama_response(conversation, context, hf_token=hf_token)
            step3_time = time.time() - step3_start
            # generate_llama_response reports API failures by returning a
            # "Llama API error: ..." string instead of raising, so detect it here.
            if llama_response.startswith("Llama API error:"):
                output_parts.append(f"✗ Step 3 Failed: {llama_response} ({step3_time:.1f}s)\n")
                logger.error(f"Llama 70B generation failed after {step3_time:.1f}s")
            else:
                output_parts.append(f"✓ Step 3 Complete: Llama 70B response generated ({step3_time:.1f}s)\n")
                logger.info(f"Llama 70B generation successful in {step3_time:.1f}s")
        except Exception as e:
            error_msg = f"✗ Step 3 FAILED (Llama 70B): {str(e)}\n{traceback.format_exc()}"
            logger.error(error_msg)
            llama_response = f"[Llama 70B error: {str(e)}]"
            output_parts.append(f"✗ Step 3 Failed: {str(e)}\n")

        total_time = time.time() - start_time
        # context and llama_response are always bound here: Step 1 returns early
        # on failure and Step 3 assigns llama_response on both paths.
        output = f"""{''.join(output_parts)}
CLINICAL SUMMARY (Llama-3.1-70B-Instruct):
{llama_response}
---
RAG RETRIEVED TRIALS (Top 3 Most Relevant):
{context}
---
Total Time: {total_time:.1f}s
"""
        return output
    # MASTER EXCEPTION HANDLER - catches ANY unhandled error
    except Exception as master_error:
        master_error_msg = f"""
========================================
MASTER ERROR HANDLER CAUGHT EXCEPTION
========================================
Error Type: {type(master_error).__name__}
Error Message: {str(master_error)}
FULL TRACEBACK:
{traceback.format_exc()}
System Info:
- Python version: {sys.version}
- Error at line: {sys.exc_info()[2].tb_lineno if sys.exc_info()[2] else 'unknown'}
========================================
"""
        logger.error(master_error_msg)
        return master_error_msg
# ============================================================================
# GRADIO INTERFACE
# ============================================================================
# Gradio UI: one question box wired to process_query, plus a description of
# the pipeline that process_query actually runs.
with gr.Blocks(title="Foundation 1.2") as demo:
    gr.Markdown("# Foundation 1.2 - Clinical Trial AI")
    query_input = gr.Textbox(
        label="Ask about clinical trials",
        placeholder="Example: What are the results for ianalumab in Sjogren's syndrome?",
        lines=3,
    )
    submit_btn = gr.Button("Generate Response", variant="primary")
    output = gr.Textbox(
        label="AI Response",
        lines=30,
    )
    submit_btn.click(
        fn=process_query,  # LLM query parsing -> hybrid RAG search -> Llama-3.1-70B answer
        inputs=query_input,
        outputs=output,
    )
    # Description kept in sync with process_query: the 355M re-ranking stage was
    # removed, and answers come from Llama-3.1-70B (Groq or HuggingFace).
    gr.Markdown("""
**Production RAG Pipeline - Optimized for Clinical Accuracy**

**Query Understanding (Llama-3.1-70B-Instruct):**
- Extracts drugs, diseases, companies, and endpoints from your question
- Builds an optimized search query for retrieval

**Search (Hybrid):**
1. Keyword matching (70%) + Semantic search (30%)
2. Returns the top 3 most relevant trials

**Generation (Llama-3.1-70B-Instruct):**
- Served via Groq when configured, otherwise via the HuggingFace Inference API
- Answers based only on the retrieved clinical trial data
- Cites specific NCT trial IDs
""")
# ============================================================================
# STARTUP
# ============================================================================
# Embedding loading is intentionally NOT triggered at import time: per the
# original note, the FastAPI startup event in app.py does it, and loading here
# caused Docker permission errors.
logger.info("=== Foundation 1.2 Module Loaded ===")
logger.info("Call load_embeddings() to initialize the system")
# Report whether the raw dataset file is present, and its size in MB.
if not DATASET_FILE.exists():
    logger.error("✗ Dataset file not found!")
else:
    dataset_mb = DATASET_FILE.stat().st_size / 1024 ** 2
    logger.info(f"✓ Dataset file found: {dataset_mb:.0f}MB")
logger.info("=== Startup Complete ===")

if __name__ == "__main__":
    demo.launch()