Commit 4159dd0 by raoqu
1 parent: 16ef8ef
Commit message: 提交 (Commit)
Files changed:
- app.py +314 -0
- mindsearch/__init__.py +0 -0
- mindsearch/agent/__init__.py +82 -0
- mindsearch/agent/graph.py +307 -0
- mindsearch/agent/mindsearch_agent.py +210 -0
- mindsearch/agent/mindsearch_prompt.py +326 -0
- mindsearch/agent/models.py +115 -0
- mindsearch/agent/streaming.py +203 -0
- mindsearch/app.py +176 -0
- mindsearch/terminal.py +66 -0
- requirements.txt +18 -0
app.py  ADDED
@@ -0,0 +1,314 @@
import json
import os
import tempfile

import requests
import streamlit as st
from lagent.schema import AgentStatusCode
from pyvis.network import Network


# Function to create the network graph
def create_network_graph(nodes, adjacency_list):
    net = Network(height="500px", width="60%", bgcolor="white", font_color="black")
    for node_id, node_content in nodes.items():
        net.add_node(node_id, label=node_id, title=node_content, color="#FF5733", size=25)
    for node_id, neighbors in adjacency_list.items():
        for neighbor in neighbors:
            if neighbor["name"] in nodes:
                net.add_edge(node_id, neighbor["name"])
    net.show_buttons(filter_=["physics"])
    return net


# Function to draw the graph and return the HTML file path
def draw_graph(net):
    path = tempfile.mktemp(suffix=".html")
    net.save_graph(path)
    return path
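(Reviewer note: a quick smoke test of the two helpers above. The node dict and adjacency-list shapes mirror what `update_chat` builds further down in this file; the values themselves are made up for illustration.)

nodes = {"root": "What is MindSearch?", "subquery": "MindSearch GitHub repo"}
adjacency_list = {"root": [{"name": "subquery", "state": 2}]}
net = create_network_graph(nodes, adjacency_list)
print(draw_graph(net))  # prints the path of the saved interactive HTML file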
def streaming(raw_response):
    for chunk in raw_response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\n"):
        if chunk:
            decoded = chunk.decode("utf-8")
            if decoded == "\r":
                continue
            if decoded[:6] == "data: ":
                decoded = decoded[6:]
            elif decoded.startswith(": ping - "):
                continue
            response = json.loads(decoded)
            yield (
                response["current_node"],
                (
                    response["response"]["formatted"]["node"][response["current_node"]]["response"]
                    if response["current_node"]
                    else response["response"]
                ),
                response["response"]["formatted"]["adjacency_list"],
            )
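(Reviewer note: `streaming` consumes the backend's server-sent events: keep-alive pings and bare "\r" lines are skipped, and each "data: " line carries a JSON payload with a `current_node` field and a nested `response`. A standalone sketch of that parsing contract; the payload shape is inferred from the accesses above, not from a documented schema.)

import json

def parse_sse(lines):
    # Mirrors the filtering in `streaming`: drop pings and bare "\r",
    # strip the "data: " prefix, then decode the JSON payload.
    for chunk in lines:
        decoded = chunk.decode("utf-8")
        if decoded == "\r" or decoded.startswith(": ping - "):
            continue
        if decoded.startswith("data: "):
            decoded = decoded[6:]
        yield json.loads(decoded)

raw = [b": ping - keepalive", b'data: {"current_node": "root", "response": {}}']
print(next(parse_sse(raw))["current_node"])  # -> root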

# Initialize Streamlit session state
if "queries" not in st.session_state:
    st.session_state["queries"] = []
    st.session_state["responses"] = []
    st.session_state["graphs_html"] = []
    st.session_state["nodes_list"] = []
    st.session_state["adjacency_list_list"] = []
    st.session_state["history"] = []
    st.session_state["already_used_keys"] = list()

# Set up page layout
st.set_page_config(layout="wide")
st.title("MindSearch-思索")


# Function to update chat
def update_chat(query):
    with st.chat_message("user"):
        st.write(query)
    if query not in st.session_state["queries"]:
        # Mock data to simulate backend response
        # response, history, nodes, adjacency_list
        st.session_state["queries"].append(query)
        st.session_state["responses"].append([])
        history = None
        # Multi-turn dialogue is not supported yet
        # message = [dict(role='user', content=query)]

        url = "http://localhost:8002/solve"
        headers = {"Content-Type": "application/json"}
        data = {"inputs": query}
        raw_response = requests.post(
            url, headers=headers, data=json.dumps(data), timeout=20, stream=True
        )

        _nodes, _node_cnt = {}, 0
        for resp in streaming(raw_response):
            node_name, response, adjacency_list = resp
            for name in set(adjacency_list) | {
                val["name"] for vals in adjacency_list.values() for val in vals
            }:
                if name not in _nodes:
                    _nodes[name] = query if name == "root" else name
                elif response["stream_state"] == 0:
                    _nodes[node_name or "response"] = response["formatted"] and response[
                        "formatted"
                    ].get("thought")
            if len(_nodes) != _node_cnt or response["stream_state"] == 0:
                net = create_network_graph(_nodes, adjacency_list)
                graph_html_path = draw_graph(net)
                with open(graph_html_path, encoding="utf-8") as f:
                    graph_html = f.read()
                _node_cnt = len(_nodes)
            else:
                graph_html = None
            if "graph_placeholder" not in st.session_state:
                st.session_state["graph_placeholder"] = st.empty()
            if "expander_placeholder" not in st.session_state:
                st.session_state["expander_placeholder"] = st.empty()
            if graph_html:
                with st.session_state["expander_placeholder"].expander(
                    "Show Graph", expanded=False
                ):
                    st.session_state["graph_placeholder"]._html(graph_html, height=500)
            if "container_placeholder" not in st.session_state:
                st.session_state["container_placeholder"] = st.empty()
            with st.session_state["container_placeholder"].container():
                if "columns_placeholder" not in st.session_state:
                    st.session_state["columns_placeholder"] = st.empty()
                col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
                with col1:
                    if "planner_placeholder" not in st.session_state:
                        st.session_state["planner_placeholder"] = st.empty()
                    if "session_info_temp" not in st.session_state:
                        st.session_state["session_info_temp"] = ""
                    if not node_name:
                        if response["stream_state"] in [
                            AgentStatusCode.STREAM_ING,
                            AgentStatusCode.CODING,
                            AgentStatusCode.CODE_END,
                        ]:
                            content = response["formatted"]["thought"]
                            if response["formatted"]["tool_type"]:
                                action = response["formatted"]["action"]
                                if isinstance(action, dict):
                                    action = json.dumps(action, ensure_ascii=False, indent=4)
                                content += "\n" + action
                            st.session_state["session_info_temp"] = content.replace(
                                "<|action_start|><|interpreter|>\n", "\n"
                            )
                        elif response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            # assert history[-1]["role"] == "environment"
                            st.session_state["session_info_temp"] += "\n" + response["content"]
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["session_info_temp"]
                        )
                        if response["stream_state"] == AgentStatusCode.CODE_RETURN:
                            st.session_state["responses"][-1].append(
                                st.session_state["session_info_temp"]
                            )
                            st.session_state["session_info_temp"] = ""
                    else:
                        st.session_state["planner_placeholder"].markdown(
                            st.session_state["responses"][-1][-1]
                            if not st.session_state["session_info_temp"]
                            else st.session_state["session_info_temp"]
                        )
                with col2:
                    if "selectbox_placeholder" not in st.session_state:
                        st.session_state["selectbox_placeholder"] = st.empty()
                    if "searcher_placeholder" not in st.session_state:
                        st.session_state["searcher_placeholder"] = st.empty()
                    if node_name:
                        selected_node_key = (
                            f"selected_node_{len(st.session_state['queries'])}_{node_name}"
                        )
                        if selected_node_key not in st.session_state:
                            st.session_state[selected_node_key] = node_name
                        if selected_node_key not in st.session_state["already_used_keys"]:
                            selected_node = st.session_state["selectbox_placeholder"].selectbox(
                                "Select a node:",
                                list(_nodes.keys()),
                                key=f"key_{selected_node_key}",
                                index=list(_nodes.keys()).index(node_name),
                            )
                            st.session_state["already_used_keys"].append(selected_node_key)
                        else:
                            selected_node = node_name
                        st.session_state[selected_node_key] = selected_node
                        node_info_key = f"{selected_node}_info"
                        if node_info_key not in st.session_state:
                            st.session_state[node_info_key] = [["thought", ""]]
                        if response["stream_state"] in [AgentStatusCode.STREAM_ING]:
                            content = response["formatted"]["thought"]
                            st.session_state[node_info_key][-1][1] = content.replace(
                                "<|action_start|><|plugin|>\n", "\n```json\n"
                            )
                        elif response["stream_state"] in [
                            AgentStatusCode.PLUGIN_START,
                            AgentStatusCode.PLUGIN_END,
                        ]:
                            thought = response["formatted"]["thought"]
                            action = response["formatted"]["action"]
                            if isinstance(action, dict):
                                action = json.dumps(action, ensure_ascii=False, indent=4)
                            content = thought + "\n```json\n" + action
                            if response["stream_state"] == AgentStatusCode.PLUGIN_RETURN:
                                content += "\n```"
                            st.session_state[node_info_key][-1][1] = content
                        elif (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            try:
                                content = json.loads(response["content"])
                            except json.decoder.JSONDecodeError:
                                content = response["content"]
                            st.session_state[node_info_key].append(
                                [
                                    "observation",
                                    (
                                        content
                                        if isinstance(content, str)
                                        else f"```json\n{json.dumps(content, ensure_ascii=False, indent=4)}\n```"
                                    ),
                                ]
                            )
                        st.session_state["searcher_placeholder"].markdown(
                            st.session_state[node_info_key][-1][1]
                        )
                        if (
                            response["stream_state"] == AgentStatusCode.PLUGIN_RETURN
                            and st.session_state[node_info_key][-1][1]
                        ):
                            st.session_state[node_info_key].append(["thought", ""])
        if st.session_state["session_info_temp"]:
            st.session_state["responses"][-1].append(st.session_state["session_info_temp"])
            st.session_state["session_info_temp"] = ""
        # st.session_state['responses'][-1] = '\n'.join(st.session_state['responses'][-1])
        st.session_state["graphs_html"].append(graph_html)
        st.session_state["nodes_list"].append(_nodes)
        st.session_state["adjacency_list_list"].append(adjacency_list)
        st.session_state["history"] = history


def display_chat_history():
    for i, query in enumerate(st.session_state["queries"][-1:]):
        # with st.chat_message('assistant'):
        if st.session_state["graphs_html"][i]:
            with st.session_state["expander_placeholder"].expander("Show Graph", expanded=False):
                st.session_state["graph_placeholder"]._html(
                    st.session_state["graphs_html"][i], height=500
                )
        with st.session_state["container_placeholder"].container():
            col1, col2 = st.session_state["columns_placeholder"].columns([2, 1])
            with col1:
                st.session_state["planner_placeholder"].markdown(
                    st.session_state["responses"][-1][-1]
                )
            with col2:
                selected_node_key = st.session_state["already_used_keys"][-1]
                st.session_state["selectbox_placeholder"] = st.empty()
                selected_node = st.session_state["selectbox_placeholder"].selectbox(
                    "Select a node:",
                    list(st.session_state["nodes_list"][i].keys()),
                    key=f"replay_key_{i}",
                    index=list(st.session_state["nodes_list"][i].keys()).index(
                        st.session_state[selected_node_key]
                    ),
                )
                st.session_state[selected_node_key] = selected_node
                if (
                    selected_node not in ["root", "response"]
                    and selected_node in st.session_state["nodes_list"][i]
                ):
                    node_info_key = f"{selected_node}_info"
                    for item in st.session_state[node_info_key]:
                        if item[0] in ["thought", "answer"]:
                            st.session_state["searcher_placeholder"] = st.empty()
                            st.session_state["searcher_placeholder"].markdown(item[1])
                        elif item[0] == "observation":
                            st.session_state["observation_expander"] = st.empty()
                            with st.session_state["observation_expander"].expander("Results"):
                                st.write(item[1])
                    # st.session_state['searcher_placeholder'].markdown(st.session_state[node_info_key])


def clean_history():
    st.session_state["queries"] = []
    st.session_state["responses"] = []
    st.session_state["graphs_html"] = []
    st.session_state["nodes_list"] = []
    st.session_state["adjacency_list_list"] = []
    st.session_state["history"] = []
    st.session_state["already_used_keys"] = list()
    for k in list(st.session_state.keys()):  # copy keys: deleting while iterating raises
        if k.endswith("placeholder") or k.endswith("_info"):
            del st.session_state[k]


# Main function to run the Streamlit app
def main():
    st.sidebar.title("Model Control")
    col1, col2 = st.columns([4, 1])
    with col1:
        user_input = st.chat_input("Enter your query:")
    with col2:
        if st.button("Clear History"):
            clean_history()
    if user_input:
        update_chat(user_input)
    display_chat_history()


if __name__ == "__main__":
    os.system("pip show starlette")
    # os.system("pip install -r requirements.txt")
    os.system("pip install tenacity")
    os.system("python -m mindsearch.app --lang en --model_format internlm_silicon --search_engine GoogleSearch &")

    main()
mindsearch/__init__.py  ADDED
File without changes
mindsearch/agent/__init__.py  ADDED
@@ -0,0 +1,82 @@
import os
from copy import deepcopy
from datetime import datetime

from lagent.actions import AsyncWebBrowser, WebBrowser
from lagent.agents.stream import get_plugin_prompt
from lagent.prompts import InterpreterParser, PluginParser
from lagent.utils import create_object

from . import models as llm_factory
from .mindsearch_agent import AsyncMindSearchAgent, MindSearchAgent
from .mindsearch_prompt import (
    FINAL_RESPONSE_CN,
    FINAL_RESPONSE_EN,
    GRAPH_PROMPT_CN,
    GRAPH_PROMPT_EN,
    searcher_context_template_cn,
    searcher_context_template_en,
    searcher_input_template_cn,
    searcher_input_template_en,
    searcher_system_prompt_cn,
    searcher_system_prompt_en,
)

LLM = {}


def init_agent(lang="cn",
               model_format="internlm_server",
               search_engine="BingSearch",
               use_async=False):
    mode = "async" if use_async else "sync"
    llm = LLM.get(model_format, {}).get(mode)
    if llm is None:
        llm_cfg = deepcopy(getattr(llm_factory, model_format))
        if llm_cfg is None:
            raise NotImplementedError
        if use_async:
            cls_name = (
                llm_cfg["type"].split(".")[-1] if isinstance(
                    llm_cfg["type"], str) else llm_cfg["type"].__name__)
            llm_cfg["type"] = f"lagent.llms.Async{cls_name}"
        llm = create_object(llm_cfg)
        LLM.setdefault(model_format, {}).setdefault(mode, llm)

    date = datetime.now().strftime("The current date is %Y-%m-%d.")
    plugins = [(dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
        secret_id=os.getenv("TENCENT_SEARCH_SECRET_ID"),
        secret_key=os.getenv("TENCENT_SEARCH_SECRET_KEY"),
    ) if search_engine == "TencentSearch" else dict(
        type=AsyncWebBrowser if use_async else WebBrowser,
        searcher_type=search_engine,
        topk=6,
        api_key=os.getenv("WEB_SEARCH_API_KEY"),
    ))]
    agent = (AsyncMindSearchAgent if use_async else MindSearchAgent)(
        llm=llm,
        template=date,
        output_format=InterpreterParser(
            template=GRAPH_PROMPT_CN if lang == "cn" else GRAPH_PROMPT_EN),
        searcher_cfg=dict(
            llm=llm,
            plugins=plugins,
            template=date,
            output_format=PluginParser(
                template=searcher_system_prompt_cn
                if lang == "cn" else searcher_system_prompt_en,
                tool_info=get_plugin_prompt(plugins),
            ),
            user_input_template=(searcher_input_template_cn if lang == "cn"
                                 else searcher_input_template_en),
            user_context_template=(searcher_context_template_cn if lang == "cn"
                                   else searcher_context_template_en),
        ),
        summary_prompt=FINAL_RESPONSE_CN
        if lang == "cn" else FINAL_RESPONSE_EN,
        max_turn=10,
    )
    return agent
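(Reviewer note: a minimal driver for `init_agent`, sketched under assumptions: it requires lagent installed, a model entry named `internlm_silicon` in `models.py`, and the matching search-engine API key exported; the model/engine names below are the ones app.py passes on launch.)

# Hypothetical usage sketch; the LLM backend must be reachable for real answers.
from mindsearch.agent import init_agent

agent = init_agent(lang="en",
                   model_format="internlm_silicon",
                   search_engine="GoogleSearch")
for message in agent("What is the tallest building in Shanghai?"):
    pass  # each yielded AgentMessage streams planner/searcher state
print(message.content)  # final answer after the stream is drained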
mindsearch/agent/graph.py  ADDED
@@ -0,0 +1,307 @@
import asyncio
import queue
import random
import re
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from threading import Thread
from typing import Dict, List

from lagent.actions import BaseAction
from lagent.schema import AgentMessage, AgentStatusCode

from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM


class SearcherAgent(StreamingAgentForInternLM):
    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        message = [self.user_input_template.format(question=question, topic=topic)]
        if history and self.user_context_template:
            message = [self.user_context_template.format_map(item) for item in history] + message
        message = "\n".join(message)
        return super().forward(message, session_id=session_id, **kwargs)


class AsyncSearcherAgent(AsyncStreamingAgentForInternLM):
    def __init__(
        self,
        user_input_template: str = "{question}",
        user_context_template: str = None,
        **kwargs,
    ):
        self.user_input_template = user_input_template
        self.user_context_template = user_context_template
        super().__init__(**kwargs)

    async def forward(
        self,
        question: str,
        topic: str,
        history: List[dict] = None,
        session_id=0,
        **kwargs,
    ):
        message = [self.user_input_template.format(question=question, topic=topic)]
        if history and self.user_context_template:
            message = [self.user_context_template.format_map(item) for item in history] + message
        message = "\n".join(message)
        async for message in super().forward(message, session_id=session_id, **kwargs):
            yield message


class WebSearchGraph:
    is_async = False
    SEARCHER_CONFIG = {}
    _SEARCHER_LOOP = []
    _SEARCHER_THREAD = []

    def __init__(self):
        self.nodes: Dict[str, Dict[str, str]] = {}
        self.adjacency_list: Dict[str, List[dict]] = defaultdict(list)
        self.future_to_query = dict()
        self.searcher_resp_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.n_active_tasks = 0

    def add_root_node(
        self,
        node_content: str,
        node_name: str = "root",
    ):
        """Add the initial root node.

        Args:
            node_content (str): The node content.
            node_name (str, optional): The node name. Defaults to 'root'.
        """
        self.nodes[node_name] = dict(content=node_content, type="root")
        self.adjacency_list[node_name] = []

    def add_node(
        self,
        node_name: str,
        node_content: str,
    ):
        """Add a search sub-question node.

        Args:
            node_name (str): The node name.
            node_content (str): The sub-question content.

        Returns:
            str: The search results.
        """
        self.nodes[node_name] = dict(content=node_content, type="searcher")
        self.adjacency_list[node_name] = []

        parent_nodes = []
        for start_node, adj in self.adjacency_list.items():
            for neighbor in adj:
                if (
                    node_name == neighbor["name"]
                    and start_node in self.nodes
                    and "response" in self.nodes[start_node]
                ):
                    parent_nodes.append(self.nodes[start_node])
        parent_response = [
            dict(question=node["content"], answer=node["response"]) for node in parent_nodes
        ]

        if self.is_async:

            async def _async_search_node_stream():
                cfg = {
                    **self.SEARCHER_CONFIG,
                    "plugins": deepcopy(self.SEARCHER_CONFIG.get("plugins")),
                }
                agent, session_id = AsyncSearcherAgent(**cfg), random.randint(0, 999999)
                searcher_message = AgentMessage(sender="SearcherAgent", content="")
                try:
                    async for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self.nodes[node_name]["response"] = searcher_message.model_dump()
                        self.nodes[node_name]["memory"] = agent.state_dict(session_id=session_id)
                        self.nodes[node_name]["session_id"] = session_id
                        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    self.searcher_resp_queue.put((exc, None, None))

            self.future_to_query[
                asyncio.run_coroutine_threadsafe(
                    _async_search_node_stream(), random.choice(self._SEARCHER_LOOP)
                )
            ] = f"{node_name}-{node_content}"
            # self.future_to_query[
            #     self.executor.submit(asyncio.run, _async_search_node_stream())
            # ] = f"{node_name}-{node_content}"
        else:

            def _search_node_stream():
                cfg = {
                    **self.SEARCHER_CONFIG,
                    "plugins": deepcopy(self.SEARCHER_CONFIG.get("plugins")),
                }
                agent, session_id = SearcherAgent(**cfg), random.randint(0, 999999)
                searcher_message = AgentMessage(sender="SearcherAgent", content="")
                try:
                    for searcher_message in agent(
                        question=node_content,
                        topic=self.nodes["root"]["content"],
                        history=parent_response,
                        session_id=session_id,
                    ):
                        self.nodes[node_name]["response"] = searcher_message.model_dump()
                        self.nodes[node_name]["memory"] = agent.state_dict(session_id=session_id)
                        self.nodes[node_name]["session_id"] = session_id
                        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
                    self.searcher_resp_queue.put((None, None, None))
                except Exception as exc:
                    self.searcher_resp_queue.put((exc, None, None))

            self.future_to_query[
                self.executor.submit(_search_node_stream)
            ] = f"{node_name}-{node_content}"

        self.n_active_tasks += 1

    def add_response_node(self, node_name="response"):
        """Add a response node once the collected information answers the question.

        Args:
            node_name (str, optional): The node name. Defaults to 'response'.
        """
        self.nodes[node_name] = dict(type="end")
        self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))

    def add_edge(self, start_node: str, end_node: str):
        """Add an edge.

        Args:
            start_node (str): The name of the start node.
            end_node (str): The name of the end node.
        """
        self.adjacency_list[start_node].append(dict(id=str(uuid.uuid4()), name=end_node, state=2))
        self.searcher_resp_queue.put(
            (start_node, self.nodes[start_node], self.adjacency_list[start_node])
        )

    def reset(self):
        self.nodes = {}
        self.adjacency_list = defaultdict(list)

    def node(self, node_name: str) -> str:
        return self.nodes[node_name].copy()

    @classmethod
    def start_loop(cls, n: int = 32):
        if not cls.is_async:
            raise RuntimeError("Event loop cannot be launched as `is_async` is disabled")

        assert len(cls._SEARCHER_LOOP) == len(cls._SEARCHER_THREAD)
        for i, (loop, thread) in enumerate(
            zip(cls._SEARCHER_LOOP.copy(), cls._SEARCHER_THREAD.copy())
        ):
            if not (loop.is_running() and thread.is_alive()):
                cls._SEARCHER_LOOP.pop(i)
                cls._SEARCHER_THREAD.pop(i)

        while len(cls._SEARCHER_THREAD) < n:

            def _start_loop():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
                cls._SEARCHER_LOOP.append(loop)
                loop.run_forever()

            thread = Thread(target=_start_loop, daemon=True)
            thread.start()
            cls._SEARCHER_THREAD.append(thread)


class ExecutionAction(BaseAction):
    """Tool used by the MindSearch planner to execute graph node queries."""

    def run(self, command, local_dict, global_dict, stream_graph=False):
        def extract_code(text: str) -> str:
            text = re.sub(r"from ([\w.]+) import WebSearchGraph", "", text)
            triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL)
            single_match = re.search(r"`([^`]*)`", text, re.DOTALL)
            if triple_match:
                return triple_match.group(1)
            elif single_match:
                return single_match.group(1)
            return text

        command = extract_code(command)
        exec(command, global_dict, local_dict)

        # Match every node queried via graph.node(...) in the generated code
        node_list = re.findall(r"graph.node\((.*?)\)", command)
        graph: WebSearchGraph = local_dict["graph"]
        while graph.n_active_tasks:
            while not graph.searcher_resp_queue.empty():
                node_name, _, _ = graph.searcher_resp_queue.get(timeout=60)
                if isinstance(node_name, Exception):
                    raise node_name
                if node_name is None:
                    graph.n_active_tasks -= 1
                    continue
                if stream_graph:
                    for neighbors in graph.adjacency_list.values():
                        for neighbor in neighbors:
                            # state: 1 = in progress, 2 = not started, 3 = finished
                            if not (
                                neighbor["name"] in graph.nodes
                                and "response" in graph.nodes[neighbor["name"]]
                            ):
                                neighbor["state"] = 2
                            elif (
                                graph.nodes[neighbor["name"]]["response"]["stream_state"]
                                == AgentStatusCode.END
                            ):
                                neighbor["state"] = 3
                            else:
                                neighbor["state"] = 1
                    if all(
                        "response" in node
                        for name, node in graph.nodes.items()
                        if name not in ["root", "response"]
                    ):
                        yield AgentMessage(
                            sender=self.name,
                            content=dict(current_node=node_name),
                            formatted=dict(
                                node=deepcopy(graph.nodes),
                                adjacency_list=deepcopy(graph.adjacency_list),
                            ),
                            stream_state=AgentStatusCode.STREAM_ING,
                        )
        res = [graph.nodes[node.strip().strip('"').strip("'")] for node in node_list]
        return res, graph.nodes, graph.adjacency_list
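(Reviewer note: `ExecutionAction.run` first strips the model's reply down to executable code. `extract_code` drops any `from ... import WebSearchGraph` line and prefers a triple-backtick block over a single-backtick span. A standalone copy of that helper, duplicated here only so the sketch runs without the lagent dependencies; the sample reply string is made up.)

import re

def extract_code(text: str) -> str:
    # Same logic as the helper inside ExecutionAction.run above.
    text = re.sub(r"from ([\w.]+) import WebSearchGraph", "", text)
    triple_match = re.search(r"```[^\n]*\n(.+?)```", text, re.DOTALL)
    single_match = re.search(r"`([^`]*)`", text, re.DOTALL)
    if triple_match:
        return triple_match.group(1)
    if single_match:
        return single_match.group(1)
    return text

reply = 'Thinking...<|action_start|><|interpreter|>```python\ngraph = WebSearchGraph()\n```<|action_end|>'
print(extract_code(reply))  # -> graph = WebSearchGraph()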
mindsearch/agent/mindsearch_agent.py  ADDED
@@ -0,0 +1,210 @@
import json
import logging
import re
from copy import deepcopy
from typing import Dict, Tuple

from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode
from lagent.utils import GeneratorWithReturn

from .graph import ExecutionAction, WebSearchGraph
from .streaming import AsyncStreamingAgentForInternLM, StreamingAgentForInternLM


def _update_ref(ref: str, ref2url: Dict[str, str], ptr: int) -> Tuple[str, Dict[int, str], int]:
    numbers = list({int(n) for n in re.findall(r"\[\[(\d+)\]\]", ref)})
    numbers = {n: idx + 1 for idx, n in enumerate(numbers)}
    updated_ref = re.sub(
        r"\[\[(\d+)\]\]",
        lambda match: f"[[{numbers[int(match.group(1))] + ptr}]]",
        ref,
    )
    updated_ref2url = {}
    if numbers:
        try:
            assert all(elem in ref2url for elem in numbers)
        except Exception as exc:
            logging.info(f"Illegal reference id: {str(exc)}")
        if ref2url:
            updated_ref2url = {
                numbers[idx] + ptr: ref2url[idx] for idx in numbers if idx in ref2url
            }
    return updated_ref, updated_ref2url, len(numbers) + 1


def _generate_references_from_graph(graph: Dict[str, dict]) -> Tuple[str, Dict[int, dict]]:
    ptr, references, references_url = 0, [], {}
    for name, data_item in graph.items():
        if name in ["root", "response"]:
            continue
        # only search once at each node, thus the result offset is 2
        assert data_item["memory"]["agent.memory"][2]["sender"].endswith("ActionExecutor")
        ref2url = {
            int(k): v
            for k, v in json.loads(data_item["memory"]["agent.memory"][2]["content"]).items()
        }
        updated_ref, ref2url, added_ptr = _update_ref(
            data_item["response"]["content"], ref2url, ptr
        )
        ptr += added_ptr
        references.append(f'## {data_item["content"]}\n\n{updated_ref}')
        references_url.update(ref2url)
    return "\n\n".join(references), references_url


class MindSearchAgent(StreamingAgentForInternLM):
    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            if not message.formatted["tool_type"]:
                message.stream_state = AgentStatusCode.END
                yield message
                return

            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec

            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message


class AsyncMindSearchAgent(AsyncStreamingAgentForInternLM):
    def __init__(
        self,
        searcher_cfg: dict,
        summary_prompt: str,
        finish_condition=lambda m: "add_response_node" in m.content,
        max_turn: int = 10,
        **kwargs,
    ):
        WebSearchGraph.SEARCHER_CONFIG = searcher_cfg
        WebSearchGraph.is_async = True
        WebSearchGraph.start_loop()
        super().__init__(finish_condition=finish_condition, max_turn=max_turn, **kwargs)
        self.summary_prompt = summary_prompt
        self.action = ExecutionAction()

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        _graph_state = dict(node={}, adjacency_list={}, ref2url={})
        local_dict, global_dict = {}, globals()
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                message.formatted.update(deepcopy(_graph_state))
                yield message
                last_agent_state = message.stream_state
            if not message.formatted["tool_type"]:
                message.stream_state = AgentStatusCode.END
                yield message
                return

            gen = GeneratorWithReturn(
                self.action.run(message.content, local_dict, global_dict, True)
            )
            for graph_exec in gen:
                graph_exec.formatted["ref2url"] = deepcopy(_graph_state["ref2url"])
                yield graph_exec

            reference, references_url = _generate_references_from_graph(gen.ret[1])
            _graph_state.update(node=gen.ret[1], adjacency_list=gen.ret[2], ref2url=references_url)
            if self.finish_condition(message):
                message = AgentMessage(
                    sender="ActionExecutor",
                    content=self.summary_prompt,
                    formatted=deepcopy(_graph_state),
                    stream_state=message.stream_state + 1,  # plugin or code return
                )
                yield message
                # summarize the references to generate the final answer
                async for message in self.agent(message, session_id=session_id, **kwargs):
                    message.formatted.update(deepcopy(_graph_state))
                    yield message
                return
            message = AgentMessage(
                sender="ActionExecutor",
                content=reference,
                formatted=deepcopy(_graph_state),
                stream_state=message.stream_state + 1,  # plugin or code return
            )
            yield message
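(Reviewer note: `_update_ref` renumbers each node's local `[[n]]` citation markers into one global sequence, shifted by the running pointer `ptr`, so `_generate_references_from_graph` can merge every searcher's citations into a single reference list. A small standalone illustration of just the renumbering step; it is re-implemented here so it runs without the package, and the local ids are sorted for a deterministic demo where the original iterates a set directly.)

import re

def renumber(ref: str, ptr: int) -> str:
    # Map the distinct local ids to 1..k, then shift them by `ptr`.
    local_ids = sorted({int(n) for n in re.findall(r"\[\[(\d+)\]\]", ref)})
    mapping = {n: idx + 1 for idx, n in enumerate(local_ids)}
    return re.sub(r"\[\[(\d+)\]\]",
                  lambda m: f"[[{mapping[int(m.group(1))] + ptr}]]", ref)

print(renumber("Fact A [[3]], fact B [[7]].", ptr=4))
# -> Fact A [[5]], fact B [[6]].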
mindsearch/agent/mindsearch_prompt.py  ADDED
@@ -0,0 +1,326 @@
| 1 |
+
# flake8: noqa
|
| 2 |
+
|
| 3 |
+
searcher_system_prompt_cn = """## 人物简介
|
| 4 |
+
你是一个可以调用网络搜索工具的智能助手。请根据"当前问题",调用搜索工具收集信息并回复问题。你能够调用如下工具:
|
| 5 |
+
{tool_info}
|
| 6 |
+
## 回复格式
|
| 7 |
+
|
| 8 |
+
调用工具时,请按照以下格式:
|
| 9 |
+
```
|
| 10 |
+
你的思考过程...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
## 要求
|
| 14 |
+
|
| 15 |
+
- 回答中每个关键点需标注引用的搜索结果来源,以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
|
| 16 |
+
- 基于"当前问题"的搜索结果,撰写详细完备的回复,优先回答"当前问题"。
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
searcher_system_prompt_en = """## Character Introduction
|
| 21 |
+
You are an intelligent assistant that can call web search tools. Please collect information and reply to the question based on the current problem. You can use the following tools:
|
| 22 |
+
{tool_info}
|
| 23 |
+
## Reply Format
|
| 24 |
+
|
| 25 |
+
When calling the tool, please follow the format below:
|
| 26 |
+
```
|
| 27 |
+
Your thought process...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
## Requirements
|
| 31 |
+
|
| 32 |
+
- Each key point in the response should be marked with the source of the search results to ensure the credibility of the information. The citation format is `[[int]]`. If there are multiple citations, use multiple [[]] to provide the index, such as `[[id_1]][[id_2]]`.
|
| 33 |
+
- Based on the search results of the "current problem", write a detailed and complete reply to answer the "current problem".
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
fewshot_example_cn = """
|
| 37 |
+
## 样例
|
| 38 |
+
|
| 39 |
+
### search
|
| 40 |
+
当我希望搜索"王者荣耀现在是什么赛季"时,我会按照以下格式进行操作:
|
| 41 |
+
现在是2024年,因此我应该搜索王者荣耀赛季关键词<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["王者荣耀 赛季", "2024年王者荣耀赛季"]}}}}<|action_end|>
|
| 42 |
+
|
| 43 |
+
### select
|
| 44 |
+
为了找到王者荣耀s36赛季最强射手,我需要寻找提及王者荣耀s36射手的网页。初步浏览网页后,发现网页0提到王者荣耀s36赛季的信息,但没有具体提及射手的相关信息。网页3提到“s36最强射手出现?”,有可能包含最强射手信息。网页13提到“四大T0英雄崛起,射手荣耀降临”,可能包含最强射手的信息。因此,我选择了网页3和网页13进行进一步阅读。<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
fewshot_example_en = """
|
| 48 |
+
## Example
|
| 49 |
+
|
| 50 |
+
### search
|
| 51 |
+
When I want to search for "What season is Honor of Kings now", I will operate in the following format:
|
| 52 |
+
Now it is 2024, so I should search for the keyword of the Honor of Kings<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["Honor of Kings Season", "season for Honor of Kings in 2024"]}}}}<|action_end|>
|
| 53 |
+
|
| 54 |
+
### select
|
| 55 |
+
In order to find the strongest shooters in Honor of Kings in season s36, I needed to look for web pages that mentioned shooters in Honor of Kings in season s36. After an initial browse of the web pages, I found that web page 0 mentions information about Honor of Kings in s36 season, but there is no specific mention of information about the shooter. Webpage 3 mentions that “the strongest shooter in s36 has appeared?”, which may contain information about the strongest shooter. Webpage 13 mentions “Four T0 heroes rise, archer's glory”, which may contain information about the strongest archer. Therefore, I chose webpages 3 and 13 for further reading.<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
searcher_input_template_en = """## Final Problem
|
| 59 |
+
{topic}
|
| 60 |
+
## Current Problem
|
| 61 |
+
{question}
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
searcher_input_template_cn = """## 主问题
|
| 65 |
+
{topic}
|
| 66 |
+
## 当前问题
|
| 67 |
+
{question}
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
searcher_context_template_en = """## Historical Problem
|
| 71 |
+
{question}
|
| 72 |
+
Answer: {answer}
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
searcher_context_template_cn = """## 历史问题
|
| 76 |
+
{question}
|
| 77 |
+
回答:{answer}
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
search_template_cn = "## {query}\n\n{result}\n"
|
| 81 |
+
search_template_en = "## {query}\n\n{result}\n"
|
| 82 |
+
|
| 83 |
+
GRAPH_PROMPT_CN = """## 人物简介
|
| 84 |
+
你是一个可以利用 Jupyter 环境 Python 编程的程序员。你可以利用提供的 API 来构建 Web 搜索图,最终生成代码并执行。
|
| 85 |
+
|
| 86 |
+
## API 介绍
|
| 87 |
+
|
| 88 |
+
下面是包含属性详细说明的 `WebSearchGraph` 类的 API 文档:
|
| 89 |
+
|
| 90 |
+
### 类:`WebSearchGraph`
|
| 91 |
+
|
| 92 |
+
此类用于管理网络搜索图的节点和边,并通过网络代理进行搜索。
|
| 93 |
+
|
| 94 |
+
#### 初始化方法
|
| 95 |
+
|
| 96 |
+
初始化 `WebSearchGraph` 实例。
|
| 97 |
+
|
| 98 |
+
**属性:**
|
| 99 |
+
|
| 100 |
+
- `nodes` (Dict[str, Dict[str, str]]): 存储图中所有节点的字典。每个节点由其名称索引,并包含内容、类型以及其他相关信息。
|
| 101 |
+
- `adjacency_list` (Dict[str, List[str]]): 存储图中所有节点之间连接关系的邻接表。每个节点由其名称索引,并包含一个相邻节点名称的列表。
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
#### 方法:`add_root_node`
|
| 105 |
+
|
| 106 |
+
添加原始问题作为根节点。
|
| 107 |
+
**参数:**
|
| 108 |
+
|
| 109 |
+
- `node_content` (str): 用户提出的问题。
|
| 110 |
+
- `node_name` (str, 可选): 节点名称,默认为 'root'。
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
#### 方法:`add_node`
|
| 114 |
+
|
| 115 |
+
添加搜索子问题节点并返回搜索结果。
|
| 116 |
+
**参数:
|
| 117 |
+
|
| 118 |
+
- `node_name` (str): 节点名称。
|
| 119 |
+
- `node_content` (str): 子问题内容。
|
| 120 |
+
|
| 121 |
+
**返回:**
|
| 122 |
+
|
| 123 |
+
- `str`: 返回搜索结果。
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
#### 方法:`add_response_node`
|
| 127 |
+
|
| 128 |
+
当前获取的信息已经满足问题需求,添加回复节点。
|
| 129 |
+
|
| 130 |
+
**参数:**
|
| 131 |
+
|
| 132 |
+
- `node_name` (str, 可选): 节点名称,默认为 'response'。
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
#### 方法:`add_edge`
|
| 136 |
+
|
| 137 |
+
添加边。
|
| 138 |
+
|
| 139 |
+
**参数:**
|
| 140 |
+
|
| 141 |
+
- `start_node` (str): 起始节点名称。
|
| 142 |
+
- `end_node` (str): 结束节点名称。
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
#### 方法:`reset`
|
| 146 |
+
|
| 147 |
+
重置节点和边。
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
#### 方法:`node`
|
| 151 |
+
|
| 152 |
+
获取节点信息。
|
| 153 |
+
|
| 154 |
+
```python
|
| 155 |
+
def node(self, node_name: str) -> str
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
**参数:**
|
| 159 |
+
|
| 160 |
+
- `node_name` (str): 节点名称。
|
| 161 |
+
|
| 162 |
+
**返回:**
|
| 163 |
+
|
| 164 |
+
- `str`: 返回包含节点信息的字典,包含节点的内容、类型、思考过程(如果有)和前驱节点列表。
|
| 165 |
+
|
| 166 |
+
## 任务介绍
|
| 167 |
+
通过将一个问题拆分成能够通过搜索回答的子问题(没有关联的问题可以同步并列搜索),每个搜索的问题应该是一个单一问题,即单个具体人、事、物、具体时间点、地点或知识点的问题,不是一个复合问题(比如某个时间段), 一步步构建搜索图,最终回答问题。
|
| 168 |
+
|
| 169 |
+
## 注意事项
|
| 170 |
+
|
| 171 |
+
1. 注意,每个搜索节点的内容必须单个问题,不要包含多个问题(比如同时问多个知识点的问题或者多个事物的比较加筛选,类似 A, B, C 有什么区别,那个价格在哪个区间 -> 分别查询)
|
| 172 |
+
2. 不要杜撰搜索结果,要等待代码返回结果
|
| 173 |
+
3. 同样的问题不要重复提问,可以在已有问题的基础上继续提问
|
| 174 |
+
4. 添加 response 节点的时候,要单独添加,不要和其他节点一起添加,不能同时添加 response 节点和其他节点
|
| 175 |
+
5. 一次输出中,不要包含多个代码块,每次只能有一个代码块
|
| 176 |
+
6. 每个代码块应该放置在一个代码块标记中,同时生成完代码后添加一个<|action_end|>标志,如下所示:
|
| 177 |
+
<|action_start|><|interpreter|>```python
|
| 178 |
+
# 你的代码块
|
| 179 |
+
```<|action_end|>
|
| 180 |
+
7. 最后一次回复应该是添加node_name为'response'的 response 节点,必须添加 response 节点,不要添加其他节点
|
| 181 |
+
"""
|
| 182 |
+
|
| 183 |
+
GRAPH_PROMPT_EN = """## Character Profile
|
| 184 |
+
You are a programmer capable of Python programming in a Jupyter environment. You can utilize the provided API to construct a Web Search Graph, ultimately generating and executing code.
|
| 185 |
+
|
| 186 |
+
## API Description
|
| 187 |
+
|
| 188 |
+
Below is the API documentation for the WebSearchGraph class, including detailed attribute descriptions:
|
| 189 |
+
|
| 190 |
+
### Class: WebSearchGraph
|
| 191 |
+
|
| 192 |
+
This class manages nodes and edges of a web search graph and conducts searches via a web proxy.
|
| 193 |
+
|
| 194 |
+
#### Initialization Method
|
| 195 |
+
|
| 196 |
+
Initializes an instance of WebSearchGraph.
|
| 197 |
+
|
| 198 |
+
**Attributes:**
|
| 199 |
+
|
| 200 |
+
- nodes (Dict[str, Dict[str, str]]): A dictionary storing all nodes in the graph. Each node is indexed by its name and contains content, type, and other related information.
|
| 201 |
+
- adjacency_list (Dict[str, List[str]]): An adjacency list storing the connections between all nodes in the graph. Each node is indexed by its name and contains a list of adjacent node names.
|
| 202 |
+
|
| 203 |
+
#### Method: add_root_node
|
| 204 |
+
|
| 205 |
+
Adds the initial question as the root node.
|
| 206 |
+
**Parameters:**
|
| 207 |
+
|
| 208 |
+
- node_content (str): The user's question.
|
| 209 |
+
- node_name (str, optional): The node name, default is 'root'.
|
| 210 |
+
|
| 211 |
+
#### Method: add_node
|
| 212 |
+
|
| 213 |
+
Adds a sub-question node and returns search results.
|
| 214 |
+
**Parameters:**
|
| 215 |
+
|
| 216 |
+
- node_name (str): The node name.
|
| 217 |
+
- node_content (str): The sub-question content.
|
| 218 |
+
|
| 219 |
+
**Returns:**
|
| 220 |
+
|
| 221 |
+
- str: Returns the search results.
|
| 222 |
+
|
| 223 |
+
#### Method: add_response_node
|
| 224 |
+
|
| 225 |
+
Adds a response node when the current information satisfies the question's requirements.
|
| 226 |
+
|
| 227 |
+
**Parameters:**
|
| 228 |
+
|
| 229 |
+
- node_name (str, optional): The node name, default is 'response'.
|
| 230 |
+
|
| 231 |
+
#### Method: add_edge
|
| 232 |
+
|
| 233 |
+
Adds an edge.
|
| 234 |
+
|
| 235 |
+
**Parameters:**
|
| 236 |
+
|
| 237 |
+
- start_node (str): The starting node name.
|
| 238 |
+
- end_node (str): The ending node name.
|
| 239 |
+
|
| 240 |
+
#### Method: reset
|
| 241 |
+
|
| 242 |
+
Resets nodes and edges.
|
| 243 |
+
|
| 244 |
+
#### Method: node
|
| 245 |
+
|
| 246 |
+
Get node information.
|
| 247 |
+
|
| 248 |
+
python
|
| 249 |
+
def node(self, node_name: str) -> str
|
| 250 |
+
|
| 251 |
+
**Parameters:**

- node_name (str): The node name.

**Returns:**

- str: Returns a dictionary containing the node's information, including content, type, thought process (if any), and list of predecessor nodes.

## Task Description

By breaking down a question into sub-questions that can be answered through searches (unrelated questions can be searched concurrently), each search query should be a single question focusing on a specific person, event, object, specific time point, location, or knowledge point. It should not be a compound question (e.g., a time period). Step by step, build the search graph to finally answer the question.

## Considerations

1. Each search node's content must be a single question; do not include multiple questions (e.g., do not ask multiple knowledge points or compare and filter multiple things simultaneously, like asking for differences between A, B, and C, or price ranges -> query each separately).
2. Do not fabricate search results; wait for the code to return results.
3. Do not repeat the same question; continue asking based on existing questions.
4. When adding a response node, add it separately; do not add a response node and other nodes simultaneously.
5. In a single output, do not include multiple code blocks; only one code block per output.
6. Each code block should be placed within a code block marker, and after generating the code, add an <|action_end|> tag as shown below:
<|action_start|><|interpreter|>
```python
# Your code block (Note that the 'Get new added node information' logic must be added at the end of the code block, such as 'graph.node('...')')
```<|action_end|>
7. The final response should add a response node with node_name 'response', and no other nodes should be added.
"""

graph_fewshot_example_cn = """
## 返回格式示例
<|action_start|><|interpreter|>```python
graph = WebSearchGraph()
graph.add_root_node(node_content="哪家大模型API最便宜?", node_name="root") # 添加原始问题作为根节点
graph.add_node(
    node_name="大模型API提供商", # 节点名称最好有意义
    node_content="目前有哪些主要的大模型API提供商?")
graph.add_node(
    node_name="sub_name_2", # 节点名称最好有意义
    node_content="content of sub_name_2")
...
graph.add_edge(start_node="root", end_node="sub_name_1")
...
graph.node("大模型API提供商"), graph.node("sub_name_2"), ...
```<|action_end|>
"""

graph_fewshot_example_en = """
## Response Format
<|action_start|><|interpreter|>```python
graph = WebSearchGraph()
graph.add_root_node(node_content="Which large model API is the cheapest?", node_name="root") # Add the original question as the root node
graph.add_node(
    node_name="Large Model API Providers", # The node name should be meaningful
    node_content="Who are the main large model API providers currently?")
graph.add_node(
    node_name="sub_name_2", # The node name should be meaningful
    node_content="content of sub_name_2")
...
graph.add_edge(start_node="root", end_node="sub_name_1")
...
# Get node info
graph.node("Large Model API Providers"), graph.node("sub_name_2"), ...
```<|action_end|>
"""

FINAL_RESPONSE_CN = """基于提供的问答对,撰写一篇详细完备的最终回答。
- 回答内容需要逻辑清晰,层次分明,确保读者易于理解。
- 回答中每个关键点需标注引用的搜索结果来源(保持跟问答对中的索引一致),以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
- 回答部分需要全面且完备,不要出现"基于上述内容"等模糊表达,最终呈现的回答不包括提供给你的问答对。
- 语言风格需要专业、严谨,避免口语化表达。
- 保持统一的语法和词汇使用,确保整体文档的一致性和连贯性。"""

FINAL_RESPONSE_EN = """Based on the provided Q&A pairs, write a detailed and comprehensive final response.
- The response content should be logically clear and well-structured to ensure reader understanding.
- Each key point in the response should be marked with the source of the search results (consistent with the indices in the Q&A pairs) to ensure information credibility. The index is in the form of `[[int]]`, and if there are multiple indices, use multiple `[[]]`, such as `[[id_1]][[id_2]]`.
- The response should be comprehensive and complete, without vague expressions like "based on the above content". The final response should not include the Q&A pairs provided to you.
- The language style should be professional and rigorous, avoiding colloquial expressions.
- Maintain consistent grammar and vocabulary usage to ensure overall document consistency and coherence."""
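Note: the `WebSearchGraph` API documented in the prompt above is implemented in `mindsearch/agent/graph.py`, also added in this commit. As a minimal sketch, this is the call sequence the prompt asks the model to emit; the question and node names are illustrative only, and the node lookup at the end follows consideration 6 above.

```python
# Minimal sketch of one model turn under GRAPH_PROMPT_EN (illustrative question).
from mindsearch.agent.graph import WebSearchGraph

graph = WebSearchGraph()
graph.add_root_node(node_content="Which large model API is the cheapest?", node_name="root")
graph.add_node(node_name="providers", node_content="Who are the main large model API providers?")
graph.add_edge(start_node="root", end_node="providers")
# The prompt requires fetching newly added node information at the end of each block.
graph.node("providers")
```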
mindsearch/agent/models.py
ADDED
@@ -0,0 +1,115 @@
import os

from dotenv import load_dotenv
from lagent.llms import (
    GPTAPI,
    INTERNLM2_META,
    HFTransformerCasualLM,
    LMDeployClient,
    LMDeployServer,
)

internlm_server = dict(
    type=LMDeployServer,
    path="internlm/internlm2_5-7b-chat",
    model_name="internlm2_5-7b-chat",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

internlm_client = dict(
    type=LMDeployClient,
    model_name="internlm2_5-7b-chat",
    url="http://127.0.0.1:23333",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

internlm_hf = dict(
    type=HFTransformerCasualLM,
    path="internlm/internlm2_5-7b-chat",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=None,
    temperature=1e-6,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)
# openai_api_base needs to fill in the complete chat api address, such as: https://api.openai.com/v1/chat/completions
gpt4 = dict(
    type=GPTAPI,
    model_type="gpt-4-turbo",
    key=os.environ.get("OPENAI_API_KEY", "YOUR OPENAI API KEY"),
    api_base=os.environ.get("OPENAI_API_BASE",
                            "https://api.openai.com/v1/chat/completions"),
)

url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
qwen = dict(
    type=GPTAPI,
    model_type="qwen-max-longcontext",
    key=os.environ.get("QWEN_API_KEY", "YOUR QWEN API KEY"),
    api_base=url,
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=4096,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)

internlm_silicon = dict(
    type=GPTAPI,
    model_type="internlm/internlm2_5-7b-chat",
    key=os.environ.get("SILICON_API_KEY", "YOUR SILICON API KEY"),
    api_base="https://api.siliconflow.cn/v1/chat/completions",
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)


internlm_api = dict(
    type=GPTAPI,
    model_type="internlm2.5-latest",
    key=os.environ.get("InternLM_API_KEY", "YOUR InternLM API KEY https://internlm.intern-ai.org.cn/api/document"),
    api_base="https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions",
    meta_template=[
        dict(role="system", api_role="system"),
        dict(role="user", api_role="user"),
        dict(role="assistant", api_role="assistant"),
        dict(role="environment", api_role="system"),
    ],
    top_p=0.8,
    top_k=1,
    temperature=0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>"],
)
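Note: each entry in `models.py` above is a plain config dict whose `type` key holds the model class. As a hedged sketch, a config can be materialized by popping `type` and calling it with the remaining keys as constructor kwargs; the actual factory lives in `mindsearch/agent/__init__.py` and may differ in detail.

```python
# Hedged sketch: turn one of the config dicts above into a model client.
import copy

from mindsearch.agent.models import internlm_silicon

cfg = copy.deepcopy(internlm_silicon)
llm_cls = cfg.pop("type")  # e.g. GPTAPI
llm = llm_cls(**cfg)       # remaining keys become constructor kwargs
```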
mindsearch/agent/streaming.py
ADDED
@@ -0,0 +1,203 @@
import copy
from typing import List, Union

from lagent.agents import Agent, AgentForInternLM, AsyncAgent, AsyncAgentForInternLM
from lagent.schema import AgentMessage, AgentStatusCode, ModelStatusCode


class StreamingAgentMixin:
    """Make agent calling output a streaming response."""

    def __call__(self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs):
        for hook in self._hooks.values():
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = AgentMessage(sender=self.name, content="")
        for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        for hook in self._hooks.values():
            response_message = response_message.model_copy(deep=True)
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message


class AsyncStreamingAgentMixin:
    """Make asynchronous agent calling output a streaming response."""

    async def __call__(
        self, *message: Union[AgentMessage, List[AgentMessage]], session_id=0, **kwargs
    ):
        for hook in self._hooks.values():
            message = copy.deepcopy(message)
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = AgentMessage(sender=self.name, content="")
        async for response_message in self.forward(*message, session_id=session_id, **kwargs):
            if not isinstance(response_message, AgentMessage):
                model_state, response = response_message
                response_message = AgentMessage(
                    sender=self.name,
                    content=response,
                    stream_state=model_state,
                )
            yield response_message.model_copy()
        self.update_memory(response_message, session_id=session_id)
        for hook in self._hooks.values():
            response_message = response_message.model_copy(deep=True)
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        yield response_message


class StreamingAgent(StreamingAgentMixin, Agent):
    """Base streaming agent class"""

    def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        for model_state, response, _ in self.llm.stream_chat(
            formatted_messages, session_id=session_id, **kwargs
        ):
            yield AgentMessage(
                sender=self.name,
                content=response,
                formatted=self.output_format.parse_response(response),
                stream_state=model_state,
            ) if self.output_format else (model_state, response)


class AsyncStreamingAgent(AsyncStreamingAgentMixin, AsyncAgent):
    """Base asynchronous streaming agent class"""

    async def forward(self, *message: AgentMessage, session_id=0, **kwargs):
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        async for model_state, response, _ in self.llm.stream_chat(
            formatted_messages, session_id=session_id, **kwargs
        ):
            yield AgentMessage(
                sender=self.name,
                content=response,
                formatted=self.output_format.parse_response(response),
                stream_state=model_state,
            ) if self.output_format else (model_state, response)


class StreamingAgentForInternLM(StreamingAgentMixin, AgentForInternLM):
    """Streaming implementation of `lagent.agents.AgentForInternLM`"""

    _INTERNAL_AGENT_CLS = StreamingAgent

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = executor(message, session_id=session_id)
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message


class AsyncStreamingAgentForInternLM(AsyncStreamingAgentMixin, AsyncAgentForInternLM):
    """Streaming implementation of `lagent.agents.AsyncAgentForInternLM`"""

    _INTERNAL_AGENT_CLS = AsyncStreamingAgent

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender="user", content=message)
        for _ in range(self.max_turn):
            last_agent_state = AgentStatusCode.SESSION_READY
            async for message in self.agent(message, session_id=session_id, **kwargs):
                if isinstance(message.formatted, dict) and message.formatted.get("tool_type"):
                    if message.stream_state == ModelStatusCode.END:
                        message.stream_state = last_agent_state + int(
                            last_agent_state
                            in [
                                AgentStatusCode.CODING,
                                AgentStatusCode.PLUGIN_START,
                            ]
                        )
                    else:
                        message.stream_state = (
                            AgentStatusCode.PLUGIN_START
                            if message.formatted["tool_type"] == "plugin"
                            else AgentStatusCode.CODING
                        )
                else:
                    message.stream_state = AgentStatusCode.STREAM_ING
                yield message
                last_agent_state = message.stream_state
            if self.finish_condition(message):
                message.stream_state = AgentStatusCode.END
                yield message
                return
            if message.formatted["tool_type"]:
                tool_type = message.formatted["tool_type"]
                executor = getattr(self, f"{tool_type}_executor", None)
                if not executor:
                    raise RuntimeError(f"No available {tool_type} executor")
                tool_return = await executor(message, session_id=session_id)
                tool_return.stream_state = message.stream_state + 1
                message = tool_return
                yield message
            else:
                message.stream_state = AgentStatusCode.STREAM_ING
                yield message
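Note: a streaming agent's `__call__` above is a generator of `AgentMessage` objects whose `stream_state` tracks progress through the turn. A hedged consumption sketch, assuming `agent` is an already-configured `StreamingAgentForInternLM` instance:

```python
# Sketch: iterate the stream and print the final answer once the agent signals END.
from lagent.schema import AgentStatusCode

for message in agent("What is MindSearch?", session_id=42):
    # `content` holds the response generated so far, not an incremental delta.
    if message.stream_state == AgentStatusCode.END:
        print(message.content)
```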
mindsearch/app.py
ADDED
@@ -0,0 +1,176 @@
import asyncio
import json
import logging
import random
from typing import Dict, List, Union

import janus
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse

from mindsearch.agent import init_agent


def parse_arguments():
    import argparse

    parser = argparse.ArgumentParser(description="MindSearch API")
    parser.add_argument("--host", default="0.0.0.0", type=str, help="Service host")
    parser.add_argument("--port", default=8002, type=int, help="Service port")
    parser.add_argument("--lang", default="cn", type=str, help="Language")
    parser.add_argument("--model_format", default="internlm_server", type=str, help="Model format")
    parser.add_argument("--search_engine", default="BingSearch", type=str, help="Search engine")
    parser.add_argument("--asy", default=False, action="store_true", help="Agent mode")
    return parser.parse_args()


args = parse_arguments()
app = FastAPI(docs_url="/")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class GenerationParams(BaseModel):
    inputs: Union[str, List[Dict]]
    session_id: int = Field(default_factory=lambda: random.randint(0, 999999))
    agent_cfg: Dict = dict()


def _postprocess_agent_message(message: dict) -> dict:
    content, fmt = message["content"], message["formatted"]
    current_node = content["current_node"] if isinstance(content, dict) else None
    if current_node:
        message["content"] = None
        for key in ["ref2url"]:
            fmt.pop(key, None)
        graph = fmt["node"]
        for key in graph.copy():
            if key != current_node:
                graph.pop(key)
        if current_node not in ["root", "response"]:
            node = graph[current_node]
            for key in ["memory", "session_id"]:
                node.pop(key, None)
            node_fmt = node["response"]["formatted"]
            if isinstance(node_fmt, dict) and "thought" in node_fmt and "action" in node_fmt:
                node["response"]["content"] = None
                node_fmt["thought"] = (
                    node_fmt["thought"] and node_fmt["thought"].split("<|action_start|>")[0]
                )
                if isinstance(node_fmt["action"], str):
                    node_fmt["action"] = node_fmt["action"].split("<|action_end|>")[0]
    else:
        if isinstance(fmt, dict) and "thought" in fmt and "action" in fmt:
            message["content"] = None
            fmt["thought"] = fmt["thought"] and fmt["thought"].split("<|action_start|>")[0]
            if isinstance(fmt["action"], str):
                fmt["action"] = fmt["action"].split("<|action_end|>")[0]
        for key in ["node"]:
            fmt.pop(key, None)
    return dict(current_node=current_node, response=message)


async def run(request: GenerationParams, _request: Request):
    async def generate():
        try:
            queue = janus.Queue()
            stop_event = asyncio.Event()

            # Wrapping a sync generator as an async generator using run_in_executor
            def sync_generator_wrapper():
                try:
                    for response in agent(inputs, session_id=session_id):
                        queue.sync_q.put(response)
                except Exception as e:
                    logging.exception(f"Exception in sync_generator_wrapper: {e}")
                finally:
                    # Notify async_generator_wrapper that the data generation is complete.
                    queue.sync_q.put(None)

            async def async_generator_wrapper():
                loop = asyncio.get_event_loop()
                loop.run_in_executor(None, sync_generator_wrapper)
                while True:
                    response = await queue.async_q.get()
                    if response is None:  # Ensure that all elements are consumed
                        break
                    yield response
                stop_event.set()  # Inform sync_generator_wrapper to stop

            async for message in async_generator_wrapper():
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            await stop_event.wait()  # Waiting for async_generator_wrapper to stop
            queue.close()
            await queue.wait_closed()
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
    )
    return EventSourceResponse(generate(), ping=300)


async def run_async(request: GenerationParams, _request: Request):
    async def generate():
        try:
            async for message in agent(inputs, session_id=session_id):
                response_json = json.dumps(
                    _postprocess_agent_message(message.model_dump()),
                    ensure_ascii=False,
                )
                yield {"data": response_json}
                if await _request.is_disconnected():
                    break
        except Exception as exc:
            msg = "An error occurred while generating the response."
            logging.exception(msg)
            response_json = json.dumps(
                dict(error=dict(msg=msg, details=str(exc))), ensure_ascii=False
            )
            yield {"data": response_json}
        finally:
            agent.agent.memory.memory_map.pop(session_id, None)

    inputs = request.inputs
    session_id = request.session_id
    agent = init_agent(
        lang=args.lang,
        model_format=args.model_format,
        search_engine=args.search_engine,
        use_async=True,
    )
    return EventSourceResponse(generate(), ping=300)


app.add_api_route("/solve", run_async if args.asy else run, methods=["POST"])

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
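Note: `/solve` responds with server-sent events, one JSON payload per `data:` line, shaped by `_postprocess_agent_message` above. A hedged client sketch against the default host and port:

```python
# Sketch: stream events from a locally running MindSearch API server.
import json

import requests

url = "http://127.0.0.1:8002/solve"  # default --host/--port from parse_arguments
with requests.post(url, json={"inputs": "What is MindSearch?"}, stream=True, timeout=300) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            print(event.get("current_node"))  # None for planner-level messages
```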
mindsearch/terminal.py
ADDED
@@ -0,0 +1,66 @@
import os
import sys
from datetime import datetime

from lagent.actions import WebBrowser
from lagent.agents.stream import get_plugin_prompt
from lagent.llms import INTERNLM2_META, LMDeployServer
from lagent.prompts import InterpreterParser, PluginParser

from mindsearch.agent.mindsearch_agent import MindSearchAgent
from mindsearch.agent.mindsearch_prompt import (
    FINAL_RESPONSE_CN,
    FINAL_RESPONSE_EN,
    GRAPH_PROMPT_CN,
    GRAPH_PROMPT_EN,
    searcher_context_template_cn,
    searcher_context_template_en,
    searcher_input_template_cn,
    searcher_input_template_en,
    searcher_system_prompt_cn,
    searcher_system_prompt_en,
)

lang = "cn"
date = datetime.now().strftime("The current date is %Y-%m-%d.")
llm = LMDeployServer(
    path="internlm/internlm2_5-7b-chat",
    model_name="internlm2",
    meta_template=INTERNLM2_META,
    top_p=0.8,
    top_k=1,
    temperature=1.0,
    max_new_tokens=8192,
    repetition_penalty=1.02,
    stop_words=["<|im_end|>", "<|action_end|>"],
)
plugins = [WebBrowser(searcher_type="BingSearch", topk=6)]
agent = MindSearchAgent(
    llm=llm,
    template=date,
    output_format=InterpreterParser(template=GRAPH_PROMPT_CN if lang == "cn" else GRAPH_PROMPT_EN),
    searcher_cfg=dict(
        llm=llm,
        plugins=plugins,
        template=date,
        output_format=PluginParser(
            template=searcher_system_prompt_cn if lang == "cn" else searcher_system_prompt_en,
            tool_info=get_plugin_prompt(plugins),
        ),
        user_input_template=searcher_input_template_cn
        if lang == "cn"
        else searcher_input_template_en,
        user_context_template=searcher_context_template_cn
        if lang == "cn"
        else searcher_context_template_en,
    ),
    summary_prompt=FINAL_RESPONSE_CN if lang == "cn" else FINAL_RESPONSE_EN,
    max_turn=10,
)

for agent_return in agent("上海今天适合穿什么衣服"):
    pass

print(agent_return.sender)
print(agent_return.content)
print(agent_return.formatted["ref2url"])
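Note: the loop in `terminal.py` drains the generator and only inspects the final message, whose `formatted["ref2url"]` maps citation indices to source URLs. To watch intermediate planner and searcher updates for the sample question ("What should I wear in Shanghai today?"), each yielded message can be inspected instead; a hedged variant:

```python
# Sketch: print each intermediate state instead of silently draining the stream.
for agent_return in agent("上海今天适合穿什么衣服"):
    print(agent_return.sender, agent_return.stream_state)
```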
requirements.txt
ADDED
@@ -0,0 +1,18 @@
duckduckgo_search==5.3.1b1
einops
fastapi
gradio==5.7.1
janus
lagent==0.5.0rc2
matplotlib
pydantic==2.6.4
python-dotenv
pyvis
schemdraw
sse-starlette
termcolor
transformers==4.41.0
uvicorn
tenacity
streamlit
git+https://github.com/vansin/ms_gradio_agentchatbot.git