# Spaces: Runtime error
"""
Base Agent - parent class for all specialized agents.

Provides shared functionality: memory access, handoff logic and coordination.
"""
import json
import math
import random
import re
import unicodedata
from typing import Any, Dict, List, Optional

from utils.memory import ConversationMemory
class BaseAgent:
    """
    Base class for all agents.

    Subclasses implement ``handle()``. This class supplies the shared pieces:
    - access to the shared ``ConversationMemory``;
    - a heuristic that decides whether a query needs RAG;
    - keyword-based handoff and collaboration hints;
    - extraction of basic health data (age, gender, weight, height, body fat)
      from user messages.
    """

    # Profile fields this class extracts from user messages and persists.
    _PROFILE_FIELDS = ('age', 'gender', 'weight', 'height', 'body_fat_percentage')

    # Keywords of at most this many characters must match a whole word. Longer
    # keywords only have to start at a word boundary (see _contains_keyword).
    _WHOLE_WORD_MAX_LEN = 3

    # A number with an optional decimal part. Vietnamese commonly writes the
    # decimal separator as a comma ("55,5kg"), so both separators are accepted.
    _NUMBER = r'(\d+(?:[.,]\d+)?)'

    # Gender cues. "nam" must be a whole word and must not be the end of
    # "Việt Nam". "male" must be a whole word, so it no longer matches inside
    # "female".
    # NOTE(review): "nam" also means "south" ("miền nam") and still matches.
    _MALE_PATTERN = r'(?<!vi[eệ]t )\bnam\b|\bmale\b|đàn ông'
    _FEMALE_PATTERN = r'\bnữ\b|\bfemale\b|đàn bà|phụ nữ'

    # Age cues used when scanning chat history. A bare "tôi <number>" is
    # rejected when the number continues with a decimal part or a unit, so
    # "tôi 70kg" and "tôi 1m7" are not read as ages.
    _HISTORY_AGE_PATTERN = (
        r'(\d+)\s*tuổi|tuổi\s*:?\s*(\d+)'
        r'|tôi\s*(\d+)(?!\d|[.,]\d|\s*(?:kg|cm|m(?:\b|\d)))'
    )
    # Age cues used on a single message by the regex fallback.
    _FALLBACK_AGE_PATTERN = r'(\d+)\s*tuổi|tuổi\s*:?\s*(\d+)|\bage\s*:?\s*(\d+)'

    # ----- Vocabularies for should_use_rag() -----
    _RAG_GREETINGS = (
        'xin chào', 'hello', 'hi', 'chào', 'hey',
        'cảm ơn', 'thanks', 'thank you', 'tks',
        'ok', 'được', 'vâng', 'ừ', 'uhm', 'uh huh',
        'bye', 'tạm biệt', 'hẹn gặp lại',
    )
    _RAG_SHORT_RESPONSES = ('có', 'không', 'rồi', 'ạ', 'dạ', 'yes', 'no', 'nope', 'yep')
    _RAG_META_QUESTIONS = (
        'bạn là ai', 'bạn tên gì', 'bạn có thể', 'bạn làm gì',
        'who are you', 'what can you', 'what do you',
    )
    _RAG_COMPLEX_PATTERNS = (
        # Medical terms
        'nguyên nhân', 'tại sao', 'why', 'how', 'làm sao',
        'cách nào', 'phương pháp', 'điều trị', 'chữa',
        'thuốc', 'medicine', 'phòng ngừa', 'prevention',
        'biến chứng', 'complication', 'nghiên cứu', 'research',
        # Specific diseases
        'bệnh', 'disease', 'viêm', 'ung thư', 'cancer',
        'tiểu đường', 'diabetes', 'huyết áp', 'blood pressure',
        # Detailed questions
        'chi tiết', 'cụ thể', 'specific', 'detail',
        'khoa học', 'scientific', 'evidence', 'hướng dẫn',
        'guideline', 'recommendation', 'chuyên gia', 'expert',
    )
    _RAG_SIMPLE_STARTS = (
        'tôi muốn', 'tôi cần', 'giúp tôi', 'tôi bị',
        'i want', 'i need', 'help me', 'i have', 'i feel',
    )

    # Ordered (agent, keywords) table. In suggest_next_agent() the first agent
    # with a matching keyword wins.
    _HANDOFF_KEYWORDS = (
        ('symptom_agent', ('đau', 'sốt', 'ho', 'buồn nôn', 'chóng mặt', 'mệt')),
        ('nutrition_agent', ('ăn', 'thực đơn', 'calo', 'giảm cân', 'tăng cân')),
        ('exercise_agent', ('tập', 'gym', 'cardio', 'yoga', 'chạy bộ')),
        ('mental_health_agent', ('stress', 'lo âu', 'trầm cảm', 'mất ngủ', 'burnout')),
    )
    # Vocabulary for needs_collaboration(), which returns every matching agent.
    _COLLABORATION_KEYWORDS = (
        ('symptom_agent', ('đau', 'sốt', 'ho', 'triệu chứng')),
        ('nutrition_agent', ('ăn', 'thực đơn', 'calo', 'dinh dưỡng')),
        ('exercise_agent', ('tập', 'gym', 'cardio', 'exercise')),
        ('mental_health_agent', ('stress', 'lo âu', 'trầm cảm', 'mental')),
    )

    # Topic metadata used to phrase handoff messages (read-only).
    _HANDOFF_TOPICS = {
        'symptom_agent': {
            'topic': 'triệu chứng',
            'action': 'đánh giá',
            'info_needed': ('triệu chứng cụ thể', 'thời gian xuất hiện'),
        },
        'nutrition_agent': {
            'topic': 'dinh dưỡng',
            'action': 'tư vấn chế độ ăn',
            'info_needed': ('mục tiêu', 'cân nặng', 'chiều cao', 'tuổi'),
        },
        'exercise_agent': {
            'topic': 'tập luyện',
            'action': 'lên lịch tập',
            'info_needed': ('mục tiêu', 'thời gian có thể tập', 'thiết bị'),
        },
        'mental_health_agent': {
            'topic': 'sức khỏe tinh thần',
            'action': 'hỗ trợ',
            'info_needed': ('cảm giác hiện tại', 'thời gian kéo dài'),
        },
    }
    _DEFAULT_HANDOFF_TOPIC = {
        'topic': 'vấn đề này',
        'action': 'tư vấn',
        'info_needed': (),
    }

    def __init__(self, memory: Optional["ConversationMemory"] = None):
        """
        Initialize base agent.

        Args:
            memory: Shared conversation memory. A new ConversationMemory is
                created only when None is passed. A falsy-but-real memory
                object (e.g. one that defines __len__) is kept, so sharing is
                not silently broken.
        """
        self.memory = memory if memory is not None else ConversationMemory()
        # e.g. "NutritionAgent" -> "nutrition"
        self.agent_name = self.__class__.__name__.replace('Agent', '').lower()
        self.system_prompt = ""
        # Handoff configuration (subclasses fill in handoff_triggers)
        self.can_handoff = True
        self.handoff_triggers = []

    # ===== Core Interface =====

    def handle(self, parameters: Dict[str, Any], chat_history: Optional[List] = None) -> str:
        """
        Handle a user request (must be implemented by subclasses).

        Args:
            parameters: Request parameters from router
            chat_history: Conversation history

        Returns:
            str: Response message

        Raises:
            NotImplementedError: always, in the base class
        """
        raise NotImplementedError("Subclasses must implement handle()")

    # ===== Text-matching helpers =====

    @staticmethod
    def _normalize_text(text: Optional[str]) -> str:
        """Return *text* NFC-normalized, lower-cased and stripped ("" for None)."""
        # The keyword literals are presumably stored precomposed (NFC).
        # Normalizing the input keeps decomposed Vietnamese text from missing
        # every keyword.
        return unicodedata.normalize('NFC', text or '').lower().strip()

    @classmethod
    def _contains_keyword(cls, text: str, keyword: str) -> bool:
        """
        Return True if *keyword* occurs in *text* starting at a word boundary.

        Plain substring tests misfire on Vietnamese, where short keywords occur
        inside other syllables ('ho' in 'cho', 'hi' in 'chi tiết', 'ăn' in
        'khăn').

        Matching rules:
        - Keywords of at most _WHOLE_WORD_MAX_LEN characters must match a whole
          word.
        - Longer keywords may be a prefix, so simple English inflections still
          match ('stress' in 'stressed').

        Python's regex word class is Unicode-aware for str patterns, so
        Vietnamese letters count as word characters.
        """
        end = r'(?!\w)' if len(keyword) <= cls._WHOLE_WORD_MAX_LEN else ''
        return re.search(r'(?<!\w)' + re.escape(keyword) + end, text) is not None

    # ===== Memory Access Helpers =====

    def get_user_profile(self) -> Dict[str, Any]:
        """Get complete user profile from memory."""
        return self.memory.get_full_profile()

    # ===== Smart RAG Helper =====

    def should_use_rag(self, user_query: str, chat_history: Optional[List] = None) -> bool:
        """
        Smart RAG decision: skip RAG for simple conversational queries.

        Reported performance impact (from the original author):
        - Simple queries: 2-3s (was 8-10s)
        - Complex queries: 6-8s (was 8-10s)

        Decision order:
        1. Medical/health keywords -> RAG. This is checked first, so
           "cảm ơn, vậy thuốc nào tốt?" still gets RAG despite the thanks.
        2. Greetings/acknowledgments (whole-word match) -> no RAG.
        3. Very short yes/no style replies -> no RAG.
        4. Questions about the bot itself -> no RAG.
        5. Simple first-turn statements ("tôi muốn ...") -> no RAG, so the
           agent can ask clarifying questions first.
        6. Anything else -> RAG (safe default in a medical context).

        Args:
            user_query: User's message
            chat_history: Conversation history

        Returns:
            bool: True if RAG is needed
        """
        query = self._normalize_text(user_query)

        # 1. Substring matching is kept here on purpose: a false positive only
        #    costs an extra retrieval, which is the safe direction.
        if any(pattern in query for pattern in self._RAG_COMPLEX_PATTERNS):
            return True

        # 2. NOTE(review): 'được' also occurs in ordinary questions
        #    ("có được không"); such messages skip RAG unless step 1 matched.
        if any(self._contains_keyword(query, g) for g in self._RAG_GREETINGS):
            return False

        # 3.
        if len(query) < 10 and any(
            query == reply or query.startswith(reply + ' ')
            for reply in self._RAG_SHORT_RESPONSES
        ):
            return False

        # 4.
        if any(self._contains_keyword(query, m) for m in self._RAG_META_QUESTIONS):
            return False

        # 5.
        if not chat_history and any(
            self._contains_keyword(query, s) for s in self._RAG_SIMPLE_STARTS
        ):
            return False

        # 6.
        return True

    def update_user_profile(self, key: str, value: Any) -> None:
        """Update user profile in shared memory."""
        self.memory.update_profile(key, value)

    def get_missing_profile_fields(self, required_fields: List[str]) -> List[str]:
        """Return the required profile fields that memory reports as missing."""
        return self.memory.get_missing_fields(required_fields)

    def save_agent_data(self, key: str, value: Any) -> None:
        """Save agent-specific data to memory under this agent's name."""
        self.memory.add_agent_data(self.agent_name, key, value)

    def get_agent_data(self, key: Optional[str] = None) -> Any:
        """Get this agent's data from memory (key=None is passed through as-is)."""
        return self.memory.get_agent_data(self.agent_name, key)

    def get_other_agent_data(self, agent_name: str, key: Optional[str] = None) -> Any:
        """Get data stored by another agent."""
        return self.memory.get_agent_data(agent_name, key)

    # ===== Context Awareness =====

    def get_context_summary(self) -> str:
        """Get summary of current conversation context."""
        return self.memory.get_context_summary()

    def get_previous_agent(self) -> Optional[str]:
        """Get name of previous agent."""
        return self.memory.get_previous_agent()

    def get_current_topic(self) -> Optional[str]:
        """Get current conversation topic."""
        return self.memory.get_current_topic()

    def set_current_topic(self, topic: str) -> None:
        """Set current conversation topic."""
        self.memory.set_current_topic(topic)

    def generate_natural_opening(self, user_query: str, chat_history: Optional[List] = None) -> str:
        """
        Generate a natural conversation opening based on context.

        Avoids robotic prefixes like "Thông tin đã tư vấn:".

        Args:
            user_query: Current user query
            chat_history: Conversation history (currently unused)

        Returns:
            str: A randomly chosen opening fragment followed by a space when
            the previous agent differs from this one and the query asks for
            help. Otherwise "".
        """
        previous_agent = self.get_previous_agent()
        # NOTE(review): self.agent_name is derived from the class name
        # ("NutritionAgent" -> "nutrition"), while suggest_next_agent() uses
        # names like "nutrition_agent". Confirm which form memory stores,
        # otherwise every turn may look like a topic change.
        is_new_topic = bool(previous_agent) and previous_agent != self.agent_name
        if not is_new_topic:
            return ""

        query = self._normalize_text(user_query)
        if any(word in query for word in ('muốn', 'cần', 'giúp', 'tư vấn')):
            openings = (
                "Ah, bây giờ bạn đang cần",
                "Được rồi, để mình",
                "Tuyệt! Mình sẽ",
                "Ok, cùng",
            )
            return random.choice(openings) + " "

        # Default: no prefix, just a natural response
        return ""

    # ===== Handoff Logic =====

    def should_handoff(self, user_query: str, chat_history: Optional[List] = None) -> bool:
        """
        Determine whether this agent should hand off to another agent.

        Args:
            user_query: User's current query
            chat_history: Conversation history (currently unused)

        Returns:
            bool: True if handoff is enabled and a configured trigger occurs
            in the query (case-insensitive substring match).
        """
        if not self.can_handoff:
            return False
        query_lower = user_query.lower()
        # Substring (not word) matching: subclasses configure the triggers and
        # may use phrases or stems. Triggers are lower-cased as well, so
        # mixed-case configuration still matches. Empty triggers are ignored.
        return any(
            trigger and trigger.lower() in query_lower
            for trigger in self.handoff_triggers
        )

    def suggest_next_agent(self, user_query: str) -> Optional[str]:
        """
        Suggest which agent to hand off to.

        Keywords are matched at word boundaries, so e.g. 'ho' (cough) does not
        fire on 'cho'. Agents are tried in the order symptom, nutrition,
        exercise, mental health.

        Args:
            user_query: User's current query

        Returns:
            str: Name of suggested agent, or None
        """
        query = self._normalize_text(user_query)
        for agent, keywords in self._HANDOFF_KEYWORDS:
            if any(self._contains_keyword(query, kw) for kw in keywords):
                return agent
        return None

    def create_handoff_message(self, next_agent: str, context: str = "", user_query: str = "") -> str:
        """
        Create a SEAMLESS topic transition (not an explicit handoff).

        Args:
            next_agent: Name of agent to hand off to
            context: Text placed before the transition, separated by a blank line
            user_query: User's query, used to pick the acknowledgment wording

        Returns:
            str: Natural transition message (NOT "chuyển sang chuyên gia").
            It asks for at most two of the target topic's missing details.
        """
        topic_info = self._HANDOFF_TOPICS.get(next_agent, self._DEFAULT_HANDOFF_TOPIC)

        message = f"{context}\n\n" if context else ""

        # Natural acknowledgment based on the query's wording
        query = self._normalize_text(user_query)
        if self._contains_keyword(query, 'tập') or self._contains_keyword(query, 'gym'):
            message += f"Ah, bây giờ bạn đang cần về {topic_info['topic']}! "
        elif self._contains_keyword(query, 'ăn') or self._contains_keyword(query, 'thực đơn'):
            message += f"Okii, giờ chuyển sang {topic_info['topic']} nhé! "
        else:
            message += f"Được, mình giúp bạn về {topic_info['topic']}! "

        # Ask for info if needed (natural, not formal)
        if topic_info['info_needed']:
            info_list = ', '.join(topic_info['info_needed'][:2])  # Max 2 items
            message += f"Để {topic_info['action']} phù hợp, cho mình biết thêm về {info_list} nhé!"

        return message

    # ===== Multi-Agent Coordination =====

    def needs_collaboration(self, user_query: str) -> List[str]:
        """
        Determine which agents are relevant to the query.

        Args:
            user_query: User's query

        Returns:
            List[str]: Every agent with a keyword in the query (word-boundary
            match), in the order symptom, nutrition, exercise, mental health
        """
        query = self._normalize_text(user_query)
        return [
            agent
            for agent, keywords in self._COLLABORATION_KEYWORDS
            if any(self._contains_keyword(query, kw) for kw in keywords)
        ]

    # ===== Health-data parsing helpers =====

    @staticmethod
    def _to_float(value: Any) -> Optional[float]:
        """
        Convert *value* to a finite float, or return None if that is impossible.

        Numeric strings are accepted, including a comma decimal separator
        ("55,5"). Booleans are rejected even though float(True) works.
        """
        if isinstance(value, bool):
            return None
        if isinstance(value, str):
            value = value.strip().replace(',', '.')
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    @staticmethod
    def _first_group(match) -> Optional[str]:
        """Return the first non-empty capture group of a regex match."""
        return next((group for group in match.groups() if group), None)

    @classmethod
    def _parse_gender(cls, text: str) -> Optional[str]:
        """Return 'male'/'female' if lower-cased *text* mentions a gender, else None.

        Male cues take precedence when both appear.
        """
        if re.search(cls._MALE_PATTERN, text):
            return 'male'
        if re.search(cls._FEMALE_PATTERN, text):
            return 'female'
        return None

    @classmethod
    def _parse_weight_kg(cls, text: str) -> Optional[float]:
        """
        Find a body weight in lower-cased *text*.

        Patterns are tried in order:
        1. a number after "nặng"/"weight" (so "nặng 80kg, muốn xuống 70kg"
           yields 80);
        2. "<number> kg";
        3. a number after "cân".

        The value is returned unvalidated; _auto_correct_health_data checks it.
        """
        for pattern in (
            r'(?:nặng|weight)\s*:?\s*' + cls._NUMBER,
            cls._NUMBER + r'\s*kg',
            r'cân\s*:?\s*' + cls._NUMBER,
        ):
            match = re.search(pattern, text)
            if match:
                return cls._to_float(match.group(1))
        return None

    @classmethod
    def _parse_height_cm(cls, text: str) -> Optional[float]:
        """
        Find a height in lower-cased *text* and return it in centimetres.

        Forms, in priority order:
        1. "1m75" / "1m7" (metres plus centimetres);
        2. "175cm";
        3. "1.75m" / "1,75 m" (values below 3 are treated as metres);
        4. a bare number after "cao"/"chiều cao"/"height".

        A bare number is returned as-is; _auto_correct_health_data converts
        small values from metres and range-checks everything.
        """
        match = re.search(r'(?<!\d)([12])\s*m\s*(\d{1,2})(?!\d|[.,]\d|\s*kg)', text)
        if match:
            metres, rest = match.groups()
            # "1m7" means 1.70 m; "1m75" means 1.75 m.
            return float(int(metres) * 100 + int(rest.ljust(2, '0')))

        match = re.search(cls._NUMBER + r'\s*cm', text)
        if match:
            return cls._to_float(match.group(1))

        match = re.search(cls._NUMBER + r'\s*m\b', text)
        if match:
            value = cls._to_float(match.group(1))
            if value is not None and value < 3:
                # Round to hide binary-float noise (1.6 * 100 == 160.00000000000003).
                return round(value * 100, 1)
            return value

        match = re.search(r'(?:cao|height)\s*:?\s*' + cls._NUMBER, text)
        if match:
            return cls._to_float(match.group(1))
        return None

    def _save_profile_fields(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Persist every known, non-None profile field of *data*; return what was saved."""
        saved = {
            key: data[key]
            for key in self._PROFILE_FIELDS
            if data.get(key) is not None
        }
        for key, value in saved.items():
            self.update_user_profile(key, value)
        return saved

    # ===== Utility Methods =====

    def extract_user_data_from_history(self, chat_history: List) -> Dict[str, Any]:
        """
        Extract basic profile data from the user side of the conversation.

        (Can be overridden by subclasses for specific extraction.)

        How values are chosen:
        - For each field, the newest user message that mentions it wins. A
          later correction ("à nhầm, tôi nặng 75kg") therefore replaces an
          earlier value.
        - Within that message, the first mention wins.
        - Values pass through _auto_correct_health_data(). Missing or
          implausible fields are omitted.

        Args:
            chat_history: List of [user_msg, bot_msg] pairs

        Returns:
            Dict: Extracted user data (a subset of age, gender, weight, height)
        """
        if not chat_history:
            return {}

        user_messages = [
            self._normalize_text(str(pair[0]))
            for pair in chat_history
            if pair and pair[0]
        ]

        def latest(parse):
            # Scan newest-first and return the first value found.
            for text in reversed(user_messages):
                value = parse(text)
                if value is not None:
                    return value
            return None

        def parse_age(text):
            match = re.search(self._HISTORY_AGE_PATTERN, text)
            return int(self._first_group(match)) if match else None

        raw = {
            'age': latest(parse_age),
            'gender': latest(self._parse_gender),
            'weight': latest(self._parse_weight_kg),
            'height': latest(self._parse_height_cm),
        }
        corrected = self._auto_correct_health_data(raw)
        return {key: value for key, value in corrected.items() if value is not None}

    def update_memory_from_history(self, chat_history: List) -> None:
        """Extract profile data from *chat_history* and write it into memory."""
        for key, value in self.extract_user_data_from_history(chat_history).items():
            # Overwrite unconditionally. The extractor prefers the newest
            # message, so a user's correction replaces the earlier value.
            self.memory.update_profile(key, value)

    def extract_and_save_user_info(self, user_message: str) -> Dict[str, Any]:
        """
        Extract user info from a single message using the LLM, and save it.

        The LLM handles typos and any field order. If the LLM is unavailable,
        misconfigured, or returns something other than a JSON object, the
        regex fallback is used instead.

        Args:
            user_message: Single user message (any format, any order)

        Returns:
            Dict: The profile fields that were saved (non-None, validated)
        """
        # NOTE(review): the raw message is interpolated into the prompt, so a
        # crafted message can steer the extraction. Every value is still
        # range-checked by _auto_correct_health_data() before it is saved.
        extraction_prompt = f"""Extract health information from this user message. Handle typos and variations.
User message: "{user_message}"
Extract these fields if present (return null if not found):
- age: integer (tuổi, age, years old)
- gender: "male" or "female" (nam, nữ, male, female, đàn ông, phụ nữ)
- weight: float in kg (nặng, cân, weight, kg)
- height: float in cm (cao, chiều cao, height, cm, m)
IMPORTANT: Height MUST be in cm (50-300 range)
- If user says "1.75m" or "1.78m" → convert to cm (175, 178)
- If user says "175cm" or "178cm" → use as is (175, 178)
- NEVER return values like 1.0, 1.5, 1.75 for height!
- body_fat_percentage: float (tỉ lệ mỡ, body fat, %, optional)
Return ONLY valid JSON with these exact keys. Example:
{{"age": 30, "gender": "male", "weight": 70.5, "height": 175, "body_fat_percentage": 25}}
CRITICAL: Height must be 50-300 (in cm). If user says "1.78m", return 178, not 1.78!
If a field is not found, use null. Be flexible with typos and word order."""

        try:
            # Imported here, inside the try, so that a missing or broken LLM
            # configuration degrades to the regex fallback instead of raising.
            from config.settings import client, MODEL

            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": "You are a data extraction assistant. Extract structured health data from user messages. Handle typos and variations. Return only valid JSON."},
                    {"role": "user", "content": extraction_prompt}
                ],
                temperature=0.1,  # Low temp for consistent extraction
                max_tokens=150
            )
            result_text = (response.choices[0].message.content or "").strip()

            # Remove a markdown code fence if the model wrapped its JSON in one
            if "```json" in result_text:
                result_text = result_text.split("```json")[1].split("```")[0].strip()
            elif "```" in result_text:
                result_text = result_text.split("```")[1].split("```")[0].strip()

            extracted = json.loads(result_text)
            if not isinstance(extracted, dict):
                raise ValueError(f"expected a JSON object, got {type(extracted).__name__}")
        except Exception as e:
            # Deliberate best-effort: any LLM/parsing failure uses the regex path.
            print(f"LLM extraction failed: {e}, using regex fallback")
            return self._extract_with_regex_fallback(user_message)

        # Auto-correct obvious errors, then save only known, non-null fields
        return self._save_profile_fields(self._auto_correct_health_data(extracted))

    def _auto_correct_health_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Normalize units and types, and drop implausible values, in health data.

        Returns a copy of *data*. Keys are never added or removed, but a value
        that cannot be parsed, or falls outside its plausible range, becomes
        None.

        Per-field rules:
        - height: metres (< 10) -> cm; millimetres (> 1000) -> cm; then the
          result must lie in 50-300 cm.
        - weight: grams (> 500) -> kg; then the result must lie in (0, 300] kg.
          Weights under 20 kg are kept (could be a child).
        - age: coerced to int; must lie in 1-120.
        - body_fat_percentage: "25%" is accepted; must lie in 3-60.
        - gender: normalized to "male"/"female" ("nam"/"nữ" also accepted).

        Examples:
        - height: 1.75 -> 175.0 (metres to cm)
        - height: 200 -> 200.0 (already cm)
        - weight: 75000 -> 75.0 (grams to kg)
        """
        corrected = dict(data)

        if corrected.get('height') is not None:
            raw_height = corrected['height']
            height = self._to_float(raw_height)
            if height is None or height <= 0:
                cm = None
            elif height < 10:
                # Round to hide binary-float noise (1.6 * 100 == 160.00000000000003).
                cm = round(height * 100, 1)
                print(f"Auto-corrected height: {height}m → {cm}cm")
            elif height > 1000:
                cm = round(height / 10, 1)
                print(f"Auto-corrected height: {height}mm → {cm}cm")
            else:
                cm = height
            # Validate after conversion too, so e.g. 5 (m) -> 500 cm is rejected
            if cm is None or not 50 <= cm <= 300:
                print(f"Invalid height: {raw_height}, setting to None")
                cm = None
            corrected['height'] = cm

        if corrected.get('weight') is not None:
            raw_weight = corrected['weight']
            weight = self._to_float(raw_weight)
            if weight is not None and weight > 500:
                # Values this large are assumed to be grams
                kg = weight / 1000
                print(f"Auto-corrected weight: {weight}g → {kg}kg")
            else:
                kg = weight
            if kg is None or not 0 < kg <= 300:
                print(f"Invalid weight: {raw_weight}, setting to None")
                kg = None
            corrected['weight'] = kg

        if corrected.get('age') is not None:
            number = self._to_float(corrected['age'])
            age = int(number) if number is not None else None
            if age is None or not 1 <= age <= 120:
                print(f"Invalid age: {corrected['age']}, setting to None")
                age = None
            corrected['age'] = age

        if corrected.get('body_fat_percentage') is not None:
            raw_bf = corrected['body_fat_percentage']
            bf = self._to_float(raw_bf.rstrip('%') if isinstance(raw_bf, str) else raw_bf)
            if bf is None or not 3 <= bf <= 60:
                print(f"Invalid body fat: {raw_bf}%, setting to None")
                bf = None
            corrected['body_fat_percentage'] = bf

        if corrected.get('gender') is not None:
            gender = self._normalize_text(str(corrected['gender']))
            gender = {'nam': 'male', 'nữ': 'female'}.get(gender, gender)
            if gender not in ('male', 'female'):
                print(f"Invalid gender: {corrected['gender']}, setting to None")
                gender = None
            corrected['gender'] = gender

        return corrected

    def _extract_with_regex_fallback(self, user_message: str) -> Dict[str, Any]:
        """
        Fallback regex extraction (less flexible than the LLM, but reliable).

        Values go through the same _auto_correct_health_data() validation as
        the LLM path before they are saved to memory.

        Returns:
            Dict: The profile fields that were saved
        """
        text = self._normalize_text(user_message)
        age_match = re.search(self._FALLBACK_AGE_PATTERN, text)
        raw = {
            'age': int(self._first_group(age_match)) if age_match else None,
            'gender': self._parse_gender(text),
            'weight': self._parse_weight_kg(text),
            'height': self._parse_height_cm(text),
        }
        return self._save_profile_fields(self._auto_correct_health_data(raw))

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}: {self.get_context_summary()}>"