# chat/utils/retrieval.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


class ChatQABot:
    """Answer generator: turns a query plus retrieved chat-log context into a response."""

    def __init__(self, model_name: str = "Qwen/Qwen1.5-1.8B-Chat"):
        # Load the tokenizer and the causal LM. Use float16 only when a GPU is
        # available and fall back to float32 on CPU, where half precision is
        # slow and poorly supported.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else "cpu"
        )

    def generate_answer(self, query: str, context: str) -> str:
        """Generate an answer to `query` grounded in the retrieved chat-log `context`."""
        # The prompt is kept in Chinese because the bot targets Chinese chat logs.
        # It says: answer the question from the chat log below; if the log contains
        # no relevant information, reply "根据现有聊天记录无法回答这个问题" ("this
        # question cannot be answered from the available chat log"); keep the
        # answer concise and accurate.
        prompt = f"""基于以下聊天记录回答问题:
聊天记录:
{context}
问题:{query}
请根据聊天记录准确回答,如果聊天记录中没有相关信息,请说"根据现有聊天记录无法回答这个问题"。回答要简洁准确。"""

        # System role (in Chinese): "You are a professional chat-log analysis assistant."
        messages = [
            {"role": "system", "content": "你是一个专业的聊天记录分析助手。"},
            {"role": "user", "content": prompt}
        ]

        # Render the conversation with the model's chat template and append the
        # generation prompt so the model continues as the assistant.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=200,
                temperature=0.7,
                do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens, skipping the prompt portion.
        response = self.tokenizer.decode(
            outputs[0][inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        return response
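

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how ChatQABot might be wired to retrieved context.
# The `retrieved_log` string below is hypothetical stand-in data; in the real
# app the context would come from the retrieval step implemented elsewhere
# in this repository.
if __name__ == "__main__":
    bot = ChatQABot()  # downloads Qwen/Qwen1.5-1.8B-Chat on first run
    retrieved_log = (
        "2024-03-01 10:02 张三: 周五晚上七点在老地方聚餐\n"
        "2024-03-01 10:05 李四: 好的,我会带蛋糕过来"
    )
    answer = bot.generate_answer("聚餐定在什么时候?", retrieved_log)
    print(answer)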