# mtDNALocation / model.py
import re
import pycountry
from docx import Document
import json
import os
import numpy as np
import faiss
from collections import defaultdict
import ast # For literal_eval
import math # For ceiling function
import data_preprocess
import mtdna_classifier
import smart_fallback
import pipeline
import asyncio
# --- IMPORTANT: UNCOMMENT AND CONFIGURE YOUR REAL API KEY ---
import google.generativeai as genai
#genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
import nltk
from nltk.corpus import stopwords
try:
nltk.data.find('corpora/stopwords')
except LookupError:
nltk.download('stopwords')
nltk.download('punkt_tab')
# # --- Define Pricing Constants (for Gemini 1.5 Flash & text-embedding-004) ---
# # Prices are per 1,000 tokens
# PRICE_PER_1K_INPUT_LLM = 0.000075 # $0.075 per 1M tokens
# PRICE_PER_1K_OUTPUT_LLM = 0.0003 # $0.30 per 1M tokens
# PRICE_PER_1K_EMBEDDING_INPUT = 0.000025 # $0.025 per 1M tokens
# Gemini 2.5 Flash-Lite pricing per 1,000 tokens
PRICE_PER_1K_INPUT_LLM = 0.00010 # $0.10 per 1M input tokens
PRICE_PER_1K_OUTPUT_LLM = 0.00040 # $0.40 per 1M output tokens
# Embedding-001 pricing per 1,000 input tokens
PRICE_PER_1K_EMBEDDING_INPUT = 0.00015 # $0.15 per 1M input tokens
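# A minimal helper sketch (not part of the original pipeline) showing how the per-1K-token
# constants above translate into a dollar estimate; it mirrors the arithmetic used later
# when token counts are known for a single LLM call.
def estimate_llm_call_cost(input_tokens, output_tokens):
    """Estimate the USD cost of one LLM call from its input/output token counts."""
    return (input_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
           (output_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
# Example: estimate_llm_call_cost(10_000, 500) -> 10 * 0.00010 + 0.5 * 0.00040 = ~$0.0012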
# --- API Functions (REAL API FUNCTIONS) ---
# def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
# """Generates an embedding for the given text using a Google embedding model."""
# try:
# result = genai.embed_content(
# model="models/text-embedding-004", # Specify the embedding model
# content=text,
# task_type=task_type
# )
# return np.array(result['embedding']).astype('float32')
# except Exception as e:
# print(f"Error getting embedding: {e}")
# return np.zeros(768, dtype='float32')
def get_embedding(text, task_type="RETRIEVAL_DOCUMENT"):
    """Safe Google embedding call (text-embedding-004) with a zero-vector fallback."""
try:
if not text or len(text.strip()) == 0:
raise ValueError("Empty text cannot be embedded.")
result = genai.embed_content(
model="models/text-embedding-004",
content=text,
task_type=task_type
)
return np.array(result['embedding'], dtype='float32')
except Exception as e:
print(f"❌ Embedding error: {e}")
return np.zeros(768, dtype='float32')
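# Usage sketch (hypothetical text): on success this returns a float32 vector of length 768
# (the text-embedding-004 dimensionality); on any error it falls back to an all-zero vector
# of the same shape, so downstream FAISS code always receives a 768-dim array.
# vec = get_embedding("Mitochondrial DNA sample collected in northern Thailand")
# assert vec.shape == (768,) and vec.dtype == np.float32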
def call_llm_api(prompt, model_name="gemini-2.5-flash-lite"):  # previous default: 'gemini-1.5-flash-latest'
"""Calls a Google Gemini LLM with the given prompt."""
try:
model = genai.GenerativeModel(model_name)
response = model.generate_content(prompt)
return response.text, model # Return model instance for token counting
except Exception as e:
print(f"Error calling LLM: {e}")
return "Error: Could not get response from LLM API.", None
# --- Core Document Processing Functions (All previously provided and fixed) ---
def read_docx_text(path):
"""
Reads text and extracts potential table-like strings from a .docx document.
Separates plain text from structured [ [ ] ] list-like tables.
Also attempts to extract a document title.
"""
doc = Document(path)
plain_text_paragraphs = []
table_strings = []
document_title = "Unknown Document Title" # Default
# Attempt to extract the document title from the first few paragraphs
title_paragraphs = [p.text.strip() for p in doc.paragraphs[:5] if p.text.strip()]
if title_paragraphs:
# A heuristic to find a title: often the first or second non-empty paragraph
# or a very long first paragraph if it's the title
if len(title_paragraphs[0]) > 50 and "Human Genetics" not in title_paragraphs[0]:
document_title = title_paragraphs[0]
elif len(title_paragraphs) > 1 and len(title_paragraphs[1]) > 50 and "Human Genetics" not in title_paragraphs[1]:
document_title = title_paragraphs[1]
elif any("Complete mitochondrial genomes" in p for p in title_paragraphs):
# Fallback to a known title phrase if present
document_title = "Complete mitochondrial genomes of Thai and Lao populations indicate an ancient origin of Austroasiatic groups and demic diffusion in the spread of Tai–Kadai languages"
current_table_lines = []
in_table_parsing_mode = False
for p in doc.paragraphs:
text = p.text.strip()
if not text:
continue
# Condition to start or continue table parsing
if text.startswith("## Table "): # Start of a new table section
if in_table_parsing_mode and current_table_lines:
table_strings.append("\n".join(current_table_lines))
current_table_lines = [text] # Include the "## Table X" line
in_table_parsing_mode = True
elif in_table_parsing_mode and (text.startswith("[") or text.startswith('"')):
# Continue collecting lines if we're in table mode and it looks like table data
# Table data often starts with '[' for lists, or '"' for quoted strings within lists.
current_table_lines.append(text)
else:
# If not in table mode, or if a line doesn't look like table data,
# then close the current table (if any) and add the line to plain text.
if in_table_parsing_mode and current_table_lines:
table_strings.append("\n".join(current_table_lines))
current_table_lines = []
in_table_parsing_mode = False
plain_text_paragraphs.append(text)
# After the loop, add any remaining table lines
if current_table_lines:
table_strings.append("\n".join(current_table_lines))
return "\n".join(plain_text_paragraphs), table_strings, document_title
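# Usage sketch (hypothetical .docx path): the function returns the plain text, the list of
# "## Table ..." blocks found in the document, and a best-guess title.
# plain_text, table_strings, title = read_docx_text("paper.docx")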
# --- Structured Data Extraction and RAG Functions ---
def parse_literal_python_list(table_str):
list_match = re.search(r'(\[\s*\[\s*(?:.|\n)*?\s*\]\s*\])', table_str)
#print("Debug: list_match object (before if check):", list_match)
    if not list_match:
        if "table" in table_str.lower():  # the table string is missing the closing "]]"
table_str += "]]"
list_match = re.search(r'(\[\s*\[\s*(?:.|\n)*?\s*\]\s*\])', table_str)
if list_match:
try:
matched_string = list_match.group(1)
#print("Debug: Matched string for literal_eval:", matched_string)
return ast.literal_eval(matched_string)
except (ValueError, SyntaxError) as e:
print(f"Error evaluating literal: {e}")
return []
return []
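# Example (hypothetical table string): the parser extracts the first [[...]] literal and
# evaluates it with ast.literal_eval; strings without such a literal return [].
#   parse_literal_python_list('## Table 1\n[["ID", "Country"], ["MJ17", "Laos"]]')
#   -> [["ID", "Country"], ["MJ17", "Laos"]]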
_individual_code_parser = re.compile(r'([A-Z0-9]+?)(\d+)$', re.IGNORECASE)
def _parse_individual_code_parts(code_str):
match = _individual_code_parser.search(code_str)
if match:
return match.group(1), match.group(2)
return None, None
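# Examples of how a sample code is split into (letter prefix, numeric part):
#   _parse_individual_code_parts("KSK12")   -> ("KSK", "12")
#   _parse_individual_code_parts("A1YU101") -> ("A1YU", "101")
#   _parse_individual_code_parts("ABC")     -> (None, None)   # no trailing digits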
def parse_sample_id_to_population_code(plain_text_content):
sample_id_map = {}
contiguous_ranges_data = defaultdict(list)
#section_start_marker = "The sample identification of each population is as follows:"
section_start_marker = ["The sample identification of each population is as follows:","## table"]
for s in section_start_marker:
relevant_text_search = re.search(
re.escape(s.lower()) + r"\s*(.*?)(?=\n##|\Z)",
plain_text_content.lower(),
re.DOTALL
)
if relevant_text_search:
break
if not relevant_text_search:
print("Warning: 'Sample ID Population Code' section start marker not found or block empty.")
return sample_id_map, contiguous_ranges_data
relevant_text_block = relevant_text_search.group(1).strip()
# print(f"\nDEBUG_PARSING: --- Start of relevant_text_block (first 500 chars) ---")
# print(relevant_text_block[:500])
# print(f"DEBUG_PARSING: --- End of relevant_text_block (last 500 chars) ---")
# print(relevant_text_block[-500:])
# print(f"DEBUG_PARSING: Relevant text block length: {len(relevant_text_block)}")
mapping_pattern = re.compile(
r'\b([A-Z0-9]+\d+)(?:-([A-Z0-9]+\d+))?\s+([A-Z0-9]+)\b', # Changed the last group
re.IGNORECASE)
range_expansion_count = 0
direct_id_count = 0
total_matches_found = 0
for match in mapping_pattern.finditer(relevant_text_block):
total_matches_found += 1
id1_full_str, id2_full_str_opt, pop_code = match.groups()
#print(f" DEBUG_PARSING: Matched: '{match.group(0)}'")
pop_code_upper = pop_code.upper()
id1_prefix, id1_num_str = _parse_individual_code_parts(id1_full_str)
if id1_prefix is None:
#print(f" DEBUG_PARSING: Failed to parse ID1: {id1_full_str}. Skipping this mapping.")
continue
if id2_full_str_opt:
id2_prefix_opt, id2_num_str_opt = _parse_individual_code_parts(id2_full_str_opt)
if id2_prefix_opt is None:
#print(f" DEBUG_PARSING: Failed to parse ID2: {id2_full_str_opt}. Treating {id1_full_str} as single ID1.")
sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
direct_id_count += 1
continue
#print(f" DEBUG_PARSING: Comparing prefixes: '{id1_prefix.lower()}' vs '{id2_prefix_opt.lower()}'")
if id1_prefix.lower() == id2_prefix_opt.lower():
#print(f" DEBUG_PARSING: ---> Prefixes MATCH for range expansion! Range: {id1_prefix}{id1_num_str}-{id2_prefix_opt}{id2_num_str_opt}")
try:
start_num = int(id1_num_str)
end_num = int(id2_num_str_opt)
for num in range(start_num, end_num + 1):
sample_id = f"{id1_prefix.upper()}{num}"
sample_id_map[sample_id] = pop_code_upper
range_expansion_count += 1
contiguous_ranges_data[id1_prefix.upper()].append(
(start_num, end_num, pop_code_upper)
)
except ValueError:
print(f" DEBUG_PARSING: ValueError in range conversion for {id1_num_str}-{id2_num_str_opt}. Adding endpoints only.")
sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
sample_id_map[f"{id2_prefix_opt.upper()}{id2_num_str_opt}"] = pop_code_upper
direct_id_count += 2
else:
#print(f" DEBUG_PARSING: Prefixes MISMATCH for range: '{id1_prefix}' vs '{id2_prefix_opt}'. Adding endpoints only.")
sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
sample_id_map[f"{id2_prefix_opt.upper()}{id2_num_str_opt}"] = pop_code_upper
direct_id_count += 2
else:
sample_id_map[f"{id1_prefix.upper()}{id1_num_str}"] = pop_code_upper
direct_id_count += 1
# print(f"DEBUG_PARSING: Total matches found by regex: {total_matches_found}.")
# print(f"DEBUG_PARSING: Parsed sample IDs: {len(sample_id_map)} total entries.")
# print(f"DEBUG_PARSING: (including {range_expansion_count} from range expansion and {direct_id_count} direct ID/endpoint entries).")
return sample_id_map, contiguous_ranges_data
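# Worked example (hypothetical excerpt of the expected mapping text):
#   "The sample identification of each population is as follows: KSK1-KSK3 KSK EBK1 EBK"
# yields (after lowercasing and re-uppercasing inside the parser):
#   sample_id_map          -> {"KSK1": "KSK", "KSK2": "KSK", "KSK3": "KSK", "EBK1": "EBK"}
#   contiguous_ranges_data -> {"KSK": [(1, 3, "KSK")]}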
country_keywords_regional_overrides = {
    "north thailand": "Thailand", "central thailand": "Thailand",
    "northeast thailand": "Thailand", "west thailand": "Thailand",
    "central india": "India", "east india": "India", "northeast india": "India",
    "south sibera": "Russia", "siberia": "Russia", "yunnan": "China",  # "tibet": "China",
    "sumatra": "Indonesia", "borneo": "Indonesia",
    "northern mindanao": "Philippines", "west malaysia": "Malaysia",
    "mongolia": "China",
    "beijing": "China",
    "north laos": "Laos", "central laos": "Laos",
    "east myanmar": "Myanmar", "west myanmar": "Myanmar"}
# Updated get_country_from_text function
def get_country_from_text(text):
text_lower = text.lower()
# 1. Use pycountry for official country names and common aliases
for country in pycountry.countries:
# Check full name match first
if text_lower == country.name.lower():
return country.name
# Safely check for common_name
if hasattr(country, 'common_name') and text_lower == country.common_name.lower():
return country.common_name
# Safely check for official_name
if hasattr(country, 'official_name') and text_lower == country.official_name.lower():
return country.official_name
# Check if country name is part of the text (e.g., 'Thailand' in 'Thailand border')
if country.name.lower() in text_lower:
return country.name
# Safely check if common_name is part of the text
if hasattr(country, 'common_name') and country.common_name.lower() in text_lower:
return country.common_name
    # 2. Fall back to specific regional overrides (checked after the pycountry matches above)
for keyword, country in country_keywords_regional_overrides.items():
if keyword in text_lower:
return country
# 3. Check for broader regions that you want to map to "unknown" or a specific country
if "north asia" in text_lower or "southeast asia" in text_lower or "east asia" in text_lower:
return "unknown"
return "unknown"
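# Quick examples (behavior depends on pycountry's name data, so treat these as illustrative):
#   get_country_from_text("North Thailand") -> "Thailand"   (pycountry substring match)
#   get_country_from_text("south sibera")   -> "Russia"     (regional override key)
#   get_country_from_text("Southeast Asia") -> "unknown"    (broad region, step 3)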
# Get the list of English stop words from NLTK
non_meaningful_pop_names = set(stopwords.words('english'))
def parse_population_code_to_country(plain_text_content, table_strings):
pop_code_country_map = {}
pop_code_ethnicity_map = {} # NEW: To store ethnicity for structured lookup
pop_code_specific_loc_map = {} # NEW: To store specific location for structured lookup
# Regex for parsing population info in structured lists and general text
# This pattern captures: (Pop Name/Ethnicity) (Pop Code) (Region/Specific Location) (Country) (Linguistic Family)
# The 'Pop Name/Ethnicity' (Group 1) is often the ethnicity
pop_info_pattern = re.compile(
r'([A-Za-z\s]+?)\s+([A-Z]+\d*)\s+' # Pop Name (Group 1), Pop Code (Group 2) - Changed \d+ to \d* for codes like 'SH'
r'([A-Za-z\s\(\)\-,\/]+?)\s+' # Region/Specific Location (Group 3)
r'(North+|South+|West+|East+|Thailand|Laos|Cambodia|Myanmar|Philippines|Indonesia|Malaysia|China|India|Taiwan|Vietnam|Russia|Nepal|Japan|South Korea)\b' # Country (Group 4)
r'(?:.*?([A-Za-z\s\-]+))?\s*' # Optional Linguistic Family (Group 5), made optional with ?, followed by optional space
r'(\d+(?:\s+\d+\.?\d*)*)?', # Match all the numbers (Group 6) - made optional
re.IGNORECASE
)
for table_str in table_strings:
table_data = parse_literal_python_list(table_str)
if table_data:
is_list_of_lists = bool(table_data) and isinstance(table_data[0], list)
if is_list_of_lists:
for row_idx, row in enumerate(table_data):
row_text = " ".join(map(str, row))
match = pop_info_pattern.search(row_text)
if match:
pop_name = match.group(1).strip()
pop_code = match.group(2).upper()
specific_loc_text = match.group(3).strip()
country_text = match.group(4).strip()
linguistic_family = match.group(5).strip() if match.group(5) else 'unknown'
final_country = get_country_from_text(country_text)
if final_country == 'unknown': # Try specific loc text for country if direct country is not found
final_country = get_country_from_text(specific_loc_text)
if pop_code:
pop_code_country_map[pop_code] = final_country
# Populate ethnicity map (often Pop Name is ethnicity)
pop_code_ethnicity_map[pop_code] = pop_name
# Populate specific location map
pop_code_specific_loc_map[pop_code] = specific_loc_text # Store as is from text
else:
row_text = " ".join(map(str, table_data))
match = pop_info_pattern.search(row_text)
if match:
pop_name = match.group(1).strip()
pop_code = match.group(2).upper()
specific_loc_text = match.group(3).strip()
country_text = match.group(4).strip()
linguistic_family = match.group(5).strip() if match.group(5) else 'unknown'
final_country = get_country_from_text(country_text)
if final_country == 'unknown': # Try specific loc text for country if direct country is not found
final_country = get_country_from_text(specific_loc_text)
if pop_code:
pop_code_country_map[pop_code] = final_country
# Populate ethnicity map (often Pop Name is ethnicity)
pop_code_ethnicity_map[pop_code] = pop_name
# Populate specific location map
pop_code_specific_loc_map[pop_code] = specific_loc_text # Store as is from text
# # Special case refinements for ethnicity/location if more specific rules are known from document:
# if pop_name.lower() == "khon mueang": # and specific conditions if needed
# pop_code_ethnicity_map[pop_code] = "Khon Mueang"
# # If Khon Mueang has a specific city/district, add here
# # e.g., if 'Chiang Mai' is directly linked to KM1 in a specific table
# # pop_code_specific_loc_map[pop_code] = "Chiang Mai"
# elif pop_name.lower() == "lawa":
# pop_code_ethnicity_map[pop_code] = "Lawa"
# # Add similar specific rules for other populations (e.g., Mon for MO1, MO2, MO3)
# elif pop_name.lower() == "mon":
# pop_code_ethnicity_map[pop_code] = "Mon"
# # For MO2: "West Thailand (Thailand Myanmar border)" -> no city
# # For MO3: "East Myanmar (Thailand Myanmar border)" -> no city
# # If the doc gives "Bangkok" for MO4, add it here for MO4's actual specific_location.
# # etc.
# Fallback to parsing general plain text content (sentences)
sentences = data_preprocess.extract_sentences(plain_text_content)
for s in sentences: # Still focusing on just this one sentence
# Use re.finditer to get all matches
matches = pop_info_pattern.finditer(s)
pop_name, pop_code, specific_loc_text, country_text = "unknown", "unknown", "unknown", "unknown"
for match in matches:
if match.group(1):
pop_name = match.group(1).strip()
if match.group(2):
pop_code = match.group(2).upper()
if match.group(3):
specific_loc_text = match.group(3).strip()
if match.group(4):
country_text = match.group(4).strip()
# linguistic_family = match.group(5).strip() if match.group(5) else 'unknown' # Already captured by pop_info_pattern
final_country = get_country_from_text(country_text)
if final_country == 'unknown':
final_country = get_country_from_text(specific_loc_text)
if pop_code.lower() not in non_meaningful_pop_names:
if final_country.lower() not in non_meaningful_pop_names:
pop_code_country_map[pop_code] = final_country
if pop_name.lower() not in non_meaningful_pop_names:
pop_code_ethnicity_map[pop_code] = pop_name # Default ethnicity from Pop Name
if specific_loc_text.lower() not in non_meaningful_pop_names:
pop_code_specific_loc_map[pop_code] = specific_loc_text
# Specific rules for ethnicity/location in plain text:
if pop_name.lower() == "khon mueang":
pop_code_ethnicity_map[pop_code] = "Khon Mueang"
elif pop_name.lower() == "lawa":
pop_code_ethnicity_map[pop_code] = "Lawa"
elif pop_name.lower() == "mon":
pop_code_ethnicity_map[pop_code] = "Mon"
elif pop_name.lower() == "seak": # Added specific rule for Seak
pop_code_ethnicity_map[pop_code] = "Seak"
elif pop_name.lower() == "nyaw": # Added specific rule for Nyaw
pop_code_ethnicity_map[pop_code] = "Nyaw"
elif pop_name.lower() == "nyahkur": # Added specific rule for Nyahkur
pop_code_ethnicity_map[pop_code] = "Nyahkur"
elif pop_name.lower() == "suay": # Added specific rule for Suay
pop_code_ethnicity_map[pop_code] = "Suay"
elif pop_name.lower() == "soa": # Added specific rule for Soa
pop_code_ethnicity_map[pop_code] = "Soa"
elif pop_name.lower() == "bru": # Added specific rule for Bru
pop_code_ethnicity_map[pop_code] = "Bru"
elif pop_name.lower() == "khamu": # Added specific rule for Khamu
pop_code_ethnicity_map[pop_code] = "Khamu"
return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map
def general_parse_population_code_to_country(plain_text_content, table_strings):
pop_code_country_map = {}
pop_code_ethnicity_map = {}
pop_code_specific_loc_map = {}
sample_id_to_pop_code = {}
for table_str in table_strings:
table_data = parse_literal_python_list(table_str)
if not table_data or not isinstance(table_data[0], list):
continue
header_row = [col.lower() for col in table_data[0]]
header_map = {col: idx for idx, col in enumerate(header_row)}
# MJ17: Direct PopCode → Country
if 'id' in header_map and 'country' in header_map:
            for row in table_data[1:]:
if len(row) < len(header_row):
continue
pop_code = str(row[header_map['id']]).strip()
country = str(row[header_map['country']]).strip()
province = row[header_map['province']].strip() if 'province' in header_map else 'unknown'
pop_group = row[header_map['population group / region']].strip() if 'population group / region' in header_map else 'unknown'
pop_code_country_map[pop_code] = country
pop_code_specific_loc_map[pop_code] = province
pop_code_ethnicity_map[pop_code] = pop_group
# A1YU101 or EBK/KSK: SampleID → PopCode
elif 'sample id' in header_map and 'population code' in header_map:
            for row in table_data[1:]:
if len(row) < 2:
continue
sample_id = row[header_map['sample id']].strip().upper()
pop_code = row[header_map['population code']].strip().upper()
sample_id_to_pop_code[sample_id] = pop_code
# PopCode → Country (A1YU101/EBK mapping)
elif 'population code' in header_map and 'country' in header_map:
            for row in table_data[1:]:
if len(row) < 2:
continue
pop_code = row[header_map['population code']].strip().upper()
country = row[header_map['country']].strip()
pop_code_country_map[pop_code] = country
return pop_code_country_map, pop_code_ethnicity_map, pop_code_specific_loc_map, sample_id_to_pop_code
def chunk_text(text, chunk_size=500, overlap=50):
"""Splits text into chunks (by words) with overlap."""
chunks = []
words = text.split()
num_words = len(words)
start = 0
while start < num_words:
end = min(start + chunk_size, num_words)
chunk = " ".join(words[start:end])
chunks.append(chunk)
if end == num_words:
break
start += chunk_size - overlap # Move start by (chunk_size - overlap)
return chunks
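# Worked example with small numbers: for a 12-word text, chunk_size=5 and overlap=2 give
# window starts at words 0, 3, 6, 9, so the chunks cover words [0:5], [3:8], [6:11], [9:12]
# (the last, shorter window ends the loop).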
def build_vector_index_and_data(doc_path, index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
"""
Reads document, builds structured lookup, chunks remaining text, embeds chunks,
and builds/saves a FAISS index.
"""
print("Step 1: Reading document and extracting structured data...")
# plain_text_content, table_strings, document_title = read_docx_text(doc_path) # Get document_title here
# sample_id_map, contiguous_ranges_data = parse_sample_id_to_population_code(plain_text_content)
# pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc = parse_population_code_to_country(plain_text_content, table_strings)
# master_structured_lookup = {}
# master_structured_lookup['document_title'] = document_title # Store document title
# master_structured_lookup['sample_id_map'] = sample_id_map
# master_structured_lookup['contiguous_ranges'] = dict(contiguous_ranges_data)
# master_structured_lookup['pop_code_to_country'] = pop_code_to_country
# master_structured_lookup['pop_code_to_ethnicity'] = pop_code_to_ethnicity # NEW: Store pop_code to ethnicity map
# master_structured_lookup['pop_code_to_specific_loc'] = pop_code_to_specific_loc # NEW: Store pop_code to specific_loc map
# # Final consolidation: Use sample_id_map to derive full info for queries
# final_structured_entries = {}
# for sample_id, pop_code in master_structured_lookup['sample_id_map'].items():
# country = master_structured_lookup['pop_code_to_country'].get(pop_code, 'unknown')
# ethnicity = master_structured_lookup['pop_code_to_ethnicity'].get(pop_code, 'unknown') # Retrieve ethnicity
# specific_location = master_structured_lookup['pop_code_to_specific_loc'].get(pop_code, 'unknown') # Retrieve specific location
# final_structured_entries[sample_id] = {
# 'population_code': pop_code,
# 'country': country,
# 'type': 'modern',
# 'ethnicity': ethnicity, # Store ethnicity
# 'specific_location': specific_location # Store specific location
# }
# master_structured_lookup['final_structured_entries'] = final_structured_entries
plain_text_content, table_strings, document_title = read_docx_text(doc_path)
pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc, sample_id_map = general_parse_population_code_to_country(plain_text_content, table_strings)
final_structured_entries = {}
if sample_id_map:
for sample_id, pop_code in sample_id_map.items():
country = pop_code_to_country.get(pop_code, 'unknown')
ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
final_structured_entries[sample_id] = {
'population_code': pop_code,
'country': country,
'type': 'modern',
'ethnicity': ethnicity,
'specific_location': specific_loc
}
else:
for pop_code in pop_code_to_country.keys():
country = pop_code_to_country.get(pop_code, 'unknown')
ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
final_structured_entries[pop_code] = {
'population_code': pop_code,
'country': country,
'type': 'modern',
'ethnicity': ethnicity,
'specific_location': specific_loc
}
if not final_structured_entries:
# traditional way of A1YU101
sample_id_map, contiguous_ranges_data = parse_sample_id_to_population_code(plain_text_content)
pop_code_to_country, pop_code_to_ethnicity, pop_code_to_specific_loc = parse_population_code_to_country(plain_text_content, table_strings)
if sample_id_map:
for sample_id, pop_code in sample_id_map.items():
country = pop_code_to_country.get(pop_code, 'unknown')
ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
final_structured_entries[sample_id] = {
'population_code': pop_code,
'country': country,
'type': 'modern',
'ethnicity': ethnicity,
'specific_location': specific_loc
}
else:
for pop_code in pop_code_to_country.keys():
country = pop_code_to_country.get(pop_code, 'unknown')
ethnicity = pop_code_to_ethnicity.get(pop_code, 'unknown')
specific_loc = pop_code_to_specific_loc.get(pop_code, 'unknown')
final_structured_entries[pop_code] = {
'population_code': pop_code,
'country': country,
'type': 'modern',
'ethnicity': ethnicity,
'specific_location': specific_loc
}
master_lookup = {
'document_title': document_title,
'pop_code_to_country': pop_code_to_country,
'pop_code_to_ethnicity': pop_code_to_ethnicity,
'pop_code_to_specific_loc': pop_code_to_specific_loc,
'sample_id_map': sample_id_map,
'final_structured_entries': final_structured_entries
}
print(f"Structured lookup built with {len(final_structured_entries)} entries in 'final_structured_entries'.")
with open(structured_path, 'w') as f:
json.dump(master_lookup, f, indent=4)
print(f"Structured lookup saved to {structured_path}.")
print("Step 2: Chunking document for RAG vector index...")
    # Prefer the all_output produced by process_inputToken as the chunk source; fall back to this traditional chunking.
clean_text, clean_table = "", ""
if plain_text_content:
clean_text = data_preprocess.normalize_for_overlap(plain_text_content)
if table_strings:
clean_table = data_preprocess.normalize_for_overlap(". ".join(table_strings))
all_clean_chunk = clean_text + clean_table
document_chunks = chunk_text(all_clean_chunk)
print(f"Document chunked into {len(document_chunks)} chunks.")
print("Step 3: Generating embeddings for chunks (this might take time and cost API calls)...")
    # Embeddings are generated via get_embedding (genai.embed_content with models/text-embedding-004).
chunk_embeddings = []
for i, chunk in enumerate(document_chunks):
embedding = get_embedding(chunk, task_type="RETRIEVAL_DOCUMENT")
if embedding is not None and embedding.shape[0] > 0:
chunk_embeddings.append(embedding)
else:
print(f"Warning: Failed to get valid embedding for chunk {i}. Skipping.")
chunk_embeddings.append(np.zeros(768, dtype='float32'))
if not chunk_embeddings:
raise ValueError("No valid embeddings generated. Check get_embedding function and API.")
embedding_dimension = chunk_embeddings[0].shape[0]
index = faiss.IndexFlatL2(embedding_dimension)
index.add(np.array(chunk_embeddings))
faiss.write_index(index, index_path)
with open(chunks_path, "w") as f:
json.dump(document_chunks, f)
print(f"FAISS index built and saved to {index_path}.")
print(f"Document chunks saved to {chunks_path}.")
return master_lookup, index, document_chunks, all_clean_chunk
def load_rag_assets(index_path="faiss_index.bin", chunks_path="document_chunks.json", structured_path="structured_lookup.json"):
"""Loads pre-built RAG assets (FAISS index, chunks, structured lookup)."""
print("Loading RAG assets...")
master_structured_lookup = {}
if os.path.exists(structured_path):
with open(structured_path, 'r') as f:
master_structured_lookup = json.load(f)
print("Structured lookup loaded.")
else:
print("Structured lookup file not found. Rebuilding is likely needed.")
index = None
chunks = []
if os.path.exists(index_path) and os.path.exists(chunks_path):
try:
index = faiss.read_index(index_path)
with open(chunks_path, "r") as f:
chunks = json.load(f)
print("FAISS index and chunks loaded.")
except Exception as e:
print(f"Error loading FAISS index or chunks: {e}. Will rebuild.")
index = None
chunks = []
else:
print("FAISS index or chunks files not found.")
return master_structured_lookup, index, chunks
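# Usage sketch (hypothetical paths): build the assets once, then reload them on later runs
# instead of re-embedding the whole document.
# master_lookup, index, chunks, full_text = build_vector_index_and_data("paper.docx")
# master_lookup, index, chunks = load_rag_assets()  # reads the three files saved above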
# Helper function for query_document_info
def exactInContext(text, keyword):
    """Return True if the keyword appears (case-insensitively) in the text."""
    # A prefix-based fallback (matching only the letter prefix of a sample code) was sketched
    # here previously but is disabled; only the exact keyword is checked.
    return keyword.lower() in text.lower()
def chooseContextLLM(contexts, kw):
    """Return (label, context) for the first context containing kw; otherwise fall back to all_output, then chunk, then document_chunk."""
for con in contexts:
context = contexts[con]
if context:
if exactInContext(context, kw):
return con, context
    # If nothing related to kw is found in any context, fall back to all_output.
if contexts["all_output"]:
return "all_output", contexts["all_output"]
else:
        # all_output is missing or empty: try chunk, then document_chunk, else give up.
if contexts["chunk"]: return "chunk", contexts["chunk"]
elif contexts["document_chunk"]: return "document_chunk", contexts["document_chunk"]
else: return None, None
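# Example of the fallback order (hypothetical contexts): the keyword is looked up in each
# candidate context in insertion order first; if none contains it, "all_output" wins,
# then "chunk", then "document_chunk".
# contexts = {"chunk": "KSK12 was sampled in Laos", "all_output": "full text", "document_chunk": "chunks"}
# chooseContextLLM(contexts, "KSK12")  -> ("chunk", "KSK12 was sampled in Laos")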
def clean_llm_output(llm_response_text, output_format_str):
results = []
lines = llm_response_text.strip().split('\n')
output_country, output_type, output_ethnicity, output_specific_location = [],[],[],[]
for line in lines:
extracted_country, extracted_type, extracted_ethnicity, extracted_specific_location = "unknown", "unknown", "unknown", "unknown"
line = line.strip()
if output_format_str == "ethnicity, specific_location/unknown": # Targeted RAG output
parsed_output = re.search(r'^\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
if parsed_output:
extracted_ethnicity = parsed_output.group(1).strip()
extracted_specific_location = parsed_output.group(2).strip()
else:
print(" DEBUG: LLM did not follow expected 2-field format for targeted RAG. Defaulting to unknown for ethnicity/specific_location.")
extracted_ethnicity = 'unknown'
extracted_specific_location = 'unknown'
elif output_format_str == "modern/ancient/unknown, ethnicity, specific_location/unknown":
parsed_output = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', llm_response_text)
if parsed_output:
extracted_type = parsed_output.group(1).strip()
extracted_ethnicity = parsed_output.group(2).strip()
extracted_specific_location = parsed_output.group(3).strip()
else:
# Fallback: check if only 2 fields
parsed_output_2_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', llm_response_text)
if parsed_output_2_fields:
extracted_type = parsed_output_2_fields.group(1).strip()
extracted_ethnicity = parsed_output_2_fields.group(2).strip()
extracted_specific_location = 'unknown'
else:
# even simpler fallback: 1 field only
parsed_output_1_field = re.search(r'^\s*([^,]+?)\s*$', llm_response_text)
if parsed_output_1_field:
extracted_type = parsed_output_1_field.group(1).strip()
extracted_ethnicity = 'unknown'
extracted_specific_location = 'unknown'
else:
print(" DEBUG: LLM did not follow any expected simplified format. Attempting verbose parsing fallback.")
type_match_fallback = re.search(r'Type:\s*([A-Za-z\s-]+)', llm_response_text)
extracted_type = type_match_fallback.group(1).strip() if type_match_fallback else 'unknown'
extracted_ethnicity = 'unknown'
extracted_specific_location = 'unknown'
else:
parsed_output = re.search(r'^\s*([^,]+?),\s*([^,]+?),\s*([^,]+?),\s*(.+?)\s*$', line)
if parsed_output:
extracted_country = parsed_output.group(1).strip()
extracted_type = parsed_output.group(2).strip()
extracted_ethnicity = parsed_output.group(3).strip()
extracted_specific_location = parsed_output.group(4).strip()
else:
print(f" DEBUG: Line did not follow expected 4-field format: {line}")
parsed_output_2_fields = re.search(r'^\s*([^,]+?),\s*([^,]+?)\s*$', line)
if parsed_output_2_fields:
extracted_country = parsed_output_2_fields.group(1).strip()
extracted_type = parsed_output_2_fields.group(2).strip()
extracted_ethnicity = 'unknown'
extracted_specific_location = 'unknown'
else:
print(f" DEBUG: Fallback to verbose-style parsing: {line}")
country_match_fallback = re.search(r'Country:\s*([A-Za-z\s-]+)', line)
type_match_fallback = re.search(r'Type:\s*([A-Za-z\s-]+)', line)
extracted_country = country_match_fallback.group(1).strip() if country_match_fallback else 'unknown'
extracted_type = type_match_fallback.group(1).strip() if type_match_fallback else 'unknown'
extracted_ethnicity = 'unknown'
extracted_specific_location = 'unknown'
results.append({
"country": extracted_country,
"type": extracted_type,
"ethnicity": extracted_ethnicity,
"specific_location": extracted_specific_location
#"country_explain":extracted_country_explain,
#"type_explain": extracted_type_explain
})
    # Merge the parsed results across lines, de-duplicating repeated values.
if output_format_str == "ethnicity, specific_location/unknown":
for result in results:
if result["ethnicity"] not in output_ethnicity:
output_ethnicity.append(result["ethnicity"])
if result["specific_location"] not in output_specific_location:
output_specific_location.append(result["specific_location"])
return " or ".join(output_ethnicity), " or ".join(output_specific_location)
elif output_format_str == "modern/ancient/unknown, ethnicity, specific_location/unknown":
for result in results:
if result["type"] not in output_type:
output_type.append(result["type"])
if result["ethnicity"] not in output_ethnicity:
output_ethnicity.append(result["ethnicity"])
if result["specific_location"] not in output_specific_location:
output_specific_location.append(result["specific_location"])
return " or ".join(output_type)," or ".join(output_ethnicity), " or ".join(output_specific_location)
else:
for result in results:
if result["country"] not in output_country:
output_country.append(result["country"])
if result["type"] not in output_type:
output_type.append(result["type"])
if result["ethnicity"] not in output_ethnicity:
output_ethnicity.append(result["ethnicity"])
if result["specific_location"] not in output_specific_location:
output_specific_location.append(result["specific_location"])
return " or ".join(output_country)," or ".join(output_type)," or ".join(output_ethnicity), " or ".join(output_specific_location)
# def parse_multi_sample_llm_output(raw_response: str, output_format_str):
# """
# Parse LLM output with possibly multiple metadata lines + shared explanations.
# """
# lines = [line.strip() for line in raw_response.strip().splitlines() if line.strip()]
# metadata_list = []
# explanation_lines = []
# if output_format_str == "country_name, modern/ancient/unknown":
# parts = [x.strip() for x in lines[0].split(",")]
# if len(parts)==2:
# metadata_list.append({
# "country": parts[0],
# "sample_type": parts[1]#,
# #"ethnicity": parts[2],
# #"location": parts[3]
# })
# if 1<len(lines):
# line = lines[1]
# if "\n" in line: line = line.split("\n")
# if ". " in line: line = line.split(". ")
# if isinstance(line,str): line = [line]
# explanation_lines += line
# elif output_format_str == "modern/ancient/unknown":
# metadata_list.append({
# "country": "unknown",
# "sample_type": lines[0]#,
# #"ethnicity": parts[2],
# #"location": parts[3]
# })
# explanation_lines.append(lines[1])
# # Assign explanations (optional) to each sample — same explanation reused
# for md in metadata_list:
# md["country_explanation"] = None
# md["sample_type_explanation"] = None
# if md["country"].lower() != "unknown" and len(explanation_lines) >= 1:
# md["country_explanation"] = explanation_lines[0]
# if md["sample_type"].lower() != "unknown":
# if len(explanation_lines) >= 2:
# md["sample_type_explanation"] = explanation_lines[1]
# elif len(explanation_lines) == 1 and md["country"].lower() == "unknown":
# md["sample_type_explanation"] = explanation_lines[0]
# elif len(explanation_lines) == 1:
# md["sample_type_explanation"] = explanation_lines[0]
# return metadata_list
def parse_multi_sample_llm_output(raw_response: str, output_format_str):
"""
Parse LLM output with possibly multiple metadata lines + shared explanations.
"""
metadata_list = {}
explanation_lines = []
output_answers = re.split(r",\s*", raw_response.split("\n")[0].strip()) #raw_response.split("\n")[0].split(", ")
explanation_lines = [x for x in raw_response.split("\n")[1:] if x.strip()]
print("raw explanation line which split by new line: ", explanation_lines)
if len(explanation_lines) == 1:
if len(explanation_lines[0].split(". ")) > len(explanation_lines):
explanation_lines = [x for x in explanation_lines[0].split(". ") if x.strip()]
print("explain line split by dot: ", explanation_lines)
output_formats = output_format_str.split(", ")
explain = ""
# assign output format to its output answer and explanation
if output_format_str:
outputs = output_format_str.split(", ")
for o in range(len(outputs)):
output = outputs[o]
metadata_list[output] = {"answer":"",
output+"_explanation":""}
# assign output answers
if o < len(output_answers):
# check if output_format unexpectedly in the answer such as:
#country_name: Europe, modern/ancient: modern
try:
if ": " in output_answers[o]:
output_answers[o] = output_answers[o].split(": ")[1]
except:
pass
# Europe, modern
metadata_list[output]["answer"] = output_answers[o]
if "unknown" in metadata_list[output]["answer"].lower():
metadata_list[output]["answer"] = "unknown"
else:
metadata_list[output]["answer"] = "unknown"
# assign explanations
if metadata_list[output]["answer"] != "unknown":
# if explanation_lines:
# explain = explanation_lines.pop(0)
# else:
# explain = ". ".join(explanation_lines)
explain = ". ".join(explanation_lines)
metadata_list[output][output+"_explanation"] = explain
else:
metadata_list[output][output+"_explanation"] = "unknown"
return metadata_list
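# Example (hypothetical two-line LLM reply) for output_format_str = "country_name, modern/ancient/unknown":
#   "Thailand, modern\nThe paper samples present-day Thai villagers."
# parses to:
#   {"country_name": {"answer": "Thailand",
#                     "country_name_explanation": "The paper samples present-day Thai villagers."},
#    "modern/ancient/unknown": {"answer": "modern",
#                               "modern/ancient/unknown_explanation": "The paper samples present-day Thai villagers."}}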
def merge_metadata_outputs(metadata_list):
"""
Merge a list of metadata dicts into one, combining differing values with 'or'.
Assumes all dicts have the same keys.
"""
if not metadata_list:
return {}
merged = {}
keys = metadata_list[0].keys()
for key in keys:
values = [md[key] for md in metadata_list if key in md]
        unique_values = list(dict.fromkeys(values))  # preserve order, remove dupes
        if "unknown" in unique_values and len(unique_values) > 1:
            unique_values.remove("unknown")  # drop 'unknown' only when a real value exists
        if len(unique_values) == 1:
            merged[key] = unique_values[0]
        else:
            merged[key] = " or ".join(unique_values)
return merged
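# Examples with flat string values (hypothetical): differing values are joined with " or ",
# and "unknown" is dropped whenever a real value is present.
#   merge_metadata_outputs([{"country": "Thailand"}, {"country": "Laos"}])  -> {"country": "Thailand or Laos"}
#   merge_metadata_outputs([{"country": "unknown"}, {"country": "Laos"}])   -> {"country": "Laos"}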
import time
import random
def safe_call_llm(prompt, model="gemini-2.5-flash-lite", max_retries=5):
retry_delay = 20
for attempt in range(max_retries):
try:
resp_text, resp_model = call_llm_api(prompt, model)
return resp_text, resp_model
except Exception as e:
error_msg = str(e)
if "429" in error_msg or "quota" in error_msg.lower():
print(f"\n⚠️ Rate limit hit (attempt {attempt+1}/{max_retries}).")
                retry_after = None
                # Try to honor a "retry in Ns" hint from the error message, if present.
                retry_match = re.search(r'retry[^0-9]*(\d+(?:\.\d+)?)\s*s', error_msg, re.IGNORECASE)
                if retry_match:
                    try:
                        retry_after = float(retry_match.group(1))
                    except ValueError:
                        retry_after = None
wait_time = retry_after if retry_after else retry_delay
print(f"⏳ Waiting {wait_time:.1f} seconds before retrying...")
time.sleep(wait_time)
retry_delay *= 2
else:
raise e
raise RuntimeError("❌ Failed after max retries because of repeated rate limits.")
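# Usage sketch: same return shape as call_llm_api, but retries with exponential backoff
# when the error message looks like a 429 / quota error.
# text, model = safe_call_llm("Summarize this abstract: ...")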
async def query_document_info(niche_cases, query_word, alternative_query_word, saveLinkFolder, metadata, master_structured_lookup, faiss_index, document_chunks, llm_api_function, chunk=None, all_output=None, model_ai=None):
"""
Queries the document using a hybrid approach:
1. Local structured lookup (fast, cheap, accurate for known patterns).
2. RAG with semantic search and LLM (general, flexible, cost-optimized).
"""
print("inside the model.query_doc_info")
outputs, links = {}, []
    # Default to Gemini 2.5 Flash-Lite pricing and token counting; override below if a
    # different model is requested. (Without these defaults the variables would be unbound
    # when model_ai is None.)
    PRICE_PER_1K_INPUT_LLM = 0.00010        # $0.10 per 1M input tokens
    PRICE_PER_1K_OUTPUT_LLM = 0.00040       # $0.40 per 1M output tokens
    PRICE_PER_1K_EMBEDDING_INPUT = 0.00015  # $0.15 per 1M input tokens (embedding-001)
    global_llm_model_for_counting_tokens = genai.GenerativeModel("gemini-2.5-flash-lite")
    if model_ai:
        if model_ai == "gemini-1.5-flash-latest":
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
            PRICE_PER_1K_INPUT_LLM = 0.000075        # $0.075 per 1M tokens
            PRICE_PER_1K_OUTPUT_LLM = 0.0003         # $0.30 per 1M tokens
            PRICE_PER_1K_EMBEDDING_INPUT = 0.000025  # $0.025 per 1M tokens
            global_llm_model_for_counting_tokens = genai.GenerativeModel("gemini-1.5-flash-latest")
        else:
            genai.configure(api_key=os.getenv("GOOGLE_API_KEY_BACKUP"))
if metadata:
extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = metadata["country"], metadata["specific_location"], metadata["ethnicity"], metadata["sample_type"]
extracted_col_date, extracted_iso, extracted_title, extracted_features = metadata["collection_date"], metadata["isolate"], metadata["title"], metadata["all_features"]
else:
extracted_country, extracted_specific_location, extracted_ethnicity, extracted_type = "unknown", "unknown", "unknown", "unknown"
        extracted_col_date, extracted_iso, extracted_title, extracted_features = "unknown", "unknown", "unknown", "unknown"
# --- NEW: Pre-process alternative_query_word to remove '.X' suffix if present ---
if alternative_query_word:
alternative_query_word_cleaned = alternative_query_word.split('.')[0]
else:
alternative_query_word_cleaned = alternative_query_word
country_explanation, sample_type_explanation = None, None
# Use the consolidated final_structured_entries for direct lookup
# final_structured_entries = master_structured_lookup.get('final_structured_entries', {})
# document_title = master_structured_lookup.get('document_title', 'Unknown Document Title') # Retrieve document title
# Default values for all extracted fields. These will be updated.
method_used = 'unknown' # Will be updated based on the method that yields a result
population_code_from_sl = 'unknown' # To pass to RAG prompt if available
total_query_cost = 0
# Attempt 1: Try primary query_word (e.g., isolate name) with structured lookup
# try:
# print("try attempt 1 in model query")
# structured_info = final_structured_entries.get(query_word.upper())
# if structured_info:
# if extracted_country == 'unknown':
# extracted_country = structured_info['country']
# if extracted_type == 'unknown':
# extracted_type = structured_info['type']
# # if extracted_ethnicity == 'unknown':
# # extracted_ethnicity = structured_info.get('ethnicity', 'unknown') # Get ethnicity from structured lookup
# # if extracted_specific_location == 'unknown':
# # extracted_specific_location = structured_info.get('specific_location', 'unknown') # Get specific_location from structured lookup
# population_code_from_sl = structured_info['population_code']
# method_used = "structured_lookup_direct"
# print(f"'{query_word}' found in structured lookup (direct match).")
# except:
# print("pass attempt 1 in model query")
# pass
# # Attempt 2: Try primary query_word with heuristic range lookup if direct fails (only if not already resolved)
# try:
# print("try attempt 2 in model query")
# if method_used == 'unknown':
# query_prefix, query_num_str = _parse_individual_code_parts(query_word)
# if query_prefix is not None and query_num_str is not None:
# try: query_num = int(query_num_str)
# except ValueError: query_num = None
# if query_num is not None:
# query_prefix_upper = query_prefix.upper()
# contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
# pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
# pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
# pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
# if query_prefix_upper in contiguous_ranges:
# for start_num, end_num, pop_code_for_range in contiguous_ranges[query_prefix_upper]:
# if start_num <= query_num <= end_num:
# country_from_heuristic = pop_code_to_country.get(pop_code_for_range, 'unknown')
# if country_from_heuristic != 'unknown':
# if extracted_country == 'unknown':
# extracted_country = country_from_heuristic
# if extracted_type == 'unknown':
# extracted_type = 'modern'
# # if extracted_ethnicity == 'unknown':
# # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
# # if extracted_specific_location == 'unknown':
# # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
# population_code_from_sl = pop_code_for_range
# method_used = "structured_lookup_heuristic_range_match"
# print(f"'{query_word}' not direct. Heuristic: Falls within range {query_prefix_upper}{start_num}-{query_prefix_upper}{end_num}.")
# break
# else:
# print(f"'{query_word}' heuristic match found, but country unknown. Will fall to RAG below.")
# except:
# print("pass attempt 2 in model query")
# pass
# # Attempt 3: If primary query_word failed all structured lookups, try alternative_query_word (cleaned)
# try:
# print("try attempt 3 in model query")
# if method_used == 'unknown' and alternative_query_word_cleaned and alternative_query_word_cleaned != query_word:
# print(f"'{query_word}' not found in structured (or heuristic). Trying alternative '{alternative_query_word_cleaned}'.")
# # Try direct lookup for alternative word
# structured_info_alt = final_structured_entries.get(alternative_query_word_cleaned.upper())
# if structured_info_alt:
# if extracted_country == 'unknown':
# extracted_country = structured_info_alt['country']
# if extracted_type == 'unknown':
# extracted_type = structured_info_alt['type']
# # if extracted_ethnicity == 'unknown':
# # extracted_ethnicity = structured_info_alt.get('ethnicity', 'unknown')
# # if extracted_specific_location == 'unknown':
# # extracted_specific_location = structured_info_alt.get('specific_location', 'unknown')
# population_code_from_sl = structured_info_alt['population_code']
# method_used = "structured_lookup_alt_direct"
# print(f"Alternative '{alternative_query_word_cleaned}' found in structured lookup (direct match).")
# else:
# # Try heuristic lookup for alternative word
# alt_prefix, alt_num_str = _parse_individual_code_parts(alternative_query_word_cleaned)
# if alt_prefix is not None and alt_num_str is not None:
# try: alt_num = int(alt_num_str)
# except ValueError: alt_num = None
# if alt_num is not None:
# alt_prefix_upper = alt_prefix.upper()
# contiguous_ranges = master_structured_lookup.get('contiguous_ranges', defaultdict(list))
# pop_code_to_country = master_structured_lookup.get('pop_code_to_country', {})
# pop_code_to_ethnicity = master_structured_lookup.get('pop_code_to_ethnicity', {})
# pop_code_to_specific_loc = master_structured_lookup.get('pop_code_to_specific_loc', {})
# if alt_prefix_upper in contiguous_ranges:
# for start_num, end_num, pop_code_for_range in contiguous_ranges[alt_prefix_upper]:
# if start_num <= alt_num <= end_num:
# country_from_heuristic_alt = pop_code_to_country.get(pop_code_for_range, 'unknown')
# if country_from_heuristic_alt != 'unknown':
# if extracted_country == 'unknown':
# extracted_country = country_from_heuristic_alt
# if extracted_type == 'unknown':
# extracted_type = 'modern'
# # if extracted_ethnicity == 'unknown':
# # extracted_ethnicity = pop_code_to_ethnicity.get(pop_code_for_range, 'unknown')
# # if extracted_specific_location == 'unknown':
# # extracted_specific_location = pop_code_to_specific_loc.get(pop_code_for_range, 'unknown')
# population_code_from_sl = pop_code_for_range
# method_used = "structured_lookup_alt_heuristic_range_match"
# break
# else:
# print(f"Alternative '{alternative_query_word_cleaned}' heuristic match found, but country unknown. Will fall to RAG below.")
# except:
# print("pass attempt 3 in model query")
# pass
# use the context_for_llm to detect present_ancient before using llm model
# retrieved_chunks_text = []
# if document_chunks:
# for idx in range(len(document_chunks)):
# retrieved_chunks_text.append(document_chunks[idx])
# context_for_llm = ""
# all_context = "\n".join(retrieved_chunks_text) #
# listOfcontexts = {"chunk": chunk,
# "all_output": all_output,
# "document_chunk": all_context}
# label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
# if not context_for_llm:
# label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
# if not context_for_llm:
# context_for_llm = "Collection_date: " + col_date +". Isolate: " + iso + ". Title: " + title + ". Features: " + extracted_features
# if context_for_llm:
# extracted_type, explain = mtdna_classifier.detect_ancient_flag(context_for_llm)
# extracted_type = extracted_type.lower()
# sample_type_explanation = explain
# 5. Execute RAG if needed (either full RAG or targeted RAG for missing fields)
# Determine if a RAG call is necessary
# run_rag = (extracted_country == 'unknown' or extracted_type == 'unknown')# or \
# #extracted_ethnicity == 'unknown' or extracted_specific_location == 'unknown')
run_rag = True
if run_rag:
print("try run rag")
# Determine the phrase for LLM query
rag_query_phrase = ""
if query_word.lower() != "unknown":
rag_query_phrase += f"the mtDNA isolate name '{query_word}'"
# Accession number (alternative_query_word)
if (
alternative_query_word_cleaned
and alternative_query_word_cleaned != query_word
and alternative_query_word_cleaned.lower() != "unknown"
):
if rag_query_phrase:
rag_query_phrase += f" or its accession number '{alternative_query_word_cleaned}'"
else:
rag_query_phrase += f"the accession number '{alternative_query_word_cleaned}'"
# Construct a more specific semantic query phrase for embedding if structured info is available
semantic_query_for_embedding = rag_query_phrase # Default
prompt_instruction_prefix = ""
output_format_str = ""
# Determine if it's a full RAG or targeted RAG scenario based on what's already extracted
        is_full_rag_scenario = True  # previously: (extracted_country == 'unknown')
if is_full_rag_scenario: # Full RAG scenario
output_format_str = "country_name, modern/ancient/unknown"#, ethnicity, specific_location/unknown"
explain_list = "country or sample type (modern/ancient)"
if niche_cases:
output_format_str += ", "+ ", ".join(niche_cases)# "ethnicity, specific_location/unknown"
explain_list += " or "+ " or ".join(niche_cases)
method_used = "rag_llm"
print(f"Proceeding to FULL RAG for {rag_query_phrase}.")
current_embedding_cost = 0
print("direct to llm")
listOfcontexts = {"chunk": chunk,
"all_output": all_output,
"document_chunk": chunk}
label, context_for_llm = chooseContextLLM(listOfcontexts, query_word)
if not context_for_llm:
label, context_for_llm = chooseContextLLM(listOfcontexts, alternative_query_word_cleaned)
        if not context_for_llm:
            context_for_llm = "Collection_date: " + extracted_col_date + ". Isolate: " + extracted_iso + ". Title: " + extracted_title + ". Features: " + extracted_features
if len(context_for_llm) > 1000*1000:
context_for_llm = context_for_llm[:900000]
        # Refine the prompt: state the organism (prioritizing Homo sapiens) when it can be read from the features.
        features = metadata["all_features"] if metadata else extracted_features
organism = "general"
if features != "unknown":
if "organism" in features:
try:
organism = features.split("organism: ")[1].split("\n")[0]
except:
organism = features.replace("\n","; ")
niche_prompt = ""
if niche_cases:
fields_list = ", ".join(niche_cases)
niche_prompt = (
f"Also, extract {fields_list}. "
f"If not explicitly stated, infer the most specific related or contextually relevant value. "
f"If no information is found, write 'unknown'. "
)
prompt_for_llm = (
f"{prompt_instruction_prefix}"
f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} "
f"or the mitochondrial DNA sample in {organism} if these identifiers are not explicitly found. "
f"Identify its **primary associated geographic location**, preferring the most specific available: "
f"first try to determine the exact country; if no country is explicitly mentioned, then provide "
f"the next most specific region, continent, island, or other clear geographic area mentioned. "
f"If no geographic clues at all are present, state 'unknown' for location. "
f"Also, determine if the genetic sample is from a 'modern' (present-day living individual) "
f"or 'ancient' (prehistoric/archaeological) source. "
f"If the text does not specify ancient or archaeological context, assume 'modern'. "
f"{niche_prompt}"
f"Provide only {output_format_str}. "
f"If any information is not explicitly present, use the fallback rules above before defaulting to 'unknown'. "
f"For each non-'unknown' field in {explain_list}, write one sentence explaining how it was inferred from the text "
f"(one sentence for each). "
f"Format your answer so that:\n"
f"1. The **first line** contains only the {output_format_str} values separated by commas.\n"
f"2. The **second line onward** contains the explanations based on the order of the non-unknown {output_format_str} answer.\n"
f"\nText Snippets:\n{context_for_llm}")
print("this is prompt: ", prompt_for_llm)
if model_ai:
print("back up to ", model_ai)
#llm_response_text, model_instance = call_llm_api(prompt_for_llm, model=model_ai)
llm_response_text, model_instance = safe_call_llm(prompt_for_llm, model=model_ai)
else:
print("still 2.5 flash gemini")
llm_response_text, model_instance = safe_call_llm(prompt_for_llm)
#llm_response_text, model_instance = call_llm_api(prompt_for_llm)
print("\n--- DEBUG INFO FOR RAG ---")
print("Retrieved Context Sent to LLM (first 500 chars):")
print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
print("\nRaw LLM Response:")
print(llm_response_text)
print("--- END DEBUG INFO ---")
llm_cost = 0
if model_instance:
try:
input_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(prompt_for_llm).total_tokens
output_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(llm_response_text).total_tokens
print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
(output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
except Exception as e:
print(f" DEBUG: Error counting LLM tokens: {e}")
llm_cost = 0
total_query_cost += current_embedding_cost + llm_cost
print(f" DEBUG: Total estimated cost for this RAG query: ${total_query_cost:.6f}")
metadata_list = parse_multi_sample_llm_output(llm_response_text, output_format_str)
print(metadata_list)
again_output_format, general_knowledge_prompt = "", ""
        # If at least one answer is unknown, run smart queries to gather more sources beyond the DOI.
        unknown_count = sum(1 for v in metadata_list.values() if v.get("answer", "").lower() == "unknown")
if unknown_count >= 1:
print("at least 1 unknown outputs")
out_links = {}
iso, acc = query_word, alternative_query_word
meta_expand = smart_fallback.fetch_ncbi(acc)
tem_links = smart_fallback.smart_google_search(acc, meta_expand)
tem_links = pipeline.unique_preserve_order(tem_links)
print("this is tem links with acc: ", tem_links)
# filter the quality link
print("start the smart filter link")
#success_process, output_process = run_with_timeout(smart_fallback.filter_links_by_metadata,args=(tem_links,saveLinkFolder),kwargs={"accession":acc},timeout=90)
output_process = await smart_fallback.async_filter_links_by_metadata(
tem_links, saveLinkFolder, accession=acc
)
if output_process:
out_links.update(output_process)
print("yeah we have out_link and len: ", len(out_links))
print("yes succeed for smart filter link")
links += list(out_links.keys())
print("link keys: ", links)
if links:
tasks = [
pipeline.process_link_chunk_allOutput(link, iso, acc, saveLinkFolder, out_links, all_output, chunk)
for link in links
]
results = await asyncio.gather(*tasks)
# combine results
for context, new_all_output, new_chunk in results:
context_for_llm += new_all_output
context_for_llm += new_chunk
print("len of context after merge all: ", len(context_for_llm))
if len(context_for_llm) > 750000:
context_for_llm = data_preprocess.normalize_for_overlap(context_for_llm)
if len(context_for_llm) > 750000:
# use build context for llm function to reduce token
texts_reduce = []
out_links_reduce = {}
reduce_context_for_llm = ""
if links:
for link in links:
all_output_reduce, chunk_reduce, context_reduce = "", "",""
context_reduce, all_output_reduce, chunk_reduce = await pipeline.process_link_chunk_allOutput(link,
iso, acc, saveLinkFolder, out_links_reduce,
all_output_reduce, chunk_reduce)
texts_reduce.append(all_output_reduce)
out_links_reduce[link] = {"all_output": all_output_reduce}
input_prompt = ["country_name", "modern/ancient/unknown"]
if niche_cases: input_prompt += niche_cases
reduce_context_for_llm = data_preprocess.build_context_for_llm(texts_reduce, acc, input_prompt)
if reduce_context_for_llm:
print("reduce context for llm")
context_for_llm = reduce_context_for_llm
else:
print("no reduce context for llm despite>1M")
context_for_llm = context_for_llm[:250000]
for key in metadata_list:
answer = metadata_list[key]["answer"]
            if answer.lower() in ["unknown", "unspecified", "could not get response from llm api.", "undefined"]:
print("have to do again")
again_output_format = key
print("output format:", again_output_format)
                general_knowledge_prompt = (
                    f"{prompt_instruction_prefix}"
                    f"Given the following text snippets, analyze the entity/concept {rag_query_phrase} "
                    f"or the mitochondrial DNA sample in {organism} if these identifiers are not explicitly found. "
                    f"Identify and extract {again_output_format}. "
                    f"If not explicitly stated, infer the most specific related or contextually relevant value. "
                    f"If no information is found, write 'unknown'. "
                    f"Provide only {again_output_format}. "
                    f"For each non-'unknown' field in {again_output_format}, write one sentence explaining how it was inferred from the text. "
                    f"Format your answer so that:\n"
                    f"1. The **first line** contains only the {again_output_format} answer.\n"
                    f"2. The **second line onward** contains the explanations based on the non-unknown {again_output_format} answer.\n"
                    f"\nText Snippets:\n{context_for_llm}")
print("len of prompt:", len(general_knowledge_prompt))
if general_knowledge_prompt:
if model_ai:
print("back up to ", model_ai)
llm_response_text, model_instance = safe_call_llm(general_knowledge_prompt, model=model_ai)
#llm_response_text, model_instance = call_llm_api(general_knowledge_prompt, model=model_ai)
else:
print("still 2.5 flash gemini")
llm_response_text, model_instance = safe_call_llm(general_knowledge_prompt)
#llm_response_text, model_instance = call_llm_api(general_knowledge_prompt)
print("\n--- DEBUG INFO FOR RAG ---")
print("Retrieved Context Sent to LLM (first 500 chars):")
print(context_for_llm[:500] + "..." if len(context_for_llm) > 500 else context_for_llm)
print("\nRaw LLM Response:")
print(llm_response_text)
print("--- END DEBUG INFO ---")
llm_cost = 0
if model_instance:
try:
                            input_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(general_knowledge_prompt).total_tokens
output_llm_tokens = global_llm_model_for_counting_tokens.count_tokens(llm_response_text).total_tokens
print(f" DEBUG: LLM Input tokens: {input_llm_tokens}")
print(f" DEBUG: LLM Output tokens: {output_llm_tokens}")
llm_cost = (input_llm_tokens / 1000) * PRICE_PER_1K_INPUT_LLM + \
(output_llm_tokens / 1000) * PRICE_PER_1K_OUTPUT_LLM
print(f" DEBUG: Estimated LLM cost: ${llm_cost:.6f}")
except Exception as e:
print(f" DEBUG: Error counting LLM tokens: {e}")
llm_cost = 0
total_query_cost += current_embedding_cost + llm_cost
print("total query cost in again: ", total_query_cost)
metadata_list_one_case = parse_multi_sample_llm_output(llm_response_text, again_output_format)
print("metadata list after running again unknown output: ", metadata_list)
for key in metadata_list_one_case:
print("keys of outputs: ", outputs.keys())
if key not in list(outputs.keys()):
print("this is key and about to be added into outputs: ", key)
outputs[key] = metadata_list_one_case[key]
else:
outputs[key] = metadata_list[key]
print("all done and method used: ", outputs, method_used)
print("total cost: ", total_query_cost)
return outputs, method_used, total_query_cost, links
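# Minimal invocation sketch (hypothetical sample ID, accession number, and paths).
# query_document_info is a coroutine, so it must be awaited or driven by asyncio:
#
# lookup, index, chunks = load_rag_assets()
# outputs, method_used, cost, links = asyncio.run(
#     query_document_info(
#         niche_cases=None, query_word="KSK12", alternative_query_word="MN123456.1",
#         saveLinkFolder="/tmp/links", metadata=None, master_structured_lookup=lookup,
#         faiss_index=index, document_chunks=chunks, llm_api_function=call_llm_api,
#     )
# )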