raoqu commited on
Commit
4159dd0
·
1 Parent(s): 16ef8ef
app.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import tempfile
3
+
4
+ import requests
5
+ import streamlit as st
6
+ from lagent.schema import AgentStatusCode
7
+ from pyvis.network import Network
8
+
9
+
10
# Build the interactive pyvis visualisation of the search graph.
def create_network_graph(nodes, adjacency_list):
    """Return a pyvis ``Network`` with one node per entry in *nodes*.

    An edge is added for every neighbor entry whose ``name`` refers to a
    known node; neighbors pointing at unknown nodes are ignored.
    """
    graph = Network(height="500px", width="60%", bgcolor="white", font_color="black")
    for name, content in nodes.items():
        graph.add_node(name, label=name, title=content, color="#FF5733", size=25)
    for source, neighbors in adjacency_list.items():
        targets = [entry["name"] for entry in neighbors if entry["name"] in nodes]
        for target in targets:
            graph.add_edge(source, target)
    # Expose the physics control panel in the rendered HTML.
    graph.show_buttons(filter_=["physics"])
    return graph
21
+
22
+
23
# Function to draw the graph and return the HTML file path
def draw_graph(net):
    """Save *net* to a uniquely named temporary HTML file and return its path.

    Fix: uses ``tempfile.mkstemp`` instead of the deprecated, race-prone
    ``tempfile.mktemp`` — the file is actually created (not merely named)
    before the path is handed to pyvis, so no other process can claim it.
    """
    import os

    fd, path = tempfile.mkstemp(suffix=".html")
    os.close(fd)  # pyvis re-opens the path itself; we only reserve the name
    net.save_graph(path)
    return path
28
+
29
+
30
def streaming(raw_response):
    """Decode a server-sent-event style HTTP stream into parsed updates.

    For every JSON payload in the stream this yields a tuple of
    ``(current_node, node_response_or_final_response, adjacency_list)``.
    Empty chunks, bare carriage returns and keep-alive ping lines are
    skipped.
    """
    for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\n"):
        if not chunk:
            continue
        line = chunk.decode("utf-8")
        if line == "\r" or line.startswith(": ping - "):
            continue
        if line.startswith("data: "):
            line = line[len("data: "):]
        payload = json.loads(line)
        node = payload["current_node"]
        if node:
            answer = payload["response"]["formatted"]["node"][node]["response"]
        else:
            answer = payload["response"]
        yield node, answer, payload["response"]["formatted"]["adjacency_list"]
50
+
51
+
52
# Initialize Streamlit session state (first run of this session only).
if "queries" not in st.session_state:
    for _key in (
        "queries",
        "responses",
        "graphs_html",
        "nodes_list",
        "adjacency_list_list",
        "history",
        "already_used_keys",
    ):
        st.session_state[_key] = []

