Integrate Novita AI as exclusive inference provider

- Add Novita AI API integration with DeepSeek-R1-Distill-Qwen-7B model
- Remove all local model dependencies
- Optimize token allocation for user inputs and context
- Add Anaconda environment setup files
- Add comprehensive test scripts and documentation

927854c
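The router below gets its Novita AI settings through the project's get_settings(), and the error message in __init__ indicates the API key is read from the environment or a .env file. A minimal setup sketch, assuming the module is importable as `app.llm_router` (the package name and import path are illustrative; only NOVITA_API_KEY appears verbatim in the code):

import os

# NOVITA_API_KEY is required; the router raises ValueError without it.
os.environ.setdefault("NOVITA_API_KEY", "sk-...")

from app.llm_router import LLMRouter  # hypothetical import path

router = LLMRouter()  # use_local_models must remain False
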
# llm_router.py - NOVITA AI API ONLY
import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG
from .config import get_settings

logger = logging.getLogger(__name__)

# Import OpenAI client for Novita AI API
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    logger.error("openai package not available - Novita AI API requires openai package")


class LLMRouter:
    def __init__(self, hf_token=None, use_local_models: bool = False):
        """
        Initialize the LLM router with the Novita AI API only.

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models are disabled)
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only Novita AI API is supported.")

        self.settings = get_settings()
        self.novita_client = None

        # Validate that the openai package is available
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "openai package is required for Novita AI API. "
                "Install it with: pip install openai>=1.0.0"
            )

        # Validate the API key
        if not self.settings.novita_api_key:
            raise ValueError(
                "NOVITA_API_KEY is required. "
                "Set it in environment variables or .env file"
            )

        # Initialize the Novita AI client
        try:
            self.novita_client = OpenAI(
                base_url=self.settings.novita_base_url,
                api_key=self.settings.novita_api_key,
            )
            logger.info("✓ Novita AI API client initialized")
            logger.info(f"  Base URL: {self.settings.novita_base_url}")
            logger.info(f"  Model: {self.settings.novita_model}")
        except Exception as e:
            logger.error(f"Failed to initialize Novita AI client: {e}")
            raise RuntimeError(f"Could not initialize Novita AI API client: {e}") from e

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Route inference to the Novita AI API.

        Args:
            task_type: Type of task (general_reasoning, intent_classification, etc.)
            prompt: Input prompt
            **kwargs: Additional parameters (max_tokens, temperature, etc.)

        Returns:
            Generated text response
        """
        logger.info(f"Routing inference to Novita AI API for task: {task_type}")

        if not self.novita_client:
            raise RuntimeError("Novita AI client not initialized")

        try:
            # Embedding generation is not a chat task; for now it is routed through the
            # same chat completion call and may need a dedicated implementation later.
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via Novita API may require special implementation")

            result = await self._call_novita_api(task_type, prompt, **kwargs)

            if result is None:
                logger.error(f"Novita AI API returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (Novita AI API)")
            return result
        except Exception as e:
            logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"Novita AI API error: {e}"
            ) from e

    async def _call_novita_api(self, task_type: str, prompt: str, **kwargs) -> Optional[str]:
        """Call the Novita AI API for inference."""
        if not self.novita_client:
            return None

        # Get the model config
        model_config = self._select_model(task_type)
        model_name = kwargs.get('model', self.settings.novita_model)

        # Get optimized parameters
        max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 4096))
        temperature = kwargs.get('temperature',
                                 model_config.get('temperature', self.settings.deepseek_r1_temperature))
        top_p = kwargs.get('top_p', model_config.get('top_p', 0.95))
        stream = kwargs.get('stream', False)

        # Format the prompt according to DeepSeek-R1 best practices
        formatted_prompt = self._format_deepseek_r1_prompt(prompt, task_type, model_config)

        # IMPORTANT: no system prompt - all instructions go in the user prompt
        messages = [{"role": "user", "content": formatted_prompt}]

        # Build request parameters
        request_params = {
            "model": model_name,
            "messages": messages,
            "stream": stream,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        try:
            if stream:
                # Handle streaming response
                response_text = ""
                stream_response = self.novita_client.chat.completions.create(**request_params)
                for chunk in stream_response:
                    if chunk.choices and len(chunk.choices) > 0:
                        delta = chunk.choices[0].delta
                        if delta and delta.content:
                            response_text += delta.content

                # Clean up reasoning tags if present
                response_text = self._clean_reasoning_tags(response_text)
                logger.info(f"Novita AI API generated response (length: {len(response_text)})")
                return response_text
            else:
                # Handle non-streaming response
                response = self.novita_client.chat.completions.create(**request_params)
                if response.choices and len(response.choices) > 0:
                    result = response.choices[0].message.content

                    # Clean up reasoning tags if present
                    result = self._clean_reasoning_tags(result)
                    logger.info(f"Novita AI API generated response (length: {len(result)})")
                    return result
                else:
                    logger.error("Novita AI API returned empty response")
                    return None
        except Exception as e:
            logger.error(f"Error calling Novita AI API: {e}", exc_info=True)
            raise

    def _format_deepseek_r1_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
        Format the prompt according to DeepSeek-R1 best practices:
        - No system prompt (all instructions go in the user prompt)
        - Force the reasoning trigger for reasoning tasks
        - Add a math directive for mathematical problems
        """
        formatted_prompt = prompt

        # Check whether the reasoning prefix should be forced
        force_reasoning = (
            self.settings.deepseek_r1_force_reasoning and
            model_config.get("force_reasoning_prefix", False)
        )
        if force_reasoning:
            # Force the model to start with the reasoning trigger
            formatted_prompt = f"<think>\n\n{formatted_prompt}"

        # Add a math directive for mathematical problems
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
            formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"

        return formatted_prompt
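
    # Illustrative example (hypothetical input): with force_reasoning enabled, a prompt
    # like "Solve the equation x^2 - 4 = 0" is sent to the API roughly as:
    #
    #   <think>
    #
    #   Solve the equation x^2 - 4 = 0
    #
    #   Please reason step by step, and put your final answer within \boxed{}.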

    def _is_math_query(self, prompt: str) -> bool:
        """Detect whether the query is mathematical."""
        math_keywords = [
            "solve", "calculate", "compute", "equation", "formula",
            "mathematical", "algebra", "geometry", "calculus", "integral",
            "derivative", "theorem", "proof", "problem"
        ]
        prompt_lower = prompt.lower()
        return any(keyword in prompt_lower for keyword in math_keywords)

    def _clean_reasoning_tags(self, text: str) -> str:
        """Remove reasoning tags from the response."""
        text = text.replace("<think>", "").replace("</think>", "")
        text = text.strip()
        return text

    def _select_model(self, task_type: str) -> dict:
        """Select the model configuration for a task type."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def get_available_models(self):
        """Get the list of available models (Novita AI only)."""
        return ["Novita AI API - DeepSeek-R1-Distill-Qwen-7B"]

    async def health_check(self):
        """Perform a health check against the Novita AI API."""
        try:
            # Test the API with a minimal request
            test_response = self.novita_client.chat.completions.create(
                model=self.settings.novita_model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
            return {
                "provider": "novita_api",
                "status": "healthy",
                "model": self.settings.novita_model,
                "base_url": self.settings.novita_base_url
            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "provider": "novita_api",
                "status": "unhealthy",
                "error": str(e)
            }

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
                                user_input: Optional[str] = None) -> str:
        """
        Smart context windowing with user-input priority.

        The user input gets the first claim on the token budget and is only truncated
        if it exceeds its own dedicated budget (user_input_max_tokens); the remaining
        context is reduced to fit.

        Args:
            raw_context: Context dictionary
            max_tokens: Optional override (uses the config default if None)
            user_input: Optional explicit user input (takes priority over raw_context['user_input'])
        """
        # Use the configured budget if none is provided
        if max_tokens is None:
            max_tokens = self.settings.context_preparation_budget

        # Get the user input (the explicit parameter takes priority)
        actual_user_input = user_input or raw_context.get('user_input', '')

        # Estimate user input tokens (simple estimate: 1 token ≈ 4 chars)
        user_input_tokens = len(actual_user_input) // 4

        # Ensure the user input fits within its dedicated budget
        user_input_max = self.settings.user_input_max_tokens
        if user_input_tokens > user_input_max:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
            max_chars = user_input_max * 4
            actual_user_input = actual_user_input[:max_chars - 3] + "..."
            user_input_tokens = user_input_max

        # Reserve space for the user input (it has the highest priority)
        remaining_tokens = max_tokens - user_input_tokens
        if remaining_tokens < 0:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
            remaining_tokens = 0

        logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")

        # Priority order for context elements (user input is already handled)
        priority_elements = [
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        # Map each context element to its source in raw_context
        element_key_map = {
            'recent_interactions': raw_context.get('interaction_contexts', []),
            'user_preferences': raw_context.get('preferences', {}),
            'session_summary': raw_context.get('session_context', {}),
            'historical_context': raw_context.get('user_context', '')
        }

        formatted_context = []
        total_tokens = user_input_tokens  # Start with the user input tokens

        # Add the user input first (it has already been bounded above)
        if actual_user_input:
            formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")

        # Add context elements within the remaining budget
        for element, priority in priority_elements:
            content = element_key_map.get(element, '')

            # Convert to string if needed
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Estimate tokens (simple: 1 token ≈ 4 chars)
            tokens = len(content) // 4
            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5 and remaining_tokens > 0:  # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if there is meaningful space left
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    total_tokens += available
                break

        logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
        return "\n\n".join(formatted_context)
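
    # Worked example of the budgeting above (hypothetical numbers): with a
    # context_preparation_budget of 2000 tokens and a 1200-character user input
    # (~300 tokens at the 1 token ≈ 4 chars estimate), remaining_tokens = 1700.
    # Context elements are then appended in priority order until the total would
    # exceed 2000 tokens; the first element that no longer fits is truncated only
    # if its priority is above 0.5 and more than 100 tokens of budget remain.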

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within a token limit."""
        # Simple character-based truncation (1 token ≈ 4 chars)
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max_chars - 3] + "..."
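
A minimal usage sketch, assuming the same hypothetical `app` package layout; the method names and signatures match the class above, and "general_reasoning" is one of the mapped task types, while everything else is illustrative:

import asyncio

from app.llm_router import LLMRouter  # hypothetical import path

async def main():
    router = LLMRouter()  # requires NOVITA_API_KEY (and related settings) to be configured
    print(await router.health_check())

    # Fold stored context and the new user input into one prompt, then route it.
    prompt = router.prepare_context_for_llm(
        raw_context={"interaction_contexts": ["User asked about pricing yesterday."]},
        user_input="Summarize our previous conversation.",
    )
    answer = await router.route_inference("general_reasoning", prompt, max_tokens=512)
    print(answer)

asyncio.run(main())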