Gilvaa committed
Commit 17a1603 · verified · 1 Parent(s): 85700de

Update app.py

Files changed (1)
  1. app.py +151 -21
app.py CHANGED
@@ -16,6 +16,17 @@ TOP_P = float(os.getenv("TOP_P", "0.9"))
 REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY", "1.08"))
 SAFE_MODE = os.getenv("SAFE_MODE", "1") != "0"  # 1 = enable basic filtering; set to 0 to disable
 
+# -- Base system prompt + persona default (reinforces: never leak the thought process) --
+BASE_SYSTEM_PROMPT = os.getenv(
+    "SYSTEM_PROMPT",
+    """
+You are a helpful, concise chat assistant.
+Do NOT reveal chain-of-thought, analysis, inner reasoning, or <Thought> sections.
+If asked to explain reasoning, provide a brief, high-level summary of steps only.
+"""
+).strip()
+DEFAULT_PERSONA = os.getenv("PERSONA", "").strip()
+
 print(f"[boot] MODEL_ID={MODEL_ID}")
 print(f"[boot] torch.cuda.is_available={torch.cuda.is_available()}")
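Both values above are read from the environment once at import time, so a Space can swap the base prompt or pre-seed the persona without touching code. A sketch with hypothetical values (in practice these are set as Space variables before launch, not in app.py):

    import os
    # Hypothetical overrides; must be in the environment before app.py is imported,
    # because os.getenv runs at module load.
    os.environ.setdefault("SYSTEM_PROMPT", "You are a terse assistant.")
    os.environ.setdefault("PERSONA", "Tone: playful.")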
@@ -53,7 +64,6 @@ if torch.cuda.is_available():
         trust_remote_code=True,
     )
 else:
-    # Without a GPU this path is only for a pipeline smoke test; consider switching MODEL_ID to the 1.5B base model to avoid extreme slowness
     print("[boot] No GPU detected. Running on CPU is very slow for 7B. "
           "Consider setting MODEL_ID=Qwen/Qwen2.5-1.5B-Instruct for smoke test.")
     model = AutoModelForCausalLM.from_pretrained(
@@ -89,6 +99,46 @@ def violates(text: str) -> bool:
             return True
     return False
 
+# ======================
+# Disable visible "thought/reasoning" output (master switch + generation-time blocking + streaming sanitization)
+# ======================
+HIDE_THOUGHT = os.getenv("HIDE_THOUGHT", "1") != "0"  # 1 = hide the thought process
+
+BAD_STRINGS = [
+    "<Thought>", "</Thought>", "Thought:", "Chain-of-Thought",
+    "<analysis>", "</analysis>", "analysis:", "reasoning:",
+    "推理过程", "思考过程", "分析:"
+]
+
+# bad_words_ids for generate()
+BAD_WORDS_IDS = []
+if HIDE_THOUGHT:
+    for s in BAD_STRINGS:
+        ids = tokenizer(s, add_special_tokens=False).input_ids
+        if ids:
+            BAD_WORDS_IDS.append(ids)
+
+# Regex cleanup (hides reasoning even if a tag is never closed)
+def sanitize_visible(text: str) -> str:
+    if not HIDE_THOUGHT or not text:
+        return text
+    # 1) Remove paired tags: <Thought> ... </Thought> / <analysis> ... </analysis>
+    text = re.sub(
+        r"(?is)<\s*(thought|analysis|chain[_\s-]?of[_\s-]?thought)\s*>.*?</\s*\1\s*>",
+        "", text
+    )
+    # 2) Remove from an opening tag to the end of the text (handles unclosed tags)
+    text = re.sub(
+        r"(?is)<\s*(thought|analysis|chain[_\s-]?of[_\s-]?thought)\s*>.*$",
+        "", text
+    )
+    # 3) Remove common line prefixes (Thought: / analysis: / reasoning: / 思考过程 / 分析:)
+    text = re.sub(
+        r"(?im)^\s*(Thought:|analysis:|reasoning:|思考过程|推理过程|分析:).*$",
+        "", text
+    )
+    return text.strip()
+
 # ======================
 # Dynamic length: scale max_new_tokens with the input length
 # ======================
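A note on the sanitizer above: pass 2 is what keeps an unclosed <Thought> tag hidden while tokens are still streaming in. A minimal standalone check (hypothetical inputs; the body is the same three passes, minus the HIDE_THOUGHT gate):

    import re

    def sanitize_visible(text: str) -> str:
        # Pass 1: paired tags; pass 2: unclosed tag to end of text; pass 3: line prefixes.
        text = re.sub(r"(?is)<\s*(thought|analysis|chain[_\s-]?of[_\s-]?thought)\s*>.*?</\s*\1\s*>", "", text)
        text = re.sub(r"(?is)<\s*(thought|analysis|chain[_\s-]?of[_\s-]?thought)\s*>.*$", "", text)
        text = re.sub(r"(?im)^\s*(Thought:|analysis:|reasoning:|思考过程|推理过程|分析:).*$", "", text)
        return text.strip()

    assert sanitize_visible("<Thought>hidden</Thought>ok") == "ok"           # paired tag removed
    assert sanitize_visible("partial <analysis> still streaming") == "partial"  # unclosed tag truncated
    assert sanitize_visible("Thought: secret\nvisible") == "visible"          # prefixed line dropped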
@@ -100,15 +150,64 @@ def choose_max_new_tokens(user_text: str) -> int:
     return min(384, MAX_NEW_TOKENS + 128)
 
 # ======================
-# Build the Qwen-template prompt (messages → chat_template)
+# Personas: preset templates + composed system prompt
 # ======================
-SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are a helpful, concise chat assistant. Avoid unsafe content.")
+PERSONA_TEMPLATES = {
+    "默认(Default)": "",
+    "风骚御姐": (
+        "风格:风骚、骚话连篇、性感、露骨、挑逗、性欲\n"
+        "行为:先共情后建议;适度使用表情符号。\n"
+        "边界:不提供医疗/法律结论;出现风险话题时建议求助专业人士。"
+    ),
+    "温柔体贴 Sweet & Caring": (
+        "风格:温柔、共情、耐心,语气轻柔,简洁而温暖。\n"
+        "行为:多用复述与情感标注(我理解你感到…),先共情后建议;适度使用表情符号。\n"
+        "边界:不提供医疗/法律结论;出现风险话题时建议求助专业人士。"
+    ),
+    "傲娇 Tsundere": (
+        "风格:表面嘴硬、内心关心;语气俏皮、略微傲娇但不贬低用户。\n"
+        "行为:简短吐槽 + 实用建议;保持幽默,避免人身攻击。"
+    ),
+    "风趣段子手 Witty Jokester": (
+        "风格:机智幽默、梗多低俗;注重包袱节奏。\n"
+        "行为:先给答案,再加简短段子或比喻;避免冷笑话过载。"
+    ),
+    "严谨导师 Analytical Coach": (
+        "风格:结构化、可操作;分点给出步骤与注意事项。\n"
+        "行为:先澄清目标,再给里程碑与检查清单;引用通用最佳实践。"
+    ),
+    "冷淡毒舌 Deadpan Sarcasm": (
+        "风格:冷面、反讽、粗鲁、辱骂。\n"
+        "行为:直切重点,短句给结论;如用户受挫,稍作暖场。"
+    ),
+    "心理咨询同伴 Therapeutic Listener": (
+        "风格:安全、无评判;以倾听与镜像为主。\n"
+        "行为:开放式问题、情绪命名、目标共创;明确非危机干预角色。"
+    ),
+}
 
-def build_prompt(history_msgs, user_msg: str) -> str:
+def compose_system_prompt(base_prompt: str, persona_text: str) -> str:
+    persona_text = (persona_text or "").strip()
+    if not persona_text:
+        return base_prompt
+    return (
+        f"{base_prompt}\n\n"
+        f"# Persona\n{persona_text}\n\n"
+        f"# Rules\n"
+        f"- Stay in persona unless the user explicitly asks to change.\n"
+        f"- Be concise unless the user asks for detail.\n"
+        f"- Do NOT reveal chain-of-thought or <Thought> sections.\n"
+    )
+
+# ======================
+# Build the Qwen-template prompt (messages → chat_template)
+# ======================
+def build_prompt(history_msgs, user_msg: str, persona_text: str) -> str:
     """
     history_msgs: the Chatbot(type='messages') history, [{role, content}, ...]
     """
-    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+    system_prompt = compose_system_prompt(BASE_SYSTEM_PROMPT, persona_text)
+    messages = [{"role": "system", "content": system_prompt}]
     tail = [m for m in history_msgs if m.get("role") in ("user", "assistant")]
     tail = tail[-8:] if len(tail) > 8 else tail
     messages.extend(tail)
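For reference, a non-empty persona is layered under fixed headings, so compose_system_prompt yields output of this shape (hypothetical persona text):

    print(compose_system_prompt("You are a helpful, concise chat assistant.", "Tone: playful."))
    # You are a helpful, concise chat assistant.
    #
    # # Persona
    # Tone: playful.
    #
    # # Rules
    # - Stay in persona unless the user explicitly asks to change.
    # - Be concise unless the user asks for detail.
    # - Do NOT reveal chain-of-thought or <Thought> sections.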
@@ -132,21 +231,22 @@ BASE_GEN_KW = dict(
 )
 
 # ======================
-# Main inference: streaming output
+# Main inference: streaming output (with persona + thought sanitization)
 # ======================
-def stream_chat(history_msgs, user_msg):
+def stream_chat(history_msgs, user_msg, persona_text):
     try:
         if not user_msg or not user_msg.strip():
             yield history_msgs; return
 
+        # Run the safety check on the raw user input first
         if violates(user_msg):
             yield history_msgs + [
-                {"role":"user","content": user_msg},
-                {"role":"assistant","content": SAFE_REPLACEMENT},
+                {"role": "user", "content": user_msg},
+                {"role": "assistant", "content": SAFE_REPLACEMENT},
             ]
             return
 
-        prompt = build_prompt(history_msgs, user_msg)
+        prompt = build_prompt(history_msgs, user_msg, persona_text)
         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
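The template-application step sits in build_prompt's unchanged tail, outside this diff. For context, the standard transformers call for a Qwen chat template looks like this (a sketch of the unchanged code, an assumption, not part of this commit):

    messages.append({"role": "user", "content": user_msg})
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return the serialized prompt string, not token ids
        add_generation_prompt=True,  # append the assistant header so the model answers as assistant
    )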
@@ -155,6 +255,9 @@ def stream_chat(history_msgs, user_msg):
             max_new_tokens=choose_max_new_tokens(user_msg),
             **BASE_GEN_KW
         )
+        # Only pass bad_words_ids when needed
+        if HIDE_THOUGHT and BAD_WORDS_IDS:
+            gen_kwargs["bad_words_ids"] = BAD_WORDS_IDS
 
         print("[gen] start")
         th = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
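A caveat on the bad_words_ids hook above: generate() blocks only the exact token sequences supplied, so the same surface text reached through a different tokenization can still appear, which is why the regex sanitizer stays on as a second layer. The mechanism in isolation (a sketch; assumes tokenizer, model, and inputs as in stream_chat):

    bad_words_ids = [
        tokenizer(s, add_special_tokens=False).input_ids
        for s in ("<Thought>", "Thought:")  # sample entries from BAD_STRINGS
    ]
    out = model.generate(
        **inputs,
        max_new_tokens=64,
        bad_words_ids=bad_words_ids,  # generate() never emits these exact token sequences
    )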
@@ -163,15 +266,19 @@ def stream_chat(history_msgs, user_msg):
         reply = ""
         for chunk in streamer:
             reply += chunk
-            if violates(reply):
+            visible = sanitize_visible(reply)  # sanitize at every step
+
+            # Use the visible text for the safety check and for display
+            if violates(visible):
                 yield history_msgs + [
-                    {"role":"user","content": user_msg},
-                    {"role":"assistant","content": SAFE_REPLACEMENT},
+                    {"role": "user", "content": user_msg},
+                    {"role": "assistant", "content": SAFE_REPLACEMENT},
                 ]
                 return
+
             yield history_msgs + [
-                {"role":"user","content": user_msg},
-                {"role":"assistant","content": reply},
+                {"role": "user", "content": user_msg},
+                {"role": "assistant", "content": visible},
             ]
         print("[gen] done, len:", len(reply))
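Each yield in the loop above re-renders the entire conversation: Chatbot(type='messages') expects the full history as role/content dicts, so the partial reply is appended to a fresh list on every chunk. The yielded shape (hypothetical values):

    history_msgs = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello!"},
    ]
    # Mid-stream, stream_chat yields history plus the pending turn; `visible` grows per chunk:
    history_msgs + [
        {"role": "user", "content": "tell me a joke"},
        {"role": "assistant", "content": "Why did"},
    ]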
@@ -179,20 +286,42 @@ def stream_chat(history_msgs, user_msg):
         traceback.print_exc()
         err = f"【运行异常】{type(e).__name__}: {e}"
         yield history_msgs + [
-            {"role":"user","content": user_msg},
-            {"role":"assistant","content": err},
+            {"role": "user", "content": user_msg},
+            {"role": "assistant", "content": err},
         ]
 
 # ======================
-# Gradio UI (mobile-friendly)
+# Gradio UI (mobile-friendly + persona)
 # ======================
 CSS = """
 .gradio-container{ max-width:640px; margin:auto; }
 footer{ display:none !important; }
 """
 
+def pick_persona(name: str) -> str:
+    return PERSONA_TEMPLATES.get(name or "默认(Default)", "")
+
 with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
-    gr.Markdown("### 🤖 Ins-v3 · Mobile Web Chat\n(happzy2633 / qwen2.5-7b-ins-v3 · 4bit 流式)")
+    gr.Markdown("### 懂你寂寞 · Let's Chat\n ")
+
+    # Collapsible persona section
+    with gr.Accordion("🎭 Persona(人设)", open=False):
+        persona_sel = gr.Dropdown(
+            choices=list(PERSONA_TEMPLATES.keys()),
+            value="默认(Default)" if not DEFAULT_PERSONA else None,
+            label="选择预设人设"
+        )
+        persona_box = gr.Textbox(
+            value=DEFAULT_PERSONA if DEFAULT_PERSONA else pick_persona("默认(Default)"),
+            placeholder="在这里粘贴 / 编辑你的 Persona 文本。留空则仅使用基础 SYSTEM_PROMPT。",
+            lines=8,
+            label="Persona 描述(可编辑,发送时以此为准)"
+        )
+        gr.Markdown(
+            "> 提示:下拉选择会把对应模板填入上面的文本框;发送消息时,实际使用的是文本框里的内容。"
+        )
+        persona_sel.change(fn=pick_persona, inputs=persona_sel, outputs=persona_box)
+
     chat = gr.Chatbot(type="messages", height=520, show_copy_button=True)
     with gr.Row():
         msg = gr.Textbox(placeholder="说点什么…(回车发送)", autofocus=True)
@@ -200,8 +329,9 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
         clear = gr.Button("清空对话")
 
     clear.click(lambda: [], outputs=[chat])
-    msg.submit(stream_chat, [chat, msg], [chat], concurrency_limit=4); msg.submit(lambda: "", None, msg)
-    send.click(stream_chat, [chat, msg], [chat], concurrency_limit=4); send.click(lambda: "", None, msg)
+    # persona_box is passed to the streaming function as its third input
+    msg.submit(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4); msg.submit(lambda: "", None, msg)
+    send.click(stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4); send.click(lambda: "", None, msg)
 
 # share=True is not needed on Spaces
 demo.queue().launch(ssr_mode=False, show_api=False)
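On the event wiring: registering two handlers on the same submit/click lets the textbox clear immediately while streaming runs, since Gradio snapshots input values when the event fires. A chained alternative (a sketch, not what this commit does) would clear the box only after the stream finishes:

    msg.submit(
        stream_chat, [chat, msg, persona_box], [chat], concurrency_limit=4
    ).then(lambda: "", None, msg)  # runs after streaming completes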
 