# chat/utils/embedding.py
# hcy276's picture
# Create utils/embedding.py
# c16e2d0 verified
from sentence_transformers import SentenceTransformer
import numpy as np
import json
import faiss
from typing import List, Dict
class ChatEmbedder:
    """Embed chat messages and retrieve them by vector similarity.

    Typical usage is ``load_chat_data()`` -> ``build_index()`` -> ``search()``.

    Attributes:
        model: The ``SentenceTransformer`` used to encode documents and queries.
        index: The FAISS index, or ``None`` until ``build_index()`` has run.
        documents: The document strings, in the order they were loaded; a
            document's position in this list is its id in the index.
    """

    def __init__(self, model_name: str = 'BAAI/bge-small-zh-v1.5'):
        """Load the sentence-embedding model named ``model_name``."""
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.documents = []

    def load_chat_data(self, json_path: str) -> List[str]:
        """Load chat records from a JSON file and append them as documents.

        The file must contain an iterable of records, each providing the keys
        ``group``, ``sender``, ``timestamp`` and ``message``.

        New documents are appended to ``self.documents``; an index that was
        built earlier is not updated, so call ``build_index()`` again to make
        the new documents searchable.

        Args:
            json_path: Path to a UTF-8 encoded JSON file.

        Returns:
            The full ``self.documents`` list (including previously loaded
            documents).

        Raises:
            KeyError: If a record lacks one of the required keys. In that
                case no document from this file is added.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            chats = json.load(f)

        # Each message becomes one document that carries its context
        # (group, sender, timestamp) alongside the message text.
        # NOTE(review): the opening 【 has no matching 】 and the fields are
        # concatenated without separators -- confirm this format is intended.
        new_documents = [
            f"【{chat['group']}{chat['sender']}{chat['timestamp']}说:{chat['message']}"
            for chat in chats
        ]
        self.documents.extend(new_documents)
        return self.documents

    def build_index(self) -> int:
        """Encode all loaded documents and build a fresh FAISS index.

        Embeddings are L2-normalized, so the inner product used by
        ``IndexFlatIP`` equals cosine similarity.

        Returns:
            The number of documents indexed.

        Raises:
            ValueError: If no documents have been loaded.
        """
        if not self.documents:
            raise ValueError(
                "no documents loaded; call load_chat_data() before build_index()"
            )

        embeddings = self.model.encode(self.documents, normalize_embeddings=True)
        # FAISS expects a contiguous float32 matrix.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        dimension = embeddings.shape[1]

        # Inner product on normalized vectors == cosine similarity.
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings)
        # Publish the index only once it is fully populated.
        self.index = index
        return len(self.documents)

    def search(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return the documents most similar to ``query``.

        Args:
            query: The query text.
            top_k: Maximum number of results to return.

        Returns:
            A list of at most ``top_k`` dicts with keys ``'content'`` (the
            document string) and ``'score'`` (similarity as a ``float``), in
            the order returned by the index. Fewer results are returned when
            the index holds fewer than ``top_k`` documents.

        Raises:
            RuntimeError: If ``build_index()`` has not been called yet.
        """
        if self.index is None:
            raise RuntimeError("index has not been built; call build_index() first")

        query_embedding = self.model.encode([query], normalize_embeddings=True)
        distances, indices = self.index.search(query_embedding, top_k)

        results = []
        for idx, score in zip(indices[0], distances[0]):
            # FAISS fills missing neighbours with label -1 when fewer than
            # top_k vectors exist. A negative label must be rejected
            # explicitly, otherwise documents[-1] would silently return the
            # last document as a bogus hit.
            if 0 <= idx < len(self.documents):
                results.append({
                    'content': self.documents[idx],
                    'score': float(score),
                })
        return results