Spaces:
Sleeping
Sleeping
| from typing import AsyncGenerator, Dict, Any | |
| import importlib | |
| import os | |
| import logging | |
| from core.exceptions import ModelError, ProviderError | |
| from core.key_manager import key_manager | |
# Module-level logger, named after this module, used by TextGenerator below.
logger = logging.getLogger(__name__)
class TextGenerator:
    """Handles text generation across different model providers.

    The public methods read no instance state and are declared
    ``@staticmethod``. They can therefore be called either on the class
    (``TextGenerator.generate_stream(...)``) or on an instance
    (``text_generator.generate_stream(...)``).
    """

    @staticmethod
    def _load_provider_module(provider: str):
        """Import and validate the ``models.<name>`` module for *provider*.

        The module name is the first whitespace-separated word of the
        provider name, lower-cased (e.g. ``"Foo Bar"`` -> ``models.foo``).

        Args:
            provider (str): Provider name as returned by ``key_manager``.

        Returns:
            The imported provider module, guaranteed to expose ``run_model_stream``.

        Raises:
            ProviderError: If the module file is missing, the import fails, the
                provider name cannot be turned into a module name, or the
                module lacks ``run_model_stream``.
        """
        try:
            # Convert provider name to a valid module name (first word, lowercase).
            module_name = provider.lower().split()[0]
            logger.info("Importing module: models.%s", module_name)

            # The models/ directory is resolved two directory levels up from this file.
            backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            module_path = os.path.join(backend_dir, 'models', f'{module_name}.py')
            if not os.path.exists(module_path):
                raise ProviderError(f"Provider module not found at: {module_path}")

            provider_module = importlib.import_module(f"models.{module_name}")
        except ProviderError:
            # Already descriptive; let it through instead of re-wrapping it
            # in the generic handler below.
            raise
        except ImportError as e:
            raise ProviderError(f"Failed to import provider module: {e}") from e
        except Exception as e:
            # e.g. empty provider name (IndexError) or a non-string provider.
            raise ProviderError(f"Error loading provider module: {e}") from e

        # Verify the required entry point exists.
        if not hasattr(provider_module, 'run_model_stream'):
            raise ProviderError(
                f"Provider module {module_name} missing required function: run_model_stream"
            )
        return provider_module

    @staticmethod
    async def generate_stream(model: str, prompt: str) -> AsyncGenerator[str, None]:
        """
        Generate streaming text responses using the specified model.

        Resolves the provider for *model* via ``key_manager``, takes the next
        API key for that provider, loads ``models.<provider>``, and re-yields
        every chunk produced by its ``run_model_stream(api_key, model, prompt)``.

        Args:
            model (str): The name of the model to use
            prompt (str): The input prompt for text generation

        Yields:
            str: Generated text chunks

        Raises:
            ProviderError: If the provider module cannot be found, imported, or
                lacks ``run_model_stream``.
            ModelError: Not raised directly here; presumably raised by
                ``key_manager.get_provider_for_model`` for unknown models (TODO confirm).
            Exception: Any other error from ``key_manager`` or the provider
                stream is logged and re-raised unchanged.
        """
        try:
            # Get provider for the selected model
            provider = key_manager.get_provider_for_model(model)
            logger.info("Using provider %s for model %s", provider, model)

            # Get API key (fetched before the module import, as before, so any
            # key-rotation side effect in key_manager happens in the same order).
            api_key = key_manager.get_next_key(provider)
            logger.info("Retrieved API key for provider %s", provider)

            provider_module = TextGenerator._load_provider_module(provider)

            # Generate text stream
            async for chunk in provider_module.run_model_stream(api_key, model, prompt):
                yield chunk
        except Exception as e:
            logger.error("Error in generate_stream: %s", e)
            raise

    @staticmethod
    def get_available_models() -> Dict[str, Any]:
        """
        Get all available models and their providers.

        The value from ``key_manager.get_available_models()`` is returned
        unchanged, since the frontend consumes it directly.
        NOTE(review): the original comment describes this as "the models array",
        which suggests a list rather than the annotated ``Dict[str, Any]``.
        Confirm against key_manager and fix the annotation.

        Returns:
            Dict[str, Any]: Model information as provided by key_manager

        Raises:
            ModelError: If ``key_manager.get_available_models()`` raises; the
                original exception is chained as ``__cause__``.
        """
        try:
            return key_manager.get_available_models()
        except Exception as e:
            logger.error("Error getting available models: %s", e)
            raise ModelError(f"Failed to get available models: {e}") from e
# Global instance: a single shared TextGenerator created when this module is imported.
text_generator = TextGenerator()