""" Agent Coordinator - Manages agent collaboration and handoffs Enables multi-agent responses and smooth transitions """ from typing import Dict, List, Optional, Any import asyncio from concurrent.futures import ThreadPoolExecutor from utils.memory import ConversationMemory from utils.session_store import get_session_store from utils.conversation_summarizer import get_summarizer from agents.core.router import route_to_agent, get_router from fine_tuning import get_data_collector from health_data import HealthContext, HealthDataStore import hashlib import json class AgentCoordinator: """ Coordinates multiple agents and manages handoffs Provides multi-agent collaboration capabilities """ def __init__(self, user_id: Optional[str] = None, use_embedding_router=True, enable_cache=True, enable_data_collection=True, enable_session_persistence=True): """ Initialize coordinator with shared memory and data store Args: user_id: Unique user identifier for session persistence use_embedding_router: Use embedding-based routing (faster) enable_cache: Enable response caching enable_data_collection: Enable conversation logging for fine-tuning enable_session_persistence: Enable session persistence across restarts """ # Session persistence self.user_id = user_id self.session_store = get_session_store() if enable_session_persistence else None # Initialize memory with session persistence self.memory = ConversationMemory( user_id=user_id, session_store=self.session_store ) self.data_store = HealthDataStore() self.health_context = None self.agents = {} # Enable embedding router (faster than LLM routing) self.use_embedding_router = use_embedding_router if use_embedding_router: self.router = get_router(use_embeddings=True) else: self.router = None # Enable response cache self.enable_cache = enable_cache self.response_cache = {} if enable_cache else None # Enable data collection for fine-tuning self.enable_data_collection = enable_data_collection if enable_data_collection: self.data_collector = get_data_collector() else: self.data_collector = None # Conversation summarizer self.summarizer = get_summarizer() self._initialize_agents() def _initialize_agents(self) -> None: """Initialize all agents with shared memory""" # Import agents (lazy import to avoid circular dependencies) from agents.specialized.nutrition_agent import NutritionAgent from agents.specialized.exercise_agent import ExerciseAgent from agents.specialized.symptom_agent import SymptomAgent from agents.specialized.mental_health_agent import MentalHealthAgent from agents.specialized.general_health_agent import GeneralHealthAgent # Create agents with shared memory self.agents = { 'nutrition_agent': NutritionAgent(memory=self.memory), 'exercise_agent': ExerciseAgent(memory=self.memory), 'symptom_agent': SymptomAgent(memory=self.memory), 'mental_health_agent': MentalHealthAgent(memory=self.memory), 'general_health_agent': GeneralHealthAgent(memory=self.memory) } def handle_query(self, message: str, chat_history: Optional[List] = None, user_id: Optional[str] = None) -> str: """ Main entry point - handles user query with coordination Args: message: User's message chat_history: Conversation history user_id: User ID for data persistence Returns: str: Response (possibly from multiple agents) """ chat_history = chat_history or [] # Create or update health context for user if user_id: self.health_context = HealthContext(user_id, self.data_store) # Inject health context into all agents for agent in self.agents.values(): if hasattr(agent, 'set_health_context'): agent.set_health_context(self.health_context) # Update memory from chat history self._update_memory_from_history(chat_history) # Summarize if conversation is too long if self.summarizer.should_summarize(chat_history): chat_history = self._summarize_if_needed(chat_history) # Check if multi-agent collaboration is needed if self._needs_multi_agent(message): return self._handle_multi_agent_query(message, chat_history) # Single agent routing return self._handle_single_agent_query(message, chat_history) def _get_cache_key(self, message: str, chat_history: List) -> str: """Generate cache key from message and recent history""" # Include last 2 exchanges for context recent_history = chat_history[-4:] if len(chat_history) > 4 else chat_history cache_data = { "message": message.lower().strip(), "history": [(h[0].lower().strip() if h[0] else "", h[1][:50] if len(h) > 1 else "") for h in recent_history] } cache_str = json.dumps(cache_data, sort_keys=True) return hashlib.md5(cache_str.encode()).hexdigest() def _handle_single_agent_query(self, message: str, chat_history: List, file_data: Optional[Dict] = None) -> str: """Handle query with single agent (with potential handoff)""" # Check cache first if self.enable_cache: cache_key = self._get_cache_key(message, chat_history) if cache_key in self.response_cache: # print("[CACHE HIT] Returning cached response") return self.response_cache[cache_key] # Route to appropriate agent (use embedding router if available) if self.router: routing_result = self.router.route(message, chat_history) else: routing_result = route_to_agent(message, chat_history) agent_name = routing_result['agent'] parameters = routing_result['parameters'] # Update current agent in memory self.memory.set_current_agent(agent_name) # Get agent agent = self.agents.get(agent_name) if not agent: return "Xin lỗi, không tìm thấy agent phù hợp." # Let agent handle the request response = agent.handle(parameters, chat_history) # Log conversation for fine-tuning (with cleaned data) if self.enable_data_collection and self.data_collector: user_data = self.memory.get_full_profile() # Clean user data before logging to prevent learning from errors cleaned_user_data = self._clean_user_data_for_training(user_data) self.data_collector.log_conversation( agent_name=agent_name, user_message=message, agent_response=response, user_data=cleaned_user_data, metadata={'data_cleaned': True} # Flag that data was cleaned ) # Cache the response if self.enable_cache: cache_key = self._get_cache_key(message, chat_history) self.response_cache[cache_key] = response # Limit cache size to 100 entries if len(self.response_cache) > 100: # Remove oldest entry (simple FIFO) self.response_cache.pop(next(iter(self.response_cache))) # Check if handoff is needed if hasattr(agent, 'should_handoff') and agent.should_handoff(message, chat_history): next_agent_name = agent.suggest_next_agent(message) if next_agent_name and next_agent_name in self.agents: return self._perform_handoff(agent, next_agent_name, response, message, chat_history) return response def _handle_multi_agent_query(self, message: str, chat_history: List) -> str: """Handle query that needs multiple agents (with parallel execution)""" # Detect which agents are needed agents_needed = self._detect_required_agents(message) if len(agents_needed) <= 1: # Fallback to single agent return self._handle_single_agent_query(message, chat_history) # Use async for parallel execution (faster!) try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) responses = loop.run_until_complete( self._handle_multi_agent_async(message, chat_history, agents_needed) ) loop.close() except Exception as e: print(f"Async multi-agent failed, falling back to sequential: {e}") # Fallback to sequential if async fails responses = {} for agent_name in agents_needed: agent = self.agents.get(agent_name) if agent: parameters = {'user_query': message} responses[agent_name] = agent.handle(parameters, chat_history) # Combine responses return self._combine_responses(responses, agents_needed) async def _handle_multi_agent_async(self, message: str, chat_history: List, agents_needed: List[str]) -> Dict[str, str]: """Execute multiple agents in parallel using asyncio""" async def call_agent(agent_name: str): """Async wrapper for agent.handle()""" agent = self.agents.get(agent_name) if not agent: return None # Run in thread pool (since agent.handle is sync) loop = asyncio.get_event_loop() with ThreadPoolExecutor() as pool: parameters = {'user_query': message} response = await loop.run_in_executor( pool, agent.handle, parameters, chat_history ) return response # Create tasks for all agents tasks = {agent_name: call_agent(agent_name) for agent_name in agents_needed} # Execute in parallel results = await asyncio.gather(*tasks.values(), return_exceptions=True) # Map results back to agent names responses = {} for agent_name, result in zip(tasks.keys(), results): if isinstance(result, Exception): print(f"Agent {agent_name} failed: {result}") responses[agent_name] = f"Xin lỗi, {agent_name} gặp lỗi." elif result: responses[agent_name] = result return responses def _perform_handoff( self, from_agent: Any, to_agent_name: str, current_response: str, message: str, chat_history: List ) -> str: """ Perform smooth handoff between agents Args: from_agent: Current agent to_agent_name: Name of agent to hand off to current_response: Current agent's response message: User's message chat_history: Conversation history Returns: str: Combined response with handoff """ # Create handoff message handoff_msg = from_agent.create_handoff_message(to_agent_name, current_response) # Update memory self.memory.set_current_agent(to_agent_name) return handoff_msg def _needs_multi_agent(self, message: str) -> bool: """ Determine if query needs multiple agents Args: message: User's message Returns: bool: True if multiple agents needed """ agents_needed = self._detect_required_agents(message) return len(agents_needed) > 1 def _detect_required_agents(self, message: str) -> List[str]: """ Detect which agents are needed for this query Args: message: User's message Returns: List[str]: List of agent names needed """ agents_needed = [] message_lower = message.lower() # PRIORITY 1: Symptom keywords (highest priority - health emergencies) symptom_keywords = ['đau', 'sốt', 'ho', 'buồn nôn', 'chóng mặt', 'triệu chứng', 'khó tiêu', 'đầy bụng', 'ợ hơi'] has_symptoms = any(kw in message_lower for kw in symptom_keywords) # PRIORITY 2: Nutrition keywords (but NOT if it's a symptom context) nutrition_keywords = ['thực đơn', 'calo', 'giảm cân', 'tăng cân', 'dinh dưỡng', 'rau củ', 'thực phẩm'] # Special handling: 'ăn' only counts as nutrition if NOT in symptom context has_nutrition = any(kw in message_lower for kw in nutrition_keywords) if not has_symptoms and 'ăn' in message_lower: has_nutrition = True # PRIORITY 3: Exercise keywords exercise_keywords = ['tập', 'gym', 'cardio', 'yoga', 'chạy bộ', 'exercise', 'workout'] has_exercise = any(kw in message_lower for kw in exercise_keywords) # PRIORITY 4: Mental health keywords mental_keywords = ['stress', 'lo âu', 'trầm cảm', 'mất ngủ', 'burnout', 'mental'] has_mental = any(kw in message_lower for kw in mental_keywords) # IMPORTANT: Only trigger multi-agent if CLEARLY needs multiple domains # Example: "Tôi bị đau bụng, nên ăn gì?" -> symptom + nutrition # But: "WHO khuyến nghị ăn bao nhiêu rau củ?" -> ONLY nutrition # Count how many domains are triggered domain_count = sum([has_symptoms, has_nutrition, has_exercise, has_mental]) # If only 1 domain -> single agent (no multi-agent) if domain_count <= 1: if has_symptoms: agents_needed.append('symptom_agent') elif has_nutrition: agents_needed.append('nutrition_agent') elif has_exercise: agents_needed.append('exercise_agent') elif has_mental: agents_needed.append('mental_health_agent') else: # Multiple domains detected # Check if it's a REAL multi-domain question or false positive # False positive patterns (should be single agent) false_positives = [ 'who khuyến nghị', # WHO recommendations -> single domain 'bao nhiêu', # Quantitative questions -> single domain 'khó tiêu', # Digestive issues -> symptom only 'đầy bụng', # Bloating -> symptom only 'đau bụng', # Stomach pain -> symptom only 'ợ hơi', # Burping -> symptom only ] is_false_positive = any(pattern in message_lower for pattern in false_positives) if is_false_positive: # Use primary domain only if has_nutrition: agents_needed.append('nutrition_agent') elif has_exercise: agents_needed.append('exercise_agent') elif has_symptoms: agents_needed.append('symptom_agent') elif has_mental: agents_needed.append('mental_health_agent') else: # Real multi-domain question if has_symptoms: agents_needed.append('symptom_agent') if has_nutrition: agents_needed.append('nutrition_agent') if has_exercise: agents_needed.append('exercise_agent') if has_mental: agents_needed.append('mental_health_agent') return agents_needed def _combine_responses(self, responses: Dict[str, str], agents_order: List[str]) -> str: """ Combine responses from multiple agents Args: responses: Dict of agent_name -> response agents_order: Order of agents Returns: str: Combined response """ # For natural flow, just combine responses without headers # Make it feel like ONE person giving comprehensive advice responses_list = [responses[agent] for agent in agents_order if agent in responses] if len(responses_list) == 1: # Single agent - return as is return responses_list[0] # Multiple agents - combine naturally combined = "" # First response (usually symptom assessment) combined += responses_list[0] # Add other responses with smooth transitions for i in range(1, len(responses_list)): # Natural transition phrases transitions = [ "\n\nNgoài ra, ", "\n\nBên cạnh đó, ", "\n\nĐồng thời, ", "\n\nVề mặt khác, " ] transition = transitions[min(i-1, len(transitions)-1)] combined += transition + responses_list[i] # Natural closing (not too formal) combined += "\n\nBạn thử làm theo xem có đỡ không nhé. Có gì thắc mắc cứ hỏi mình!" return combined def _update_memory_from_history(self, chat_history: List) -> None: """Extract and update SHARED memory from chat history to prevent duplicate questions""" if not chat_history: return # Extract user info from ALL conversations (not just current agent) user_info = self._extract_user_info_from_all_history(chat_history) # Update SHARED memory that ALL agents can access if user_info: for key, value in user_info.items(): self.memory.update_profile(key, value) def _extract_user_info_from_all_history(self, chat_history: List) -> Dict: """Extract user information from entire conversation history""" user_info = {} # Common patterns to extract patterns = { 'age': [r'(\d+)\s*tuổi', r'tôi\s*(\d+)', r'(\d+)\s*years?\s*old'], 'gender': [r'tôi là (nam|nữ)', r'giới tính[:\s]*(nam|nữ)', r'(male|female|nam|nữ)'], 'weight': [r'(\d+)\s*kg', r'nặng\s*(\d+)', r'cân nặng[:\s]*(\d+)'], 'height': [r'(\d+)\s*cm', r'cao\s*(\d+)', r'chiều cao[:\s]*(\d+)'], 'goal': [r'muốn\s*(giảm cân|tăng cân|tăng cơ|khỏe mạnh)', r'mục tiêu[:\s]*(.+)'] } # Search through all user messages import re for user_msg, _ in chat_history: if not user_msg: continue for field, field_patterns in patterns.items(): if field not in user_info: # Only extract if not already found for pattern in field_patterns: match = re.search(pattern, user_msg.lower()) if match: user_info[field] = match.group(1) break return user_info # Extract gender if not self.memory.get_profile('gender'): if re.search(r'\bnam\b|male', all_messages.lower()): self.memory.update_profile('gender', 'male') elif re.search(r'\bnữ\b|female', all_messages.lower()): self.memory.update_profile('gender', 'female') # Extract weight if not self.memory.get_profile('weight'): weight_match = re.search(r'(\d+)\s*kg|nặng\s*(\d+)', all_messages.lower()) if weight_match: weight = float([g for g in weight_match.groups() if g][0]) self.memory.update_profile('weight', weight) # Extract height if not self.memory.get_profile('height'): height_match = re.search(r'(\d+)\s*cm|cao\s*(\d+)', all_messages.lower()) if height_match: height = float([g for g in height_match.groups() if g][0]) self.memory.update_profile('height', height) def _summarize_if_needed(self, chat_history: List) -> List: """ Summarize conversation if it's too long Args: chat_history: Full conversation history Returns: Compressed history with summary """ user_profile = self.memory.get_full_profile() compressed = self.summarizer.compress_history( chat_history, target_turns=10 # Keep last 10 turns + summary ) # print(f"📝 Summarized {len(chat_history)} turns → {len(compressed)} turns") return compressed def get_conversation_stats(self, chat_history: List) -> Dict[str, Any]: """Get statistics about current conversation""" return self.summarizer.get_summary_stats(chat_history) def get_memory_summary(self) -> str: """Get summary of current memory state""" return self.memory.get_context_summary() def _clean_user_data_for_training(self, user_data: Dict[str, Any]) -> Dict[str, Any]: """ Clean user data before logging for training Ensures only valid, corrected data is used for fine-tuning This prevents the model from learning bad patterns like: - "cao 200m" (should be 200cm) - "nặng 75g" (should be 75kg) - Invalid BMI values """ cleaned = user_data.copy() # Validate and clean height (should be 50-300 cm) if 'height' in cleaned and cleaned['height'] is not None: height = float(cleaned['height']) if not (50 <= height <= 300): # Invalid height - don't log it cleaned['height'] = None # Validate and clean weight (should be 20-300 kg) if 'weight' in cleaned and cleaned['weight'] is not None: weight = float(cleaned['weight']) if not (20 <= weight <= 300): # Invalid weight - don't log it cleaned['weight'] = None # Validate and clean age (should be 1-120) if 'age' in cleaned and cleaned['age'] is not None: age = int(cleaned['age']) if not (1 <= age <= 120): # Invalid age - don't log it cleaned['age'] = None # Validate and clean body fat (should be 3-60%) if 'body_fat_percentage' in cleaned and cleaned['body_fat_percentage'] is not None: bf = float(cleaned['body_fat_percentage']) if not (3 <= bf <= 60): # Invalid body fat - don't log it cleaned['body_fat_percentage'] = None # Remove any None values to keep training data clean cleaned = {k: v for k, v in cleaned.items() if v is not None} return cleaned def clear_memory(self) -> None: """Clear all memory (start fresh)""" self.memory.clear() def __repr__(self) -> str: return f""