from sentence_transformers import SentenceTransformer
import numpy as np
import json
import faiss
from typing import List, Dict

class ChatEmbedder:
    """Embed chat messages with a SentenceTransformer model and index them in FAISS for semantic search."""

    def __init__(self, model_name='BAAI/bge-small-zh-v1.5'):
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.documents = []
        
    def load_chat_data(self, json_path: str) -> List[str]:
        """Load chat records from a JSON file and convert each message into a document string."""
        with open(json_path, 'r', encoding='utf-8') as f:
            chats = json.load(f)
        
        for chat in chats:
            # Treat each message as a single document that keeps its context (group, sender, timestamp)
            doc = f"【{chat['group']}】{chat['sender']}在{chat['timestamp']}说:{chat['message']}"
            self.documents.append(doc)
        
        return self.documents
    
    def build_index(self) -> int:
        """Build a FAISS vector index over the loaded documents."""
        embeddings = self.model.encode(self.documents, normalize_embeddings=True)
        dimension = embeddings.shape[1]
        
        # Inner-product index; with normalized embeddings this is equivalent to cosine similarity
        self.index = faiss.IndexFlatIP(dimension)
        # FAISS expects float32; SentenceTransformer already returns float32, so this cast is a cheap safeguard
        self.index.add(np.asarray(embeddings, dtype=np.float32))
        
        return len(self.documents)
    
    def search(self, query: str, top_k: int = 3) -> List[Dict]:
        """Retrieve the top_k documents most similar to the query."""
        query_embedding = self.model.encode([query], normalize_embeddings=True)
        # IndexFlatIP returns similarity scores (inner products), higher is better
        scores, indices = self.index.search(np.asarray(query_embedding, dtype=np.float32), top_k)
        
        results = []
        for idx, score in zip(indices[0], scores[0]):
            # FAISS pads missing results with -1, so guard the lower bound as well
            if 0 <= idx < len(self.documents):
                results.append({
                    'content': self.documents[idx],
                    'score': float(score)
                })
        
        return results
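

# Usage sketch. Assumptions not taken from the source: a hypothetical "chats.json" file
# whose records carry the 'group', 'sender', 'timestamp' and 'message' fields that
# load_chat_data() expects, and an illustrative placeholder query string.
if __name__ == "__main__":
    embedder = ChatEmbedder()              # defaults to BAAI/bge-small-zh-v1.5
    embedder.load_chat_data("chats.json")  # hypothetical path, replace with your export
    n_docs = embedder.build_index()
    print(f"Indexed {n_docs} chat messages")

    # With the zh model, real queries would typically be Chinese; this is just a placeholder.
    for hit in embedder.search("example query text", top_k=3):
        print(f"{hit['score']:.3f}  {hit['content']}")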