cafe3310 commited on
Commit
9602bb7
·
1 Parent(s): 491422c

Refactor: Implement streaming response and simplify architecture

Browse files

- Replace LangGraph with a direct Gradio implementation for simplicity and performance.
- Implement streaming responses using `TextIteratorStreamer` for a better user experience.
- Use `tokenizer.apply_chat_template` for robust prompt formatting.
- Remove obsolete `graph.py.

Files changed (4) hide show
  1. GEMINI.md +4 -5
  2. app.py +19 -42
  3. comp.py +46 -46
  4. graph.py +0 -42
GEMINI.md CHANGED
@@ -19,21 +19,20 @@
19
 
20
  # 子目标
21
  ## 未完成
22
- - [ ] **(进行中)** 解决模型体积过大导致部署失败的问题。
23
- - [ ] (已暂停) 实现自动化部署和验证流程。
24
 
25
  ## 已完成
 
26
  - [x] 使用 LangGraph 实现一个可以路由两个模型的聊天网页应用。
27
 
28
  ---
29
 
30
  # Todolist
31
  ## 未完成
32
- - [ ] **当前任务**: 修改 `app.py`,移除 `Ling-flash-2.0` 模型,只保留 `Ring-mini-2.0`。
33
- - [ ] (待定) 根据用户找到的量化模型,更新 `app.py` 中的模型路径。
34
  - [ ] (已暂停) 搜索 `huggingface_hub` 文档,确认是否存在用于重启 Space 的 API。
35
 
36
  ## 已完成
 
37
  - [x] **(用户决策)** 确认 `Ling-flash-2.0` 模型过大,暂时移除,仅使用 `Ring-mini-2.0`。
38
  - [x] 搭建 LangGraph 基础架构并重构 `app.py`。
39
  - [x] 实现基于用户输入的模型路由逻辑。
@@ -65,4 +64,4 @@
65
  - **平台:** HuggingFace Spaces
66
  - **订阅:** HuggingFace Pro
67
  - **推理资源:** 可以使用 ZeroGPU
68
- - **文档参考:** 在必要的时候,主动搜索 HuggingFace 以及 Gradio 的在线 API 文档。
 
19
 
20
  # 子目标
21
  ## 未完成
22
+ - [ ] **(进行中)** 实现自动化部署和验证流程。
 
23
 
24
  ## 已完成
25
+ - [x] 解决模型体积过大导致部署失败的问题。
26
  - [x] 使用 LangGraph 实现一个可以路由两个模型的聊天网页应用。
27
 
28
  ---
29
 
30
  # Todolist
31
  ## 未完成
 
 
32
  - [ ] (已暂停) 搜索 `huggingface_hub` 文档,确认是否存在用于重启 Space 的 API。
33
 
34
  ## 已完成
35
+ - [x] 修改 `app.py`,移除 `Ling-flash-2.0` 模型,只保留 `Ring-mini-2.0`。
36
  - [x] **(用户决策)** 确认 `Ling-flash-2.0` 模型过大,暂时移除,仅使用 `Ring-mini-2.0`。
37
  - [x] 搭建 LangGraph 基础架构并重构 `app.py`。
38
  - [x] 实现基于用户输入的模型路由逻辑。
 
64
  - **平台:** HuggingFace Spaces
65
  - **订阅:** HuggingFace Pro
66
  - **推理资源:** 可以使用 ZeroGPU
67
+ - **文档参考:** 在必要的时候,主动搜索 HuggingFace 以及 Gradio 的在线 API 文档。
app.py CHANGED
@@ -1,51 +1,28 @@
1
  import gradio as gr
2
- import spaces
3
- from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
4
 
5
- # 导入已编译的 LangGraph 应用
6
- from graph import app
7
 
8
- @spaces.GPU
9
- def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
10
- """Gradio 接口的响应函数,调用 LangGraph 应用"""
11
-
12
- # Gradio 的 history 格式转换为 LangChain 消息格式
13
- messages = []
14
- if system_message:
15
- messages.append(SystemMessage(content=system_message))
16
-
17
- for chat_message in history:
18
- if chat_message['role'] == "user":
19
- messages.append(HumanMessage(content=chat_message['content']))
20
- elif chat_message['role'] == "assistant":
21
- messages.append(AIMessage(content=chat_message['content']))
22
-
23
- messages.append(HumanMessage(content=message))
24
-
25
- # 使用 invoke 方法进行一次性调用
26
- inputs = {"messages": messages}
27
- final_state = app.invoke(inputs)
28
-
29
- # 从最终状态中提取最后一条消息
30
- final_response = final_state["messages"][-1].content
31
-
32
- return final_response
33
 
34
- # 重新定义 ChatInterface
35
- chatbot = gr.ChatInterface(
36
- respond,
37
- type="messages", # 改为 messages 类型以更好地匹配 LangChain
38
- additional_inputs=[
39
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
40
- ],
41
- )
42
 
43
- with gr.Blocks() as demo:
44
- gr.Markdown("# HuggingFace Running")
45
- with gr.Sidebar():
46
- gr.LoginButton()
47
- chatbot.render()
 
48
 
 
 
 
 
49
 
50
  if __name__ == "__main__":
51
  demo.launch()
 
1
  import gradio as gr
2
+ from comp import generate_response
 
3
 
4
+ # --- Gradio UI ---
 
5
 
6
+ with gr.Blocks() as demo:
7
+ gr.Markdown("# Ling Playground")
8
+ chatbot = gr.Chatbot()
9
+ msg = gr.Textbox()
10
+ clear = gr.ClearButton([msg, chatbot])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ def user(user_message, history):
13
+ return "", history + [[user_message, None]]
 
 
 
 
 
 
14
 
15
+ def bot(history):
16
+ user_message = history[-1][0]
17
+ history[-1][1] = ""
18
+ for response in generate_response(user_message, history[:-1]):
19
+ history[-1][1] = response
20
+ yield history
21
 
22
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
23
+ bot, chatbot, chatbot
24
+ )
25
+ clear.click(lambda: None, None, chatbot, queue=False)
26
 
27
  if __name__ == "__main__":
28
  demo.launch()
comp.py CHANGED
@@ -1,12 +1,7 @@
1
  import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
- from langchain_core.messages import AIMessage
4
- from typing import TypedDict, Annotated, List
5
- import operator
6
-
7
- # 定义此组件操作的图状态的子集
8
- class GraphState(TypedDict):
9
- messages: Annotated[List[AIMessage], operator.add]
10
 
11
  # --- 模型加载 ---
12
  # 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度
@@ -20,44 +15,49 @@ model = AutoModelForCausalLM.from_pretrained(
20
  trust_remote_code=True
21
  )
22
 
23
-
24
- def completion_node(state: GraphState) -> dict:
25
- """
26
- 一个调用语言模型以获取响应的节点。
27
-
28
- Args:
29
- state (GraphState): 图的当前状态,包含消息历史。
30
-
31
- Returns:
32
- dict: 一个包含新 AI 消息的字典,该消息将被添加到状态中。
33
- """
34
- messages = state["messages"]
35
-
36
- # --- 提示工程 ---
37
- # 从消息历史中组装提示。
38
- prompt = ""
39
- for msg in messages:
40
- if msg.type == "system":
41
- prompt += f"{msg.content}\n"
42
- elif msg.type == "human":
43
- prompt += f"User: {msg.content}\n"
44
- elif msg.type == "ai":
45
- prompt += f"Assistant: {msg.content}\n"
46
- prompt += "Assistant:"
47
-
48
- # --- 模型调用 ---
49
- # 调用 tokenizer 时获取 input_ids 和 attention_mask
50
- inputs = tokenizer(prompt, return_tensors="pt")
51
-
52
- # 将 attention_mask 和 input_ids 一起传递给 model.generate
53
- output_ids = model.generate(
54
- inputs.input_ids.to(model.device),
55
- attention_mask=inputs.attention_mask.to(model.device),
56
- max_new_tokens=512, # 暂时硬编码
57
  do_sample=True,
58
- pad_token_id=tokenizer.eos_token_id,
59
  )
60
- output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
61
 
62
- # AIMessage 的形式返回响应,以添加到图的状态中。
63
- return {"messages": [AIMessage(content=output)]}
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
+ from threading import Thread
4
+ import spaces
 
 
 
 
 
5
 
6
  # --- 模型加载 ---
7
  # 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度
 
15
  trust_remote_code=True
16
  )
17
 
18
+ @spaces.GPU(duration=120)
19
+ def generate_response(message, history):
20
+ # Convert history to messages format
21
+ messages = [
22
+ {"role": "system", "content": "You are Ring, an assistant created by inclusionAI"}
23
+ ]
24
+
25
+ # Add conversation history
26
+ for human, assistant in history:
27
+ messages.append({"role": "user", "content": human})
28
+ messages.append({"role": "assistant", "content": assistant})
29
+
30
+ # Add current message
31
+ messages.append({"role": "user", "content": message})
32
+
33
+ # Apply chat template
34
+ text = tokenizer.apply_chat_template(
35
+ messages,
36
+ tokenize=False,
37
+ add_generation_prompt=True
38
+ )
39
+
40
+ # Tokenize input
41
+ model_inputs = tokenizer([text], return_tensors="pt", return_token_type_ids=False).to(model.device)
42
+
43
+ # Generate response with streaming
44
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
45
+
46
+ generation_kwargs = dict(
47
+ **model_inputs,
48
+ max_new_tokens=8192,
49
+ temperature=0.7,
 
 
50
  do_sample=True,
51
+ streamer=streamer,
52
  )
 
53
 
54
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
55
+ thread.start()
56
+
57
+ # Stream the response
58
+ response = ""
59
+ for new_text in streamer:
60
+ response += new_text
61
+ yield response
62
+
63
+ thread.join()
graph.py DELETED
@@ -1,42 +0,0 @@
1
- import operator
2
- from typing import Annotated, List
3
- from typing_extensions import TypedDict
4
-
5
- from langchain_core.messages import AnyMessage
6
- from langgraph.graph import StateGraph, END
7
-
8
- # 从我们的组件文件中导入模型补全节点
9
- from comp import completion_node
10
-
11
- # --- 图状态定义 ---
12
- # 状态是我们图的内存或上下文。它是一个字典,
13
- # 保存了对话过程中交换的所有消息。
14
- class GraphState(TypedDict):
15
- """
16
- 表示我们图的状态。
17
-
18
- Attributes:
19
- messages: 一个随时间自动累积的消息列表。
20
- `operator.add` 注解告诉 LangGraph 将新消息附加到此列表,
21
- 而不是覆盖它。这就是图如何维护对话历史(上下文)的方式。
22
- """
23
- messages: Annotated[List[AnyMessage], operator.add]
24
-
25
-
26
- # --- 图工作流构建 ---
27
- # 使用我们定义的状态创建一个新的状态图
28
- workflow = StateGraph(GraphState)
29
-
30
- # 将补全节点添加到图中。我们将其命名为 “llm”。
31
- # 这个节点负责调用语言模型。
32
- workflow.add_node("llm", completion_node)
33
-
34
- # 设置图的入口点。第一个被调用的节点是 “llm”。
35
- workflow.set_entry_point("llm")
36
-
37
- # 从 “llm” 节点到 END 添加一条简单的边。
38
- # 这意味着在调用 LLM 后,图的执行就完成了。
39
- workflow.add_edge("llm", END)
40
-
41
- # 将工作流编译成一个可运行的应用。
42
- app = workflow.compile()