# ConvoBot/src/llm.py
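"""LLM provider abstractions for ConvoBot.
Defines OllamaMistral (local Mistral served by Ollama), GeminiProvider (Google Gemini via
the google-generativeai package), OpenChatProvider (local OpenChat via transformers), and
LLMFactory, which returns the provider matching a requested model name.
"""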
import logging
import os
import ollama
from typing import Any, List, Optional
from src.config import Config
class OllamaMistral:
"""
    A class to interact with the Ollama API using the Mistral model.
Handles both chat completions and embeddings generation.
"""
def __init__(self):
"""Initialize the Ollama Mistral client with default settings."""
self.logger = logging.getLogger(__name__)
# Initialize Ollama client with default host
self.client = ollama.Client(host='http://localhost:11434')
self.model = 'mistral' # Default model name
async def generate_response(self, prompt: str) -> str:
"""
        Asynchronously generate a text response from the Mistral model.
Args:
prompt: The input text prompt for the model
Returns:
Generated response text or error message if failed
"""
try:
print(f"[Ollama] Sending prompt:\n{prompt}\n")
# Send chat request to Ollama API
response = self.client.chat(
model=self.model,
messages=[{
'role': 'user',
'content': prompt
}]
)
print(f"[Ollama] Received response:\n{response}\n")
# Handle different response formats from Ollama
if isinstance(response, dict):
if 'message' in response and 'content' in response['message']:
return response['message']['content']
elif hasattr(response, 'message') and hasattr(response.message, 'content'):
return response.message.content
# Fallback: try to convert to string
return str(response)
except Exception as e:
self.logger.error(f"[OllamaMistral] Error generating response: {str(e)}", exc_info=True)
return f"Error generating response: {str(e)}"
def generate_embedding(self, text: str, model: str = Config.OLLAMA_MODEL) -> Optional[List[float]]:
"""
        Generate an embedding for the input text using the specified model.
Args:
text: Input text to generate embeddings for
model: Model name to use for embeddings (default from Config)
Returns:
            The embedding as a list of floats, or None if generation failed
"""
try:
print(f"[Ollama] Generating embedding for: {text[:60]}...")
# Request embeddings from Ollama API
response = self.client.embeddings(
model=model,
prompts=[text] # prompts must be a list of strings
)
print(f"[Ollama] Embedding response: {response}")
# Handle different response formats
if isinstance(response, dict) and 'embeddings' in response:
return response['embeddings'][0]
elif isinstance(response, dict) and 'embedding' in response:
return response['embedding']
else:
self.logger.warning(f"Unexpected embedding response format: {response}")
return None
except Exception as e:
self.logger.error(f"[OllamaMistral] Error generating embedding: {str(e)}", exc_info=True)
return None
def generate(self, prompt: str) -> str:
"""
Synchronous wrapper for generate_response.
Args:
prompt: Input text prompt
Returns:
Generated response text
"""
import asyncio
try:
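            # asyncio.run() starts a fresh event loop; it raises if called from a thread that already has a running loop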
return asyncio.run(self.generate_response(prompt))
except Exception as e:
self.logger.error(f"Error in synchronous generate: {e}")
return f"Error generating response: {str(e)}"
class GeminiProvider:
"""
A class to interact with Google's Gemini API.
Requires GEMINI_API_KEY environment variable.
"""
def __init__(self):
"""Initialize Gemini provider with API key."""
self.logger = logging.getLogger(__name__)
self.api_key = os.getenv('GEMINI_API_KEY')
if not self.api_key:
raise ValueError("GEMINI_API_KEY environment variable is required for Gemini provider")
try:
import google.generativeai as genai
# Configure Gemini API
genai.configure(api_key=self.api_key)
self.model = genai.GenerativeModel('gemini-1.5-flash')
except ImportError:
raise ImportError("google-generativeai package is required for Gemini provider")
def generate(self, prompt: str) -> str:
"""
        Generate a text response using the Gemini model.
Args:
prompt: Input text prompt
Returns:
Generated response text or error message
"""
try:
response = self.model.generate_content(prompt)
return response.text
except Exception as e:
self.logger.error(f"[Gemini] Error generating response: {str(e)}")
return f"Error generating response: {str(e)}"
class OpenChatProvider:
"""
A class to use OpenChat models locally via transformers.
Requires transformers package to be installed.
"""
def __init__(self):
"""Initialize OpenChat model and tokenizer."""
self.logger = logging.getLogger(__name__)
try:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load pretrained OpenChat model
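            # Note: openchat-3.5-0106 is a ~7B-parameter model, so loading it locally needs several GB of RAM or a GPU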
self.tokenizer = AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")
self.model = AutoModelForCausalLM.from_pretrained("openchat/openchat-3.5-0106")
except ImportError:
raise ImportError("transformers package is required for OpenChat provider")
def generate(self, prompt: str) -> str:
"""
        Generate a text response using the OpenChat model.
Args:
prompt: Input text prompt
Returns:
Generated response text
"""
try:
            # Tokenize the input and generate a response; sampling must be enabled
            # for the temperature setting to take effect
            inputs = self.tokenizer(prompt, return_tensors="pt")
            outputs = self.model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
            # Decode only the newly generated tokens, skipping the echoed prompt
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            return self.tokenizer.decode(new_tokens, skip_special_tokens=True)
except Exception as e:
self.logger.error(f"[OpenChat] Error generating response: {str(e)}")
return f"Error generating response: {str(e)}"
class LLMFactory:
"""
Factory class to create and manage different LLM providers.
Implements the Factory design pattern for LLM provider instantiation.
"""
@staticmethod
    def get_provider(model_name: Optional[str] = None) -> Any:
"""
Get appropriate LLM provider based on model name.
Args:
model_name: Name of the model ('mistral', 'gemini', 'openchat')
Defaults to 'mistral' if None or unknown
Returns:
Instance of the requested LLM provider
Raises:
ValueError: If required dependencies are missing for the provider
"""
if model_name is None:
model_name = "mistral" # Default to mistral
model_name = model_name.lower()
# Return appropriate provider based on model name
if model_name == "mistral":
return OllamaMistral()
elif model_name == "gemini":
return GeminiProvider()
elif model_name == "openchat":
return OpenChatProvider()
else:
# Default to mistral if unknown model is specified
logging.warning(f"Unknown model '{model_name}', defaulting to mistral")
return OllamaMistral()
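# Minimal usage sketch (illustrative only): assumes a local Ollama server at
# http://localhost:11434 with the 'mistral' model pulled; the 'gemini' and
# 'openchat' providers additionally need GEMINI_API_KEY or the transformers
# package, as noted in their docstrings.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    provider = LLMFactory.get_provider("mistral")
    print(provider.generate("Say hello in one short sentence."))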