# Set up page layout
st.set_page_config(layout="wide")
st.title("MindSearch-思索")
65
+
66
+
67
# Function to update chat
def update_chat(query):
    """Send *query* to the MindSearch solver and render the streamed reply.

    Streams ``(node_name, response, adjacency_list)`` updates from the
    local backend, incrementally re-rendering the search graph, the
    planner column (left) and the per-node searcher column (right).
    Widgets live in ``st.session_state`` placeholders so they can be
    updated in place while the stream is running.
    """
    with st.chat_message("user"):
        st.write(query)
    # Only contact the backend for queries not answered in this session.
    if query not in st.session_state["queries"]:
        # Mock data to simulate backend response
        # response, history, nodes, adjacency_list
        st.session_state["queries"].append(query)
        st.session_state["responses"].append([])
        history = None
        # Multi-turn conversation is not supported yet.
        # message = [dict(role='user', content=query)]

        # Solver service launched by the __main__ block at the bottom of
        # this file; the response is consumed as a stream.
        url = "http://localhost:8002/solve"
        headers = {"Content-Type": "application/json"}
        data = {"inputs": query}
        raw_response = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=20, stream=True
        )

        # _nodes: node name -> display text; _node_cnt tracks when the
        # graph picture must be re-rendered.
        # NOTE(review): if the stream yields nothing, graph_html and
        # adjacency_list below are unbound — confirm backend always emits
        # at least one update.
        _nodes, _node_cnt = {}, 0
        for resp in streaming(raw_response):
            node_name, response, adjacency_list = resp
            # Register every node referenced anywhere in the adjacency list.
            for name in set(adjacency_list) | {
                val["name"] for vals in adjacency_list.values() for val in vals
            }:
                if name not in _nodes:
                    _nodes[name] = query if name == "root" else name
                elif response["stream_state"] == 0:
                    # stream_state 0 — presumably AgentStatusCode.END; at
                    # end of stream store the node's final thought. TODO confirm.
                    _nodes[node_name or "response"] = response["formatted"] and response[
                        "formatted"
                    ].get("thought")
            # Re-draw the graph only when the node set changed (or the
            # stream finished); drawing goes through a temp HTML file.
            if len(_nodes) != _node_cnt or response["stream_state"] == 0:
                net = create_network_graph(_nodes, adjacency_list)
                graph_html_path = draw_graph(net)
                with open(graph_html_path, encoding="utf-8") as f:
                    graph_html = f.read()
                _node_cnt = len(_nodes)
            else:
                graph_html = None
            # Lazily create all in-place placeholders on first update.
            if "graph_placeholder" not in st.session_state:
                st.session_state["graph_placeholder"] = st.empty()
            if "expander_placeholder" not in st.session_state:
                st.session_state["expander_placeholder"] = st.empty()
            if graph_html:
                with st.session_state["expander_placeholder"].expander(
                    "Show Graph", expanded=False
                ):
                    st.session_state["graph_placeholder"]._html(graph_html, height=500)
            if "container_placeholder" not in st.session_state:
                st.session_state["container_placeholder"] = st.empty()
            with st.session_state["container_placeholder"].container():
                if "columns_placeholder" not in st.session_state:
                    st.session_state["columns_placeholder"] = st.empty()
                col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
                with col1:
                    # Left column: planner (root agent) output.
                    if "planner_placeholder" not in st.session_state:
                        st.session_state["planner_placeholder"] = st.empty()
                    if "session_info_temp" not in st.session_state:
                        st.session_state["session_info_temp"] = ""
                    if not node_name:
                        # Planner-level update (no specific graph node).
                        if response["stream_state"] in [
                            AgentStatusCode.STREAM_ING,
                            AgentStatusCode.CODING,
                            AgentStatusCode.CODE_END,
                        ]:
                            content = response["formatted"]["thought"]
                            if response["formatted"]["tool_type"]:
                                action = response["formatted"]["action"]
                                if isinstance(action, dict):
                                    action = json.dumps(action, ensure_ascii=False, indent=4)
                                content += "\n" + action
                            st.session_state["session_info_temp"] = content.replace(
                                "<|action_start|><|interpreter|>\n", "\n"
                            )
                        elif response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            # assert history[-1]["role"] == "environment"
                            st.session_state["session_info_temp"] += "\n" + response["content"]
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["session_info_temp"]
                        )
                        # Flush the accumulated planner text after each
                        # code-interpreter return.
                        if response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            st.session_state["responses"][-1].append(
                                st.session_state["session_info_temp"]
                            )
                            st.session_state["session_info_temp"] = ""
                    else:
                        # A searcher node is active: keep showing the last
                        # planner text.
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["responses"][-1][-1]
                            if not st.session_state["session_info_temp"]
                            else st.session_state["session_info_temp"]
                        )
                with col2:
                    # Right column: per-node searcher output with a node
                    # selector.
                    if "selectbox_placeholder" not in st.session_state:
                        st.session_state["selectbox_placeholder"] = st.empty()
                    if "searcher_placeholder" not in st.session_state:
                        st.session_state["searcher_placeholder"] = st.empty()
                    if node_name:
                        selected_node_key = (
                            f"selected_node_{len(st.session_state['queries'])}_{node_name}"
                        )
                        if selected_node_key not in st.session_state:
                            st.session_state[selected_node_key] = node_name
                        # Streamlit forbids reusing a widget key within one
                        # run, hence the already_used_keys bookkeeping.
                        if selected_node_key not in st.session_state["already_used_keys"]:
                            selected_node = st.session_state["selectbox_placeholder"].selectbox(
                                "Select a node:",
                                list(_nodes.keys()),
                                key=f"key_{selected_node_key}",
                                index=list(_nodes.keys()).index(node_name),
                            )
                            st.session_state["already_used_keys"].append(selected_node_key)
                        else:
                            selected_node = node_name
                        st.session_state[selected_node_key] = selected_node
                        # Per-node transcript: list of [kind, text] pairs,
                        # kind in {"thought", "observation"}.
                        node_info_key = f"{selected_node}_info"
                        if node_info_key not in st.session_state:
                            st.session_state[node_info_key] = [["thought", ""]]
                        if response["stream_state"] in [AgentStatusCode.STREAM_ING]:
                            content = response["formatted"]["thought"]
                            st.session_state[node_info_key][-1][1] = content.replace(
                                "<|action_start|><|plugin|>\n", "\n```json\n"
                            )
                        elif response["stream_state"] in [
                            AgentStatusCode.PLUGIN_START,
                            AgentStatusCode.PLUGIN_END,
                        ]:
                            thought = response["formatted"]["thought"]
                            action = response["formatted"]["action"]
                            if isinstance(action, dict):
                                action = json.dumps(action, ensure_ascii=False, indent=4)
                            content = thought + "\n```json\n" + action
                            if response["stream_state"] == AgentStatusCode.PLUGIN_RETURN:
                                content += "\n```"
                            st.session_state[node_info_key][-1][1] = content
                        elif (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            # Tool output: keep as plain string unless it is
                            # valid JSON, then pretty-print.
                            try:
                                content = json.loads(response["content"])
                            except json.decoder.JSONDecodeError:
                                content = response["content"]
                            st.session_state[node_info_key].append(
                                [
                                    "observation",
                                    (
                                        content
                                        if isinstance(content, str)
                                        else f"```json\n{json.dumps(content, ensure_ascii=False, indent=4)}\n```"
                                    ),
                                ]
                            )
                        st.session_state["searcher_placeholder"].markdown(
                            st.session_state[node_info_key][-1][1]
                        )
                        # Open a fresh "thought" entry after each tool return.
                        if (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            st.session_state[node_info_key].append(["thought", ""])
        # Flush any trailing planner text once the stream ends.
        if st.session_state["session_info_temp"]:
            st.session_state["responses"][-1].append(st.session_state["session_info_temp"])
            st.session_state["session_info_temp"] = ""
        # st.session_state['responses'][-1] = '\n'.join(st.session_state['responses'][-1])
        st.session_state["graphs_html"].append(graph_html)
        st.session_state["nodes_list"].append(_nodes)
        st.session_state["adjacency_list_list"].append(adjacency_list)
        st.session_state["history"] = history
235
+
236
+
237
def display_chat_history():
    """Re-render the most recent query's graph, planner text and node info.

    Relies on the placeholders created by ``update_chat`` being present in
    ``st.session_state``.

    NOTE(review): the loop slices ``[-1:]`` (last query only) but indexes
    the parallel lists with ``i`` (always 0 here) — with more than one
    stored query this shows the FIRST query's graph/nodes alongside the
    LAST response; confirm intended behavior.
    """
    for i, query in enumerate(st.session_state["queries"][-1:]):
        # with st.chat_message('assistant'):
        if st.session_state["graphs_html"][i]:
            with st.session_state["expander_placeholder"].expander("Show Graph", expanded=False):
                st.session_state["graph_placeholder"]._html(
                    st.session_state["graphs_html"][i], height=500
                )
        with st.session_state["container_placeholder"].container():
            col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
            with col1:
                # Left column: last planner response.
                st.session_state["planner_placeholder"].markdown(
                    st.session_state["responses"][-1][-1]
                )
            with col2:
                # Right column: node selector replaying the stored graph.
                selected_node_key = st.session_state["already_used_keys"][-1]
                st.session_state["selectbox_placeholder"] = st.empty()
                selected_node = st.session_state["selectbox_placeholder"].selectbox(
                    "Select a node:",
                    list(st.session_state["nodes_list"][i].keys()),
                    key=f"replay_key_{i}",
                    index=list(st.session_state["nodes_list"][i].keys()).index(
                        st.session_state[selected_node_key]
                    ),
                )
                st.session_state[selected_node_key] = selected_node
                if (
                    selected_node not in ["root", "response"]
                    and selected_node in st.session_state["nodes_list"][i]
                ):
                    # Replay the selected node's thought/observation log.
                    node_info_key = f"{selected_node}_info"
                    for item in st.session_state[node_info_key]:
                        if item[0] in ["thought", "answer"]:
                            st.session_state["searcher_placeholder"] = st.empty()
                            st.session_state["searcher_placeholder"].markdown(item[1])
                        elif item[0] == "observation":
                            st.session_state["observation_expander"] = st.empty()
                            with st.session_state["observation_expander"].expander("Results"):
                                st.write(item[1])
                        # st.session_state['searcher_placeholder'].markdown(st.session_state[node_info_key])
277
+
278
+
279
def clean_history():
    """Reset all conversation state and drop per-run placeholder/info keys.

    Fix: iterate over a snapshot of the keys — deleting entries from
    ``st.session_state`` while iterating it directly raises
    ``RuntimeError: dictionary changed size during iteration``.
    """
    st.session_state["queries"] = []
    st.session_state["responses"] = []
    st.session_state["graphs_html"] = []
    st.session_state["nodes_list"] = []
    st.session_state["adjacency_list_list"] = []
    st.session_state["history"] = []
    st.session_state["already_used_keys"] = list()
    for k in list(st.session_state):
        if k.endswith("placeholder") or k.endswith("_info"):
            del st.session_state[k]
290
+
291
+
292
# Main function to run the Streamlit app
def main():
    """Render sidebar, chat input and clear button; dispatch user queries."""
    st.sidebar.title("Model Control")
    col1, col2 = st.columns([4, 1])
    with col1:
        user_input = st.chat_input("Enter your query:")
    with col2:
        if st.button("Clear History"):
            clean_history()
    # Only stream and re-render when the user actually submitted a query.
    if user_input:
        update_chat(user_input)
        display_chat_history()
304
+
305
+
306
import os

if __name__ == "__main__":
    # NOTE(review): installing packages and launching the backend with
    # os.system at app start-up is fragile (exit codes ignored) —
    # presumably a workaround for the hosting environment; consider
    # subprocess.run and a requirements file instead.
    os.system("pip show starlette")
    # os.system("pip install -r requirements.txt")
    os.system("pip install tenacity")
    # Start the MindSearch solver backend in the background; update_chat
    # above consumes it at http://localhost:8002/solve (port mapping
    # assumed — TODO confirm against mindsearch.app defaults).
    os.system("python -m mindsearch.app --lang en --model_format internlm_silicon --search_engine GoogleSearch &")

    main()
mindsearch/__init__.py ADDED
File without changes
mindsearch/agent/__init__.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from copy import deepcopy
3
+ from datetime import datetime
4
+
5
+ from lagent.actions import AsyncWebBrowser, WebBrowser
6
+ from lagent.agents.stream import get_plugin_prompt
7
+ from lagent.prompts import InterpreterParser, PluginParser
8
+ from lagent.utils import create_object
9
+
10
+ from . import models as llm_factory
11
+ from .mindsearch_agent import AsyncMindSearchAgent, MindSearchAgent
12
+ from .mindsearch_prompt import (
13
+ FINAL_RESPONSE_CN,
14
+ FINAL_RESPONSE_EN,
15
+ GRAPH_PROMPT_CN,
16
+ GRAPH_PROMPT_EN,
17
+ searcher_context_template_cn,
18
+ searcher_context_template_en,
19
+ searcher_input_template_cn,
20
+ searcher_input_template_en,
21
+ searcher_system_prompt_cn,
22
+ searcher_system_prompt_en,
23
+ )
24
+
25
+ LLM = {}
26
+
27
+
28
def init_agent(lang="cn",
               model_format="internlm_server",
               search_engine="BingSearch",
               use_async=False):
    """Build a (Async)MindSearchAgent wired to an LLM and a web searcher.

    Args:
        lang: 'cn' or 'en'; selects the prompt set.
        model_format: name of an LLM config attribute declared in ``.models``.
        search_engine: searcher backend name, e.g. 'BingSearch' or
            'TencentSearch' (the latter uses Tencent secret env vars).
        use_async: build the asyncio variants of agent and LLM client.

    Raises:
        NotImplementedError: if ``model_format`` is not declared in models.
    """
    mode = "async" if use_async else "sync"
    # LLM clients are cached per (model_format, mode).
    llm = LLM.get(model_format, {}).get(mode)
    if llm is None:
        # Fix: getattr without a default raises AttributeError before the
        # None check could ever run — supply None so unknown model formats
        # raise the intended NotImplementedError.
        llm_cfg = getattr(llm_factory, model_format, None)
        if llm_cfg is None:
            raise NotImplementedError(f"Unknown model format: {model_format}")
        llm_cfg = deepcopy(llm_cfg)
        if use_async:
            # Swap the sync client class for its lagent Async* counterpart.
            cls_name = (
                llm_cfg["type"].split(".")[-1] if isinstance(
                    llm_cfg["type"], str) else llm_cfg["type"].__name__)
            llm_cfg["type"] = f"lagent.llms.Async{cls_name}"
        llm = create_object(llm_cfg)
        LLM.setdefault(model_format, {}).setdefault(mode, llm)

    date = datetime.now().strftime("The current date is %Y-%m-%d.")
    plugins = [(dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
        secret_id=os.getenv("TENCENT_SEARCH_SECRET_ID"),
        secret_key=os.getenv("TENCENT_SEARCH_SECRET_KEY"),
    ) if search_engine == "TencentSearch" else dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
        api_key=os.getenv("WEB_SEARCH_API_KEY"),
    ))]
    agent = (AsyncMindSearchAgent if use_async else MindSearchAgent)(
        llm=llm,
        template=date,
        output_format=InterpreterParser(
            template=GRAPH_PROMPT_CN if lang == "cn" else GRAPH_PROMPT_EN),
        searcher_cfg=dict(
            llm=llm,
            plugins=plugins,
            template=date,
            output_format=PluginParser(
                template=searcher_system_prompt_cn
                if lang == "cn" else searcher_system_prompt_en,
                tool_info=get_plugin_prompt(plugins),
            ),
            user_input_template=(searcher_input_template_cn if lang == "cn"
                                 else searcher_input_template_en),
            user_context_template=(searcher_context_template_cn if lang == "cn"
                                   else searcher_context_template_en),
        ),
        summary_prompt=FINAL_RESPONSE_CN
        if lang == "cn" else FINAL_RESPONSE_EN,
        max_turn=10,
    )
    return agent
