| """ | |
| Agent Coordinator - Manages agent collaboration and handoffs | |
| Enables multi-agent responses and smooth transitions | |
| """ | |
| from typing import Dict, List, Optional, Any | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| from utils.memory import ConversationMemory | |
| from utils.session_store import get_session_store | |
| from utils.conversation_summarizer import get_summarizer | |
| from agents.core.router import route_to_agent, get_router | |
| from fine_tuning import get_data_collector | |
| from health_data import HealthContext, HealthDataStore | |
| import hashlib | |
| import json | |


class AgentCoordinator:
    """
    Coordinates multiple agents and manages handoffs
    Provides multi-agent collaboration capabilities
    """

    def __init__(self, user_id: Optional[str] = None, use_embedding_router=True, enable_cache=True, enable_data_collection=True, enable_session_persistence=True):
        """
        Initialize coordinator with shared memory and data store

        Args:
            user_id: Unique user identifier for session persistence
            use_embedding_router: Use embedding-based routing (faster)
            enable_cache: Enable response caching
            enable_data_collection: Enable conversation logging for fine-tuning
            enable_session_persistence: Enable session persistence across restarts
        """
        # Session persistence
        self.user_id = user_id
        self.session_store = get_session_store() if enable_session_persistence else None

        # Initialize memory with session persistence
        self.memory = ConversationMemory(
            user_id=user_id,
            session_store=self.session_store
        )
        self.data_store = HealthDataStore()
        self.health_context = None
        self.agents = {}

        # Enable embedding router (faster than LLM routing)
        self.use_embedding_router = use_embedding_router
        if use_embedding_router:
            self.router = get_router(use_embeddings=True)
        else:
            self.router = None

        # Enable response cache
        self.enable_cache = enable_cache
        self.response_cache = {} if enable_cache else None

        # Enable data collection for fine-tuning
        self.enable_data_collection = enable_data_collection
        if enable_data_collection:
            self.data_collector = get_data_collector()
        else:
            self.data_collector = None

        # Conversation summarizer
        self.summarizer = get_summarizer()

        self._initialize_agents()

    def _initialize_agents(self) -> None:
        """Initialize all agents with shared memory"""
        # Import agents (lazy import to avoid circular dependencies)
        from agents.specialized.nutrition_agent import NutritionAgent
        from agents.specialized.exercise_agent import ExerciseAgent
        from agents.specialized.symptom_agent import SymptomAgent
        from agents.specialized.mental_health_agent import MentalHealthAgent
        from agents.specialized.general_health_agent import GeneralHealthAgent

        # Create agents with shared memory
        self.agents = {
            'nutrition_agent': NutritionAgent(memory=self.memory),
            'exercise_agent': ExerciseAgent(memory=self.memory),
            'symptom_agent': SymptomAgent(memory=self.memory),
            'mental_health_agent': MentalHealthAgent(memory=self.memory),
            'general_health_agent': GeneralHealthAgent(memory=self.memory)
        }

    def handle_query(self, message: str, chat_history: Optional[List] = None, user_id: Optional[str] = None) -> str:
        """
        Main entry point - handles user query with coordination

        Args:
            message: User's message
            chat_history: Conversation history
            user_id: User ID for data persistence

        Returns:
            str: Response (possibly from multiple agents)
        """
        chat_history = chat_history or []

        # Create or update health context for user
        if user_id:
            self.health_context = HealthContext(user_id, self.data_store)
            # Inject health context into all agents
            for agent in self.agents.values():
                if hasattr(agent, 'set_health_context'):
                    agent.set_health_context(self.health_context)

        # Update memory from chat history
        self._update_memory_from_history(chat_history)

        # Summarize if conversation is too long
        if self.summarizer.should_summarize(chat_history):
            chat_history = self._summarize_if_needed(chat_history)

        # Check if multi-agent collaboration is needed
        if self._needs_multi_agent(message):
            return self._handle_multi_agent_query(message, chat_history)

        # Single agent routing
        return self._handle_single_agent_query(message, chat_history)

    def _get_cache_key(self, message: str, chat_history: List) -> str:
        """Generate cache key from message and recent history"""
        # Include last 2 exchanges for context
        recent_history = chat_history[-4:] if len(chat_history) > 4 else chat_history
        cache_data = {
            "message": message.lower().strip(),
            "history": [(h[0].lower().strip() if h[0] else "", h[1][:50] if len(h) > 1 else "") for h in recent_history]
        }
        cache_str = json.dumps(cache_data, sort_keys=True)
        return hashlib.md5(cache_str.encode()).hexdigest()

    def _handle_single_agent_query(self, message: str, chat_history: List, file_data: Optional[Dict] = None) -> str:
        """Handle query with single agent (with potential handoff)"""
        # Check cache first
        if self.enable_cache:
            cache_key = self._get_cache_key(message, chat_history)
            if cache_key in self.response_cache:
                # print("[CACHE HIT] Returning cached response")
                return self.response_cache[cache_key]

        # Route to appropriate agent (use embedding router if available)
        if self.router:
            routing_result = self.router.route(message, chat_history)
        else:
            routing_result = route_to_agent(message, chat_history)

        agent_name = routing_result['agent']
        parameters = routing_result['parameters']

        # Update current agent in memory
        self.memory.set_current_agent(agent_name)

        # Get agent
        agent = self.agents.get(agent_name)
        if not agent:
            return "Xin lỗi, không tìm thấy agent phù hợp."

        # Let agent handle the request
        response = agent.handle(parameters, chat_history)

        # Log conversation for fine-tuning (with cleaned data)
        if self.enable_data_collection and self.data_collector:
            user_data = self.memory.get_full_profile()
            # Clean user data before logging to prevent learning from errors
            cleaned_user_data = self._clean_user_data_for_training(user_data)
            self.data_collector.log_conversation(
                agent_name=agent_name,
                user_message=message,
                agent_response=response,
                user_data=cleaned_user_data,
                metadata={'data_cleaned': True}  # Flag that data was cleaned
            )

        # Cache the response
        if self.enable_cache:
            cache_key = self._get_cache_key(message, chat_history)
            self.response_cache[cache_key] = response
            # Limit cache size to 100 entries
            if len(self.response_cache) > 100:
                # Remove oldest entry (simple FIFO)
                self.response_cache.pop(next(iter(self.response_cache)))

        # Check if handoff is needed
        if hasattr(agent, 'should_handoff') and agent.should_handoff(message, chat_history):
            next_agent_name = agent.suggest_next_agent(message)
            if next_agent_name and next_agent_name in self.agents:
                return self._perform_handoff(agent, next_agent_name, response, message, chat_history)

        return response

    def _handle_multi_agent_query(self, message: str, chat_history: List) -> str:
        """Handle query that needs multiple agents (with parallel execution)"""
        # Detect which agents are needed
        agents_needed = self._detect_required_agents(message)

        if len(agents_needed) <= 1:
            # Fallback to single agent
            return self._handle_single_agent_query(message, chat_history)

        # Use async for parallel execution (faster!)
        try:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            responses = loop.run_until_complete(
                self._handle_multi_agent_async(message, chat_history, agents_needed)
            )
            loop.close()
        except Exception as e:
            print(f"Async multi-agent failed, falling back to sequential: {e}")
            # Fallback to sequential if async fails
            responses = {}
            for agent_name in agents_needed:
                agent = self.agents.get(agent_name)
                if agent:
                    parameters = {'user_query': message}
                    responses[agent_name] = agent.handle(parameters, chat_history)

        # Combine responses
        return self._combine_responses(responses, agents_needed)

    async def _handle_multi_agent_async(self, message: str, chat_history: List, agents_needed: List[str]) -> Dict[str, str]:
        """Execute multiple agents in parallel using asyncio"""

        async def call_agent(agent_name: str):
            """Async wrapper for agent.handle()"""
            agent = self.agents.get(agent_name)
            if not agent:
                return None

            # Run in thread pool (since agent.handle is sync)
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as pool:
                parameters = {'user_query': message}
                response = await loop.run_in_executor(
                    pool,
                    agent.handle,
                    parameters,
                    chat_history
                )
            return response

        # Create tasks for all agents
        tasks = {agent_name: call_agent(agent_name) for agent_name in agents_needed}

        # Execute in parallel
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)

        # Map results back to agent names
        responses = {}
        for agent_name, result in zip(tasks.keys(), results):
            if isinstance(result, Exception):
                print(f"Agent {agent_name} failed: {result}")
                responses[agent_name] = f"Xin lỗi, {agent_name} gặp lỗi."
            elif result:
                responses[agent_name] = result

        return responses

    def _perform_handoff(
        self,
        from_agent: Any,
        to_agent_name: str,
        current_response: str,
        message: str,
        chat_history: List
    ) -> str:
        """
        Perform smooth handoff between agents

        Args:
            from_agent: Current agent
            to_agent_name: Name of agent to hand off to
            current_response: Current agent's response
            message: User's message
            chat_history: Conversation history

        Returns:
            str: Combined response with handoff
        """
        # Create handoff message
        handoff_msg = from_agent.create_handoff_message(to_agent_name, current_response)

        # Update memory
        self.memory.set_current_agent(to_agent_name)

        return handoff_msg

    def _needs_multi_agent(self, message: str) -> bool:
        """
        Determine if query needs multiple agents

        Args:
            message: User's message

        Returns:
            bool: True if multiple agents needed
        """
        agents_needed = self._detect_required_agents(message)
        return len(agents_needed) > 1

    def _detect_required_agents(self, message: str) -> List[str]:
        """
        Detect which agents are needed for this query

        Args:
            message: User's message

        Returns:
            List[str]: List of agent names needed
        """
        agents_needed = []
        message_lower = message.lower()

        # PRIORITY 1: Symptom keywords (highest priority - health emergencies)
        symptom_keywords = ['đau', 'sốt', 'ho', 'buồn nôn', 'chóng mặt', 'triệu chứng', 'khó tiêu', 'đầy bụng', 'ợ hơi']
        has_symptoms = any(kw in message_lower for kw in symptom_keywords)

        # PRIORITY 2: Nutrition keywords (but NOT if it's a symptom context)
        nutrition_keywords = ['thực đơn', 'calo', 'giảm cân', 'tăng cân', 'dinh dưỡng', 'rau củ', 'thực phẩm']
        has_nutrition = any(kw in message_lower for kw in nutrition_keywords)
        # Special handling: 'ăn' only counts as nutrition if NOT in symptom context
        if not has_symptoms and 'ăn' in message_lower:
            has_nutrition = True

        # PRIORITY 3: Exercise keywords
        exercise_keywords = ['tập', 'gym', 'cardio', 'yoga', 'chạy bộ', 'exercise', 'workout']
        has_exercise = any(kw in message_lower for kw in exercise_keywords)

        # PRIORITY 4: Mental health keywords
        mental_keywords = ['stress', 'lo âu', 'trầm cảm', 'mất ngủ', 'burnout', 'mental']
        has_mental = any(kw in message_lower for kw in mental_keywords)

        # IMPORTANT: Only trigger multi-agent if CLEARLY needs multiple domains
        # Example: "Tôi bị đau bụng, nên ăn gì?" -> symptom + nutrition
        # But: "WHO khuyến nghị ăn bao nhiêu rau củ?" -> ONLY nutrition

        # Count how many domains are triggered
        domain_count = sum([has_symptoms, has_nutrition, has_exercise, has_mental])

        # If only 1 domain -> single agent (no multi-agent)
        if domain_count <= 1:
            if has_symptoms:
                agents_needed.append('symptom_agent')
            elif has_nutrition:
                agents_needed.append('nutrition_agent')
            elif has_exercise:
                agents_needed.append('exercise_agent')
            elif has_mental:
                agents_needed.append('mental_health_agent')
        else:
            # Multiple domains detected
            # Check if it's a REAL multi-domain question or false positive

            # False positive patterns (should be single agent)
            false_positives = [
                'who khuyến nghị',  # WHO recommendations -> single domain
                'bao nhiêu',        # Quantitative questions -> single domain
                'khó tiêu',         # Digestive issues -> symptom only
                'đầy bụng',         # Bloating -> symptom only
                'đau bụng',         # Stomach pain -> symptom only
                'ợ hơi',            # Burping -> symptom only
            ]
            is_false_positive = any(pattern in message_lower for pattern in false_positives)

            if is_false_positive:
                # Use primary domain only
                if has_nutrition:
                    agents_needed.append('nutrition_agent')
                elif has_exercise:
                    agents_needed.append('exercise_agent')
                elif has_symptoms:
                    agents_needed.append('symptom_agent')
                elif has_mental:
                    agents_needed.append('mental_health_agent')
            else:
                # Real multi-domain question
                if has_symptoms:
                    agents_needed.append('symptom_agent')
                if has_nutrition:
                    agents_needed.append('nutrition_agent')
                if has_exercise:
                    agents_needed.append('exercise_agent')
                if has_mental:
                    agents_needed.append('mental_health_agent')

        return agents_needed

    def _combine_responses(self, responses: Dict[str, str], agents_order: List[str]) -> str:
        """
        Combine responses from multiple agents

        Args:
            responses: Dict of agent_name -> response
            agents_order: Order of agents

        Returns:
            str: Combined response
        """
        # For natural flow, just combine responses without headers
        # Make it feel like ONE person giving comprehensive advice
        responses_list = [responses[agent] for agent in agents_order if agent in responses]

        if len(responses_list) == 1:
            # Single agent - return as is
            return responses_list[0]

        # Multiple agents - combine naturally
        combined = ""

        # First response (usually symptom assessment)
        combined += responses_list[0]

        # Add other responses with smooth transitions
        for i in range(1, len(responses_list)):
            # Natural transition phrases
            transitions = [
                "\n\nNgoài ra, ",
                "\n\nBên cạnh đó, ",
                "\n\nĐồng thời, ",
                "\n\nVề mặt khác, "
            ]
            transition = transitions[min(i - 1, len(transitions) - 1)]
            combined += transition + responses_list[i]

        # Natural closing (not too formal)
        combined += "\n\nBạn thử làm theo xem có đỡ không nhé. Có gì thắc mắc cứ hỏi mình!"

        return combined

    def _update_memory_from_history(self, chat_history: List) -> None:
        """Extract and update SHARED memory from chat history to prevent duplicate questions"""
        if not chat_history:
            return

        # Extract user info from ALL conversations (not just current agent)
        user_info = self._extract_user_info_from_all_history(chat_history)

        # Update SHARED memory that ALL agents can access
        if user_info:
            for key, value in user_info.items():
                self.memory.update_profile(key, value)

    def _extract_user_info_from_all_history(self, chat_history: List) -> Dict:
        """Extract user information from entire conversation history"""
        user_info = {}

        # Common patterns to extract
        patterns = {
            'age': [r'(\d+)\s*tuổi', r'tôi\s*(\d+)', r'(\d+)\s*years?\s*old'],
            'gender': [r'tôi là (nam|nữ)', r'giới tính[:\s]*(nam|nữ)', r'(male|female|nam|nữ)'],
            'weight': [r'(\d+)\s*kg', r'nặng\s*(\d+)', r'cân nặng[:\s]*(\d+)'],
            'height': [r'(\d+)\s*cm', r'cao\s*(\d+)', r'chiều cao[:\s]*(\d+)'],
            'goal': [r'muốn\s*(giảm cân|tăng cân|tăng cơ|khỏe mạnh)', r'mục tiêu[:\s]*(.+)']
        }

        # Search through all user messages
        import re
        for user_msg, _ in chat_history:
            if not user_msg:
                continue
            for field, field_patterns in patterns.items():
                if field not in user_info:  # Only extract if not already found
                    for pattern in field_patterns:
                        match = re.search(pattern, user_msg.lower())
                        if match:
                            user_info[field] = match.group(1)
                            break
        # Fallback: scan the joined history for fields the loop above may have missed,
        # normalizing weight/height to floats and gender to 'male'/'female'
        all_messages = " ".join(msg for msg, _ in chat_history if msg)

        # Extract gender
        if 'gender' not in user_info and not self.memory.get_profile('gender'):
            if re.search(r'\bnam\b|male', all_messages.lower()):
                user_info['gender'] = 'male'
            elif re.search(r'\bnữ\b|female', all_messages.lower()):
                user_info['gender'] = 'female'

        # Extract weight
        if 'weight' not in user_info and not self.memory.get_profile('weight'):
            weight_match = re.search(r'(\d+)\s*kg|nặng\s*(\d+)', all_messages.lower())
            if weight_match:
                user_info['weight'] = float([g for g in weight_match.groups() if g][0])

        # Extract height
        if 'height' not in user_info and not self.memory.get_profile('height'):
            height_match = re.search(r'(\d+)\s*cm|cao\s*(\d+)', all_messages.lower())
            if height_match:
                user_info['height'] = float([g for g in height_match.groups() if g][0])

        return user_info

    def _summarize_if_needed(self, chat_history: List) -> List:
        """
        Summarize conversation if it's too long

        Args:
            chat_history: Full conversation history

        Returns:
            Compressed history with summary
        """
        user_profile = self.memory.get_full_profile()
        compressed = self.summarizer.compress_history(
            chat_history,
            target_turns=10  # Keep last 10 turns + summary
        )
        # print(f"📝 Summarized {len(chat_history)} turns → {len(compressed)} turns")
        return compressed

    def get_conversation_stats(self, chat_history: List) -> Dict[str, Any]:
        """Get statistics about current conversation"""
        return self.summarizer.get_summary_stats(chat_history)

    def get_memory_summary(self) -> str:
        """Get summary of current memory state"""
        return self.memory.get_context_summary()

    def _clean_user_data_for_training(self, user_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Clean user data before logging for training
        Ensures only valid, corrected data is used for fine-tuning

        This prevents the model from learning bad patterns like:
        - "cao 200m" (should be 200cm)
        - "nặng 75g" (should be 75kg)
        - Invalid BMI values
        """
        cleaned = user_data.copy()

        # Validate and clean height (should be 50-300 cm)
        if 'height' in cleaned and cleaned['height'] is not None:
            height = float(cleaned['height'])
            if not (50 <= height <= 300):
                # Invalid height - don't log it
                cleaned['height'] = None

        # Validate and clean weight (should be 20-300 kg)
        if 'weight' in cleaned and cleaned['weight'] is not None:
            weight = float(cleaned['weight'])
            if not (20 <= weight <= 300):
                # Invalid weight - don't log it
                cleaned['weight'] = None

        # Validate and clean age (should be 1-120)
        if 'age' in cleaned and cleaned['age'] is not None:
            age = int(cleaned['age'])
            if not (1 <= age <= 120):
                # Invalid age - don't log it
                cleaned['age'] = None

        # Validate and clean body fat (should be 3-60%)
        if 'body_fat_percentage' in cleaned and cleaned['body_fat_percentage'] is not None:
            bf = float(cleaned['body_fat_percentage'])
            if not (3 <= bf <= 60):
                # Invalid body fat - don't log it
                cleaned['body_fat_percentage'] = None

        # Remove any None values to keep training data clean
        cleaned = {k: v for k, v in cleaned.items() if v is not None}

        return cleaned

    def clear_memory(self) -> None:
        """Clear all memory (start fresh)"""
        self.memory.clear()

    def __repr__(self) -> str:
        return f"<AgentCoordinator: {self.get_memory_summary()}>"