|
|
import gradio as gr |
|
|
from utils.embedding import ChatEmbedder |
|
|
from utils.retrieval import ChatQABot |
|
|
import json |
|
|
|
|
|
class ChatAnalyzerApp:
    """Chat-log question answering built on ``ChatEmbedder`` and ``ChatQABot``.

    The embedder loads ``chat_log.json`` and serves similarity search; the
    QA bot turns a question plus the retrieved snippets into an answer.
    """

    # Class-level default so the flag is readable even on an instance whose
    # __init__ did not run to completion.
    _data_ready = False

    def __init__(self):
        """Create the embedder and QA bot, then try to load the chat data."""
        self.embedder = ChatEmbedder()
        self.qa_bot = ChatQABot()
        self.setup_data()

    def setup_data(self, force: bool = False) -> bool:
        """Load the chat log and build the search index.

        A successful load is remembered, so repeated calls (``__init__``
        calls this once already) return ``True`` without loading the file
        and rebuilding the index again. A failed load is not remembered and
        is retried on the next call.

        Args:
            force: Reload and rebuild even if a previous call succeeded.

        Returns:
            ``True`` if the data is loaded and indexed, ``False`` if loading
            or indexing raised; the error is printed, not propagated.
        """
        if self._data_ready and not force:
            return True

        self._data_ready = False
        try:
            count = self.embedder.load_chat_data('chat_log.json')
            index_size = self.embedder.build_index()
            print(f"✓ 加载了 {count} 条聊天记录,构建了 {index_size} 维索引")
            self._data_ready = True
            return True
        except Exception as e:
            # Top-level boundary: report the failure and let the caller
            # decide what to do based on the boolean result.
            print(f"× 数据加载失败: {e}")
            return False

    def query_chat(self, question: str, top_k: int = 3):
        """Answer ``question`` from the most relevant chat snippets.

        Args:
            question: The user's question. Blank or ``None`` input is
                rejected without querying the embedder or the QA bot.
            top_k: Number of snippets to retrieve. Coerced to ``int``
                because UI sliders may deliver the value as a float.

        Returns:
            A ``(answer, context)`` tuple of strings. ``context`` is the
            numbered list of retrieved snippets with their scores; it is
            empty when the question is blank or nothing was retrieved.
        """
        if not question or not question.strip():
            return "请输入问题。", ""

        results = self.embedder.search(question, int(top_k))
        if not results:
            return "未找到相关聊天记录。", ""

        # Each hit is indexed by 'content' and 'score'; number them from 1.
        context = "\n".join(
            f"{rank}. {hit['content']} (相关度: {hit['score']:.3f})"
            for rank, hit in enumerate(results, start=1)
        )

        answer = self.qa_bot.generate_answer(question, context)
        return answer, context
|
|
|
|
|
def main():
    """Build the Gradio UI for the chat-log Q&A assistant and launch it.

    Returns early, without starting the UI, when the chat data cannot be
    loaded.
    """
    app = ChatAnalyzerApp()

    # ChatAnalyzerApp.__init__ already calls setup_data(), but discards the
    # result; calling it here is how we learn whether loading succeeded.
    if not app.setup_data():
        return

    # The theme is passed to the constructor: assigning ``demo.theme`` after
    # construction may not restyle the app.
    with gr.Blocks(title="群聊记录分析助手", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📱 群聊记录分析与问答系统")
        gr.Markdown("基于RAG技术,智能分析群聊记录并回答问题")

        with gr.Row():
            # Left column: the question and retrieval controls
            with gr.Column(scale=2):
                question_input = gr.Textbox(
                    label="输入您的问题",
                    placeholder="例如:关于新版UI设计,最终的结论是什么?",
                    lines=2,
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=5,
                    value=3,
                    step=1,
                    label="检索结果数量",
                )
                submit_btn = gr.Button("🔍 搜索", variant="primary")

            # Right column: the generated answer and the retrieved snippets
            with gr.Column(scale=3):
                answer_output = gr.Textbox(
                    label="答案",
                    lines=4,
                    interactive=False,
                )
                context_output = gr.Textbox(
                    label="检索到的相关聊天记录",
                    lines=6,
                    interactive=False,
                )

        # Example questions that fill the question box
        gr.Examples(
            examples=[
                "关于新版UI设计,最终的结论是什么?",
                "谁负责联系市场部?",
                "市场部的联系方式是什么?",
                "讨论过哪些设计方案?",
            ],
            inputs=[question_input],
        )

        gr.Markdown("### 功能说明")
        gr.Markdown("""
        1. 系统基于模拟的飞书/企微群聊记录构建知识库
        2. 使用中文嵌入模型检索相关聊天片段
        3. 使用大语言模型生成准确答案
        4. 支持调整检索结果数量以优化答案质量
        """)

        # The button click and pressing Enter in the question box both run
        # the same query with the same inputs and outputs.
        for register in (submit_btn.click, question_input.submit):
            register(
                fn=app.query_chat,
                inputs=[question_input, top_k_slider],
                outputs=[answer_output, context_output],
            )

    demo.launch()
|
|
|
|
|
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()