# llm_router.py - NOVITA AI API ONLY
import logging
import asyncio
from typing import Dict, Optional
from .models_config import LLM_CONFIG
from .config import get_settings
logger = logging.getLogger(__name__)

# Import OpenAI client for Novita AI API
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    logger.error("openai package not available - Novita AI API requires openai package")
class LLMRouter:
def __init__(self, hf_token=None, use_local_models: bool = False):
"""
Initialize LLM Router with Novita AI API only.
Args:
hf_token: Not used (kept for backward compatibility)
use_local_models: Must be False (local models disabled)
"""
if use_local_models:
raise ValueError("Local models are disabled. Only Novita AI API is supported.")
self.settings = get_settings()
self.novita_client = None
# Validate OpenAI package
if not OPENAI_AVAILABLE:
raise ImportError(
"openai package is required for Novita AI API. "
"Install it with: pip install openai>=1.0.0"
)
# Validate API key
if not self.settings.novita_api_key:
raise ValueError(
"NOVITA_API_KEY is required. "
"Set it in environment variables or .env file"
)
# Initialize Novita AI client
try:
self.novita_client = OpenAI(
base_url=self.settings.novita_base_url,
api_key=self.settings.novita_api_key,
)
logger.info(f"✓ Novita AI API client initialized")
logger.info(f" Base URL: {self.settings.novita_base_url}")
logger.info(f" Model: {self.settings.novita_model}")
except Exception as e:
logger.error(f"Failed to initialize Novita AI client: {e}")
raise RuntimeError(f"Could not initialize Novita AI API client: {e}") from e
async def route_inference(self, task_type: str, prompt: str, **kwargs):
"""
Route inference to Novita AI API.
Args:
task_type: Type of task (general_reasoning, intent_classification, etc.)
prompt: Input prompt
**kwargs: Additional parameters (max_tokens, temperature, etc.)
Returns:
Generated text response
"""
logger.info(f"Routing inference to Novita AI API for task: {task_type}")
if not self.novita_client:
raise RuntimeError("Novita AI client not initialized")
        try:
            # Embedding generation currently falls back to chat completion;
            # Novita may eventually require a dedicated embeddings endpoint.
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via Novita API may require special implementation")
            result = await self._call_novita_api(task_type, prompt, **kwargs)
if result is None:
logger.error(f"Novita AI API returned None for task: {task_type}")
raise RuntimeError(f"Inference failed for task: {task_type}")
logger.info(f"Inference complete for {task_type} (Novita AI API)")
return result
except Exception as e:
logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
raise RuntimeError(
f"Inference failed for task: {task_type}. "
f"Novita AI API error: {e}"
) from e
async def _call_novita_api(self, task_type: str, prompt: str, **kwargs) -> Optional[str]:
"""Call Novita AI API for inference."""
if not self.novita_client:
return None
# Get model config
model_config = self._select_model(task_type)
model_name = kwargs.get('model', self.settings.novita_model)
# Get optimized parameters
max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 4096))
temperature = kwargs.get('temperature',
model_config.get('temperature', self.settings.deepseek_r1_temperature))
top_p = kwargs.get('top_p', model_config.get('top_p', 0.95))
stream = kwargs.get('stream', False)
# Format prompt according to DeepSeek-R1 best practices
formatted_prompt = self._format_deepseek_r1_prompt(prompt, task_type, model_config)
# IMPORTANT: No system prompt - all instructions in user prompt
messages = [{"role": "user", "content": formatted_prompt}]
# Build request parameters
request_params = {
"model": model_name,
"messages": messages,
"stream": stream,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
try:
if stream:
# Handle streaming response
response_text = ""
stream_response = self.novita_client.chat.completions.create(**request_params)
for chunk in stream_response:
if chunk.choices and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
if delta and delta.content:
response_text += delta.content
# Clean up reasoning tags if present
response_text = self._clean_reasoning_tags(response_text)
logger.info(f"Novita AI API generated response (length: {len(response_text)})")
return response_text
else:
# Handle non-streaming response
response = self.novita_client.chat.completions.create(**request_params)
                if response.choices and response.choices[0].message.content:
                    result = response.choices[0].message.content
                    # Clean up reasoning tags if present
                    result = self._clean_reasoning_tags(result)
                    logger.info(f"Novita AI API generated response (length: {len(result)})")
                    return result
else:
logger.error("Novita AI API returned empty response")
return None
except Exception as e:
logger.error(f"Error calling Novita AI API: {e}", exc_info=True)
raise
def _format_deepseek_r1_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
"""
Format prompt according to DeepSeek-R1 best practices:
- No system prompt (all instructions in user prompt)
- Force reasoning trigger for reasoning tasks
- Add math directive for mathematical problems
"""
formatted_prompt = prompt
# Check if we should force reasoning prefix
force_reasoning = (
self.settings.deepseek_r1_force_reasoning and
model_config.get("force_reasoning_prefix", False)
)
        if force_reasoning:
            # Force model to start with the reasoning trigger
            formatted_prompt = f"<think>\n\n{formatted_prompt}"
# Add math directive for mathematical problems
if self._is_math_query(prompt):
math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"
return formatted_prompt
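    # Illustrative example (not executed): for a math-style prompt with
    # force_reasoning enabled, the formatted prompt would look roughly like:
    #
    #   <think>
    #
    #   Solve the equation 2x + 3 = 7
    #
    #   Please reason step by step, and put your final answer within \boxed{}.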
def _is_math_query(self, prompt: str) -> bool:
"""Detect if query is mathematical"""
math_keywords = [
"solve", "calculate", "compute", "equation", "formula",
"mathematical", "algebra", "geometry", "calculus", "integral",
"derivative", "theorem", "proof", "problem"
]
prompt_lower = prompt.lower()
return any(keyword in prompt_lower for keyword in math_keywords)
def _clean_reasoning_tags(self, text: str) -> str:
"""Clean up reasoning tags from response"""
        text = text.replace("<think>", "").replace("</think>", "")
text = text.strip()
return text
def _select_model(self, task_type: str) -> dict:
"""Select model configuration based on task type"""
model_map = {
"intent_classification": LLM_CONFIG["models"]["classification_specialist"],
"embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
"safety_check": LLM_CONFIG["models"]["safety_checker"],
"general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
"response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
}
return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
async def get_available_models(self):
"""Get list of available models (Novita AI only)"""
return ["Novita AI API - DeepSeek-R1-Distill-Qwen-7B"]
async def health_check(self):
"""Perform health check on Novita AI API"""
try:
            # Test API with a simple request (response content is not inspected)
            self.novita_client.chat.completions.create(
model=self.settings.novita_model,
messages=[{"role": "user", "content": "test"}],
max_tokens=5
)
return {
"provider": "novita_api",
"status": "healthy",
"model": self.settings.novita_model,
"base_url": self.settings.novita_base_url
}
except Exception as e:
logger.error(f"Health check failed: {e}")
return {
"provider": "novita_api",
"status": "unhealthy",
"error": str(e)
}
def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
user_input: Optional[str] = None) -> str:
"""
Smart context windowing with user input priority.
User input is NEVER truncated - context is reduced to fit.
Args:
raw_context: Context dictionary
max_tokens: Optional override (uses config default if None)
user_input: Optional explicit user input (takes priority over raw_context['user_input'])
"""
# Use config budget if not provided
if max_tokens is None:
max_tokens = self.settings.context_preparation_budget
# Get user input (explicit parameter takes priority)
actual_user_input = user_input or raw_context.get('user_input', '')
# Calculate user input tokens (simple estimation: 1 token ≈ 4 chars)
user_input_tokens = len(actual_user_input) // 4
# Ensure user input fits within dedicated budget
user_input_max = self.settings.user_input_max_tokens
if user_input_tokens > user_input_max:
logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
max_chars = user_input_max * 4
actual_user_input = actual_user_input[:max_chars - 3] + "..."
user_input_tokens = user_input_max
# Reserve space for user input (it has highest priority)
remaining_tokens = max_tokens - user_input_tokens
if remaining_tokens < 0:
logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
remaining_tokens = 0
logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")
# Priority order for context elements (user input already handled)
priority_elements = [
('recent_interactions', 0.8),
('user_preferences', 0.6),
('session_summary', 0.4),
('historical_context', 0.2)
]
formatted_context = []
total_tokens = user_input_tokens # Start with user input tokens
        # Add user input first (already capped to its dedicated budget above)
if actual_user_input:
formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")
        # Map context elements to their source fields once, then add them within the remaining budget
        element_key_map = {
            'recent_interactions': raw_context.get('interaction_contexts', []),
            'user_preferences': raw_context.get('preferences', {}),
            'session_summary': raw_context.get('session_context', {}),
            'historical_context': raw_context.get('user_context', '')
        }
        for element, priority in priority_elements:
            content = element_key_map.get(element, '')
# Convert to string if needed
if isinstance(content, dict):
content = str(content)
elif isinstance(content, list):
content = "\n".join([str(item) for item in content[:10]])
if not content:
continue
# Estimate tokens (simple: 1 token ≈ 4 chars)
tokens = len(content) // 4
if total_tokens + tokens <= max_tokens:
formatted_context.append(f"=== {element.upper()} ===\n{content}")
total_tokens += tokens
elif priority > 0.5 and remaining_tokens > 0: # Critical elements - truncate if needed
available = max_tokens - total_tokens
if available > 100: # Only truncate if we have meaningful space
truncated = self._truncate_to_tokens(content, available)
formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
total_tokens += available
break
logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
return "\n\n".join(formatted_context)
def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
"""Truncate content to fit within token limit"""
# Simple character-based truncation (1 token ≈ 4 chars)
max_chars = max_tokens * 4
if len(content) <= max_chars:
return content
return content[:max_chars - 3] + "..."
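

# --- Usage sketch (illustrative, not part of the router) ---------------------
# A minimal, hedged example of how the router might be exercised end to end.
# It assumes NOVITA_API_KEY (and optionally NOVITA_BASE_URL / NOVITA_MODEL) are
# set so get_settings() can resolve them, and it must be run as a module
# (python -m src.llm_router) so the relative imports above resolve. The sample
# context keys and prompt are placeholders for illustration only.
if __name__ == "__main__":
    async def _demo():
        router = LLMRouter()
        print(await router.health_check())
        prompt = router.prepare_context_for_llm(
            {"interaction_contexts": ["User previously asked about algebra."]},
            user_input="Solve the equation x^2 - 4 = 0",
        )
        answer = await router.route_inference("general_reasoning", prompt, max_tokens=512)
        print(answer)

    asyncio.run(_demo())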