# llm_router.py - NOVITA AI API ONLY
import logging
import asyncio
from typing import Dict, Optional

from .models_config import LLM_CONFIG
from .config import get_settings

logger = logging.getLogger(__name__)

# Import OpenAI client for Novita AI API
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False
    logger.error("openai package not available - Novita AI API requires openai package")


class LLMRouter:
    def __init__(self, hf_token=None, use_local_models: bool = False):
        """
        Initialize LLM Router with Novita AI API only.

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models disabled)
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only Novita AI API is supported.")

        self.settings = get_settings()
        self.novita_client = None

        # Validate OpenAI package
        if not OPENAI_AVAILABLE:
            raise ImportError(
                "openai package is required for Novita AI API. "
                "Install it with: pip install openai>=1.0.0"
            )

        # Validate API key
        if not self.settings.novita_api_key:
            raise ValueError(
                "NOVITA_API_KEY is required. "
                "Set it in environment variables or .env file"
            )

        # Initialize Novita AI client
        try:
            self.novita_client = OpenAI(
                base_url=self.settings.novita_base_url,
                api_key=self.settings.novita_api_key,
            )
            logger.info("✓ Novita AI API client initialized")
            logger.info(f"  Base URL: {self.settings.novita_base_url}")
            logger.info(f"  Model: {self.settings.novita_model}")
        except Exception as e:
            logger.error(f"Failed to initialize Novita AI client: {e}")
            raise RuntimeError(f"Could not initialize Novita AI API client: {e}") from e

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Route inference to Novita AI API.

        Args:
            task_type: Type of task (general_reasoning, intent_classification, etc.)
            prompt: Input prompt
            **kwargs: Additional parameters (max_tokens, temperature, etc.)

        Returns:
            Generated text response
        """
        logger.info(f"Routing inference to Novita AI API for task: {task_type}")

        if not self.novita_client:
            raise RuntimeError("Novita AI client not initialized")

        try:
            # Embedding generation may need special handling; for now it goes
            # through chat completion like every other task
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via Novita API may require special implementation")

            result = await self._call_novita_api(task_type, prompt, **kwargs)

            if result is None:
                logger.error(f"Novita AI API returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (Novita AI API)")
            return result
        except Exception as e:
            logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"Novita AI API error: {e}"
            ) from e

    async def _call_novita_api(self, task_type: str, prompt: str, **kwargs) -> Optional[str]:
        """Call Novita AI API for inference."""
        if not self.novita_client:
            return None

        # Get model config
        model_config = self._select_model(task_type)
        model_name = kwargs.get('model', self.settings.novita_model)

        # Get optimized parameters
        max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 4096))
        temperature = kwargs.get('temperature',
                                 model_config.get('temperature', self.settings.deepseek_r1_temperature))
        top_p = kwargs.get('top_p', model_config.get('top_p', 0.95))
        stream = kwargs.get('stream', False)

        # Format prompt according to DeepSeek-R1 best practices
        formatted_prompt = self._format_deepseek_r1_prompt(prompt, task_type, model_config)

        # IMPORTANT: No system prompt - all instructions in user prompt
        messages = [{"role": "user", "content": formatted_prompt}]

        # Build request parameters
        request_params = {
            "model": model_name,
            "messages": messages,
            "stream": stream,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        try:
            if stream:
                # Consume the streaming response in a worker thread so the
                # blocking HTTP reads do not stall the event loop
                def _consume_stream() -> str:
                    text = ""
                    for chunk in self.novita_client.chat.completions.create(**request_params):
                        if chunk.choices and len(chunk.choices) > 0:
                            delta = chunk.choices[0].delta
                            if delta and delta.content:
                                text += delta.content
                    return text

                response_text = await asyncio.to_thread(_consume_stream)

                # Clean up reasoning tags if present
                response_text = self._clean_reasoning_tags(response_text)
                logger.info(f"Novita AI API generated response (length: {len(response_text)})")
                return response_text
            else:
                # Run the blocking API call in a worker thread
                response = await asyncio.to_thread(
                    self.novita_client.chat.completions.create, **request_params
                )
                if response.choices and len(response.choices) > 0:
                    result = response.choices[0].message.content or ""

                    # Clean up reasoning tags if present
                    result = self._clean_reasoning_tags(result)
                    logger.info(f"Novita AI API generated response (length: {len(result)})")
                    return result
                else:
                    logger.error("Novita AI API returned empty response")
                    return None
        except Exception as e:
            logger.error(f"Error calling Novita AI API: {e}", exc_info=True)
            raise

    def _format_deepseek_r1_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
        Format prompt according to DeepSeek-R1 best practices:
        - No system prompt (all instructions in user prompt)
        - Force reasoning trigger for reasoning tasks
        - Add math directive for mathematical problems
        """
        formatted_prompt = prompt

        # Check if we should force reasoning prefix
        force_reasoning = (
            self.settings.deepseek_r1_force_reasoning and
            model_config.get("force_reasoning_prefix", False)
        )
        if force_reasoning:
            # Force model to start with reasoning trigger
            formatted_prompt = f"<think>\n\n{formatted_prompt}"

        # Add math directive for mathematical problems
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
            formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"

        return formatted_prompt

    def _is_math_query(self, prompt: str) -> bool:
        """Detect if query is mathematical"""
        math_keywords = [
            "solve", "calculate", "compute", "equation", "formula",
            "mathematical", "algebra", "geometry", "calculus", "integral",
            "derivative", "theorem", "proof", "problem"
        ]
        prompt_lower = prompt.lower()
        return any(keyword in prompt_lower for keyword in math_keywords)

    def _clean_reasoning_tags(self, text: str) -> str:
        """Clean up reasoning tags from response"""
        text = text.replace("<think>", "").replace("</think>", "")
        text = text.strip()
        return text
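
    # Illustrative sketch only (not taken from models_config.py): _select_model
    # below assumes LLM_CONFIG exposes a "models" mapping whose entries carry
    # the generation parameters read in _call_novita_api. The keys and values
    # shown here are assumptions for documentation purposes, e.g.:
    #
    #   LLM_CONFIG = {
    #       "models": {
    #           "reasoning_primary": {
    #               "max_tokens": 4096,
    #               "temperature": 0.6,
    #               "top_p": 0.95,
    #               "force_reasoning_prefix": True,
    #           },
    #           "classification_specialist": {...},
    #           "embedding_specialist": {...},
    #           "safety_checker": {...},
    #       },
    #   }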

    def _select_model(self, task_type: str) -> dict:
        """Select model configuration based on task type"""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def get_available_models(self):
        """Get list of available models (Novita AI only)"""
        return ["Novita AI API - DeepSeek-R1-Distill-Qwen-7B"]

    async def health_check(self):
        """Perform health check on Novita AI API"""
        try:
            # Test API with a simple request (run in a worker thread so the
            # blocking call does not stall the event loop)
            await asyncio.to_thread(
                self.novita_client.chat.completions.create,
                model=self.settings.novita_model,
                messages=[{"role": "user", "content": "test"}],
                max_tokens=5
            )
            return {
                "provider": "novita_api",
                "status": "healthy",
                "model": self.settings.novita_model,
                "base_url": self.settings.novita_base_url
            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "provider": "novita_api",
                "status": "unhealthy",
                "error": str(e)
            }

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
                                user_input: Optional[str] = None) -> str:
        """
        Smart context windowing with user input priority.

        User input gets the highest priority: the context budget is reduced to
        fit around it, and the input itself is only truncated if it exceeds its
        own dedicated budget (settings.user_input_max_tokens).

        Args:
            raw_context: Context dictionary
            max_tokens: Optional override (uses config default if None)
            user_input: Optional explicit user input (takes priority over raw_context['user_input'])
        """
        # Use config budget if not provided
        if max_tokens is None:
            max_tokens = self.settings.context_preparation_budget

        # Get user input (explicit parameter takes priority)
        actual_user_input = user_input or raw_context.get('user_input', '')

        # Calculate user input tokens (simple estimation: 1 token ≈ 4 chars)
        user_input_tokens = len(actual_user_input) // 4

        # Ensure user input fits within its dedicated budget
        user_input_max = self.settings.user_input_max_tokens
        if user_input_tokens > user_input_max:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
            max_chars = user_input_max * 4
            actual_user_input = actual_user_input[:max_chars - 3] + "..."
            user_input_tokens = user_input_max

        # Reserve space for user input (it has highest priority)
        remaining_tokens = max_tokens - user_input_tokens
        if remaining_tokens < 0:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
            remaining_tokens = 0

        logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")

        # Priority order for context elements (user input already handled)
        priority_elements = [
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        # Map each priority element to its source in raw_context
        element_key_map = {
            'recent_interactions': raw_context.get('interaction_contexts', []),
            'user_preferences': raw_context.get('preferences', {}),
            'session_summary': raw_context.get('session_context', {}),
            'historical_context': raw_context.get('user_context', '')
        }

        formatted_context = []
        total_tokens = user_input_tokens  # Start with user input tokens

        # Add user input first (highest priority, already within its budget)
        if actual_user_input:
            formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")

        # Now add context elements within remaining budget
        for element, priority in priority_elements:
            content = element_key_map.get(element, '')

            # Convert to string if needed
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Estimate tokens (simple: 1 token ≈ 4 chars)
            tokens = len(content) // 4
            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5 and remaining_tokens > 0:  # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if we have meaningful space
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    total_tokens += available
                break

        logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within token limit"""
        # Simple character-based truncation (1 token ≈ 4 chars)
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max_chars - 3] + "..."
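

# Minimal usage sketch (illustrative only): it assumes get_settings() supplies
# novita_api_key, novita_base_url, and novita_model (e.g. via NOVITA_API_KEY in
# the environment or a .env file). The task type and context keys below are
# example values, not part of the router's contract.
if __name__ == "__main__":
    async def _demo():
        router = LLMRouter()

        # Check connectivity before routing real traffic
        status = await router.health_check()
        print(f"Health: {status['status']}")

        # Build a prompt from (hypothetical) stored context plus the user's query
        context = router.prepare_context_for_llm(
            raw_context={"interaction_contexts": ["User asked about shipping times."]},
            user_input="Solve 2x + 3 = 11 and explain each step.",
        )

        # Route a reasoning task through the Novita AI API
        answer = await router.route_inference(
            task_type="general_reasoning",
            prompt=context,
            max_tokens=512,
        )
        print(answer)

    asyncio.run(_demo())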