mindsearch/agent/graph.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import queue
3
+ import random
4
+ import re
5
+ import uuid
6
+ from collections import defaultdict
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from copy import deepcopy
9
+ from threading import Thread
10
+ from typing import Dict, List
11
+
12
+ from lagent.actions import BaseAction
13
+ from lagent.schema import AgentMessage, AgentStatusCode
14
+
15
+ from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM
16
+
17
+
18
class SearcherAgent(StreamingAgentForInternLM):
    """Streaming searcher that prefixes the question with prior Q/A context."""

    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        # Render parent Q/A turns (if any) ahead of the current question.
        parts = []
        if history and self.user_context_template:
            parts = [self.user_context_template.format_map(turn) for turn in history]
        parts.append(self.user_input_template.format(question=question, topic=topic))
        return super().forward("\n".join(parts), session_id=session_id, **kwargs)
42
+
43
+
44
class AsyncSearcherAgent(AsyncStreamingAgentForInternLM):
    """Async streaming searcher that prefixes the question with prior Q/A context."""

    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    async def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        # Render parent Q/A turns (if any) ahead of the current question.
        parts = []
        if history and self.user_context_template:
            parts = [self.user_context_template.format_map(turn) for turn in history]
        parts.append(self.user_input_template.format(question=question, topic=topic))
        async for chunk in super().forward("\n".join(parts), session_id=session_id, **kwargs):
            yield chunk
69
+
70
+
71
class WebSearchGraph:
    """Search graph: nodes are (sub-)questions, edges are dependencies.

    Searcher agents run one per node — on a thread pool when sync, or on a
    pool of background event loops when ``is_async`` — and every state
    change is pushed onto ``searcher_resp_queue`` for streaming consumers.
    """

    is_async = False
    # Injected by the planner agent before any node is searched.
    SEARCHER_CONFIG = {}
    _SEARCHER_LOOP = []
    _SEARCHER_THREAD = []

    def __init__(self):
        self.nodes: Dict[str, Dict[str, str]] = {}
        self.adjacency_list: Dict[str, List[dict]] = defaultdict(list)
        self.future_to_query = dict()
        self.searcher_resp_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.n_active_tasks = 0

    def add_root_node(
        self,
        node_content: str,
        node_name: str = "root",
    ):
        """Add the root node holding the user's original question.

        Args:
            node_content (str): the question text.
            node_name (str, optional): node name. Defaults to 'root'.
        """
        self.nodes[node_name] = dict(content=node_content, type="root")
        self.adjacency_list[node_name] = []

    def add_node(
        self,
        node_name: str,
        node_content: str,
    ):
        """Add a search sub-question node and launch its searcher agent.

        Answers of already-finished parent nodes are passed to the
        searcher as conversational context.

        Args:
            node_name (str): node name.
            node_content (str): sub-question text.
        """
        self.nodes[node_name] = dict(content=node_content, type="searcher")
        self.adjacency_list[node_name] = []

        parent_nodes = []
        for start_node, adj in self.adjacency_list.items():
            for neighbor in adj:
                # Fix: adjacency entries are dicts (see add_edge), so the
                # original `node_name == neighbor` comparison was always
                # False and parent answers were never collected.
                if (
                    node_name == neighbor["name"]
                    and start_node in self.nodes
                    and "response" in self.nodes[start_node]
                ):
                    parent_nodes.append(self.nodes[start_node])
        parent_response = [
            dict(question=node["content"], answer=node["response"]) for node in parent_nodes
        ]

        if self.is_async:

            async def _async_search_node_stream():
                # Fresh plugin copies per task — plugin objects may hold
                # per-call state.
                cfg = {
                    **self.SEARCHER_CONFIG,
                    "plugins": deepcopy(self.SEARCHER_CONFIG.get("plugins")),
                }
                agent, session_id = AsyncSearcherAgent(**cfg), random.randint(0, 999999)
                searcher_message = AgentMessage(sender="SearcherAgent", content="")
                try:
                    async for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self.nodes[node_name]["response"] = searcher_message.model_dump()
                        self.nodes[node_name]["memory"] = agent.state_dict(session_id=session_id)
                        self.nodes[node_name]["session_id"] = session_id
                        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
                    # (None, None, None) is the "task finished" sentinel.
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    # Exceptions travel through the queue to the consumer.
                    self.searcher_resp_queue.put((exc, None, None))

            self.future_to_query[
                asyncio.run_coroutine_threadsafe(
                    _async_search_node_stream(), random.choice(self._SEARCHER_LOOP)
                )
            ] = f"{node_name}-{node_content}"
            # self.future_to_query[
            #     self.executor.submit(asyncio.run, _async_search_node_stream())
            # ] = f"{node_name}-{node_content}"
        else:

            def _search_node_stream():
                cfg = {
                    **self.SEARCHER_CONFIG,
                    "plugins": deepcopy(self.SEARCHER_CONFIG.get("plugins")),
                }
                agent, session_id = SearcherAgent(**cfg), random.randint(0, 999999)
                searcher_message = AgentMessage(sender="SearcherAgent", content="")
                try:
                    for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self.nodes[node_name]["response"] = searcher_message.model_dump()
                        self.nodes[node_name]["memory"] = agent.state_dict(session_id=session_id)
                        self.nodes[node_name]["session_id"] = session_id
                        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    self.searcher_resp_queue.put((exc, None, None))

            self.future_to_query[
                self.executor.submit(_search_node_stream)
            ] = f"{node_name}-{node_content}"

        self.n_active_tasks += 1

    def add_response_node(self, node_name="response"):
        """Add the terminal response node.

        Args:
            node_name (str, optional): node name. Defaults to 'response'.
        """
        self.nodes[node_name] = dict(type="end")
        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))

    def add_edge(self, start_node: str, end_node: str):
        """Add a directed edge and publish the update.

        Args:
            start_node (str): source node name.
            end_node (str): target node name.
        """
        # state=2 means "not started" (see ExecutionAction state codes).
        self.adjacency_list[start_node].append(dict(id=str(uuid.uuid4()), name=end_node, state=2))
        self.searcher_resp_queue.put(
            (start_node, self.nodes[start_node], self.adjacency_list[start_node])
        )

    def reset(self):
        """Drop all nodes and edges (queue and executor are kept)."""
        self.nodes = {}
        self.adjacency_list = defaultdict(list)

    def node(self, node_name: str) -> str:
        """Return a shallow copy of the named node's data."""
        return self.nodes[node_name].copy()

    @classmethod
    def start_loop(cls, n: int = 32):
        """Ensure ``n`` live daemon threads each running an event loop."""
        if not cls.is_async:
            raise RuntimeError("Event loop cannot be launched as `is_async` is disabled")

        assert len(cls._SEARCHER_LOOP) == len(cls._SEARCHER_THREAD)
        # Fix: the original popped dead entries by enumerate index, which
        # shifts after the first pop; rebuild the lists from live pairs.
        alive = [
            (loop, thread)
            for loop, thread in zip(cls._SEARCHER_LOOP, cls._SEARCHER_THREAD)
            if loop.is_running() and thread.is_alive()
        ]
        cls._SEARCHER_LOOP[:] = [loop for loop, _ in alive]
        cls._SEARCHER_THREAD[:] = [thread for _, thread in alive]

        while len(cls._SEARCHER_THREAD) < n:

            def _start_loop():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                cls._SEARCHER_LOOP.append(loop)
                loop.run_forever()

            thread = Thread(target=_start_loop, daemon=True)
            thread.start()
            cls._SEARCHER_THREAD.append(thread)
