from sentence_transformers import SentenceTransformer
import numpy as np
import json
import faiss
from typing import List, Dict


class ChatEmbedder:
    def __init__(self, model_name='BAAI/bge-small-zh-v1.5'):
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.documents = []

    def load_chat_data(self, json_path: str) -> List[str]:
        """Load chat records and convert them into documents."""
        with open(json_path, 'r', encoding='utf-8') as f:
            chats = json.load(f)

        for chat in chats:
            # Turn each message into one document that carries its context
            # (group, sender, timestamp) alongside the message text.
            doc = f"【{chat['group']}】{chat['sender']}在{chat['timestamp']}说:{chat['message']}"
            self.documents.append(doc)

        return self.documents

    def build_index(self) -> int:
        """Build the FAISS vector index."""
        embeddings = self.model.encode(self.documents, normalize_embeddings=True)
        dimension = embeddings.shape[1]

        # With normalized embeddings, inner product equals cosine similarity.
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(np.asarray(embeddings, dtype=np.float32))

        return len(self.documents)

    def search(self, query: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the documents most relevant to the query."""
        query_embedding = self.model.encode([query], normalize_embeddings=True)
        distances, indices = self.index.search(
            np.asarray(query_embedding, dtype=np.float32), top_k
        )

        results = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS returns -1 for empty slots when top_k exceeds the index size,
            # so check the lower bound as well.
            if 0 <= idx < len(self.documents):
                results.append({
                    'content': self.documents[idx],
                    'score': float(distance)
                })
        return results
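
# A minimal usage sketch (not part of the original code): it assumes a
# hypothetical "chats.json" file whose records contain the 'group', 'sender',
# 'timestamp', and 'message' fields that load_chat_data expects.
if __name__ == '__main__':
    embedder = ChatEmbedder()
    embedder.load_chat_data('chats.json')  # hypothetical path
    num_docs = embedder.build_index()
    print(f'Indexed {num_docs} chat messages')

    # Query in natural language; scores are cosine similarities in [-1, 1].
    for hit in embedder.search('周末聚餐安排', top_k=3):
        print(hit['score'], hit['content'])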