# my-gradio-app/agents/core/coordinator.py
"""
Agent Coordinator - Manages agent collaboration and handoffs
Enables multi-agent responses and smooth transitions
"""
from typing import Dict, List, Optional, Any
import asyncio
from concurrent.futures import ThreadPoolExecutor
from utils.memory import ConversationMemory
from utils.session_store import get_session_store
from utils.conversation_summarizer import get_summarizer
from agents.core.router import route_to_agent, get_router
from fine_tuning import get_data_collector
from health_data import HealthContext, HealthDataStore
import hashlib
import json
class AgentCoordinator:
"""
Coordinates multiple agents and manages handoffs
Provides multi-agent collaboration capabilities
"""
def __init__(self, user_id: Optional[str] = None, use_embedding_router=True, enable_cache=True, enable_data_collection=True, enable_session_persistence=True):
"""
Initialize coordinator with shared memory and data store
Args:
user_id: Unique user identifier for session persistence
use_embedding_router: Use embedding-based routing (faster)
enable_cache: Enable response caching
enable_data_collection: Enable conversation logging for fine-tuning
enable_session_persistence: Enable session persistence across restarts
"""
# Session persistence
self.user_id = user_id
self.session_store = get_session_store() if enable_session_persistence else None
# Initialize memory with session persistence
self.memory = ConversationMemory(
user_id=user_id,
session_store=self.session_store
)
self.data_store = HealthDataStore()
self.health_context = None
self.agents = {}
# Enable embedding router (faster than LLM routing)
self.use_embedding_router = use_embedding_router
if use_embedding_router:
self.router = get_router(use_embeddings=True)
else:
self.router = None
# Enable response cache
self.enable_cache = enable_cache
self.response_cache = {} if enable_cache else None
# Enable data collection for fine-tuning
self.enable_data_collection = enable_data_collection
if enable_data_collection:
self.data_collector = get_data_collector()
else:
self.data_collector = None
# Conversation summarizer
self.summarizer = get_summarizer()
self._initialize_agents()
def _initialize_agents(self) -> None:
"""Initialize all agents with shared memory"""
# Import agents (lazy import to avoid circular dependencies)
from agents.specialized.nutrition_agent import NutritionAgent
from agents.specialized.exercise_agent import ExerciseAgent
from agents.specialized.symptom_agent import SymptomAgent
from agents.specialized.mental_health_agent import MentalHealthAgent
from agents.specialized.general_health_agent import GeneralHealthAgent
# Create agents with shared memory
self.agents = {
'nutrition_agent': NutritionAgent(memory=self.memory),
'exercise_agent': ExerciseAgent(memory=self.memory),
'symptom_agent': SymptomAgent(memory=self.memory),
'mental_health_agent': MentalHealthAgent(memory=self.memory),
'general_health_agent': GeneralHealthAgent(memory=self.memory)
}
def handle_query(self, message: str, chat_history: Optional[List] = None, user_id: Optional[str] = None) -> str:
"""
Main entry point - handles user query with coordination
Args:
message: User's message
chat_history: Conversation history
user_id: User ID for data persistence
Returns:
str: Response (possibly from multiple agents)
"""
chat_history = chat_history or []
# Create or update health context for user
if user_id:
self.health_context = HealthContext(user_id, self.data_store)
# Inject health context into all agents
for agent in self.agents.values():
if hasattr(agent, 'set_health_context'):
agent.set_health_context(self.health_context)
# Update memory from chat history
self._update_memory_from_history(chat_history)
# Summarize if conversation is too long
if self.summarizer.should_summarize(chat_history):
chat_history = self._summarize_if_needed(chat_history)
# Check if multi-agent collaboration is needed
if self._needs_multi_agent(message):
return self._handle_multi_agent_query(message, chat_history)
# Single agent routing
return self._handle_single_agent_query(message, chat_history)
def _get_cache_key(self, message: str, chat_history: List) -> str:
"""Generate cache key from message and recent history"""
        # Include up to the last 4 exchanges for context
recent_history = chat_history[-4:] if len(chat_history) > 4 else chat_history
cache_data = {
"message": message.lower().strip(),
"history": [(h[0].lower().strip() if h[0] else "", h[1][:50] if len(h) > 1 else "") for h in recent_history]
}
cache_str = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_str.encode()).hexdigest()
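    # Illustrative behaviour of the cache key above: messages are normalized with
    # lower()/strip(), so whitespace/case variants share one MD5 key:
    #
    #     key1 = coordinator._get_cache_key("  Ăn gì để giảm cân? ", [])
    #     key2 = coordinator._get_cache_key("ăn gì để giảm cân?", [])
    #     assert key1 == key2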
def _handle_single_agent_query(self, message: str, chat_history: List, file_data: Optional[Dict] = None) -> str:
"""Handle query with single agent (with potential handoff)"""
# Check cache first
if self.enable_cache:
cache_key = self._get_cache_key(message, chat_history)
if cache_key in self.response_cache:
# print("[CACHE HIT] Returning cached response")
return self.response_cache[cache_key]
# Route to appropriate agent (use embedding router if available)
if self.router:
routing_result = self.router.route(message, chat_history)
else:
routing_result = route_to_agent(message, chat_history)
agent_name = routing_result['agent']
parameters = routing_result['parameters']
# Update current agent in memory
self.memory.set_current_agent(agent_name)
# Get agent
agent = self.agents.get(agent_name)
if not agent:
return "Xin lỗi, không tìm thấy agent phù hợp."
# Let agent handle the request
response = agent.handle(parameters, chat_history)
# Log conversation for fine-tuning (with cleaned data)
if self.enable_data_collection and self.data_collector:
user_data = self.memory.get_full_profile()
# Clean user data before logging to prevent learning from errors
cleaned_user_data = self._clean_user_data_for_training(user_data)
self.data_collector.log_conversation(
agent_name=agent_name,
user_message=message,
agent_response=response,
user_data=cleaned_user_data,
metadata={'data_cleaned': True} # Flag that data was cleaned
)
        # Check whether a handoff is needed BEFORE caching, so the cache never
        # stores a response that differs from what is actually returned
        if hasattr(agent, 'should_handoff') and agent.should_handoff(message, chat_history):
            next_agent_name = agent.suggest_next_agent(message)
            if next_agent_name and next_agent_name in self.agents:
                return self._perform_handoff(agent, next_agent_name, response, message, chat_history)
        # Cache the response
        if self.enable_cache:
            cache_key = self._get_cache_key(message, chat_history)
            self.response_cache[cache_key] = response
            # Limit cache size to 100 entries (simple FIFO eviction)
            if len(self.response_cache) > 100:
                self.response_cache.pop(next(iter(self.response_cache)))
        return response
def _handle_multi_agent_query(self, message: str, chat_history: List) -> str:
"""Handle query that needs multiple agents (with parallel execution)"""
# Detect which agents are needed
agents_needed = self._detect_required_agents(message)
if len(agents_needed) <= 1:
# Fallback to single agent
return self._handle_single_agent_query(message, chat_history)
        # Use asyncio for parallel execution (faster than calling agents sequentially)
        try:
            # asyncio.run() creates, runs, and closes the event loop, even on error
            responses = asyncio.run(
                self._handle_multi_agent_async(message, chat_history, agents_needed)
            )
except Exception as e:
print(f"Async multi-agent failed, falling back to sequential: {e}")
# Fallback to sequential if async fails
responses = {}
for agent_name in agents_needed:
agent = self.agents.get(agent_name)
if agent:
parameters = {'user_query': message}
responses[agent_name] = agent.handle(parameters, chat_history)
# Combine responses
return self._combine_responses(responses, agents_needed)
async def _handle_multi_agent_async(self, message: str, chat_history: List, agents_needed: List[str]) -> Dict[str, str]:
"""Execute multiple agents in parallel using asyncio"""
async def call_agent(agent_name: str):
"""Async wrapper for agent.handle()"""
agent = self.agents.get(agent_name)
if not agent:
return None
            # Run in a thread pool (agent.handle is synchronous)
            loop = asyncio.get_running_loop()
with ThreadPoolExecutor() as pool:
parameters = {'user_query': message}
response = await loop.run_in_executor(
pool,
agent.handle,
parameters,
chat_history
)
return response
# Create tasks for all agents
tasks = {agent_name: call_agent(agent_name) for agent_name in agents_needed}
# Execute in parallel
results = await asyncio.gather(*tasks.values(), return_exceptions=True)
# Map results back to agent names
responses = {}
for agent_name, result in zip(tasks.keys(), results):
if isinstance(result, Exception):
print(f"Agent {agent_name} failed: {result}")
responses[agent_name] = f"Xin lỗi, {agent_name} gặp lỗi."
elif result:
responses[agent_name] = result
return responses
def _perform_handoff(
self,
from_agent: Any,
to_agent_name: str,
current_response: str,
message: str,
chat_history: List
) -> str:
"""
Perform smooth handoff between agents
Args:
from_agent: Current agent
to_agent_name: Name of agent to hand off to
current_response: Current agent's response
message: User's message
chat_history: Conversation history
Returns:
str: Combined response with handoff
"""
# Create handoff message
handoff_msg = from_agent.create_handoff_message(to_agent_name, current_response)
# Update memory
self.memory.set_current_agent(to_agent_name)
return handoff_msg
def _needs_multi_agent(self, message: str) -> bool:
"""
Determine if query needs multiple agents
Args:
message: User's message
Returns:
bool: True if multiple agents needed
"""
agents_needed = self._detect_required_agents(message)
return len(agents_needed) > 1
def _detect_required_agents(self, message: str) -> List[str]:
"""
Detect which agents are needed for this query
Args:
message: User's message
Returns:
List[str]: List of agent names needed
"""
agents_needed = []
message_lower = message.lower()
# PRIORITY 1: Symptom keywords (highest priority - health emergencies)
symptom_keywords = ['đau', 'sốt', 'ho', 'buồn nôn', 'chóng mặt', 'triệu chứng', 'khó tiêu', 'đầy bụng', 'ợ hơi']
has_symptoms = any(kw in message_lower for kw in symptom_keywords)
# PRIORITY 2: Nutrition keywords (but NOT if it's a symptom context)
nutrition_keywords = ['thực đơn', 'calo', 'giảm cân', 'tăng cân', 'dinh dưỡng', 'rau củ', 'thực phẩm']
# Special handling: 'ăn' only counts as nutrition if NOT in symptom context
has_nutrition = any(kw in message_lower for kw in nutrition_keywords)
if not has_symptoms and 'ăn' in message_lower:
has_nutrition = True
# PRIORITY 3: Exercise keywords
exercise_keywords = ['tập', 'gym', 'cardio', 'yoga', 'chạy bộ', 'exercise', 'workout']
has_exercise = any(kw in message_lower for kw in exercise_keywords)
# PRIORITY 4: Mental health keywords
mental_keywords = ['stress', 'lo âu', 'trầm cảm', 'mất ngủ', 'burnout', 'mental']
has_mental = any(kw in message_lower for kw in mental_keywords)
# IMPORTANT: Only trigger multi-agent if CLEARLY needs multiple domains
# Example: "Tôi bị đau bụng, nên ăn gì?" -> symptom + nutrition
# But: "WHO khuyến nghị ăn bao nhiêu rau củ?" -> ONLY nutrition
# Count how many domains are triggered
domain_count = sum([has_symptoms, has_nutrition, has_exercise, has_mental])
# If only 1 domain -> single agent (no multi-agent)
if domain_count <= 1:
if has_symptoms:
agents_needed.append('symptom_agent')
elif has_nutrition:
agents_needed.append('nutrition_agent')
elif has_exercise:
agents_needed.append('exercise_agent')
elif has_mental:
agents_needed.append('mental_health_agent')
else:
# Multiple domains detected
# Check if it's a REAL multi-domain question or false positive
# False positive patterns (should be single agent)
false_positives = [
'who khuyến nghị', # WHO recommendations -> single domain
'bao nhiêu', # Quantitative questions -> single domain
'khó tiêu', # Digestive issues -> symptom only
'đầy bụng', # Bloating -> symptom only
'đau bụng', # Stomach pain -> symptom only
'ợ hơi', # Burping -> symptom only
]
is_false_positive = any(pattern in message_lower for pattern in false_positives)
if is_false_positive:
# Use primary domain only
if has_nutrition:
agents_needed.append('nutrition_agent')
elif has_exercise:
agents_needed.append('exercise_agent')
elif has_symptoms:
agents_needed.append('symptom_agent')
elif has_mental:
agents_needed.append('mental_health_agent')
else:
# Real multi-domain question
if has_symptoms:
agents_needed.append('symptom_agent')
if has_nutrition:
agents_needed.append('nutrition_agent')
if has_exercise:
agents_needed.append('exercise_agent')
if has_mental:
agents_needed.append('mental_health_agent')
return agents_needed
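    # Hand-checked expectations for the detection above (illustrative, based on
    # the keyword and false-positive lists; not an exhaustive test):
    #   "Tôi bị sốt và stress, nên làm gì?"    -> ['symptom_agent', 'mental_health_agent']
    #   "WHO khuyến nghị ăn bao nhiêu rau củ?" -> ['nutrition_agent']  (caught by the
    #                                             'who khuyến nghị' false-positive pattern)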
def _combine_responses(self, responses: Dict[str, str], agents_order: List[str]) -> str:
"""
Combine responses from multiple agents
Args:
responses: Dict of agent_name -> response
agents_order: Order of agents
Returns:
str: Combined response
"""
# For natural flow, just combine responses without headers
# Make it feel like ONE person giving comprehensive advice
        responses_list = [responses[agent] for agent in agents_order if agent in responses]
        if not responses_list:
            # No agent produced a usable response - return a safe fallback
            return "Xin lỗi, không tìm thấy agent phù hợp."
        if len(responses_list) == 1:
            # Single agent - return as is
            return responses_list[0]
# Multiple agents - combine naturally
combined = ""
# First response (usually symptom assessment)
combined += responses_list[0]
# Add other responses with smooth transitions
for i in range(1, len(responses_list)):
# Natural transition phrases
transitions = [
"\n\nNgoài ra, ",
"\n\nBên cạnh đó, ",
"\n\nĐồng thời, ",
"\n\nVề mặt khác, "
]
transition = transitions[min(i-1, len(transitions)-1)]
combined += transition + responses_list[i]
# Natural closing (not too formal)
combined += "\n\nBạn thử làm theo xem có đỡ không nhé. Có gì thắc mắc cứ hỏi mình!"
return combined
def _update_memory_from_history(self, chat_history: List) -> None:
"""Extract and update SHARED memory from chat history to prevent duplicate questions"""
if not chat_history:
return
# Extract user info from ALL conversations (not just current agent)
user_info = self._extract_user_info_from_all_history(chat_history)
# Update SHARED memory that ALL agents can access
if user_info:
for key, value in user_info.items():
self.memory.update_profile(key, value)
def _extract_user_info_from_all_history(self, chat_history: List) -> Dict:
"""Extract user information from entire conversation history"""
user_info = {}
# Common patterns to extract
patterns = {
'age': [r'(\d+)\s*tuổi', r'tôi\s*(\d+)', r'(\d+)\s*years?\s*old'],
'gender': [r'tôi là (nam|nữ)', r'giới tính[:\s]*(nam|nữ)', r'(male|female|nam|nữ)'],
'weight': [r'(\d+)\s*kg', r'nặng\s*(\d+)', r'cân nặng[:\s]*(\d+)'],
'height': [r'(\d+)\s*cm', r'cao\s*(\d+)', r'chiều cao[:\s]*(\d+)'],
'goal': [r'muốn\s*(giảm cân|tăng cân|tăng cơ|khỏe mạnh)', r'mục tiêu[:\s]*(.+)']
}
# Search through all user messages
import re
for user_msg, _ in chat_history:
if not user_msg:
continue
for field, field_patterns in patterns.items():
if field not in user_info: # Only extract if not already found
for pattern in field_patterns:
match = re.search(pattern, user_msg.lower())
if match:
user_info[field] = match.group(1)
break
return user_info
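    # Illustrative extraction result (regex groups are returned as strings):
    #   "Tôi 25 tuổi, nặng 70 kg, cao 170 cm"
    #   -> {'age': '25', 'weight': '70', 'height': '170'}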
def _summarize_if_needed(self, chat_history: List) -> List:
"""
Summarize conversation if it's too long
Args:
chat_history: Full conversation history
Returns:
Compressed history with summary
"""
compressed = self.summarizer.compress_history(
chat_history,
target_turns=10 # Keep last 10 turns + summary
)
# print(f"📝 Summarized {len(chat_history)} turns → {len(compressed)} turns")
return compressed
def get_conversation_stats(self, chat_history: List) -> Dict[str, Any]:
"""Get statistics about current conversation"""
return self.summarizer.get_summary_stats(chat_history)
def get_memory_summary(self) -> str:
"""Get summary of current memory state"""
return self.memory.get_context_summary()
def _clean_user_data_for_training(self, user_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Clean user data before logging for training
Ensures only valid, corrected data is used for fine-tuning
This prevents the model from learning bad patterns like:
- "cao 200m" (should be 200cm)
- "nặng 75g" (should be 75kg)
        - Implausible age or body-fat values
        """
        cleaned = user_data.copy()
cleaned = user_data.copy()
# Validate and clean height (should be 50-300 cm)
if 'height' in cleaned and cleaned['height'] is not None:
height = float(cleaned['height'])
if not (50 <= height <= 300):
# Invalid height - don't log it
cleaned['height'] = None
# Validate and clean weight (should be 20-300 kg)
if 'weight' in cleaned and cleaned['weight'] is not None:
weight = float(cleaned['weight'])
if not (20 <= weight <= 300):
# Invalid weight - don't log it
cleaned['weight'] = None
# Validate and clean age (should be 1-120)
if 'age' in cleaned and cleaned['age'] is not None:
age = int(cleaned['age'])
if not (1 <= age <= 120):
# Invalid age - don't log it
cleaned['age'] = None
# Validate and clean body fat (should be 3-60%)
if 'body_fat_percentage' in cleaned and cleaned['body_fat_percentage'] is not None:
bf = float(cleaned['body_fat_percentage'])
if not (3 <= bf <= 60):
# Invalid body fat - don't log it
cleaned['body_fat_percentage'] = None
# Remove any None values to keep training data clean
cleaned = {k: v for k, v in cleaned.items() if v is not None}
return cleaned
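    # Illustrative cleaning result (hypothetical profile values):
    #   {'height': 170.0, 'weight': 75.0, 'age': 25} -> kept unchanged (all in range)
    #   {'height': 2000, 'weight': 75, 'age': 25}    -> {'weight': 75, 'age': 25}  (height dropped)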
def clear_memory(self) -> None:
"""Clear all memory (start fresh)"""
self.memory.clear()
def __repr__(self) -> str:
return f"<AgentCoordinator: {self.get_memory_summary()}>"