246
+
247
+
248
class ExecutionAction(BaseAction):
    """Tool used by MindSearch planner to execute graph node query."""

    def run(self, command, local_dict, global_dict, stream_graph=False):
        """Execute planner-emitted graph code and stream graph snapshots.

        Extracts the code fenced in *command*, exec()s it (it is expected
        to drive a ``WebSearchGraph`` bound as ``graph`` in *local_dict*),
        drains the graph's response queue until all searcher tasks finish,
        and finally returns the results of every ``graph.node(...)`` call
        found in the command plus the full node/adjacency maps.

        SECURITY NOTE(review): exec() of model-generated code — acceptable
        only because the code comes from the trusted planner LLM; do not
        feed untrusted input here.
        """
        def extract_code(text: str) -> str:
            # Drop any `from ... import WebSearchGraph` the model emitted;
            # the class is already provided via global_dict.
            text = re.sub(r"from ([\w.]+) import WebSearchGraph", "", text)
            triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL)
            single_match = re.search(r"`([^`]*)`", text, re.DOTALL)
            if triple_match:
                return triple_match.group(1)
            elif single_match:
                return single_match.group(1)
            return text

        command = extract_code(command)
        exec(command, global_dict, local_dict)

        # Collect every graph.node(...) argument in the executed command.
        # NOTE(review): the '.' is unescaped in the pattern (matches any
        # char) — benign for the expected inputs, but worth tightening.
        node_list = re.findall(r"graph.node\((.*?)\)", command)
        graph: WebSearchGraph = local_dict["graph"]
        # Drain searcher updates until every launched task has reported
        # its (None, None, None) completion sentinel.
        # NOTE(review): busy-wait when the queue is momentarily empty but
        # tasks are still active — confirm CPU usage is acceptable.
        while graph.n_active_tasks:
            while not graph.searcher_resp_queue.empty():
                node_name, _, _ = graph.searcher_resp_queue.get(timeout=60)
                if isinstance(node_name, Exception):
                    # Worker exceptions are forwarded through the queue.
                    raise node_name
                if node_name is None:
                    # Completion sentinel: one searcher task finished.
                    graph.n_active_tasks -= 1
                    continue
                if stream_graph:
                    for neighbors in graph.adjacency_list.values():
                        for neighbor in neighbors:
                            # state: 1 running, 2 not started, 3 finished
                            if not (
                                neighbor["name"] in graph.nodes
                                and "response" in graph.nodes[neighbor["name"]]
                            ):
                                neighbor["state"] = 2
                            elif (
                                graph.nodes[neighbor["name"]]["response"]["stream_state"]
                                == AgentStatusCode.END
                            ):
                                neighbor["state"] = 3
                            else:
                                neighbor["state"] = 1
                    # Emit a snapshot once every searcher node has started
                    # producing a response.
                    if all(
                        "response" in node
                        for name, node in graph.nodes.items()
                        if name not in ["root", "response"]
                    ):
                        yield AgentMessage(
                            sender=self.name,
                            content=dict(current_node=node_name),
                            formatted=dict(
                                node=deepcopy(graph.nodes),
                                adjacency_list=deepcopy(graph.adjacency_list),
                            ),
                            stream_state=AgentStatusCode.STREAM_ING,
                        )
        # Generator return value, consumed via GeneratorWithReturn.
        res = [graph.nodes[node.strip().strip('"').strip("'")] for node in node_list]
        return res, graph.nodes, graph.adjacency_list
mindsearch/agent/mindsearch_agent.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import re
4
+ from copy import deepcopy
5
+ from typing import Dict, Tuple
6
+
7
+ from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode
8
+ from lagent.utils import GeneratorWithReturn
9
+
10
+ from .graph import ExecutionAction, WebSearchGraph
11
+ from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM
12
+
13
+
14
+ def _update_ref(ref: str, ref2url: Dict[str, str], ptr: int) -> str:
15
+ numbers = list({int(n) for n in re.findall(r"\[\[(\d+)\]\]", ref)})
16
+ numbers = {n: idx + 1 for idx, n in enumerate(numbers)}
17
+ updated_ref = re.sub(
18
+ r"\[\[(\d+)\]\]",
19
+ lambda match: f"[[{numbers[int(match.group(1))] + ptr}]]",
20
+ ref,
21
+ )
22
+ updated_ref2url = {}
23
+ if numbers:
24
+ try:
25
+ assert all(elem in ref2url for elem in numbers)
26
+ except Exception as exc:
27
+ logging.info(f"Illegal reference id: {str(exc)}")
28
+ if ref2url:
29
+ updated_ref2url = {
30
+ numbers[idx] + ptr: ref2url[idx] for idx in numbers if idx in ref2url
31
+ }
32
+ return updated_ref, updated_ref2url, len(numbers) + 1
33
+
34
+
35
def _generate_references_from_graph(graph: Dict[str, dict]) -> Tuple[str, Dict[int, dict]]:
    """Collect per-node answers into one markdown reference section.

    Walks every searcher node (skipping 'root' and 'response'), renumbers
    its ``[[n]]`` citations so they are globally unique, and returns the
    concatenated ``## question\\n\\nanswer`` sections together with the
    merged citation-id -> URL mapping.
    """
    ptr, references, references_url = 0, [], {}
    for name, data_item in graph.items():
        if name in ["root", "response"]:
            continue
        # only search once at each node, thus the result offset is 2
        # NOTE(review): assumes memory slot 2 is the ActionExecutor result
        # holding a JSON id->url map — verify against the searcher agent's
        # memory layout if it changes.
        assert data_item["memory"]["agent.memory"][2]["sender"].endswith("ActionExecutor")
        ref2url = {
            int(k): v
            for k, v in json.loads(data_item["memory"]["agent.memory"][2]["content"]).items()
        }
        # Shift this node's citation ids past those already emitted.
        updata_ref, ref2url, added_ptr = _update_ref(
            data_item["response"]["content"], ref2url, ptr
        )
        ptr += added_ptr
        references.append(f'## {data_item["content"]}\n\n{updata_ref}')
        references_url.update(ref2url)
    return "\n\n".join(references), references_url
