|
|
""" |
|
|
Embedding utilities for LifeUnity AI Cognitive Twin System. |
|
|
Provides text embedding functionality using Sentence-BERT. |
|
|
""" |
|
|
|
|
|
from typing import List, Tuple, Union

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

from app.utils.logger import get_logger
|
|
|
|
|
# Module-level logger used by TextEmbedder for load, embed and similarity events.
logger = get_logger("Embedder")
|
|
|
|
|
|
|
|
class TextEmbedder:
    """Text embedding handler using Sentence-BERT.

    The model is not loaded in the constructor; ``embed_text`` loads it on
    first use (or call ``load_model`` explicitly to load it up front).
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """
        Initialize the text embedder.

        Args:
            model_name: Name of the sentence-transformers model
        """
        self.model_name = model_name
        # Both stay None until load_model() succeeds.
        self.model = None
        self.embedding_dim = None
        logger.info(f"Initializing TextEmbedder with model: {model_name}")

    def load_model(self):
        """
        Load the sentence transformer model.

        Does nothing when a model is already loaded.

        Raises:
            Exception: Any error raised while constructing the model or
                querying its embedding dimension is logged and re-raised.
        """
        if self.model is not None:
            return

        try:
            logger.info(f"Loading model: {self.model_name}")
            model = SentenceTransformer(self.model_name)
            embedding_dim = model.get_sentence_embedding_dimension()
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}", exc_info=True)
            raise

        # Publish both attributes together so a failure above never leaves
        # the instance with a model but no embedding dimension.
        self.model = model
        self.embedding_dim = embedding_dim
        logger.info(f"Model loaded successfully. Embedding dim: {self.embedding_dim}")

    def embed_text(self, text: Union[str, List[str]]) -> np.ndarray:
        """
        Generate embeddings for text.

        A single string is wrapped in a one-element list before encoding, so
        the model is always called with a list of texts.

        Args:
            text: Single text string or list of text strings

        Returns:
            Numpy array of embeddings as returned by the model's ``encode``

        Raises:
            Exception: Any error from loading the model or encoding is
                logged and re-raised.
        """
        if self.model is None:
            self.load_model()

        try:
            texts = [text] if isinstance(text, str) else text
            embeddings = self.model.encode(
                texts,
                convert_to_numpy=True,
                show_progress_bar=False,
            )
            logger.debug(f"Generated embeddings for {len(texts)} texts")
            return embeddings
        except Exception as e:
            logger.error(f"Error generating embeddings: {str(e)}", exc_info=True)
            raise

    @staticmethod
    def _cosine(vec_a: np.ndarray, vec_b: np.ndarray) -> float:
        """Return the cosine similarity of two vectors, or 0.0 if it is undefined."""
        denom = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
        # A zero (or non-finite) norm product would yield NaN/inf rather than
        # an exception; treat those vectors as having no similarity.
        if denom == 0 or not np.isfinite(denom):
            return 0.0
        return float(np.dot(vec_a, vec_b) / denom)

    def compute_similarity(self, text1: str, text2: str) -> float:
        """
        Compute cosine similarity between two texts.

        Args:
            text1: First text
            text2: Second text

        Returns:
            Cosine similarity in the range [-1, 1]. Returns 0.0 when either
            embedding has zero norm or when embedding fails (the error is
            logged, not raised).
        """
        try:
            first, second = self.embed_text([text1, text2])
            return self._cosine(first, second)
        except Exception as e:
            logger.error(f"Error computing similarity: {str(e)}", exc_info=True)
            return 0.0

    def find_most_similar(
        self,
        query: str,
        candidates: List[str],
        top_k: int = 5
    ) -> List[Tuple[int, str, float]]:
        """
        Find most similar texts to a query.

        Args:
            query: Query text
            candidates: List of candidate texts
            top_k: Number of top results to return; zero or a negative
                value yields an empty list

        Returns:
            List of (index, text, similarity_score) tuples sorted by
            descending similarity, at most ``top_k`` long. Returns an empty
            list when there are no candidates or when embedding fails (the
            error is logged, not raised).
        """
        try:
            # Nothing to rank: skip model loading and encoding entirely.
            # (A negative top_k would otherwise slice off the *last* items.)
            if not candidates or (top_k is not None and top_k <= 0):
                return []

            # embed_text was given a single string, so row 0 is the query.
            query_vec = self.embed_text(query)[0]
            candidate_vecs = self.embed_text(candidates)

            ranked = [
                (idx, candidates[idx], self._cosine(query_vec, candidate_vec))
                for idx, candidate_vec in enumerate(candidate_vecs)
            ]
            ranked.sort(key=lambda item: item[2], reverse=True)
            return ranked[:top_k]
        except Exception as e:
            logger.error(f"Error finding similar texts: {str(e)}", exc_info=True)
            return []
|
|
|
|
|
|
|
|
|
|
|
# Module-level TextEmbedder shared by get_embedder(); created on first call.
_embedder = None


def get_embedder(model_name: str = 'all-MiniLM-L6-v2') -> TextEmbedder:
    """
    Get or create a global embedder instance.

    The instance is created on the first call and reused afterwards, so
    ``model_name`` only takes effect on that first call. A later call that
    names a different model still receives the existing instance, and a
    warning is logged so the mismatch is not silent.

    Args:
        model_name: Name of the sentence-transformers model

    Returns:
        TextEmbedder instance
    """
    global _embedder
    if _embedder is None:
        _embedder = TextEmbedder(model_name)
    elif _embedder.model_name != model_name:
        logger.warning(
            f"Embedder already initialized with model '{_embedder.model_name}'; "
            f"ignoring requested model '{model_name}'"
        )
    return _embedder
|
|
|