|
|
from sentence_transformers import SentenceTransformer |
|
|
import numpy as np |
|
|
import json |
|
|
import faiss |
|
|
from typing import List, Dict |
|
|
|
|
|
class ChatEmbedder:
    """Embed chat messages and provide FAISS-based cosine-similarity search."""

    def __init__(self, model_name: str = 'BAAI/bge-small-zh-v1.5'):
        """Load the sentence-embedding model.

        Args:
            model_name: SentenceTransformer model identifier.
        """
        self.model = SentenceTransformer(model_name)
        # FAISS index; built lazily by build_index().
        self.index = None
        # Accumulated formatted chat documents.
        self.documents: List[str] = []

    def load_chat_data(self, json_path: str) -> List[str]:
        """Load chat records from a JSON file and convert them to documents.

        Each record must provide 'group', 'sender', 'timestamp' and 'message'
        keys. Formatted documents are appended to ``self.documents`` (repeated
        calls accumulate).

        Args:
            json_path: Path to a JSON file containing a list of chat dicts.

        Returns:
            The accumulated list of formatted documents.

        Raises:
            KeyError: If a record lacks one of the expected keys.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            chats = json.load(f)

        for chat in chats:
            doc = f"【{chat['group']}】{chat['sender']}在{chat['timestamp']}说:{chat['message']}"
            self.documents.append(doc)

        return self.documents

    def build_index(self) -> int:
        """Build a FAISS inner-product index over the loaded documents.

        Embeddings are L2-normalized, so inner product equals cosine
        similarity.

        Returns:
            The number of indexed documents.

        Raises:
            ValueError: If no documents have been loaded yet.
        """
        if not self.documents:
            raise ValueError("No documents loaded; call load_chat_data() first")

        embeddings = self.model.encode(self.documents, normalize_embeddings=True)
        # FAISS expects a contiguous float32 matrix.
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        dimension = embeddings.shape[1]

        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings)

        return len(self.documents)

    def search(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to ``top_k`` documents most similar to ``query``.

        Args:
            query: Free-text query string.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with 'content' (document text) and 'score'
            (cosine similarity), ordered from most to least similar.

        Raises:
            RuntimeError: If the index has not been built yet.
        """
        if self.index is None:
            raise RuntimeError("Index not built; call build_index() first")

        query_embedding = self.model.encode([query], normalize_embeddings=True)
        query_embedding = np.ascontiguousarray(query_embedding, dtype=np.float32)
        distances, indices = self.index.search(query_embedding, top_k)

        results = []
        for idx, score in zip(indices[0], distances[0]):
            # FAISS pads with -1 when fewer than top_k vectors are indexed.
            # The original `idx < len(...)` check let -1 through and silently
            # returned the LAST document as a phantom hit; reject it here.
            if 0 <= idx < len(self.documents):
                results.append({
                    'content': self.documents[idx],
                    'score': float(score)
                })

        return results