53
+
54
+
55
class MindSearchAgent(StreamingAgentForInternLM):
    """Planner agent: streams graph-building turns and a final summary.

    Each turn streams the inner LLM agent's output, executes the emitted
    graph code via ``ExecutionAction``, folds searcher results into a
    reference section, and either loops for another turn or — once
    ``finish_condition`` fires — asks the LLM to summarize.
    """

    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        # Searcher config is shared globally through the graph class.
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Yield streamed AgentMessages for up to ``max_turn`` plan/execute turns."""
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        # Rolling snapshot of the graph attached to every streamed message.
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        # Promote CODING/PLUGIN_START to its *_END code
                        # (presumably consecutive enum values — TODO confirm
                        # against AgentStatusCode ordering).
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            # No tool call means the LLM produced a direct final answer.
            if not message.formatted["tool_type"]:
                message.stream_state = AgentStatusCode.END
                yield message
                return

            # Execute the emitted graph code, streaming graph snapshots.
            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec

            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            # Not finished: feed the references back as the next turn's input.
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message
+
132
+
133
class AsyncMindSearchAgent(AsyncStreamingAgentForInternLM):
    """Async twin of MindSearchAgent; searchers run on background event loops."""

    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        # Searcher config is shared globally through the graph class, which
        # is also switched to async mode and given its loop threads here.
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        WebSearchGraph.is_async = True
        WebSearchGraph.start_loop()
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Yield streamed AgentMessages for up to ``max_turn`` plan/execute turns."""
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        # Rolling snapshot of the graph attached to every streamed message.
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        # Promote CODING/PLUGIN_START to its *_END code
                        # (presumably consecutive enum values — TODO confirm
                        # against AgentStatusCode ordering).
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            # No tool call means the LLM produced a direct final answer.
            if not message.formatted["tool_type"]:
                message.stream_state = AgentStatusCode.END
                yield message
                return

            # Execute the emitted graph code, streaming graph snapshots.
            # NOTE(review): ExecutionAction.run is a sync generator; it is
            # iterated synchronously here even in the async agent.
            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec

            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                async for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            # Not finished: feed the references back as the next turn's input.
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message
mindsearch/agent/mindsearch_prompt.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa
2
+
3
+ searcher_system_prompt_cn = """## 人物简介
4
+ 你是一个可以调用网络搜索工具的智能助手。请根据"当前问题",调用搜索工具收集信息并回复问题。你能够调用如下工具:
5
+ {tool_info}
6
+ ## 回复格式
7
+
8
+ 调用工具时,请按照以下格式:
9
+ ```
10
+ 你的思考过程...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
11
+ ```
12
+
13
+ ## 要求
14
+
15
+ - 回答中每个关键点需标注引用的搜索结果来源,以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
16
+ - 基于"当前问题"的搜索结果,撰写详细完备的回复,优先回答"当前问题"。
17
+
18
+ """
19
+
20
+ searcher_system_prompt_en = """## Character Introduction
21
+ You are an intelligent assistant that can call web search tools. Please collect information and reply to the question based on the current problem. You can use the following tools:
22
+ {tool_info}
23
+ ## Reply Format
24
+
25
+ When calling the tool, please follow the format below:
26
+ ```
27
+ Your thought process...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
28
+ ```
29
+
30
+ ## Requirements
31
+
32
+ - Each key point in the response should be marked with the source of the search results to ensure the credibility of the information. The citation format is `[[int]]`. If there are multiple citations, use multiple [[]] to provide the index, such as `[[id_1]][[id_2]]`.
33
+ - Based on the search results of the "current problem", write a detailed and complete reply to answer the "current problem".
34
+ """
35
+
36
+ fewshot_example_cn = """
37
+ ## 样例
38
+
39
+ ### search
40
+ 当我希望搜索"王者荣耀现在是什么赛季"时,我会按照以下格式进行操作:
41
+ 现在是2024年,因此我应该搜索王者荣耀赛季关键词<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["王者荣耀 赛季", "2024年王者荣耀赛季"]}}}}<|action_end|>
42
+
43
+ ### select
44
+ 为了找到王者荣耀s36赛季最强射手,我需要寻找提及王者荣耀s36射手的网页。初步浏览网页后,发现网页0提到王者荣耀s36赛季的信息,但没有具体提及射手的相关信息。网页3提到“s36最强射手出现?”,有可能包含最强射手信息。网页13提到“四大T0英雄崛起,射手荣耀降临”,可能包含最强射手的信息。因此,我选择了网页3和网页13进行进一步阅读。<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
45
+ """
46
+
47
+ fewshot_example_en = """
48
+ ## Example
49
+
50
+ ### search
51
+ When I want to search for "What season is Honor of Kings now", I will operate in the following format:
52
+ Now it is 2024, so I should search for the keyword of the Honor of Kings<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["Honor of Kings Season", "season for Honor of Kings in 2024"]}}}}<|action_end|>
53
+
54
+ ### select
55
+ In order to find the strongest shooters in Honor of Kings in season s36, I needed to look for web pages that mentioned shooters in Honor of Kings in season s36. After an initial browse of the web pages, I found that web page 0 mentions information about Honor of Kings in s36 season, but there is no specific mention of information about the shooter. Webpage 3 mentions that “the strongest shooter in s36 has appeared?”, which may contain information about the strongest shooter. Webpage 13 mentions “Four T0 heroes rise, archer's glory”, which may contain information about the strongest archer. Therefore, I chose webpages 3 and 13 for further reading.<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
56
+ """
57
+
58
+ searcher_input_template_en = """## Final Problem
59
+ {topic}
60
+ ## Current Problem
61
+ {question}
62
+ """
63
+
64
+ searcher_input_template_cn = """## 主问题
65
+ {topic}
66
+ ## 当前问题
67
+ {question}
68
+ """
69
+
70
+ searcher_context_template_en = """## Historical Problem
71
+ {question}
72
+ Answer: {answer}
73
+ """
74
+
75
+ searcher_context_template_cn = """## 历史问题
76
+ {question}
77
+ 回答:{answer}
78
+ """
79
+
80
+ search_template_cn = "## {query}\n\n{result}\n"
81
+ search_template_en = "## {query}\n\n{result}\n"
82
+
83
+ GRAPH_PROMPT_CN = """## 人物简介
84
+ 你是一个可以利用 Jupyter 环境 Python 编程的程序员。你可以利用提供的 API 来构建 Web 搜索图,最终生成代码并执行。
85
+
86
+ ## API 介绍
87
+
88
+ 下面是包含属性详细说明的 `WebSearchGraph` 类的 API 文档:
89
+
90
+ ### 类:`WebSearchGraph`
91
+
92
+ 此类用于管理网络搜索图的节点和边,并通过网络代理进行搜索。
93
+
94
+ #### 初始化方法
95
+
96
+ 初始化 `WebSearchGraph` 实例。
97
+
98
+ **属性:**
99
+
100
+ - `nodes` (Dict[str, Dict[str, str]]): 存储图中所有节点的字典。每个节点由其名称索引,并包含内容、类型以及其他相关信息。
101
+ - `adjacency_list` (Dict[str, List[str]]): 存储图中所有节点之间连接关系的邻接表。每个节点由其名称索引,并包含一个相邻节点名称的列表。
102
+
103
+
104
+ #### 方法:`add_root_node`
105
+
106
+ 添加原始问题作为根节点。
107
+ **参数:**
108
+
109
+ - `node_content` (str): 用户提出的问题。
110
+ - `node_name` (str, 可选): 节点名称,默认为 'root'。
111
+
112
+
113
+ #### 方法:`add_node`
114
+
115
+ 添加搜索子问题节点并返回搜索结果。
116
+ **参数:
117
+
118
+ - `node_name` (str): 节点名称。
119
+ - `node_content` (str): 子问题内容。
120
+
121
+ **返回:**
122
+
123
+ - `str`: 返回搜索结果。
124
+
125
+
126
+ #### 方法:`add_response_node`
127
+
128
+ 当前获取的信息已经满足问题需求,添加回复节点。
129
+
130
+ **参数:**
131
+
132
+ - `node_name` (str, 可选): 节点名称,默认为 'response'。
133
+
134
+
135
+ #### 方法:`add_edge`
136
+
137
+ 添加边。
138
+
139
+ **参数:**
140
+
141
+ - `start_node` (str): 起始节点名称。
142
+ - `end_node` (str): 结束节点名称。
143
+
144
+
145
+ #### 方法:`reset`
146
+
147
+ 重置节点和边。
148
+
149
+
150
+ #### 方法:`node`
151
+
152
+ 获取节点信息。
153
+
154
+ ```python
155
+ def node(self, node_name: str) -> str
156
+ ```
157
+
158
+ **参数:**
159
+
160
+ - `node_name` (str): 节点名称。
161
+
162
+ **返回:**
163
+
164
+ - `str`: 返回包含节点信息的字典,包含节点的内容、类型、思考过程(如果有)和前驱节点列表。
165
+
166
+ ## 任务介绍
167
+ 通过将一个问题拆分成能够通过搜索回答的子问题(没有关联的问题可以同步并列搜索),每个搜索的问题应该是一个单一问题,即单个具体人、事、物、具体时间点、地点或知识点的问题,不是一个复合问题(比如某个时间段), 一步步构建搜索图,最终回答问题。
168
+
169
+ ## 注意事项
170
+
171
+ 1. 注意,每个搜索节点的内容必须单个问题,不要包含多个问题(比如同时问多个知识点的问题或者多个事物的比较加筛选,类似 A, B, C 有什么区别,那个价格在哪个区间 -> 分别查询)
172
+ 2. 不要杜撰搜索结果,要等待代码返回结果
173
+ 3. 同样的问题不要重复提问,可以在已有问题的基础上继续提问
174
+ 4. 添加 response 节点的时候,要单独添加,不要和其他节点一起添加,不能同时添加 response 节点和其他节点
175
+ 5. 一次输出中,不要包含多个代码块,每次只能有一个代码块
176
+ 6. 每个代码块应该放置在一个代码块标记中,同时生成完代码后添加一个<|action_end|>标志,如下所示:
177
+ <|action_start|><|interpreter|>```python
178
+ # 你的代码块
179
+ ```<|action_end|>
180
+ 7. 最后一次回复应该是添加node_name为'response'的 response 节点,必须添加 response 节点,不要添加其他节点
181
+ """
182
+
183
+ GRAPH_PROMPT_EN = """## Character Profile
184
+ You are a programmer capable of Python programming in a Jupyter environment. You can utilize the provided API to construct a Web Search Graph, ultimately generating and executing code.
185
+
186
+ ## API Description
187
+
188
+ Below is the API documentation for the WebSearchGraph class, including detailed attribute descriptions:
189
+
190
+ ### Class: WebSearchGraph
191
+
192
+ This class manages nodes and edges of a web search graph and conducts searches via a web proxy.
193
+
194
+ #### Initialization Method
195
+
196
+ Initializes an instance of WebSearchGraph.
197
+
198
+ **Attributes:**
199
+
200
+ - nodes (Dict[str, Dict[str, str]]): A dictionary storing all nodes in the graph. Each node is indexed by its name and contains content, type, and other related information.
201
+ - adjacency_list (Dict[str, List[str]]): An adjacency list storing the connections between all nodes in the graph. Each node is indexed by its name and contains a list of adjacent node names.
202
+
203
+ #### Method: add_root_node
204
+
205
+ Adds the initial question as the root node.
206
+ **Parameters:**
207
+
208
+ - node_content (str): The user's question.
209
+ - node_name (str, optional): The node name, default is 'root'.
210
+
211
+ #### Method: add_node
212
+
213
+ Adds a sub-question node and returns search results.
214
+ **Parameters:**
215
+
216
+ - node_name (str): The node name.
217
+ - node_content (str): The sub-question content.
218
+
219
+ **Returns:**
220
+
221
+ - str: Returns the search results.
222
+
223
+ #### Method: add_response_node
224
+
225
+ Adds a response node when the current information satisfies the question's requirements.
226
+
227
+ **Parameters:**
228
+
229
+ - node_name (str, optional): The node name, default is 'response'.
230
+
231
+ #### Method: add_edge
232
+
233
+ Adds an edge.
234
+
235
+ **Parameters:**
236
+
237
+ - start_node (str): The starting node name.
238
+ - end_node (str): The ending node name.
239
+
240
+ #### Method: reset
241
+
242
+ Resets nodes and edges.
243
+
244
+ #### Method: node
245
+
246
+ Get node information.
247
+
248
+ python
249
+ def node(self, node_name: str) -> str
250
+
251
+ **Parameters:**
252
+
253
+ - node_name (str): The node name.
254
+
255
+ **Returns:**
256
+
257
+ - str: Returns a dictionary containing the node's information, including content, type, thought process (if any), and list of predecessor nodes.
258
+
259
+ ## Task Description
260
+ By breaking down a question into sub-questions that can be answered through searches (unrelated questions can be searched concurrently), each search query should be a single question focusing on a specific person, event, object, specific time point, location, or knowledge point. It should not be a compound question (e.g., a time period). Step by step, build the search graph to finally answer the question.
261
+
262
+ ## Considerations
263
+
264
+ 1. Each search node's content must be a single question; do not include multiple questions (e.g., do not ask multiple knowledge points or compare and filter multiple things simultaneously, like asking for differences between A, B, and C, or price ranges -> query each separately).
265
+ 2. Do not fabricate search results; wait for the code to return results.
266
+ 3. Do not repeat the same question; continue asking based on existing questions.
267
+ 4. When adding a response node, add it separately; do not add a response node and other nodes simultaneously.
268
+ 5. In a single output, do not include multiple code blocks; only one code block per output.
269
+ 6. Each code block should be placed within a code block marker, and after generating the code, add an <|action_end|> tag as shown below:
270
+ <|action_start|><|interpreter|>
271
+ ```python
272
+ # Your code block (Note that the 'Get new added node information' logic must be added at the end of the code block, such as 'graph.node('...')')
273
+ ```<|action_end|>
274
+ 7. The final response should add a response node with node_name 'response', and no other nodes should be added.
275
+ """
276
+
277
+ graph_fewshot_example_cn = """
278
+ ## 返回格式示例
279
+ <|action_start|><|interpreter|>```python
280
+ graph = WebSearchGraph()
281
+ graph.add_root_node(node_content="哪家大模型API最便宜?", node_name="root") # 添加原始问题作为根节点
282
+ graph.add_node(
283
+ node_name="大模型API提供商", # 节点名称最好有意义
284
+ node_content="目前有哪些主要的大模型API提供商?")
285
+ graph.add_node(
286
+ node_name="sub_name_2", # 节点名称最好有意义
287
+ node_content="content of sub_name_2")
288
+ ...
289
+ graph.add_edge(start_node="root", end_node="sub_name_1")
290
+ ...
291
+ graph.node("大模型API提供商"), graph.node("sub_name_2"), ...
292
+ ```<|action_end|>
293
+ """
294
+
295
+ graph_fewshot_example_en = """
296
+ ## Response Format
297
+ <|action_start|><|interpreter|>```python
298
+ graph = WebSearchGraph()
299
+ graph.add_root_node(node_content="Which large model API is the cheapest?", node_name="root") # Add the original question as the root node
300
+ graph.add_node(
301
+ node_name="Large Model API Providers", # The node name should be meaningful
302
+ node_content="Who are the main large model API providers currently?")
303
+ graph.add_node(
304
+ node_name="sub_name_2", # The node name should be meaningful
305
+ node_content="content of sub_name_2")
306
+ ...
307
+ graph.add_edge(start_node="root", end_node="sub_name_1")
308
+ ...
309
+ # Get node info
310
+ graph.node("Large Model API Providers"), graph.node("sub_name_2"), ...
311
+ ```<|action_end|>
312
+ """
313
+
314
+ FINAL_RESPONSE_CN = """基于提供的问答对,撰写一篇详细完备的最终回答。
315
+ - 回答内容需要逻辑清晰,层次分明,确保读者易于理解。
316
+ - 回答中每个关键点需标注引用的搜索结果来源(保持跟问答对中的索引一致),以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
317
+ - 回答部分需要全面且完备,不要出现"基于上述内容"等模糊表达,最终呈现的回答不包括提供给你的问答对。
318
+ - 语言风格需要专业、严谨,避免口语化表达。
319
+ - 保持统一的语法和词汇使用,确保整体文档的一致性和连贯性。"""
320
+
321
+ FINAL_RESPONSE_EN = """Based on the provided Q&A pairs, write a detailed and comprehensive final response.
322
+ - The response content should be logically clear and well-structured to ensure reader understanding.
323
+ - Each key point in the response should be marked with the source of the search results (consistent with the indices in the Q&A pairs) to ensure information credibility. The index is in the form of `[[int]]`, and if there are multiple indices, use multiple `[[]]`, such as `[[id_1]][[id_2]]`.
324
+ - The response should be comprehensive and complete, without vague expressions like "based on the above content". The final response should not include the Q&A pairs provided to you.
325
+ - The language style should be professional and rigorous, avoiding colloquial expressions.
326
+ - Maintain consistent grammar and vocabulary usage to ensure overall document consistency and coherence."""
mindsearch/agent/models.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""LLM backend presets for MindSearch.

