import logging
import os
from typing import Any, List, Dict, Optional

import ollama

from src.config import Config


class OllamaMistral:
    """
    A class to interact with the Ollama API for the Mistral model.
    Handles both chat completions and embeddings generation.
    """

    def __init__(self):
        """Initialize the Ollama Mistral client with default settings."""
        self.logger = logging.getLogger(__name__)
        # Initialize the Ollama client against the default local host
        self.client = ollama.Client(host='http://localhost:11434')
        self.model = 'mistral'  # Default model name

    async def generate_response(self, prompt: str) -> str:
        """
        Asynchronously generate a text response from the Mistral model.

        Args:
            prompt: The input text prompt for the model

        Returns:
            Generated response text, or an error message if the call failed
        """
        try:
            print(f"[Ollama] Sending prompt:\n{prompt}\n")
            # Send a chat request to the Ollama API
            response = self.client.chat(
                model=self.model,
                messages=[{
                    'role': 'user',
                    'content': prompt
                }]
            )
            print(f"[Ollama] Received response:\n{response}\n")
            # Handle both dict-style and object-style responses from Ollama
            if isinstance(response, dict):
                if 'message' in response and 'content' in response['message']:
                    return response['message']['content']
            elif hasattr(response, 'message') and hasattr(response.message, 'content'):
                return response.message.content
            # Fallback: convert the raw response to a string
            return str(response)
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
            return f"Error generating response: {str(e)}"

    def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
        """
        Generate an embedding for the input text using the specified model.

        Args:
            text: Input text to generate an embedding for
            model: Model name to use for embeddings (defaults to Config.OLLAMA_MODEL)

        Returns:
            List of embedding values, or None if the call failed
        """
        try:
            print(f"[Ollama] Generating embedding for: {text[:60]}...")
            # Request an embedding from the Ollama API
            # (the embeddings endpoint takes a single `prompt` string)
            response = self.client.embeddings(
                model=model,
                prompt=text
            )
            print(f"[Ollama] Embedding response: {response}")
            # Handle both dict-style and object-style responses
            if isinstance(response, dict) and 'embeddings' in response:
                return response['embeddings'][0]
            elif isinstance(response, dict) and 'embedding' in response:
                return response['embedding']
            elif hasattr(response, 'embedding'):
                return list(response.embedding)
            else:
                self.logger.warning(f"Unexpected embedding response format: {response}")
                return None
        except Exception as e:
            self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
            return None

    def generate(self, prompt: str) -> str:
        """
        Synchronous wrapper around generate_response.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text
        """
        import asyncio
        try:
            return asyncio.run(self.generate_response(prompt))
        except Exception as e:
            self.logger.error(f"Error in synchronous generate: {e}")
            return f"Error generating response: {str(e)}"


class GeminiProvider:
    """
    A class to interact with Google's Gemini API.
    Requires the GEMINI_API_KEY environment variable.
    """

    def __init__(self):
        """Initialize the Gemini provider with an API key."""
        self.logger = logging.getLogger(__name__)
        self.api_key = os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY environment variable is required for the Gemini provider")
        try:
            import google.generativeai as genai
            # Configure the Gemini API client
            genai.configure(api_key=self.api_key)
            self.model = genai.GenerativeModel('gemini-1.5-flash')
        except ImportError:
            raise ImportError("google-generativeai package is required for the Gemini provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the Gemini model.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text or an error message
        """
        try:
            response = self.model.generate_content(prompt)
            return response.text
        except Exception as e:
            self.logger.error(f"[Gemini] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class OpenChatProvider:
    """
    A class to run OpenChat models locally via transformers.
    Requires the transformers package to be installed.
    """

    def __init__(self):
        """Initialize the OpenChat model and tokenizer."""
        self.logger = logging.getLogger(__name__)
        try:
            from transformers import AutoTokenizer, AutoModelForCausalLM
            # Load the pretrained OpenChat model and its tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
            self.model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.5-0106")
        except ImportError:
            raise ImportError("transformers package is required for the OpenChat provider")

    def generate(self, prompt: str) -> str:
        """
        Generate a text response using the OpenChat model.

        Args:
            prompt: Input text prompt

        Returns:
            Generated response text
        """
        try:
            # Tokenize the input and generate a response
            # (do_sample=True is needed for temperature to take effect)
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_length=512, do_sample=True, temperature=0.7)
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return response
        except Exception as e:
            self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
            return f"Error generating response: {str(e)}"


class LLMFactory:
    """
    Factory class to create and manage different LLM providers.
    Implements the Factory design pattern for LLM provider instantiation.
    """

    @staticmethod
    def get_provider(model_name: Optional[str] = None) -> Any:
        """
        Get the appropriate LLM provider based on the model name.

        Args:
            model_name: Name of the model ('mistral', 'gemini', 'openchat').
                Defaults to 'mistral' if None or unknown.

        Returns:
            Instance of the requested LLM provider

        Raises:
            ValueError: If required configuration is missing for the provider
            ImportError: If required dependencies are missing for the provider
        """
        if model_name is None:
            model_name = "mistral"  # Default to mistral
        model_name = model_name.lower()
        # Return the appropriate provider based on the model name
        if model_name == "mistral":
            return OllamaMistral()
        elif model_name == "gemini":
            return GeminiProvider()
        elif model_name == "openchat":
            return OpenChatProvider()
        else:
            # Default to mistral if an unknown model is specified
            logging.warning(f"Unknown model '{model_name}', defaulting to mistral")
            return OllamaMistral()
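

# Minimal smoke test for the factory -- a sketch, assuming a local Ollama
# server with the 'mistral' model is available; run this module directly,
# adjusting the import path so that src.config resolves in your project layout.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = LLMFactory.get_provider("mistral")
    print(provider.generate("Say hello in one short sentence."))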