cafe3310 commited on
Commit
551e9e2
·
1 Parent(s): 21916d9

refactor: 重构项目结构并优化模型加载方式

Browse files

将应用逻辑拆分为三个独立的模块:

- `app.py`: Gradio 界面及应用入口。

- `graph.py`: LangGraph 状态及工作流定义。

- `comp.py`: 模型加载及推理逻辑。

此次变更还在 `comp.py` 中更新了模型加载方式,使用 `device_map="auto"` 和 `torch_dtype="auto"` 以实现硬件自动优化,提高可移植性。

Files changed (3) hide show
  1. app.py +4 -61
  2. comp.py +60 -0
  3. graph.py +42 -0
app.py CHANGED
@@ -1,67 +1,10 @@
1
  import gradio as gr
2
  import spaces
3
- import torch
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
 
6
- import operator
7
- from typing import Annotated, Literal
8
- from typing_extensions import TypedDict
9
 
10
- from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, HumanMessage, ToolMessage
11
- from langgraph.graph import StateGraph, END
12
-
13
-
14
- # 定义图的状态
15
- class GraphState(TypedDict):
16
- messages: Annotated[list[AnyMessage], operator.add]
17
-
18
-
19
- # 只加载一次模型和分词器
20
- MODEL_NAME = "inclusionAI/Ring-mini-2.0"
21
- device = "cuda" if torch.cuda.is_available() else "cpu"
22
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
23
- model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_NAME,
25
- torch_dtype=torch.float16 if device == "cuda" else torch.float32,
26
- trust_remote_code=True
27
- ).to(device)
28
-
29
-
30
- # 定义图的节点
31
- def call_model(state: GraphState):
32
- """模型调用节点"""
33
- messages = state["messages"]
34
-
35
- # 拼接 prompt
36
- prompt = ""
37
- for msg in messages:
38
- if msg.type == "system":
39
- prompt += f"{msg.content}\n"
40
- elif msg.type == "human":
41
- prompt += f"User: {msg.content}\n"
42
- elif msg.type == "ai":
43
- prompt += f"Assistant: {msg.content}\n"
44
- prompt += "Assistant:"
45
-
46
- input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
47
- output_ids = model.generate(
48
- input_ids,
49
- max_new_tokens=512, # 暂时硬编码
50
- do_sample=True,
51
- pad_token_id=tokenizer.eos_token_id,
52
- )
53
- output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
54
-
55
- return {"messages": [AIMessage(content=output)]}
56
-
57
- # 构建图
58
- workflow = StateGraph(GraphState)
59
- workflow.add_node("llm", call_model)
60
- workflow.set_entry_point("llm")
61
- workflow.add_edge("llm", END)
62
-
63
- # 编译图
64
- app = workflow.compile()
65
  @spaces.GPU
66
  def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
67
  """Gradio 接口的响应函数,调用 LangGraph 应用"""
@@ -106,4 +49,4 @@ with gr.Blocks() as demo:
106
 
107
 
108
  if __name__ == "__main__":
109
- demo.launch()
 
1
  import gradio as gr
2
  import spaces
3
+ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
 
4
 
5
+ # 导入已编译的 LangGraph 应用
6
+ from graph import app
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  @spaces.GPU
9
  def respond(message, history, system_message, hf_token: gr.OAuthToken = None):
10
  """Gradio 接口的响应函数,调用 LangGraph 应用"""
 
49
 
50
 
51
  if __name__ == "__main__":
52
+ demo.launch()
comp.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from langchain_core.messages import AIMessage
4
+ from typing import TypedDict, Annotated, List
5
+ import operator
6
+
7
+ # 定义此组件操作的图状态的子集
8
+ class GraphState(TypedDict):
9
+ messages: Annotated[List[AIMessage], operator.add]
10
+
11
+ # --- 模型加载 ---
12
+ # 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度
13
+ MODEL_NAME = "inclusionAI/Ring-mini-2.0"
14
+
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ MODEL_NAME,
18
+ torch_dtype="auto",
19
+ device_map="auto",
20
+ trust_remote_code=True
21
+ )
22
+
23
+
24
+ def completion_node(state: GraphState) -> dict:
25
+ """
26
+ 一个调用语言模型以获取响应的节点。
27
+
28
+ Args:
29
+ state (GraphState): 图的当前状态,包含消息历史。
30
+
31
+ Returns:
32
+ dict: 一个包含新 AI 消息的字典,该消息将被添加到状态中。
33
+ """
34
+ messages = state["messages"]
35
+
36
+ # --- 提示工程 ---
37
+ # 从消息历史中组装提示。
38
+ prompt = ""
39
+ for msg in messages:
40
+ if msg.type == "system":
41
+ prompt += f"{msg.content}\n"
42
+ elif msg.type == "human":
43
+ prompt += f"User: {msg.content}\n"
44
+ elif msg.type == "ai":
45
+ prompt += f"Assistant: {msg.content}\n"
46
+ prompt += "Assistant:"
47
+
48
+ # --- 模型调用 ---
49
+ # 使用 device_map="auto" 时,我们无需手动将张量移动到特定设备
50
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
51
+ output_ids = model.generate(
52
+ input_ids,
53
+ max_new_tokens=512, # 暂时硬编码
54
+ do_sample=True,
55
+ pad_token_id=tokenizer.eos_token_id,
56
+ )
57
+ output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
58
+
59
+ # 以 AIMessage 的形式返回响应,以添加到图的状态中。
60
+ return {"messages": [AIMessage(content=output)]}
graph.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import operator
2
+ from typing import Annotated, List
3
+ from typing_extensions import TypedDict
4
+
5
+ from langchain_core.messages import AnyMessage
6
+ from langgraph.graph import StateGraph, END
7
+
8
+ # 从我们的组件文件中导入模型补全节点
9
+ from comp import completion_node
10
+
11
+ # --- 图状态定义 ---
12
+ # 状态是我们图的内存或上下文。它是一个字典,
13
+ # 保存了对话过程中交换的所有消息。
14
+ class GraphState(TypedDict):
15
+ """
16
+ 表示我们图的状态。
17
+
18
+ Attributes:
19
+ messages: 一个随时间自动累积的消息列表。
20
+ `operator.add` 注解告诉 LangGraph 将新消息附加到此列表,
21
+ 而不是覆盖它。这就是图如何维护对话历史(上下文)的方式。
22
+ """
23
+ messages: Annotated[List[AnyMessage], operator.add]
24
+
25
+
26
+ # --- 图工作流构建 ---
27
+ # 使用我们定义的状态创建一个新的状态图
28
+ workflow = StateGraph(GraphState)
29
+
30
+ # 将补全节点添加到图中。我们将其命名为 “llm”。
31
+ # 这个节点负责调用语言模型。
32
+ workflow.add_node("llm", completion_node)
33
+
34
+ # 设置图的入口点。第一个被调用的节点是 “llm”。
35
+ workflow.set_entry_point("llm")
36
+
37
+ # 从 “llm” 节点到 END 添加一条简单的边。
38
+ # 这意味着在调用 LLM 后,图的执行就完成了。
39
+ workflow.add_edge("llm", END)
40
+
41
+ # 将工作流编译成一个可运行的应用。
42
+ app = workflow.compile()