Each constant below is a lazy constructor spec — a ``dict`` holding the
client class under ``type`` plus its keyword arguments — that is resolved
into a concrete lagent LLM instance elsewhere (``init_agent``).
"""
import os

from dotenv import load_dotenv
from lagent.llms import (
    GPTAPI,
    INTERNLM2_META,
    HFTransformerCasualLM,
    LMDeployClient,
    LMDeployServer,
)

# Fix: ``load_dotenv`` was imported but never called, so API keys placed in a
# local .env file never reached ``os.environ`` and every
# ``os.environ.get(...)`` lookup below silently fell back to its placeholder.
load_dotenv()

# InternLM2.5 served by an LMDeploy server spawned in-process.
internlm_server = dict(
    type=LMDeployServer,
    path="internlm/internlm2_5-7b-chat",
    model_name="internlm2_5-7b-chat",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

# Client for an LMDeploy server that is already running locally.
internlm_client = dict(
    type=LMDeployClient,
    model_name="internlm2_5-7b-chat",
    url="http://127.0.0.1:23333",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

# HuggingFace Transformers backend ("CasualLM" is lagent's own class spelling).
internlm_hf = dict(
    type=HFTransformerCasualLM,
    path="internlm/internlm2_5-7b-chat",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=None,
    temperature=1e-6,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)
# openai_api_base needs to fill in the complete chat api address, such as: https://api.openai.com/v1/chat/completions
gpt4 = dict(
    type=GPTAPI,
    model_type="gpt-4-turbo",
    key=os.environ.get("OPENAI_API_KEY", "YOUR OPENAI API KEY"),
    api_base=os.environ.get("OPENAI_API_BASE",
                            "https://api.openai.com/v1/chat/completions"),
)

url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
# Alibaba DashScope Qwen backend (OpenAI-style chat completion API).
qwen = dict(
    type=GPTAPI,
    model_type="qwen-max-longcontext",
    key=os.environ.get("QWEN_API_KEY", "YOUR QWEN API KEY"),
    api_base=url,
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        # lagent "environment" turns (tool results) are sent as system messages.
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=4096,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

# InternLM2.5 hosted behind SiliconFlow's OpenAI-compatible endpoint.
internlm_silicon = dict(
    type=GPTAPI,
    model_type="internlm/internlm2_5-7b-chat",
    key=os.environ.get("SILICON_API_KEY", "YOUR SILICON API KEY"),
    api_base="https://api.siliconflow.cn/v1/chat/completions",
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)


# Official InternLM hosted API.
internlm_api = dict(
    type=GPTAPI,
    model_type="internlm2.5-latest",
    key=os.environ.get("InternLM_API_KEY", "YOUR InternLM API KEY https://internlm.intern-ai.org.cn/api/document"),
    api_base="https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions",
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)
mindsearch/agent/streaming.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from typing import List, Union
3
+
4
+ from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM
5
+ from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode
6
+
7
+
8
class StreamingAgentMixin:
    """Make agent calling output a streaming response.

    Turns ``__call__`` into a generator: one (copied) ``AgentMessage`` is
    yielded per chunk produced by ``forward``, and the hook-processed final
    message is yielded last.
    """

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        # Deep-copy before each hook so a hook cannot mutate what the next
        # hook (or the agent) sees unless it explicitly returns a replacement.
        for hook in self._hooks.values():
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Fallback value in case ``forward`` yields nothing at all.
        response_message = AgentMessage(sender=self.name, content="")
        for response_message in self.forward(*message, session_id=session_id, **kwargs):
            # ``forward`` may yield raw ``(model_state, response)`` tuples when
            # no output format is configured; normalize them to AgentMessage.
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            # Yield a copy so consumers cannot mutate the message that is
            # later written to memory.
            yield response_message.model_copy()
        # Only the final streamed message is persisted and passed to hooks.
        self.update_memory(response_message, session_id=session_id)
        for hook in self._hooks.values():
            response_message = response_message.model_copy(deep=True)
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message
35
+
36
+
37
class AsyncStreamingAgentMixin:
    """Make asynchronous agent calling output a streaming response.

    Async twin of ``StreamingAgentMixin``: ``__call__`` becomes an async
    generator yielding per-chunk message copies, then the hook-processed
    final message.
    """

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        # Deep-copy before each hook so a hook cannot mutate what the next
        # hook (or the agent) sees unless it explicitly returns a replacement.
        for hook in self._hooks.values():
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        # Fallback value in case ``forward`` yields nothing at all.
        response_message = AgentMessage(sender=self.name, content="")
        async for response_message in self.forward(*message, session_id=session_id, **kwargs):
            # Normalize raw ``(model_state, response)`` tuples (yielded when no
            # output format is configured) into AgentMessage instances.
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            # Yield a copy so consumers cannot mutate the stored message.
            yield response_message.model_copy()
        # Only the final streamed message is persisted and passed to hooks.
        self.update_memory(response_message, session_id=session_id)
        for hook in self._hooks.values():
            response_message = response_message.model_copy(deep=True)
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message
66
+
67
+
68
class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent class: yields each chunk of the LLM stream."""

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        # Build the prompt from the full session memory; the incoming
        # messages were already stored there by ``__call__``.
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        stream = self.llm.stream_chat(prompt, session_id=session_id, **kwargs)
        for model_state, chunk, _ in stream:
            if self.output_format:
                # Parsed message when a formatter is configured.
                yield AgentMessage(
                    sender=self.name,
                    content=chunk,
                    formatted=self.output_format.parse_response(chunk),
                    stream_state=model_state,
                )
            else:
                # Raw tuple otherwise; the mixin normalizes it.
                yield (model_state, chunk)
87
+
88
+
89
class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent class."""

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        # Build the prompt from the full session memory; the incoming
        # messages were already stored there by ``__call__``.
        prompt = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        async for model_state, chunk, _ in self.llm.stream_chat(
            prompt, session_id=session_id, **kwargs
        ):
            if self.output_format:
                # Parsed message when a formatter is configured.
                yield AgentMessage(
                    sender=self.name,
                    content=chunk,
                    formatted=self.output_format.parse_response(chunk),
                    stream_state=model_state,
                )
            else:
                # Raw tuple otherwise; the mixin normalizes it.
                yield (model_state, chunk)
108
+
109
+
110
class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming implementation of `lagent.agents.AgentForInternLM`"""

    _INTERNAL_AGENT_CLS = StreamingAgent

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        # Drive up to ``max_turn`` think/act rounds. Each streamed chunk is
        # re-tagged with an AgentStatusCode so consumers can distinguish plain
        # text, plugin calls and code-interpreter calls.
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        # Model stream finished with a tool call pending:
                        # bump the status by 1 exactly when we were in a
                        # CODING/PLUGIN_START state (the int(...) is the bump),
                        # presumably landing on the matching successor code —
                        # confirm against lagent's AgentStatusCode ordering.
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        # Mid-stream tool call: tag by tool kind.
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    # Plain text chunk.
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            # NOTE(review): ``formatted`` is subscripted directly here; this
            # assumes the parser always produced a dict containing "tool_type"
            # once the stream ended — verify against the output_format parser.
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                # Run the tool and feed its result back in as the next turn's
                # input message.
                tool_return = executor(message, session_id=session_id)
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # No tool requested and not finished: keep streaming state and
                # loop for another turn with the same message.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
156
+
157
+
158
class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming implementation of `lagent.agents.AsyncAgentForInternLM`"""

    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        # Async twin of StreamingAgentForInternLM.forward: up to ``max_turn``
        # think/act rounds, re-tagging every chunk with an AgentStatusCode.
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        # Model stream finished with a tool call pending: bump
                        # the status by 1 exactly when we were in CODING or
                        # PLUGIN_START (the int(...) is the bump), presumably
                        # landing on the matching successor code — confirm
                        # against lagent's AgentStatusCode ordering.
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        # Mid-stream tool call: tag by tool kind.
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    # Plain text chunk.
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            # NOTE(review): direct subscript assumes the parser produced a dict
            # with "tool_type" by end of stream — verify against the parser.
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                # Run the tool and feed its result back as the next turn input.
                tool_return = await executor(message, session_id=session_id)
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                # No tool requested and not finished: loop for another turn.
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
mindsearch/app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import random
5
+ from typing import Dict, List, Union
6
+
7
+ import janus
8
+ from fastapi import FastAPI
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.requests import Request
11
+ from pydantic import BaseModel, Field
12
+ from sse_starlette.sse import EventSourceResponse
13
+
14
+ from mindsearch.agent import init_agent
15
+
16
+
17
def parse_arguments():
    """Build and parse the CLI options for the MindSearch API service."""
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Service host")
    parser.add_argument("--port", type=int, default=8002, help="Service port")
    parser.add_argument("--lang", type=str, default="cn", help="Language")
    parser.add_argument("--model_format", type=str, default="internlm_server", help="Model format")
    parser.add_argument("--search_engine", type=str, default="BingSearch", help="Search engine")
    parser.add_argument("--asy", action="store_true", default=False, help="Agent mode")
    return parser.parse_args()
28
+
29
+
30
# CLI options are parsed at import time so the route handlers below can read
# them from the module-level ``args``.
args = parse_arguments()
app = FastAPI(docs_url="/")
# Fully open CORS: the API is consumed by a separately-served web frontend.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
39
+
40
+
41
class GenerationParams(BaseModel):
    """Request payload accepted by the /solve endpoint."""

    # Either a raw query string or a list of chat-message dicts.
    inputs: Union[str, List[Dict]]
    # Random id keeping concurrent requests' agent memory separate.
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    # Optional agent configuration overrides; not read by the handlers in
    # this module.
    agent_cfg: Dict = dict()
45
+
46
+
47
+ def _postprocess_agent_message(message: dict) -> dict:
48
+ content, fmt = message["content"], message["formatted"]
49
+ current_node = content["current_node"] if isinstance(content, dict) else None
50
+ if current_node:
51
+ message["content"] = None
52
+ for key in ["ref2url"]:
53
+ fmt.pop(key, None)
54
+ graph = fmt["node"]
55
+ for key in graph.copy():
56
+ if key != current_node:
57
+ graph.pop(key)
58
+ if current_node not in ["root", "response"]:
59
+ node = graph[current_node]
60
+ for key in ["memory", "session_id"]:
61
+ node.pop(key, None)
62
+ node_fmt = node["response"]["formatted"]
63
+ if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt:
64
+ node["response"]["content"] = None
65
+ node_fmt["thought"] = (
66
+ node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0]
67
+ )
68
+ if isinstance(node_fmt["action"], str):
69
+ node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0]
70
+ else:
71
+ if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt:
72
+ message["content"] = None
73
+ fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0]
74
+ if isinstance(fmt["action"], str):
75
+ fmt["action"] = fmt["action"].split("<|action_end|>")[0]
76
+ for key in ["node"]:
77
+ fmt.pop(key, None)
78
+ return dict(current_node=current_node, response=message)
79
+
80
+
81
async def run(request: GenerationParams, _request: Request):
    """SSE endpoint bridging a *synchronous* agent onto the async event loop.

    The sync agent generator runs in a thread-pool executor and pushes chunks
    through a janus queue that both worlds can read/write.
    """

    async def generate():
        try:
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                # Fire-and-forget: the future is intentionally not awaited;
                # completion is signalled by the ``None`` sentinel instead.
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                stop_event.set()  # Inform sync_generator_wrapper to stop

            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                # Stop producing as soon as the client goes away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            # Surface the failure to the client as a final SSE error event.
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            # NOTE(review): if the client disconnects mid-stream (the ``break``
            # above), ``stop_event`` may never be set and this wait could hang;
            # likewise ``queue``/``stop_event`` are unbound if Queue() itself
            # raised — confirm and harden.
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            # Drop this session's conversation memory to avoid unbounded growth.
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    # A fresh agent per request; per-session state lives in its memory map.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    return EventSourceResponse(generate(), ping=300)
137
+
138
+
139
async def run_async(request: GenerationParams, _request: Request):
    """SSE endpoint backed by a natively asynchronous agent (``--asy`` mode)."""

    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                # Stop producing as soon as the client goes away.
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            # Surface the failure to the client as a final SSE error event.
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            # Drop this session's conversation memory to avoid unbounded growth.
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    # A fresh agent per request; per-session state lives in its memory map.
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )
    return EventSourceResponse(generate(), ping=300)
169
+
170
+
171
# Single solve endpoint; --asy selects the fully-async agent handler.
app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
mindsearch/terminal.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Terminal demo: run a single MindSearch query end-to-end and print the result.

Fixes over the original:
- removed unused ``os``/``sys`` imports (this is the whole file, nothing else
  uses them);
- the final prints no longer raise ``NameError`` when the agent generator
  yields nothing — ``agent_return`` is pre-initialized and the prints guarded;
- the repeated ``lang == "cn"`` checks are hoisted into one flag.
"""

from datetime import datetime

from lagent.actions import WebBrowser
from lagent.agents.stream import get_plugin_prompt
from lagent.llms import INTERNLM2_META, LMDeployServer
from lagent.prompts import InterpreterParser, PluginParser

from mindsearch.agent.mindsearch_agent import MindSearchAgent
from mindsearch.agent.mindsearch_prompt import (
    FINAL_RESPONSE_CN,
    FINAL_RESPONSE_EN,
    GRAPH_PROMPT_CN,
    GRAPH_PROMPT_EN,
    searcher_context_template_cn,
    searcher_context_template_en,
    searcher_input_template_cn,
    searcher_input_template_en,
    searcher_system_prompt_cn,
    searcher_system_prompt_en,
)

# Language switch for every prompt template below: "cn" or "en".
lang = "cn"
_is_cn = lang == "cn"
# Injected into both agents' system templates so answers use today's date.
date = datetime.now().strftime("The current date is %Y-%m-%d.")

# LLM served through an LMDeploy server endpoint. top_k=1 makes decoding
# effectively greedy despite temperature=1.0.
llm = LMDeployServer(
    path="internlm/internlm2_5-7b-chat",
    model_name="internlm2",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=1.0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>", "<|action_end|>"],
)

# Search tool for the searcher sub-agent; BingSearch reads its API key from
# the environment — presumably BING_SEARCH_API_KEY, verify against lagent.
plugins = [WebBrowser(searcher_type="BingSearch", topk=6)]

agent = MindSearchAgent(
    llm=llm,
    template=date,
    # Planner emits a search graph parsed by the interpreter format.
    output_format=InterpreterParser(template=GRAPH_PROMPT_CN if _is_cn else GRAPH_PROMPT_EN),
    # Configuration for the per-node searcher sub-agent.
    searcher_cfg=dict(
        llm=llm,
        plugins=plugins,
        template=date,
        output_format=PluginParser(
            template=searcher_system_prompt_cn if _is_cn else searcher_system_prompt_en,
            tool_info=get_plugin_prompt(plugins),
        ),
        user_input_template=searcher_input_template_cn if _is_cn else searcher_input_template_en,
        user_context_template=searcher_context_template_cn
        if _is_cn
        else searcher_context_template_en,
    ),
    summary_prompt=FINAL_RESPONSE_CN if _is_cn else FINAL_RESPONSE_EN,
    max_turn=10,
)

# Drain the streaming generator, keeping only the final message.
agent_return = None
for agent_return in agent("上海今天适合穿什么衣服"):
    pass

# Guard against an empty stream so we fail soft instead of with NameError.
if agent_return is not None:
    print(agent_return.sender)
    print(agent_return.content)
    print(agent_return.formatted["ref2url"])
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ duckduckgo_search==5.3.1b1
2
+ einops
3
+ fastapi
4
+ gradio==5.7.1
5
+ janus
6
+ lagent==0.5.0rc2
7
+ matplotlib
8
+ pydantic==2.6.4
9
+ python-dotenv
10
+ pyvis
11
+ schemdraw
12
+ sse-starlette
13
+ termcolor
14
+ transformers==4.41.0
15
+ uvicorn
16
+ tenacity
17
+ streamlit
18
+ git+https://github.com/vansin/ms_gradio_agentchatbot.git