DESKTOP-P3A1PV5\inShine committed
Commit a5ab2ca · 1 parent: e3846de
Add application file
(This view is limited to 50 files because the commit contains too many changes.)
- app.py +196 -0
- examples/agent_api_web_demo.py +196 -0
- examples/multi_agents_api_web_demo.py +198 -0
- lagent/__init__.py +4 -0
- lagent/actions/__init__.py +26 -0
- lagent/actions/action_executor.py +198 -0
- lagent/actions/arxiv_search.py +79 -0
- lagent/actions/base_action.py +434 -0
- lagent/actions/bing_map.py +268 -0
- lagent/actions/builtin_actions.py +109 -0
- lagent/actions/google_scholar_search.py +438 -0
- lagent/actions/google_search.py +244 -0
- lagent/actions/ipython_interactive.py +273 -0
- lagent/actions/ipython_interpreter.py +584 -0
- lagent/actions/ipython_manager.py +220 -0
- lagent/actions/parser.py +146 -0
- lagent/actions/ppt.py +233 -0
- lagent/actions/python_interpreter.py +176 -0
- lagent/actions/weather_query.py +71 -0
- lagent/actions/web_browser.py +908 -0
- lagent/agents/__init__.py +9 -0
- lagent/agents/agent.py +400 -0
- lagent/agents/aggregator/__init__.py +4 -0
- lagent/agents/aggregator/default_aggregator.py +44 -0
- lagent/agents/aggregator/tool_aggregator.py +106 -0
- lagent/agents/react.py +161 -0
- lagent/agents/stream.py +316 -0
- lagent/distributed/__init__.py +8 -0
- lagent/distributed/http_serve/__init__.py +7 -0
- lagent/distributed/http_serve/api_server.py +123 -0
- lagent/distributed/http_serve/app.py +96 -0
- lagent/distributed/ray_serve/__init__.py +3 -0
- lagent/distributed/ray_serve/ray_warpper.py +48 -0
- lagent/hooks/__init__.py +8 -0
- lagent/hooks/action_preprocessor.py +62 -0
- lagent/hooks/hook.py +50 -0
- lagent/hooks/logger.py +37 -0
- lagent/llms/__init__.py +32 -0
- lagent/llms/base_api.py +175 -0
- lagent/llms/base_llm.py +305 -0
- lagent/llms/huggingface.py +337 -0
- lagent/llms/lmdeploy_wrapper.py +790 -0
- lagent/llms/meta_template.py +40 -0
- lagent/llms/openai.py +924 -0
- lagent/llms/sensenova.py +406 -0
- lagent/llms/vllm_wrapper.py +176 -0
- lagent/memory/__init__.py +4 -0
- lagent/memory/base_memory.py +60 -0
- lagent/memory/manager.py +29 -0
- lagent/prompts/__init__.py +4 -0
app.py
ADDED
@@ -0,0 +1,196 @@
import copy
import os
from typing import List
import streamlit as st
from lagent.actions import ArxivSearch, WeatherQuery
from lagent.prompts.parsers import PluginParser
from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
from lagent.llms import GPTAPI

class SessionState:
    """Manage the Streamlit session state."""

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []  # assistant message history
        st.session_state['user'] = []  # user message history
        # initialize the plugin list
        action_list = [
            ArxivSearch(),
            WeatherQuery()
        ]
        st.session_state['plugin_map'] = {action.name: action for action in action_list}
        st.session_state['model_map'] = {}  # model instances
        st.session_state['model_selected'] = None  # currently selected model
        st.session_state['plugin_actions'] = set()  # currently active plugins
        st.session_state['history'] = []  # chat history
        st.session_state['api_base'] = None  # API base address

    def clear_state(self):
        """Clear the current session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None


class StreamlitUI:
    """Manage the Streamlit interface."""

    def __init__(self, session_state: SessionState):
        self.session_state = session_state
        self.plugin_action = []  # currently selected plugins
        # initialize the prompts
        self.meta_prompt = META_CN
        self.plugin_prompt = PLUGIN_CN
        self.init_streamlit()

    def init_streamlit(self):
        """Initialize Streamlit UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png'
        )
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')

    def setup_sidebar(self):
        """Set up the sidebar for model and plugin selection."""
        # model name and API base input boxes
        model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')

        # ================================== SiliconFlow API ==================================
        # Note: with the SiliconFlow API, the model name must be changed to
        # internlm/internlm2_5-7b-chat or internlm/internlm2_5-20b-chat.
        # api_base = st.sidebar.text_input(
        #     'API Base 地址:', value='https://api.siliconflow.cn/v1/chat/completions'
        # )
        # ================================== official InternLM (puyu) API ==================================
        api_base = st.sidebar.text_input(
            'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
        )
        # ==================================================================================
        # plugin selection
        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )

        # build the plugin action list from the selection
        self.plugin_action = [st.session_state['plugin_map'][name] for name in plugin_name]

        # dynamically generate the plugin prompt
        if self.plugin_action:
            self.plugin_prompt = get_plugin_prompt(self.plugin_action)

        # clear-conversation button
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()

        return model_name, api_base, self.plugin_action

    def initialize_chatbot(self, model_name, api_base, plugin_action):
        """Initialize a GPTAPI instance as the chatbot."""
        token = os.getenv("INTERNLM_API_KEY")
        if not token:
            st.error("未检测到环境变量 `INTERNLM_API_KEY`,请设置环境变量,例如 `export INTERNLM_API_KEY='your_token_here'` 后重新运行 X﹏X")
            st.stop()  # stop the app

        # build the full meta_prompt, keeping the original structure while
        # dynamically inserting the sidebar configuration
        meta_prompt = [
            {"role": "system", "content": self.meta_prompt, "api_role": "system"},
            {"role": "user", "content": "", "api_role": "user"},
            {"role": "assistant", "content": self.plugin_prompt, "api_role": "assistant"},
            {"role": "environment", "content": "", "api_role": "environment"}
        ]

        api_model = GPTAPI(
            model_type=model_name,
            api_base=api_base,
            key=token,  # authorization token taken from the environment
            meta_template=meta_prompt,
            max_new_tokens=512,
            temperature=0.8,
            top_p=0.9
        )
        return api_model

    def render_user(self, prompt: str):
        """Render user input."""
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        """Render the assistant response."""
        with st.chat_message('assistant'):
            content = getattr(agent_return, "content", str(agent_return))
            st.markdown(content if isinstance(content, str) else str(content))


def main():
    """Main entry point of the Streamlit app."""
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)
    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png'
        )
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')

    # set up the sidebar and fetch model and plugin info
    model_name, api_base, plugin_action = st.session_state['ui'].setup_sidebar()
    plugins = [dict(type=f"lagent.actions.{plugin.__class__.__name__}") for plugin in plugin_action]

    if (
        'chatbot' not in st.session_state or
        model_name != st.session_state['chatbot'].model_type or
        'last_plugin_action' not in st.session_state or
        plugin_action != st.session_state['last_plugin_action'] or
        api_base != st.session_state['api_base']
    ):
        # update the chatbot
        st.session_state['chatbot'] = st.session_state['ui'].initialize_chatbot(model_name, api_base, plugin_action)
        st.session_state['last_plugin_action'] = plugin_action  # remember the plugin selection
        st.session_state['api_base'] = api_base  # remember the API base address

        # initialize AgentForInternLM
        st.session_state['agent'] = AgentForInternLM(
            llm=st.session_state['chatbot'],
            plugins=plugins,
            output_format=dict(
                type=PluginParser,
                template=PLUGIN_CN,
                prompt=get_plugin_prompt(plugin_action)
            )
        )
        # reset the conversation history
        st.session_state['session_history'] = []

    if 'agent' not in st.session_state:
        st.session_state['agent'] = None

    agent = st.session_state['agent']
    for prompt, agent_return in zip(st.session_state['user'], st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    # handle user input
    if user_input := st.chat_input(''):
        st.session_state['ui'].render_user(user_input)

        # ensure the sidebar system prompt and plugin prompt take effect
        res = agent(user_input, session_id=0)
        st.session_state['ui'].render_assistant(res)

        # update the session state
        st.session_state['user'].append(user_input)
        st.session_state['assistant'].append(copy.deepcopy(res))

    st.session_state['last_status'] = None


if __name__ == '__main__':
    main()
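Not part of the commit: a minimal sketch isolating how the sidebar selection above becomes the `plugins` config consumed by AgentForInternLM. The class names come from this repo's `lagent.actions`; the printed output assumes both demo plugins are selected.

# Sketch (not in the commit): reproducing the `plugins` derivation from app.py.
from lagent.actions import ArxivSearch, WeatherQuery

plugin_action = [ArxivSearch(), WeatherQuery()]
# each selected action instance is mapped to an import-path config dict,
# which AgentForInternLM resolves back into an object
plugins = [dict(type=f"lagent.actions.{p.__class__.__name__}") for p in plugin_action]
print(plugins)
# expected: [{'type': 'lagent.actions.ArxivSearch'}, {'type': 'lagent.actions.WeatherQuery'}]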
examples/agent_api_web_demo.py
ADDED
@@ -0,0 +1,196 @@
import copy
import os
from typing import List
import streamlit as st
from lagent.actions import ArxivSearch, WeatherQuery
from lagent.prompts.parsers import PluginParser
from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
from lagent.llms import GPTAPI

class SessionState:
    """Manage the Streamlit session state."""

    def init_state(self):
        """Initialize session state variables."""
        st.session_state['assistant'] = []  # assistant message history
        st.session_state['user'] = []  # user message history
        # initialize the plugin list
        action_list = [
            ArxivSearch(),
            WeatherQuery()
        ]
        st.session_state['plugin_map'] = {action.name: action for action in action_list}
        st.session_state['model_map'] = {}  # model instances
        st.session_state['model_selected'] = None  # currently selected model
        st.session_state['plugin_actions'] = set()  # currently active plugins
        st.session_state['history'] = []  # chat history
        st.session_state['api_base'] = None  # API base address

    def clear_state(self):
        """Clear the current session state."""
        st.session_state['assistant'] = []
        st.session_state['user'] = []
        st.session_state['model_selected'] = None


class StreamlitUI:
    """Manage the Streamlit interface."""

    def __init__(self, session_state: SessionState):
        self.session_state = session_state
        self.plugin_action = []  # currently selected plugins
        # initialize the prompts
        self.meta_prompt = META_CN
        self.plugin_prompt = PLUGIN_CN
        self.init_streamlit()

    def init_streamlit(self):
        """Initialize Streamlit UI settings."""
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png'
        )
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')

    def setup_sidebar(self):
        """Set up the sidebar for model and plugin selection."""
        # model name and API base input boxes
        model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')

        # ================================== SiliconFlow API ==================================
        # Note: with the SiliconFlow API, the model name must be changed to
        # internlm/internlm2_5-7b-chat or internlm/internlm2_5-20b-chat.
        # api_base = st.sidebar.text_input(
        #     'API Base 地址:', value='https://api.siliconflow.cn/v1/chat/completions'
        # )
        # ================================== official InternLM (puyu) API ==================================
        api_base = st.sidebar.text_input(
            'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
        )
        # ==================================================================================
        # plugin selection
        plugin_name = st.sidebar.multiselect(
            '插件选择',
            options=list(st.session_state['plugin_map'].keys()),
            default=[],
        )

        # build the plugin action list from the selection
        self.plugin_action = [st.session_state['plugin_map'][name] for name in plugin_name]

        # dynamically generate the plugin prompt
        if self.plugin_action:
            self.plugin_prompt = get_plugin_prompt(self.plugin_action)

        # clear-conversation button
        if st.sidebar.button('清空对话', key='clear'):
            self.session_state.clear_state()

        return model_name, api_base, self.plugin_action

    def initialize_chatbot(self, model_name, api_base, plugin_action):
        """Initialize a GPTAPI instance as the chatbot."""
        token = os.getenv("INTERNLM_API_KEY")
        if not token:
            st.error("未检测到环境变量 `INTERNLM_API_KEY`,请设置环境变量,例如 `export INTERNLM_API_KEY='your_token_here'` 后重新运行 X﹏X")
            st.stop()  # stop the app

        # build the full meta_prompt, keeping the original structure while
        # dynamically inserting the sidebar configuration
        meta_prompt = [
            {"role": "system", "content": self.meta_prompt, "api_role": "system"},
            {"role": "user", "content": "", "api_role": "user"},
            {"role": "assistant", "content": self.plugin_prompt, "api_role": "assistant"},
            {"role": "environment", "content": "", "api_role": "environment"}
        ]

        api_model = GPTAPI(
            model_type=model_name,
            api_base=api_base,
            key=token,  # authorization token taken from the environment
            meta_template=meta_prompt,
            max_new_tokens=512,
            temperature=0.8,
            top_p=0.9
        )
        return api_model

    def render_user(self, prompt: str):
        """Render user input."""
        with st.chat_message('user'):
            st.markdown(prompt)

    def render_assistant(self, agent_return):
        """Render the assistant response."""
        with st.chat_message('assistant'):
            content = getattr(agent_return, "content", str(agent_return))
            st.markdown(content if isinstance(content, str) else str(content))


def main():
    """Main entry point of the Streamlit app."""
    if 'ui' not in st.session_state:
        session_state = SessionState()
        session_state.init_state()
        st.session_state['ui'] = StreamlitUI(session_state)
    else:
        st.set_page_config(
            layout='wide',
            page_title='lagent-web',
            page_icon='./docs/imgs/lagent_icon.png'
        )
        st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')

    # set up the sidebar and fetch model and plugin info
    model_name, api_base, plugin_action = st.session_state['ui'].setup_sidebar()
    plugins = [dict(type=f"lagent.actions.{plugin.__class__.__name__}") for plugin in plugin_action]

    if (
        'chatbot' not in st.session_state or
        model_name != st.session_state['chatbot'].model_type or
        'last_plugin_action' not in st.session_state or
        plugin_action != st.session_state['last_plugin_action'] or
        api_base != st.session_state['api_base']
    ):
        # update the chatbot
        st.session_state['chatbot'] = st.session_state['ui'].initialize_chatbot(model_name, api_base, plugin_action)
        st.session_state['last_plugin_action'] = plugin_action  # remember the plugin selection
        st.session_state['api_base'] = api_base  # remember the API base address

        # initialize AgentForInternLM
        st.session_state['agent'] = AgentForInternLM(
            llm=st.session_state['chatbot'],
            plugins=plugins,
            output_format=dict(
                type=PluginParser,
                template=PLUGIN_CN,
                prompt=get_plugin_prompt(plugin_action)
            )
        )
        # reset the conversation history
        st.session_state['session_history'] = []

    if 'agent' not in st.session_state:
        st.session_state['agent'] = None

    agent = st.session_state['agent']
    for prompt, agent_return in zip(st.session_state['user'], st.session_state['assistant']):
        st.session_state['ui'].render_user(prompt)
        st.session_state['ui'].render_assistant(agent_return)

    # handle user input
    if user_input := st.chat_input(''):
        st.session_state['ui'].render_user(user_input)

        # ensure the sidebar system prompt and plugin prompt take effect
        res = agent(user_input, session_id=0)
        st.session_state['ui'].render_assistant(res)

        # update the session state
        st.session_state['user'].append(user_input)
        st.session_state['assistant'].append(copy.deepcopy(res))

    st.session_state['last_status'] = None


if __name__ == '__main__':
    main()
examples/multi_agents_api_web_demo.py
ADDED
@@ -0,0 +1,198 @@
import os
import asyncio
import json
import re
import requests
import streamlit as st

from lagent.agents import Agent
from lagent.prompts.parsers import PluginParser
from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt
from lagent.schema import AgentMessage
from lagent.actions import ArxivSearch
from lagent.hooks import Hook
from lagent.llms import GPTAPI

YOUR_TOKEN_HERE = os.getenv("INTERNLM_API_KEY")
if not YOUR_TOKEN_HERE:
    raise EnvironmentError("未找到环境变量 'INTERNLM_API_KEY',请设置后再运行程序。")

# Hook that prepends a prefix to messages
class PrefixedMessageHook(Hook):
    def __init__(self, prefix, senders=None):
        """
        Initialize the hook.
        :param prefix: message prefix
        :param senders: list of senders the prefix applies to
        """
        self.prefix = prefix
        self.senders = senders or []

    def before_agent(self, agent, messages, session_id):
        """
        Modify message content before the agent processes it.
        :param agent: the current agent
        :param messages: message list
        :param session_id: session ID
        """
        for message in messages:
            if message.sender in self.senders:
                message.content = self.prefix + message.content

class AsyncBlogger:
    """Blog generator that combines a writer and a critic."""

    def __init__(self, model_type, api_base, writer_prompt, critic_prompt, critic_prefix='', max_turn=2):
        """
        Initialize the blog generator.
        :param model_type: model type
        :param api_base: API base address
        :param writer_prompt: writer prompt
        :param critic_prompt: critic prompt
        :param critic_prefix: prefix for critic messages
        :param max_turn: maximum number of turns
        """
        self.model_type = model_type
        self.api_base = api_base
        self.llm = GPTAPI(
            model_type=model_type,
            api_base=api_base,
            key=YOUR_TOKEN_HERE,
            max_new_tokens=4096,
        )
        self.plugins = [dict(type='lagent.actions.ArxivSearch')]
        self.writer = Agent(
            self.llm,
            writer_prompt,
            name='写作者',
            output_format=dict(
                type=PluginParser,
                template=PLUGIN_CN,
                prompt=get_plugin_prompt(self.plugins)
            )
        )
        self.critic = Agent(
            self.llm,
            critic_prompt,
            name='批评者',
            hooks=[PrefixedMessageHook(critic_prefix, ['写作者'])]
        )
        self.max_turn = max_turn

    async def forward(self, message: AgentMessage, update_placeholder):
        """
        Run the multi-stage blog generation pipeline.
        :param message: the initial message
        :param update_placeholder: Streamlit placeholder
        :return: the final polished blog content
        """
        step1_placeholder = update_placeholder.container()
        step2_placeholder = update_placeholder.container()
        step3_placeholder = update_placeholder.container()

        # Step 1: generate the initial content
        step1_placeholder.markdown("**Step 1: 生成初始内容...**")
        message = self.writer(message)
        if message.content:
            step1_placeholder.markdown(f"**生成的初始内容**:\n\n{message.content}")
        else:
            step1_placeholder.markdown("**生成的初始内容为空,请检查生成逻辑。**")

        # Step 2: the critic provides feedback
        step2_placeholder.markdown("**Step 2: 批评者正在提供反馈和文献推荐...**")
        message = self.critic(message)
        if message.content:
            # parse the critic's feedback
            suggestions = re.search(r"1\. 批评建议:\n(.*?)2\. 推荐的关键词:", message.content, re.S)
            keywords = re.search(r"2\. 推荐的关键词:\n- (.*)", message.content)
            feedback = suggestions.group(1).strip() if suggestions else "未提供批评建议"
            keywords = keywords.group(1).strip() if keywords else "未提供关键词"

            # arXiv literature search
            arxiv_search = ArxivSearch()
            arxiv_results = arxiv_search.get_arxiv_article_information(keywords)

            # show the critique and the recommended literature
            message.content = f"**批评建议**:\n{feedback}\n\n**推荐的文献**:\n{arxiv_results}"
            step2_placeholder.markdown(f"**批评和文献推荐**:\n\n{message.content}")
        else:
            step2_placeholder.markdown("**批评内容为空,请检查批评逻辑。**")

        # Step 3: the writer revises the content based on the feedback
        step3_placeholder.markdown("**Step 3: 根据反馈改进内容...**")
        improvement_prompt = AgentMessage(
            sender="critic",
            content=(
                f"根据以下批评建议和推荐文献对内容进行改进:\n\n"
                f"批评建议:\n{feedback}\n\n"
                f"推荐文献:\n{arxiv_results}\n\n"
                f"请优化初始内容,使其更加清晰、丰富,并符合专业水准。"
            ),
        )
        message = self.writer(improvement_prompt)
        if message.content:
            step3_placeholder.markdown(f"**最终优化的博客内容**:\n\n{message.content}")
        else:
            step3_placeholder.markdown("**最终优化的博客内容为空,请检查生成逻辑。**")

        return message

def setup_sidebar():
    """Set up the sidebar for model selection."""
    model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
    api_base = st.sidebar.text_input(
        'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
    )

    return model_name, api_base

def main():
    """
    Main entry point: build the Streamlit UI and handle user interaction.
    """
    st.set_page_config(layout='wide', page_title='Lagent Web Demo', page_icon='🤖')
    st.title("多代理博客优化助手")

    model_type, api_base = setup_sidebar()
    topic = st.text_input('输入一个话题:', 'Self-Supervised Learning')
    generate_button = st.button('生成博客内容')

    if (
        'blogger' not in st.session_state or
        st.session_state['model_type'] != model_type or
        st.session_state['api_base'] != api_base
    ):
        st.session_state['blogger'] = AsyncBlogger(
            model_type=model_type,
            api_base=api_base,
            writer_prompt="你是一位优秀的AI内容写作者,请撰写一篇有吸引力且信息丰富的博客内容。",
            critic_prompt="""
                作为一位严谨的批评者,请给出建设性的批评和改进建议,并基于相关主题使用已有的工具推荐一些参考文献,推荐的关键词应该是英语形式,简洁且切题。
                请按照以下格式提供反馈:
                1. 批评建议:
                - (具体建议)
                2. 推荐的关键词:
                - (关键词1, 关键词2, ...)
            """,
            critic_prefix="请批评以下内容,并提供改进建议:\n\n"
        )
        st.session_state['model_type'] = model_type
        st.session_state['api_base'] = api_base

    if generate_button:
        update_placeholder = st.empty()

        async def run_async_blogger():
            message = AgentMessage(
                sender='user',
                content=f"请撰写一篇关于{topic}的博客文章,要求表达专业,生动有趣,并且易于理解。"
            )
            result = await st.session_state['blogger'].forward(message, update_placeholder)
            return result

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(run_async_blogger())

if __name__ == '__main__':
    main()
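A quick sketch (not part of the commit) of what `PrefixedMessageHook` does in isolation; it assumes `AgentMessage` exposes the `sender` and `content` fields exactly as used above.

from lagent.schema import AgentMessage

hook = PrefixedMessageHook('请批评以下内容,并提供改进建议:\n\n', senders=['写作者'])
messages = [AgentMessage(sender='写作者', content='这是初稿内容')]
# before_agent mutates matching messages in place
hook.before_agent(agent=None, messages=messages, session_id=0)
print(messages[0].content)  # the prefix followed by the original draft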
lagent/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .version import __version__, version_info

__all__ = ['__version__', 'version_info']
lagent/actions/__init__.py
ADDED
@@ -0,0 +1,26 @@
from .action_executor import ActionExecutor, AsyncActionExecutor
from .arxiv_search import ArxivSearch, AsyncArxivSearch
from .base_action import BaseAction, tool_api
from .bing_map import AsyncBINGMap, BINGMap
from .builtin_actions import FinishAction, InvalidAction, NoAction
from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
from .google_search import AsyncGoogleSearch, GoogleSearch
from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
from .ipython_manager import IPythonInteractiveManager
from .parser import BaseParser, JsonParser, TupleParser
from .ppt import PPT, AsyncPPT
from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
from .web_browser import AsyncWebBrowser, WebBrowser
from .weather_query import WeatherQuery

__all__ = [
    'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
    'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
    'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
    'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
    'IPythonInteractive', 'AsyncIPythonInteractive',
    'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
    'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
    'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'
]
lagent/actions/action_executor.py
ADDED
@@ -0,0 +1,198 @@
import inspect
from collections import OrderedDict
from typing import Callable, Dict, List, Union

from lagent.actions.base_action import BaseAction
from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
from lagent.hooks import Hook, RemovableHandle
from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
from lagent.utils import create_object


class ActionExecutor:
    """The action executor class.

    Args:
        actions (Union[BaseAction, List[BaseAction]]): The action or actions.
        invalid_action (BaseAction, optional): The invalid action. Defaults to
            InvalidAction().
        no_action (BaseAction, optional): The no action.
            Defaults to NoAction().
        finish_action (BaseAction, optional): The finish action. Defaults to
            FinishAction().
        finish_in_action (bool, optional): Whether the finish action is in the
            action list. Defaults to False.
    """

    def __init__(
        self,
        actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
        invalid_action: BaseAction = dict(type=InvalidAction),
        no_action: BaseAction = dict(type=NoAction),
        finish_action: BaseAction = dict(type=FinishAction),
        finish_in_action: bool = False,
        hooks: List[Dict] = None,
    ):

        if not isinstance(actions, list):
            actions = [actions]
        finish_action = create_object(finish_action)
        if finish_in_action:
            actions.append(finish_action)
        for i, action in enumerate(actions):
            actions[i] = create_object(action)
        self.actions = {action.name: action for action in actions}

        self.invalid_action = create_object(invalid_action)
        self.no_action = create_object(no_action)
        self.finish_action = finish_action
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            for hook in hooks:
                hook = create_object(hook)
                self.register_hook(hook)

    def description(self) -> List[Dict]:
        actions = []
        for action_name, action in self.actions.items():
            if action.is_toolkit:
                for api in action.description['api_list']:
                    api_desc = api.copy()
                    api_desc['name'] = f"{action_name}.{api_desc['name']}"
                    actions.append(api_desc)
            else:
                action_desc = action.description.copy()
                actions.append(action_desc)
        return actions

    def __contains__(self, name: str):
        return name in self.actions

    def keys(self):
        return list(self.actions.keys())

    def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
        action = create_object(action)
        self.actions[action.name] = action

    def __delitem__(self, name: str):
        del self.actions[name]

    def forward(self, name, parameters, **kwargs) -> ActionReturn:
        action_name, api_name = (
            name.split('.') if '.' in name else (name, 'run'))
        action_return: ActionReturn = ActionReturn()
        if action_name not in self:
            if name == self.no_action.name:
                action_return = self.no_action(parameters)
            elif name == self.finish_action.name:
                action_return = self.finish_action(parameters)
            else:
                action_return = self.invalid_action(parameters)
        else:
            action_return = self.actions[action_name](parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
        return action_return

    def __call__(self,
                 message: AgentMessage,
                 session_id=0,
                 **kwargs) -> AgentMessage:
        # message.receiver = self.name
        for hook in self._hooks.values():
            result = hook.before_action(self, message, session_id)
            if result:
                message = result

        assert isinstance(message.content, FunctionCall) or (
            isinstance(message.content, dict) and 'name' in message.content
            and 'parameters' in message.content)
        if isinstance(message.content, dict):
            name = message.content.get('name')
            parameters = message.content.get('parameters')
        else:
            name = message.content.name
            parameters = message.content.parameters

        response_message = self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def register_hook(self, hook: Callable):
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle


class AsyncActionExecutor(ActionExecutor):

    async def forward(self, name, parameters, **kwargs) -> ActionReturn:
        action_name, api_name = (
            name.split('.') if '.' in name else (name, 'run'))
        action_return: ActionReturn = ActionReturn()
        if action_name not in self:
            if name == self.no_action.name:
                action_return = self.no_action(parameters)
            elif name == self.finish_action.name:
                action_return = self.finish_action(parameters)
            else:
                action_return = self.invalid_action(parameters)
        else:
            action = self.actions[action_name]
            if inspect.iscoroutinefunction(action.__call__):
                action_return = await action(parameters, api_name)
            else:
                action_return = action(parameters, api_name)
            action_return.valid = ActionValidCode.OPEN
        return action_return

    async def __call__(self,
                       message: AgentMessage,
                       session_id=0,
                       **kwargs) -> AgentMessage:
        # message.receiver = self.name
        for hook in self._hooks.values():
            if inspect.iscoroutinefunction(hook.before_action):
                result = await hook.before_action(self, message, session_id)
            else:
                result = hook.before_action(self, message, session_id)
            if result:
                message = result

        assert isinstance(message.content, FunctionCall) or (
            isinstance(message.content, dict) and 'name' in message.content
            and 'parameters' in message.content)
        if isinstance(message.content, dict):
            name = message.content.get('name')
            parameters = message.content.get('parameters')
        else:
            name = message.content.name
            parameters = message.content.parameters

        response_message = await self.forward(
            name=name, parameters=parameters, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.__class__.__name__,
                content=response_message,
            )

        for hook in self._hooks.values():
            if inspect.iscoroutinefunction(hook.after_action):
                result = await hook.after_action(self, response_message,
                                                 session_id)
            else:
                result = hook.after_action(self, response_message, session_id)
            if result:
                response_message = result
        return response_message
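To make the control flow above concrete, here is a hedged usage sketch (not part of the commit): `__call__` accepts either a `FunctionCall` or a plain dict with `name` and `parameters`, and a dotted name such as `ArxivSearch.get_arxiv_article_information` is routed by `forward` to a toolkit API. The query string is illustrative, and it assumes the action's `JsonParser` accepts the parameters dict as-is.

from lagent.actions import ActionExecutor, ArxivSearch
from lagent.schema import AgentMessage

executor = ActionExecutor(actions=[ArxivSearch()])
call = AgentMessage(
    sender='agent',
    content={  # a dict with 'name' and 'parameters' satisfies the assertion above
        'name': 'ArxivSearch.get_arxiv_article_information',
        'parameters': {'query': 'agent frameworks'},
    },
)
reply = executor(call, session_id=0)  # reply.content wraps an ActionReturn
print(reply.content)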
lagent/actions/arxiv_search.py
ADDED
@@ -0,0 +1,79 @@
from typing import Optional, Type

from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode


class ArxivSearch(BaseAction):
    """Search information from Arxiv.org. \
Useful for when you need to answer questions about Physics, Mathematics, \
Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
Electrical Engineering, and Economics from scientific articles on arxiv.org.
    """

    def __init__(
        self,
        top_k_results: int = 3,
        max_query_len: int = 300,
        doc_content_chars_max: int = 1500,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        self.top_k_results = top_k_results
        self.max_query_len = max_query_len
        self.doc_content_chars_max = doc_content_chars_max

    @tool_api(explode_return=True)
    def get_arxiv_article_information(self, query: str) -> dict:
        """Run Arxiv search and get the article meta information.

        Args:
            query (:class:`str`): the content of search query

        Returns:
            :class:`dict`: article information
                * content (str): a list of 3 arxiv search papers
        """
        import arxiv

        try:
            results = arxiv.Search(  # type: ignore
                query[: self.max_query_len], max_results=self.top_k_results
            ).results()
        except Exception as exc:
            return ActionReturn(
                errmsg=f'Arxiv exception: {exc}',
                state=ActionStatusCode.HTTP_ERROR)
        docs = [
            f'Published: {result.updated.date()}\nTitle: {result.title}\n'
            f'Authors: {", ".join(a.name for a in result.authors)}\n'
            f'Summary: {result.summary[:self.doc_content_chars_max]}'
            for result in results
        ]
        if docs:
            return {'content': '\n\n'.join(docs)}
        return {'content': 'No good Arxiv Result was found'}


class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
    """Search information from Arxiv.org. \
Useful for when you need to answer questions about Physics, Mathematics, \
Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
Electrical Engineering, and Economics from scientific articles on arxiv.org.
    """

    @tool_api(explode_return=True)
    @asyncify
    def get_arxiv_article_information(self, query: str) -> dict:
        """Run Arxiv search and get the article meta information.

        Args:
            query (:class:`str`): the content of search query

        Returns:
            :class:`dict`: article information
                * content (str): a list of 3 arxiv search papers
        """
        return super().get_arxiv_article_information(query)
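Usage sketch (not part of the commit). Since `tool_api` wraps the method but still forwards the call, the tool can be invoked directly; on network failure the method returns an `ActionReturn` rather than a dict, which is why the branch below checks the type first.

searcher = ArxivSearch(top_k_results=3)
info = searcher.get_arxiv_article_information('self-supervised learning')
if isinstance(info, dict):
    print(info['content'])  # formatted metadata for up to top_k_results papers
else:
    print(info.errmsg)      # ActionReturn carrying the Arxiv exception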
lagent/actions/base_action.py
ADDED
@@ -0,0 +1,434 @@
1 |
+
import inspect
|
2 |
+
import logging
|
3 |
+
import re
|
4 |
+
from abc import ABCMeta
|
5 |
+
from copy import deepcopy
|
6 |
+
from functools import wraps
|
7 |
+
from typing import Callable, Optional, Type, get_args, get_origin
|
8 |
+
|
9 |
+
try:
|
10 |
+
from typing import Annotated
|
11 |
+
except ImportError:
|
12 |
+
from typing_extensions import Annotated
|
13 |
+
|
14 |
+
from griffe import Docstring
|
15 |
+
|
16 |
+
try:
|
17 |
+
from griffe import DocstringSectionKind
|
18 |
+
except ImportError:
|
19 |
+
from griffe.enumerations import DocstringSectionKind
|
20 |
+
|
21 |
+
from ..schema import ActionReturn, ActionStatusCode
|
22 |
+
from .parser import BaseParser, JsonParser, ParseError
|
23 |
+
|
24 |
+
logging.getLogger('griffe').setLevel(logging.ERROR)
|
25 |
+
|
26 |
+
|
27 |
+
def tool_api(func: Optional[Callable] = None,
|
28 |
+
*,
|
29 |
+
explode_return: bool = False,
|
30 |
+
returns_named_value: bool = False,
|
31 |
+
**kwargs):
|
32 |
+
"""Turn functions into tools. It will parse typehints as well as docstrings
|
33 |
+
to build the tool description and attach it to functions via an attribute
|
34 |
+
``api_description``.
|
35 |
+
|
36 |
+
Examples:
|
37 |
+
|
38 |
+
.. code-block:: python
|
39 |
+
|
40 |
+
# typehints has higher priority than docstrings
|
41 |
+
from typing import Annotated
|
42 |
+
|
43 |
+
@tool_api
|
44 |
+
def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
|
45 |
+
'''Add operation
|
46 |
+
|
47 |
+
Args:
|
48 |
+
x (int): a
|
49 |
+
y (int): b
|
50 |
+
'''
|
51 |
+
return a + b
|
52 |
+
|
53 |
+
print(add.api_description)
|
54 |
+
|
55 |
+
Args:
|
56 |
+
func (Optional[Callable]): function to decorate. Defaults to ``None``.
|
57 |
+
explode_return (bool): whether to flatten the dictionary or tuple return
|
58 |
+
as the ``return_data`` field. When enabled, it is recommended to
|
59 |
+
annotate the member in docstrings. Defaults to ``False``.
|
60 |
+
|
61 |
+
.. code-block:: python
|
62 |
+
|
63 |
+
@tool_api(explode_return=True)
|
64 |
+
def foo(a, b):
|
65 |
+
'''A simple function
|
66 |
+
|
67 |
+
Args:
|
68 |
+
a (int): a
|
69 |
+
b (int): b
|
70 |
+
|
71 |
+
Returns:
|
72 |
+
dict: information of inputs
|
73 |
+
* x: value of a
|
74 |
+
* y: value of b
|
75 |
+
'''
|
76 |
+
return {'x': a, 'y': b}
|
77 |
+
|
78 |
+
print(foo.api_description)
|
79 |
+
|
80 |
+
returns_named_value (bool): whether to parse ``thing: Description`` in
|
81 |
+
returns sections as a name and description, rather than a type and
|
82 |
+
description. When true, type must be wrapped in parentheses:
|
83 |
+
``(int): Description``. When false, parentheses are optional but
|
84 |
+
the items cannot be named: ``int: Description``. Defaults to ``False``.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
Callable: wrapped function or partial decorator
|
88 |
+
|
89 |
+
Important:
|
90 |
+
``return_data`` field will be added to ``api_description`` only
|
91 |
+
when ``explode_return`` or ``returns_named_value`` is enabled.
|
92 |
+
"""
|
93 |
+
|
94 |
+
def _detect_type(string):
|
95 |
+
field_type = 'STRING'
|
96 |
+
if 'list' in string:
|
97 |
+
field_type = 'Array'
|
98 |
+
elif 'str' not in string:
|
99 |
+
if 'float' in string:
|
100 |
+
field_type = 'FLOAT'
|
101 |
+
elif 'int' in string:
|
102 |
+
field_type = 'NUMBER'
|
103 |
+
elif 'bool' in string:
|
104 |
+
field_type = 'BOOLEAN'
|
105 |
+
return field_type
|
106 |
+
|
107 |
+
def _explode(desc):
|
108 |
+
kvs = []
|
109 |
+
desc = '\nArgs:\n' + '\n'.join([
|
110 |
+
' ' + item.lstrip(' -+*#.')
|
111 |
+
for item in desc.split('\n')[1:] if item.strip()
|
112 |
+
])
|
113 |
+
docs = Docstring(desc).parse('google')
|
114 |
+
if not docs:
|
115 |
+
return kvs
|
116 |
+
if docs[0].kind is DocstringSectionKind.parameters:
|
117 |
+
for d in docs[0].value:
|
118 |
+
d = d.as_dict()
|
119 |
+
if not d['annotation']:
|
120 |
+
d.pop('annotation')
|
121 |
+
else:
|
122 |
+
d['type'] = _detect_type(d.pop('annotation').lower())
|
123 |
+
kvs.append(d)
|
124 |
+
return kvs
|
125 |
+
|
126 |
+
def _parse_tool(function):
|
127 |
+
# remove rst syntax
|
128 |
+
docs = Docstring(
|
129 |
+
re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
|
130 |
+
'google', returns_named_value=returns_named_value, **kwargs)
|
131 |
+
desc = dict(
|
132 |
+
name=function.__name__,
|
133 |
+
description=docs[0].value
|
134 |
+
if docs[0].kind is DocstringSectionKind.text else '',
|
135 |
+
parameters=[],
|
136 |
+
required=[],
|
137 |
+
)
|
138 |
+
args_doc, returns_doc = {}, []
|
139 |
+
for doc in docs:
|
140 |
+
if doc.kind is DocstringSectionKind.parameters:
|
141 |
+
for d in doc.value:
|
142 |
+
d = d.as_dict()
|
143 |
+
d['type'] = _detect_type(d.pop('annotation').lower())
|
144 |
+
args_doc[d['name']] = d
|
145 |
+
if doc.kind is DocstringSectionKind.returns:
|
146 |
+
for d in doc.value:
|
147 |
+
d = d.as_dict()
|
148 |
+
if not d['name']:
|
149 |
+
d.pop('name')
|
150 |
+
if not d['annotation']:
|
151 |
+
d.pop('annotation')
|
152 |
+
else:
|
153 |
+
d['type'] = _detect_type(d.pop('annotation').lower())
|
154 |
+
returns_doc.append(d)
|
155 |
+
|
156 |
+
sig = inspect.signature(function)
|
157 |
+
for name, param in sig.parameters.items():
|
158 |
+
if name == 'self':
|
159 |
+
continue
|
160 |
+
parameter = dict(
|
161 |
+
name=param.name,
|
162 |
+
type='STRING',
|
163 |
+
description=args_doc.get(param.name,
|
164 |
+
{}).get('description', ''))
|
165 |
+
annotation = param.annotation
|
166 |
+
if annotation is inspect.Signature.empty:
|
167 |
+
parameter['type'] = args_doc.get(param.name,
|
168 |
+
{}).get('type', 'STRING')
|
169 |
+
else:
|
170 |
+
if get_origin(annotation) is Annotated:
|
171 |
+
annotation, info = get_args(annotation)
|
172 |
+
if info:
|
173 |
+
parameter['description'] = info
|
174 |
+
while get_origin(annotation):
|
175 |
+
annotation = get_args(annotation)
|
176 |
+
parameter['type'] = _detect_type(str(annotation))
|
177 |
+
desc['parameters'].append(parameter)
|
178 |
+
if param.default is inspect.Signature.empty:
|
179 |
+
desc['required'].append(param.name)
|
180 |
+
|
181 |
+
return_data = []
|
182 |
+
if explode_return:
|
183 |
+
return_data = _explode(returns_doc[0]['description'])
|
184 |
+
elif returns_named_value:
|
185 |
+
return_data = returns_doc
|
186 |
+
if return_data:
|
187 |
+
desc['return_data'] = return_data
|
188 |
+
return desc
|
189 |
+
|
190 |
+
if callable(func):
|
191 |
+
|
192 |
+
if inspect.iscoroutinefunction(func):
|
193 |
+
|
194 |
+
@wraps(func)
|
195 |
+
async def wrapper(self, *args, **kwargs):
|
196 |
+
return await func(self, *args, **kwargs)
|
197 |
+
|
198 |
+
else:
|
199 |
+
|
200 |
+
@wraps(func)
|
201 |
+
def wrapper(self, *args, **kwargs):
|
202 |
+
return func(self, *args, **kwargs)
|
203 |
+
|
204 |
+
wrapper.api_description = _parse_tool(func)
|
205 |
+
return wrapper
|
206 |
+
|
207 |
+
def decorate(func):
|
208 |
+
|
209 |
+
if inspect.iscoroutinefunction(func):
|
210 |
+
|
211 |
+
@wraps(func)
|
212 |
+
async def wrapper(self, *args, **kwargs):
|
213 |
+
return await func(self, *args, **kwargs)
|
214 |
+
|
215 |
+
else:
|
216 |
+
|
217 |
+
@wraps(func)
|
218 |
+
def wrapper(self, *args, **kwargs):
|
219 |
+
return func(self, *args, **kwargs)
|
220 |
+
|
221 |
+
wrapper.api_description = _parse_tool(func)
|
222 |
+
return wrapper
|
223 |
+
|
224 |
+
return decorate
|
225 |
+
|
226 |
+
|
227 |
+
class ToolMeta(ABCMeta):
|
228 |
+
"""Metaclass of tools."""
|
229 |
+
|
230 |
+
def __new__(mcs, name, base, attrs):
|
231 |
+
is_toolkit, tool_desc = True, dict(
|
232 |
+
name=name,
|
233 |
+
description=Docstring(attrs.get('__doc__',
|
234 |
+
'')).parse('google')[0].value)
|
235 |
+
for key, value in attrs.items():
|
236 |
+
if callable(value) and hasattr(value, 'api_description'):
|
237 |
+
api_desc = getattr(value, 'api_description')
|
238 |
+
if key == 'run':
|
239 |
+
tool_desc['parameters'] = api_desc['parameters']
|
240 |
+
tool_desc['required'] = api_desc['required']
|
241 |
+
if api_desc['description']:
|
242 |
+
tool_desc['description'] = api_desc['description']
|
243 |
+
if api_desc.get('return_data'):
|
244 |
+
tool_desc['return_data'] = api_desc['return_data']
|
245 |
+
is_toolkit = False
|
246 |
+
else:
|
247 |
+
tool_desc.setdefault('api_list', []).append(api_desc)
|
248 |
+
if not is_toolkit and 'api_list' in tool_desc:
|
249 |
+
raise KeyError('`run` and other tool APIs can not be implemented '
|
250 |
+
'at the same time')
|
251 |
+
if is_toolkit and 'api_list' not in tool_desc:
|
252 |
+
is_toolkit = False
|
253 |
+
if callable(attrs.get('run')):
|
254 |
+
run_api = tool_api(attrs['run'])
|
255 |
+
api_desc = run_api.api_description
|
256 |
+
tool_desc['parameters'] = api_desc['parameters']
|
257 |
+
tool_desc['required'] = api_desc['required']
|
258 |
+
if api_desc['description']:
|
259 |
+
tool_desc['description'] = api_desc['description']
|
260 |
+
if api_desc.get('return_data'):
|
261 |
+
tool_desc['return_data'] = api_desc['return_data']
|
262 |
+
attrs['run'] = run_api
|
263 |
+
else:
|
264 |
+
tool_desc['parameters'], tool_desc['required'] = [], []
|
265 |
+
attrs['_is_toolkit'] = is_toolkit
|
266 |
+
attrs['__tool_description__'] = tool_desc
|
267 |
+
return super().__new__(mcs, name, base, attrs)
|
268 |
+
|
269 |
+
|
270 |
+
class BaseAction(metaclass=ToolMeta):
|
271 |
+
"""Base class for all actions.
|
272 |
+
|
273 |
+
Args:
|
274 |
+
description (:class:`Optional[dict]`): The description of the action.
|
275 |
+
Defaults to ``None``.
|
276 |
+
parser (:class:`Type[BaseParser]`): The parser class to process the
|
277 |
+
action's inputs and outputs. Defaults to :class:`JsonParser`.
|
278 |
+
|
279 |
+
Examples:
|
280 |
+
|
281 |
+
* simple tool
|
282 |
+
|
283 |
+
.. code-block:: python
|
284 |
+
|
285 |
+
class Bold(BaseAction):
|
286 |
+
'''Make text bold'''
|
287 |
+
|
288 |
+
def run(self, text: str):
|
289 |
+
'''
|
290 |
+
Args:
|
291 |
+
text (str): input text
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
str: bold text
|
295 |
+
'''
|
296 |
+
return '**' + text + '**'
|
297 |
+
|
298 |
+
action = Bold()
|
299 |
+
|
300 |
+
* toolkit with multiple APIs
|
301 |
+
|
302 |
+
.. code-block:: python
|
303 |
+
|
304 |
+
class Calculator(BaseAction):
|
305 |
+
'''Calculator'''
|
306 |
+
|
307 |
+
@tool_api
|
308 |
+
def add(self, a, b):
|
309 |
+
'''Add operation
|
310 |
+
|
311 |
+
Args:
|
312 |
+
a (int): augend
|
313 |
+
b (int): addend
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
int: sum
|
317 |
+
'''
|
318 |
+
return a + b
|
319 |
+
|
320 |
+
@tool_api
|
321 |
+
            def sub(self, a, b):
                '''Subtraction operation

                Args:
                    a (int): minuend
                    b (int): subtrahend

                Returns:
                    int: difference
                '''
                return a - b

        action = Calculator()
    """

    def __init__(
        self,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        self._description = deepcopy(description or self.__tool_description__)
        self._name = self._description['name']
        self._parser = parser(self)

    def __call__(self, inputs: str, name='run') -> ActionReturn:
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = getattr(self, name)(**inputs)
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return

    @property
    def name(self):
        return self._name

    @property
    def is_toolkit(self):
        return self._is_toolkit

    @property
    def description(self) -> dict:
        """Description of the tool."""
        return self._description

    def __repr__(self):
        return f'{self.description}'

    __str__ = __repr__


class AsyncActionMixin:

    async def __call__(self, inputs: str, name='run') -> ActionReturn:
        fallback_args = {'inputs': inputs, 'name': name}
        if not hasattr(self, name):
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=f'invalid API: {name}',
                state=ActionStatusCode.API_ERROR)
        try:
            inputs = self._parser.parse_inputs(inputs, name)
        except ParseError as exc:
            return ActionReturn(
                fallback_args,
                type=self.name,
                errmsg=exc.err_msg,
                state=ActionStatusCode.ARGS_ERROR)
        try:
            outputs = await getattr(self, name)(**inputs)
        except Exception as exc:
            return ActionReturn(
                inputs,
                type=self.name,
                errmsg=str(exc),
                state=ActionStatusCode.API_ERROR)
        if isinstance(outputs, ActionReturn):
            action_return = outputs
            if not action_return.args:
                action_return.args = inputs
            if not action_return.type:
                action_return.type = self.name
        else:
            result = self._parser.parse_outputs(outputs)
            action_return = ActionReturn(inputs, type=self.name, result=result)
        return action_return
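A minimal usage sketch (illustrative; not part of this commit). It assumes only the behaviour defined above: `tool_api` leaves the decorated method callable, and `BaseAction.__call__` parses a JSON input string with `JsonParser` before dispatching to the named API.

# Usage sketch (illustrative; the `Echo` class and its docstring are
# hypothetical, written only to exercise BaseAction.__call__).
from lagent.actions.base_action import BaseAction, tool_api

class Echo(BaseAction):
    """Trivial single-API tool."""

    @tool_api
    def run(self, text: str) -> dict:
        """Echo the input.

        Args:
            text (str): any text

        Returns:
            dict: the echoed text
        """
        return {'text': text}

echo = Echo()
ret = echo('{"text": "hello"}')  # routed to `run` by default
print(ret.result)                # outputs parsed by JsonParser into a result list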
lagent/actions/bing_map.py
ADDED
@@ -0,0 +1,268 @@
# flake8: noqa: E501
import json
import os
from typing import Optional, Type

import aiohttp
import requests

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser


class BINGMap(BaseAction):
    """BING Map plugin for looking up map information."""

    def __init__(
        self,
        key: Optional[str] = None,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ) -> None:
        super().__init__(description, parser)
        key = os.environ.get('BING_MAP_KEY', key)
        if key is None:
            raise ValueError(
                'Please set BING Map API key either in the environment '
                'as BING_MAP_KEY or pass it as `key` parameter.')
        self.key = key
        self.base_url = 'http://dev.virtualearth.net/REST/V1/'

    @tool_api(explode_return=True)
    def get_distance(self, start: str, end: str) -> dict:
        """Get the distance between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: distance information
                * distance (str): the distance in km.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        # TODO check request status?
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        # Extract distance in km
        distance = route['travelDistance']
        return dict(distance=distance)

    @tool_api(explode_return=True)
    def get_route(self, start: str, end: str) -> dict:
        """Get the driving route between two locations.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: route information
                * route (list): the route, a list of actions.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        r = requests.get(url)
        data = json.loads(r.text)
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        itinerary = route['routeLegs'][0]['itineraryItems']
        # Extract route text information
        route_text = []
        for item in itinerary:
            if 'instruction' in item:
                route_text.append(item['instruction']['text'])
        return dict(route=route_text)

    @tool_api(explode_return=True)
    def get_coordinates(self, location: str) -> dict:
        """Get the coordinates of a location.

        Args:
            location (:class:`str`): the location to get coordinates for.

        Returns:
            :class:`dict`: coordinates information
                * latitude (float): the latitude of the location.
                * longitude (float): the longitude of the location.
        """
        url = self.base_url + 'Locations'
        params = {'query': location, 'key': self.key}
        response = requests.get(url, params=params)
        json_data = response.json()
        coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
            'coordinates']
        return dict(latitude=coordinates[0], longitude=coordinates[1])

    @tool_api(explode_return=True)
    def search_nearby(self,
                      search_term: str,
                      places: str = 'unknown',
                      latitude: float = 0.0,
                      longitude: float = 0.0,
                      radius: int = 5000) -> dict:
        """Search for places near a location, within a given radius, and return the results as a list. You can use either the place name or the latitude and longitude.

        Args:
            search_term (:class:`str`): the place name.
            places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
            latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
            longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
            radius (:class:`int`): radius in meters. Defaults to ``5000``.

        Returns:
            :class:`dict`: places information
                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
        """
        url = self.base_url + 'LocalSearch'
        if places != 'unknown':
            pos = self.get_coordinates(**{'location': places})
            # get_coordinates returns a plain dict, so index it by key
            latitude, longitude = pos['latitude'], pos['longitude']
        # Build the request query string
        params = {
            'query': search_term,
            'userLocation': f'{latitude},{longitude}',
            'radius': radius,
            'key': self.key
        }
        # Make the request
        response = requests.get(url, params=params)
        # Parse the response
        response_data = json.loads(response.content)
        # Get the results
        results = response_data['resourceSets'][0]['resources']
        addresses = []
        for result in results:
            name = result['name']
            address = result['Address']['formattedAddress']
            addresses.append(dict(name=name, address=address))
            if len(addresses) == 5:
                break
        return dict(place=addresses)


class AsyncBINGMap(AsyncActionMixin, BINGMap):
    """BING Map plugin for looking up map information."""

    @tool_api(explode_return=True)
    async def get_distance(self, start: str, end: str) -> dict:
        """Get the distance between two locations in km.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: distance information
                * distance (str): the distance in km.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                # TODO check request status?
                data = await resp.json()
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        # Extract distance in km
        distance = route['travelDistance']
        return dict(distance=distance)

    @tool_api(explode_return=True)
    async def get_route(self, start: str, end: str) -> dict:
        """Get the driving route between two locations.

        Args:
            start (:class:`str`): The start location
            end (:class:`str`): The end location

        Returns:
            :class:`dict`: route information
                * route (list): the route, a list of actions.
        """
        # Request URL
        url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
        # GET request
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as resp:
                data = await resp.json()
        # Extract route information
        route = data['resourceSets'][0]['resources'][0]
        itinerary = route['routeLegs'][0]['itineraryItems']
        # Extract route text information
        route_text = []
        for item in itinerary:
            if 'instruction' in item:
                route_text.append(item['instruction']['text'])
        return dict(route=route_text)

    @tool_api(explode_return=True)
    async def get_coordinates(self, location: str) -> dict:
        """Get the coordinates of a location.

        Args:
            location (:class:`str`): the location to get coordinates for.

        Returns:
            :class:`dict`: coordinates information
                * latitude (float): the latitude of the location.
                * longitude (float): the longitude of the location.
        """
        url = self.base_url + 'Locations'
        params = {'query': location, 'key': self.key}
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as resp:
                data = await resp.json()
        coordinates = data['resourceSets'][0]['resources'][0]['point'][
            'coordinates']
        return dict(latitude=coordinates[0], longitude=coordinates[1])

    @tool_api(explode_return=True)
    async def search_nearby(self,
                            search_term: str,
                            places: str = 'unknown',
                            latitude: float = 0.0,
                            longitude: float = 0.0,
                            radius: int = 5000) -> dict:
        """Search for places near a location, within a given radius, and return the results as a list. You can use either the place name or the latitude and longitude.

        Args:
            search_term (:class:`str`): the place name.
            places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
            latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
            longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
            radius (:class:`int`): radius in meters. Defaults to ``5000``.

        Returns:
            :class:`dict`: places information
                * places (list): the list of places, each place is a dict with name and address, at most 5 places.
        """
        url = self.base_url + 'LocalSearch'
        if places != 'unknown':
            # get_coordinates is a coroutine in this class, so await it
            pos = await self.get_coordinates(**{'location': places})
            latitude, longitude = pos['latitude'], pos['longitude']
        # Build the request query string
        params = {
            'query': search_term,
            'userLocation': f'{latitude},{longitude}',
            'radius': radius,
            'key': self.key
        }
        async with aiohttp.ClientSession() as session:
            async with session.get(url, params=params) as resp:
                data = await resp.json()
        results = data['resourceSets'][0]['resources']
        addresses = []
        for result in results:
            name = result['name']
            address = result['Address']['formattedAddress']
            addresses.append(dict(name=name, address=address))
            if len(addresses) == 5:
                break
        return dict(place=addresses)
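A short usage sketch (illustrative; not part of this commit). It assumes a valid Bing Maps key supplied via the BING_MAP_KEY environment variable; since `tool_api` leaves the methods directly callable, they can be exercised without the agent loop.

# Usage sketch (illustrative). The place names are arbitrary examples.
from lagent.actions.bing_map import BINGMap

bing = BINGMap()  # raises ValueError if BING_MAP_KEY is unset and no key is passed
print(bing.get_distance(start='Seattle', end='Portland'))  # {'distance': ...}
print(bing.get_coordinates(location='Space Needle'))       # {'latitude': ..., 'longitude': ...}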
lagent/actions/builtin_actions.py
ADDED
@@ -0,0 +1,109 @@
from typing import Optional

from lagent.actions.base_action import BaseAction, tool_api
from lagent.actions.parser import BaseParser
from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode


class InvalidAction(BaseAction):
    """An invalid action class, used to return an error message when the
    action is invalid.

    Args:
        err_msg (str): The error message. Defaults to 'The action is invalid,
            please check the action name'.

    Returns:
        ActionReturn: The action return.
    """

    def __init__(self,
                 err_msg: str = 'The action is invalid, please check the action name.',
                 description: Optional[dict] = None,
                 parser=BaseParser) -> None:
        super().__init__(description, parser)
        self._err_msg = err_msg

    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
            err_msg (str, optional): The error message. If err_msg is not None,
                it will be returned, otherwise the default error message will
                be returned. Defaults to None.
        """
        action_return = ActionReturn(
            url=None,
            args=dict(text=err_msg),
            errmsg=err_msg or self._err_msg,
            type=self.name,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
        return action_return


class NoAction(BaseAction):
    """A no-op action class, used to return an error message when the
    response does not follow the expected format.

    Args:
        err_msg (str): The error message. Defaults to
            'Please follow the format'.
    """

    def __init__(self,
                 err_msg: str = 'Please follow the format',
                 description: Optional[dict] = None,
                 parser=BaseParser):
        super().__init__(description, parser)
        self._err_msg = err_msg

    @tool_api
    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
        """Return the error message.

        Args:
            err_msg (str, optional): The error message. If err_msg is not None,
                it will be returned, otherwise the default error message will
                be returned. Defaults to None.

        Returns:
            ActionReturn: The action return.
        """
        action_return = ActionReturn(
            url=None,
            args=dict(text=err_msg),
            type=self.name,
            errmsg=err_msg or self._err_msg,
            valid=ActionValidCode.INVALID,
            state=ActionStatusCode.API_ERROR)
        return action_return


class FinishAction(BaseAction):
    """A finish action class, used to return the final result."""

    def __init__(self, description: Optional[dict] = None, parser=BaseParser):
        super().__init__(description, parser)

    @tool_api
    def run(self, response: str) -> ActionReturn:
        """Return the final result.

        Args:
            response (str): The final result.

        Returns:
            ActionReturn: The action return.
        """
        action_return = ActionReturn(
            url=None,
            args=dict(text=response),
            result=[dict(type='text', content=response)],
            type=self.name,
            valid=ActionValidCode.FINISH,
            state=ActionStatusCode.SUCCESS)
        return action_return
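A quick sketch of how an agent loop typically consumes these built-ins (illustrative; not part of this commit): InvalidAction and NoAction surface protocol errors back to the model, while FinishAction packages the final answer as a successful ActionReturn.

# Usage sketch (illustrative). The response text is an arbitrary example.
from lagent.actions.builtin_actions import FinishAction, InvalidAction

finish = FinishAction()
ret = finish.run(response='The answer is 42.')
print(ret.result)   # [{'type': 'text', 'content': 'The answer is 42.'}]

invalid = InvalidAction()
err = invalid.run()
print(err.errmsg)   # 'The action is invalid, please check the action name.'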
lagent/actions/google_scholar_search.py
ADDED
@@ -0,0 +1,438 @@
# flake8: noqa: E501
import os
from typing import Optional, Type

from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
from .parser import BaseParser, JsonParser


class GoogleScholar(BaseAction):
    """Plugin for Google Scholar search.

    Args:
        api_key (str): API key to use the Serper Google search API.
            You can create a free API key at https://serper.dev.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please set Serper API key either in the environment '
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key

    @tool_api(explode_return=True)
    def search_google_scholar(
        self,
        query: str,
        cites: Optional[str] = None,
        as_ylo: Optional[int] = None,
        as_yhi: Optional[int] = None,
        scisbd: Optional[int] = None,
        cluster: Optional[str] = None,
        hl: Optional[str] = None,
        lr: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        as_sdt: Optional[str] = None,
        safe: Optional[str] = None,
        filter: Optional[str] = None,
        as_vis: Optional[str] = None,
    ) -> dict:
        """Search Google Scholar for scholarly articles matching a query.

        Args:
            query (str): The query to search for.
            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
            hl (Optional[str]): The language to use for the Google Scholar search.
            lr (Optional[str]): One or multiple languages to limit the search to.
            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
            num (Optional[int]): The maximum number of results to return, limited to 20.
            as_sdt (Optional[str]): Can be used either as a search type or a filter.
            safe (Optional[str]): The level of filtering for adult content.
            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
            as_vis (Optional[str]): Defines whether to include citations or not.

        Returns:
            :class:`dict`: article information
                - title: a list of the titles of the three selected papers
                - cited_by: a list of the citation numbers of the three selected papers
                - organic_id: a list of the organic results' ids of the three selected papers
                - snippets: a list of snippets of the three selected papers
                - pub_info: publication information of selected papers
        """
        from serpapi import GoogleSearch

        params = {
            'q': query,
            'engine': 'google_scholar',
            'api_key': self.api_key,
            'cites': cites,
            'as_ylo': as_ylo,
            'as_yhi': as_yhi,
            'scisbd': scisbd,
            'cluster': cluster,
            'hl': hl,
            'lr': lr,
            'start': start,
            'num': num,
            'as_sdt': as_sdt,
            'safe': safe,
            'filter': filter,
            'as_vis': as_vis,
        }
        search = GoogleSearch(params)
        try:
            r = search.get_dict()
            results = r['organic_results']
            title = []
            snippets = []
            cited_by = []
            organic_id = []
            pub_info = []
            for item in results[:3]:
                title.append(item['title'])
                pub_info.append(item['publication_info']['summary'])
                citation = item['inline_links'].get('cited_by', {'total': ''})
                cited_by.append(citation['total'])
                snippets.append(item['snippet'])
                organic_id.append(item['result_id'])
            # return pub_info as well, matching the documented fields
            return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets, pub_info=pub_info)
        except Exception as e:
            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_information(
        self,
        author_id: str,
        hl: Optional[str] = None,
        view_op: Optional[str] = None,
        sort: Optional[str] = None,
        citation_id: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        no_cache: Optional[bool] = None,
        async_req: Optional[bool] = None,
        output: Optional[str] = None,
    ) -> dict:
        """Search for an author's information by the author id provided by get_author_id.

        Args:
            author_id (str): Required. The ID of an author.
            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
            view_op (Optional[str]): Used for viewing specific parts of a page.
            sort (Optional[str]): Used for sorting and refining articles.
            citation_id (Optional[str]): Used for retrieving individual article citations.
            start (Optional[int]): Defines the result offset. Default is 0.
            num (Optional[int]): Defines the number of results to return. Default is 20.
            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
            output (Optional[str]): Defines the final output you want. Default is 'json'.

        Returns:
            :class:`dict`: author information
                * name: author's name
                * affiliation: the affiliation of the author
                * articles: at most 3 articles by the author
                * website: the author's homepage url
        """
        from serpapi import GoogleSearch

        params = {
            'engine': 'google_scholar_author',
            'author_id': author_id,
            'api_key': self.api_key,
            'hl': hl,
            'view_op': view_op,
            'sort': sort,
            'citation_id': citation_id,
            'start': start,
            'num': num,
            'no_cache': no_cache,
            'async': async_req,
            'output': output,
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            author = results['author']
            articles = results.get('articles', [])
            return dict(
                name=author['name'],
                affiliations=author.get('affiliations', ''),
                website=author.get('website', ''),
                articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
            )
        except Exception as e:
            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_citation_format(
        self,
        q: str,
        no_cache: Optional[bool] = None,
        async_: Optional[bool] = None,
        output: Optional[str] = 'json',
    ) -> dict:
        """Get the MLA citation format for an organic result by its id, as provided by search_google_scholar.

        Args:
            q (str): ID of an individual Google Scholar organic search result.
            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: citation format
                * authors: the authors of the article
                * citation: the citation format of the article
        """
        from serpapi import GoogleSearch

        params = {
            'q': q,
            'engine': 'google_scholar_cite',
            'api_key': self.api_key,
            'no_cache': no_cache,
            'async': async_,
            'output': output,
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            citation = results['citations']
            citation_info = citation[0]['snippet']
            return citation_info
        except Exception as e:
            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

    @tool_api(explode_return=True)
    def get_author_id(
        self,
        mauthors: str,
        hl: Optional[str] = 'en',
        after_author: Optional[str] = None,
        before_author: Optional[str] = None,
        no_cache: Optional[bool] = False,
        _async: Optional[bool] = False,
        output: Optional[str] = 'json',
    ) -> dict:
        """Get an author's id by his or her name.

        Args:
            mauthors (str): Defines the author you want to search for.
            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. This parameter takes precedence over the before_author parameter. Defaults to None.
            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: author id
                * author_id: the author_id of the author
        """
        from serpapi import GoogleSearch

        params = {
            'mauthors': mauthors,
            'engine': 'google_scholar_profiles',
            'api_key': self.api_key,
            'hl': hl,
            'after_author': after_author,
            'before_author': before_author,
            'no_cache': no_cache,
            'async': _async,
            'output': output,
        }
        try:
            search = GoogleSearch(params)
            results = search.get_dict()
            profile = results['profiles']
            author_info = dict(author_id=profile[0]['author_id'])
            return author_info
        except Exception as e:
            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)


class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
    """Plugin for Google Scholar search.

    Args:
        api_key (str): API key to use the Serper Google search API.
            You can create a free API key at https://serper.dev.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    @tool_api(explode_return=True)
    @asyncify
    def search_google_scholar(
        self,
        query: str,
        cites: Optional[str] = None,
        as_ylo: Optional[int] = None,
        as_yhi: Optional[int] = None,
        scisbd: Optional[int] = None,
        cluster: Optional[str] = None,
        hl: Optional[str] = None,
        lr: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        as_sdt: Optional[str] = None,
        safe: Optional[str] = None,
        filter: Optional[str] = None,
        as_vis: Optional[str] = None,
    ) -> dict:
        """Search Google Scholar for scholarly articles matching a query.

        Args:
            query (str): The query to search for.
            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
            hl (Optional[str]): The language to use for the Google Scholar search.
            lr (Optional[str]): One or multiple languages to limit the search to.
            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
            num (Optional[int]): The maximum number of results to return, limited to 20.
            as_sdt (Optional[str]): Can be used either as a search type or a filter.
            safe (Optional[str]): The level of filtering for adult content.
            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
            as_vis (Optional[str]): Defines whether to include citations or not.

        Returns:
            :class:`dict`: article information
                - title: a list of the titles of the three selected papers
                - cited_by: a list of the citation numbers of the three selected papers
                - organic_id: a list of the organic results' ids of the three selected papers
                - snippets: a list of snippets of the three selected papers
                - pub_info: publication information of selected papers
        """
        return super().search_google_scholar(
            query,
            cites,
            as_ylo,
            as_yhi,
            scisbd,
            cluster,
            hl,
            lr,
            start,
            num,
            as_sdt,
            safe,
            filter,
            as_vis,
        )

    @tool_api(explode_return=True)
    @asyncify
    def get_author_information(
        self,
        author_id: str,
        hl: Optional[str] = None,
        view_op: Optional[str] = None,
        sort: Optional[str] = None,
        citation_id: Optional[str] = None,
        start: Optional[int] = None,
        num: Optional[int] = None,
        no_cache: Optional[bool] = None,
        async_req: Optional[bool] = None,
        output: Optional[str] = None,
    ) -> dict:
        """Search for an author's information by the author id provided by get_author_id.

        Args:
            author_id (str): Required. The ID of an author.
            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
            view_op (Optional[str]): Used for viewing specific parts of a page.
            sort (Optional[str]): Used for sorting and refining articles.
            citation_id (Optional[str]): Used for retrieving individual article citations.
            start (Optional[int]): Defines the result offset. Default is 0.
            num (Optional[int]): Defines the number of results to return. Default is 20.
            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
            output (Optional[str]): Defines the final output you want. Default is 'json'.

        Returns:
            :class:`dict`: author information
                * name: author's name
                * affiliation: the affiliation of the author
                * articles: at most 3 articles by the author
                * website: the author's homepage url
        """
        return super().get_author_information(
            author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
        )

    @tool_api(explode_return=True)
    @asyncify
    def get_citation_format(
        self,
        q: str,
        no_cache: Optional[bool] = None,
        async_: Optional[bool] = None,
        output: Optional[str] = 'json',
    ) -> dict:
        """Get the MLA citation format for an organic result by its id, as provided by search_google_scholar.

        Args:
            q (str): ID of an individual Google Scholar organic search result.
            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: citation format
                * authors: the authors of the article
                * citation: the citation format of the article
        """
        return super().get_citation_format(q, no_cache, async_, output)

    @tool_api(explode_return=True)
    @asyncify
    def get_author_id(
        self,
        mauthors: str,
        hl: Optional[str] = 'en',
        after_author: Optional[str] = None,
        before_author: Optional[str] = None,
        no_cache: Optional[bool] = False,
        _async: Optional[bool] = False,
        output: Optional[str] = 'json',
    ) -> dict:
        """Get an author's id by his or her name.

        Args:
            mauthors (str): Defines the author you want to search for.
            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. This parameter takes precedence over the before_author parameter. Defaults to None.
            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.

        Returns:
            :class:`dict`: author id
                * author_id: the author_id of the author
        """
        return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
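A usage sketch (illustrative; not part of this commit). Note that although the docstrings point at serper.dev, the key is consumed by SerpApi's `GoogleSearch` client (the `serpapi` / `google-search-results` package), so a SerpApi-compatible key is assumed here.

# Usage sketch (illustrative). The query is an arbitrary example.
from lagent.actions.google_scholar_search import GoogleScholar

scholar = GoogleScholar()  # reads SERPER_API_KEY from the environment
info = scholar.search_google_scholar(query='large language model agents')
print(info['title'])     # titles of up to three organic results
print(info['cited_by'])  # their citation counts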
lagent/actions/google_search.py
ADDED
@@ -0,0 +1,244 @@
import os
from typing import List, Optional, Tuple, Type, Union

import aiohttp
import requests

from lagent.schema import ActionReturn, ActionStatusCode
from .base_action import AsyncActionMixin, BaseAction, tool_api
from .parser import BaseParser, JsonParser


class GoogleSearch(BaseAction):
    """Wrapper around the Serper.dev Google Search API.

    To use, you should pass your serper API key to the constructor.

    Code is modified from langchain GoogleSerperAPIWrapper
    (https://github.com/langchain-ai/langchain/blob/ba5f
    baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
    langchain/utilities/google_serper.py)

    Args:
        api_key (str): API key to use the Serper Google search API.
            You can create a free API key at https://serper.dev.
        timeout (int): Upper bound of waiting time for a serper request.
        search_type (str): The Serper API supports ['search', 'images', 'news',
            'places'] search types; currently only 'search' is supported.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """
    result_key_for_type = {
        'news': 'news',
        'places': 'places',
        'images': 'images',
        'search': 'organic',
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        timeout: int = 5,
        search_type: str = 'search',
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        api_key = os.environ.get('SERPER_API_KEY', api_key)
        if api_key is None:
            raise ValueError(
                'Please set Serper API key either in the environment '
                'as SERPER_API_KEY or pass it as `api_key` parameter.')
        self.api_key = api_key
        self.timeout = timeout
        self.search_type = search_type

    @tool_api
    def run(self, query: str, k: int = 10) -> ActionReturn:
        """An API that searches Google. Use it when you need a short and clear answer to a specific question. The input should be a search query.

        Args:
            query (str): the search content
            k (int): select first k results in the search results as response
        """
        tool_return = ActionReturn(type=self.name)
        status_code, response = self._search(query, k=k)
        # convert search results to ToolReturn format
        if status_code == -1:
            tool_return.errmsg = response
            tool_return.state = ActionStatusCode.HTTP_ERROR
        elif status_code == 200:
            parsed_res = self._parse_results(response, k)
            tool_return.result = [dict(type='text', content=str(parsed_res))]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = str(status_code)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def _parse_results(self, results: dict, k: int) -> Union[str, List[str]]:
        """Parse the search results from Serper API.

        Args:
            results (dict): The search content from Serper API
                in json format.

        Returns:
            List[str]: The parsed search results.
        """
        snippets = []

        if results.get('answerBox'):
            answer_box = results.get('answerBox', {})
            if answer_box.get('answer'):
                return [answer_box.get('answer')]
            elif answer_box.get('snippet'):
                return [answer_box.get('snippet').replace('\n', ' ')]
            elif answer_box.get('snippetHighlighted'):
                return answer_box.get('snippetHighlighted')

        if results.get('knowledgeGraph'):
            kg = results.get('knowledgeGraph', {})
            title = kg.get('title')
            entity_type = kg.get('type')
            if entity_type:
                snippets.append(f'{title}: {entity_type}.')
            description = kg.get('description')
            if description:
                snippets.append(description)
            for attribute, value in kg.get('attributes', {}).items():
                snippets.append(f'{title} {attribute}: {value}.')

        for result in results[self.result_key_for_type[
                self.search_type]][:k]:
            if 'snippet' in result:
                snippets.append(result['snippet'])
            for attribute, value in result.get('attributes', {}).items():
                snippets.append(f'{attribute}: {value}.')

        if len(snippets) == 0:
            return ['No good Google Search Result was found']
        return snippets

    def _search(self,
                search_term: str,
                search_type: Optional[str] = None,
                **kwargs) -> Tuple[int, Union[dict, str]]:
        """HTTP requests to Serper API.

        Args:
            search_term (str): The search query.
            search_type (str): search type supported by Serper API,
                defaults to 'search'.

        Returns:
            tuple: the return value is a tuple containing:
                - status_code (int): HTTP status code from Serper API.
                - response (dict): response context in json format.
        """
        headers = {
            'X-API-KEY': self.api_key or '',
            'Content-Type': 'application/json',
        }
        params = {
            'q': search_term,
            **{
                key: value
                for key, value in kwargs.items() if value is not None
            },
        }
        try:
            response = requests.post(
                f'https://google.serper.dev/{search_type or self.search_type}',
                headers=headers,
                params=params,
                timeout=self.timeout)
        except Exception as e:
            return -1, str(e)
        return response.status_code, response.json()


class AsyncGoogleSearch(AsyncActionMixin, GoogleSearch):
    """Wrapper around the Serper.dev Google Search API.

    To use, you should pass your serper API key to the constructor.

    Code is modified from langchain GoogleSerperAPIWrapper
    (https://github.com/langchain-ai/langchain/blob/ba5f
    baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
    langchain/utilities/google_serper.py)

    Args:
        api_key (str): API key to use the Serper Google search API.
            You can create a free API key at https://serper.dev.
        timeout (int): Upper bound of waiting time for a serper request.
        search_type (str): The Serper API supports ['search', 'images', 'news',
            'places'] search types; currently only 'search' is supported.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    @tool_api
    async def run(self, query: str, k: int = 10) -> ActionReturn:
        """An API that searches Google. Use it when you need a short and clear answer to a specific question. The input should be a search query.

        Args:
            query (str): the search content
            k (int): select first k results in the search results as response
        """
        tool_return = ActionReturn(type=self.name)
        status_code, response = await self._search(query, k=k)
        # convert search results to ToolReturn format
        if status_code == -1:
            tool_return.errmsg = response
            tool_return.state = ActionStatusCode.HTTP_ERROR
        elif status_code == 200:
            # pass `k` through, matching the synchronous implementation
            parsed_res = self._parse_results(response, k)
            tool_return.result = [dict(type='text', content=str(parsed_res))]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = str(status_code)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    async def _search(self,
                      search_term: str,
                      search_type: Optional[str] = None,
                      **kwargs) -> Tuple[int, Union[dict, str]]:
        """HTTP requests to Serper API.

        Args:
            search_term (str): The search query.
            search_type (str): search type supported by Serper API,
                defaults to 'search'.

        Returns:
            tuple: the return value is a tuple containing:
                - status_code (int): HTTP status code from Serper API.
                - response (dict): response context in json format.
        """
        headers = {
            'X-API-KEY': self.api_key or '',
            'Content-Type': 'application/json',
        }
        params = {
            'q': search_term,
            **{
                key: value
                for key, value in kwargs.items() if value is not None
            },
        }
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.post(
                        f'https://google.serper.dev/{search_type or self.search_type}',
                        headers=headers,
                        params=params) as resp:
                    code, ret = resp.status, await resp.json()
            except aiohttp.ClientError as e:
                code, ret = -1, str(e)
        return code, ret
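A usage sketch (illustrative; not part of this commit), assuming a serper.dev key is available in SERPER_API_KEY:

# Usage sketch (illustrative). The query is an arbitrary example.
from lagent.actions.google_search import GoogleSearch
from lagent.schema import ActionStatusCode

search = GoogleSearch(timeout=10)
ret = search.run(query='InternLM lagent', k=3)
if ret.state == ActionStatusCode.SUCCESS:
    print(ret.result[0]['content'])  # first k parsed snippets as text
else:
    print(ret.errmsg)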
lagent/actions/ipython_interactive.py
ADDED
@@ -0,0 +1,273 @@
import re
import signal
from contextlib import contextmanager, redirect_stdout
from dataclasses import dataclass
from enum import Enum
from io import StringIO
from typing import Optional, Type

from ..schema import ActionReturn, ActionStatusCode
from .base_action import AsyncActionMixin, BaseAction, tool_api
from .parser import BaseParser, JsonParser


class Status(str, Enum):
    """Execution status."""
    SUCCESS = 'success'
    FAILURE = 'failure'


@dataclass
class ExecutionResult:
    """Execution result."""
    status: Status
    value: Optional[str] = None
    msg: Optional[str] = None


@contextmanager
def _raise_timeout(timeout):

    def _handler(signum, frame):
        raise TimeoutError()

    signal.signal(signal.SIGALRM, _handler)
    signal.alarm(timeout)

    try:
        yield
    finally:
        signal.alarm(0)


class IPythonInteractive(BaseAction):
    """An interactive IPython shell for code execution.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to ``30``.
        max_out_len (int): maximum output length. No truncation occurs if negative.
            Defaults to ``8192``.
        use_signals (bool): whether to use signals for timing out execution.
            Set to ``False`` when not running in the main thread, e.g. in web
            applications. Defaults to ``True``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    def __init__(
        self,
        timeout: int = 30,
        max_out_len: int = 8192,
        use_signals: bool = True,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        self.timeout = timeout
        self._executor = self.create_shell()
        self._highlighting = re.compile(
            r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
        self._max_out_len = max_out_len if max_out_len >= 0 else None
        self._use_signals = use_signals

    def reset(self):
        """Clear the context."""
        self._executor.reset()

    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        """Launch an IPython Interactive Shell to execute code.

        Args:
            command (:class:`str`): Python code snippet
            timeout (:class:`Optional[int]`): timeout for execution.
                This argument only works in the main thread. Defaults to ``None``.
        """
        from timeout_decorator import timeout as timer
        tool_return = ActionReturn(args={'text': command}, type=self.name)
        ret = (
            timer(timeout or self.timeout)(self.exec)(command)
            if self._use_signals else self.exec(command))
        if ret.status is Status.SUCCESS:
            tool_return.result = [{'type': 'text', 'content': ret.value}]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = ret.msg
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def exec(self, code: str) -> ExecutionResult:
        """Run Python scripts in IPython shell.

        Args:
            code (:class:`str`): code block

        Returns:
            :py:class:`ExecutionResult`: execution result
        """
        with StringIO() as io:
            with redirect_stdout(io):
                ret = self._executor.run_cell(self.extract_code(code))
                result = ret.result
                if result is not None:
                    return ExecutionResult(Status.SUCCESS,
                                           str(result)[:self._max_out_len])
            outs = io.getvalue().strip().split('\n')
        if not outs:
            return ExecutionResult(Status.SUCCESS, '')
        for i, out in enumerate(outs):
            if re.search('Error|Traceback', out, re.S):
                if 'TimeoutError' in out:
                    return ExecutionResult(
                        Status.FAILURE,
                        msg=('The code interpreter encountered '
                             'a timeout error.'))
                err_idx = i
                break
        else:
            return ExecutionResult(Status.SUCCESS,
                                   outs[-1].strip()[:self._max_out_len])
        return ExecutionResult(
            Status.FAILURE,
            msg=self._highlighting.sub(
                '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
        )

    @staticmethod
    def create_shell():
        from IPython import InteractiveShell
        from traitlets.config import Config

        c = Config()
        c.HistoryManager.enabled = False
        c.HistoryManager.hist_file = ':memory:'
        return InteractiveShell(
            user_ns={'_raise_timeout': _raise_timeout}, config=c)

    @staticmethod
    def extract_code(text: str) -> str:
        """Extract Python code from markup languages.

        Args:
            text (:class:`str`): Markdown-formatted text

        Returns:
            :class:`str`: Python code
        """
        import json5

        # Match triple backtick blocks first
        triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
        # Match single backtick blocks second
        single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
        if triple_match:
            text = triple_match.group(1)
        elif single_match:
            text = single_match.group(1)
        else:
            try:
                text = json5.loads(text)['code']
            except Exception:
                pass
        # If no code blocks found, return original text
        return text

    @staticmethod
    def wrap_code_with_timeout(code: str, timeout: int) -> str:
        if not code.strip():
            return code
        code = code.strip('\n').rstrip()
        indent = len(code) - len(code.lstrip())
        handle = ' ' * indent + f'with _raise_timeout({timeout}):\n'
        block = '\n'.join(['    ' + line for line in code.split('\n')])
        wrapped_code = handle + block
        last_line = code.split('\n')[-1]
        is_expression = True
        try:
            compile(last_line.lstrip(), '<stdin>', 'eval')
        except SyntaxError:
            is_expression = False
        if is_expression:
            wrapped_code += '\n' * 5 + last_line
        return wrapped_code


class AsyncIPythonInteractive(AsyncActionMixin, IPythonInteractive):
    """An interactive IPython shell for code execution.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to ``30``.
        max_out_len (int): maximum output length. No truncation occurs if negative.
            Defaults to ``8192``.
        use_signals (bool): whether to use signals for timing out execution.
            Set to ``False`` when not running in the main thread, e.g. in web
            applications. Defaults to ``True``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    @tool_api
    async def run(self,
                  command: str,
                  timeout: Optional[int] = None) -> ActionReturn:
        """Launch an IPython Interactive Shell to execute code.

        Args:
            command (:class:`str`): Python code snippet
            timeout (:class:`Optional[int]`): timeout for execution.
                This argument only works in the main thread. Defaults to ``None``.
        """
        tool_return = ActionReturn(args={'text': command}, type=self.name)
        ret = await self.exec(command, timeout)
        if ret.status is Status.SUCCESS:
            tool_return.result = [{'type': 'text', 'content': ret.value}]
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = ret.msg
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    async def exec(self, code: str, timeout: int = None) -> ExecutionResult:
        """Asynchronously run Python scripts in IPython shell.

        Args:
            code (:class:`str`): code block
            timeout (:class:`int`): max waiting time for code execution

        Returns:
            :py:class:`ExecutionResult`: execution result
        """
        with StringIO() as io:
            with redirect_stdout(io):
                ret = await self._executor.run_cell_async(
                    # ret = await self.create_shell().run_cell_async(
                    self.wrap_code_with_timeout(
                        self.extract_code(code), timeout or self.timeout))
                result = ret.result
                if result is not None:
                    return ExecutionResult(Status.SUCCESS,
                                           str(result)[:self._max_out_len])
            outs = io.getvalue().strip().split('\n')
        if not outs:
            return ExecutionResult(Status.SUCCESS, '')
        for i, out in enumerate(outs):
            if re.search('Error|Traceback', out, re.S):
259 |
+
if 'TimeoutError' in out:
|
260 |
+
return ExecutionResult(
|
261 |
+
Status.FAILURE,
|
262 |
+
msg=('The code interpreter encountered a '
|
263 |
+
'timeout error.'))
|
264 |
+
err_idx = i
|
265 |
+
break
|
266 |
+
else:
|
267 |
+
return ExecutionResult(Status.SUCCESS,
|
268 |
+
outs[-1].strip()[:self._max_out_len])
|
269 |
+
return ExecutionResult(
|
270 |
+
Status.FAILURE,
|
271 |
+
msg=self._highlighting.sub(
|
272 |
+
'', '\n'.join(outs[err_idx:])[:self._max_out_len]),
|
273 |
+
)
|
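
For reference, a minimal usage sketch of the synchronous shell above (not part of the commit; it assumes lagent and its timeout-decorator dependency are installed, and that the constructor accepts the arguments listed in the class docstring):

# Usage sketch, not part of the commit. Assumes lagent is importable.
from lagent.actions.ipython_interactive import IPythonInteractive

shell = IPythonInteractive(timeout=10)  # pass use_signals=False off the main thread
ret = shell.run("```python\nx = 21\nx * 2\n```")
print(ret.state)   # ActionStatusCode.SUCCESS if execution succeeded
print(ret.result)  # e.g. [{'type': 'text', 'content': '42'}]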
lagent/actions/ipython_interpreter.py
ADDED
@@ -0,0 +1,584 @@
# flake8: noqa: E501
import asyncio
import base64
import io
import json
import logging
import os
import queue
import re
import signal
import sys
import tempfile
import traceback
import uuid
from typing import Optional, Tuple, Type

from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager
from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode

logger = logging.getLogger(__name__)

START_CODE = """
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa


class TimeoutError(Exception):
    pass


class KernelDeath(Exception):
    pass


async def async_run_code(
    km: AsyncKernelManager,
    code,
    *,
    interrupt_after=30,
    iopub_timeout=40,
    wait_for_ready_timeout=60,
    shutdown_kernel=True,
):
    assert iopub_timeout > interrupt_after
    try:

        async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
                                                     *,
                                                     timeout=None):
            loop = asyncio.get_running_loop()
            dead_fut = loop.create_future()

            def restarting():
                assert (
                    False
                ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"

            def dead():
                logger.info("Kernel has died, will NOT restart")
                dead_fut.set_result(None)

            msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
            km.add_restart_callback(restarting, "restart")
            km.add_restart_callback(dead, "dead")
            try:
                done, _ = await asyncio.wait(
                    [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
                if dead_fut in done:
                    raise KernelDeath()
                assert msg_task in done
                return await msg_task
            finally:
                msg_task.cancel()
                km.remove_restart_callback(restarting, "restart")
                km.remove_restart_callback(dead, "dead")

        async def send_interrupt():
            await asyncio.sleep(interrupt_after)
            logger.info("Sending interrupt to kernel")
            await km.interrupt_kernel()

        @retry(
            retry=retry_if_result(lambda ret: ret[-1].strip() in [
                'KeyboardInterrupt',
                f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
            ] if isinstance(ret, tuple) else False),
            stop=stop_after_attempt(3),
            wait=wait_fixed(1),
            retry_error_callback=lambda state: state.outcome.result())
        async def run():
            execute_result = None
            error_traceback = None
            stream_text_list = []
            kc = km.client()
            assert isinstance(kc, AsyncKernelClient)
            kc.start_channels()
            try:
                await kc.wait_for_ready(timeout=wait_for_ready_timeout)
                msg_id = kc.execute(code)
                while True:
                    message = await get_iopub_msg_with_death_detection(
                        kc, timeout=iopub_timeout)
                    if logger.isEnabledFor(logging.DEBUG):
                        logger.debug(
                            json.dumps(message, indent=2, default=str))
                    assert message["parent_header"]["msg_id"] == msg_id
                    msg_type = message["msg_type"]
                    if msg_type == "status":
                        if message["content"]["execution_state"] == "idle":
                            break
                    elif msg_type == "stream":
                        stream_name = message["content"]["name"]
                        stream_text = message["content"]["text"]
                        stream_text_list.append(stream_text)
                    elif msg_type == "execute_result":
                        execute_result = message["content"]["data"]
                    elif msg_type == "error":
                        error_traceback_lines = message["content"]["traceback"]
                        error_traceback = "\n".join(error_traceback_lines)
                    elif msg_type == "execute_input":
                        pass
                    else:
                        assert False, f"Unknown message_type: {msg_type}"
            finally:
                kc.stop_channels()
            return execute_result, error_traceback, "".join(stream_text_list)

        if interrupt_after:
            run_task = asyncio.create_task(run())
            send_interrupt_task = asyncio.create_task(send_interrupt())
            done, _ = await asyncio.wait([run_task, send_interrupt_task],
                                         return_when=asyncio.FIRST_COMPLETED)
            if run_task in done:
                send_interrupt_task.cancel()
            else:
                assert send_interrupt_task in done
            result = await run_task
        else:
            result = await run()
        return result
    finally:
        if shutdown_kernel:
            await km.shutdown_kernel()


class IPythonInterpreter(BaseAction):
    """An IPython executor that can execute Python scripts in a Jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specifies the user data directory for
            file loading. If set to `ENV`, the `USER_DATA_DIR` environment
            variable is used. Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    _KERNEL_CLIENTS = {}

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir='./work_dir/tmp_dir',
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            user_data_dir = os.path.dirname(user_data_dir)
            user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.work_dir = work_dir
        if not os.path.exists(self.work_dir):
            os.makedirs(self.work_dir, exist_ok=True)

    @staticmethod
    def start_kernel():
        from jupyter_client import KernelManager

        # start the kernel and manager
        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[str, bool]:
        self.initialize()
        command = extract_code(command)

        # check previous remaining result
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=5)
                msg_type = msg['msg_type']
                if msg_type == 'status':
                    if msg['content'].get('execution_state') == 'idle':
                        break
            except queue.Empty:
                # assume no result
                break

        self.kernel_client.execute(command)

        def _inner_call():
            result = ''
            images = []
            succeed = True
            image_idx = 0

            while True:
                text = ''
                image = ''
                finished = False
                msg_type = 'error'
                try:
                    msg = self.kernel_client.get_iopub_msg(timeout=20)
                    msg_type = msg['msg_type']
                    if msg_type == 'status':
                        if msg['content'].get('execution_state') == 'idle':
                            finished = True
                    elif msg_type == 'execute_result':
                        text = msg['content']['data'].get('text/plain', '')
                        if 'image/png' in msg['content']['data']:
                            image_b64 = msg['content']['data']['image/png']
                            image_url = publish_image_to_local(
                                image_b64, self.work_dir)
                            image_idx += 1
                            image = '![fig-%03d](%s)' % (image_idx, image_url)

                    elif msg_type == 'display_data':
                        if 'image/png' in msg['content']['data']:
                            image_b64 = msg['content']['data']['image/png']
                            image_url = publish_image_to_local(
                                image_b64, self.work_dir)
                            image_idx += 1
                            image = '![fig-%03d](%s)' % (image_idx, image_url)

                        else:
                            text = msg['content']['data'].get('text/plain', '')
                    elif msg_type == 'stream':
                        msg_type = msg['content']['name']  # stdout, stderr
                        text = msg['content']['text']
                    elif msg_type == 'error':
                        succeed = False
                        text = escape_ansi('\n'.join(
                            msg['content']['traceback']))
                        if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                            text = f'Timeout. No response after {timeout} seconds.'  # noqa
                except queue.Empty:
                    # stop current task in case it breaks the next input.
                    self.kernel_manager.interrupt_kernel()
                    succeed = False
                    text = f'Timeout. No response after {timeout} seconds.'
                    finished = True
                except Exception:
                    succeed = False
                    msg = ''.join(traceback.format_exception(*sys.exc_info()))
                    # text = 'The code interpreter encountered an unexpected error.'  # noqa
                    text = msg
                    logging.warning(msg)
                    finished = True
                if text:
                    # result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
                    result += f'{text}'

                if image:
                    images.append(image_url)
                if finished:
                    return succeed, dict(text=result, image=images)

        try:
            if timeout:

                def handler(signum, frame):
                    raise TimeoutError()

                signal.signal(signal.SIGALRM, handler)
                signal.alarm(timeout)
            succeed, result = _inner_call()
        except TimeoutError:
            succeed = False
            text = 'The code interpreter encountered an unexpected error.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if timeout:
                signal.alarm(0)

        # result = result.strip('\n')
        return succeed, result

    @tool_api
    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = self._call(command, timeout)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            # tool_return.result = dict(
            #     text=result['text'], image=result.get('image', [])[0])
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return


class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
    """An IPython executor that can execute Python scripts in a Jupyter manner.

    Args:
        timeout (int): Upper bound of waiting time for Python script execution.
            Defaults to 20.
        user_data_dir (str, optional): Specifies the user data directory for
            file loading. If set to `ENV`, the `USER_DATA_DIR` environment
            variable is used. Defaults to `ENV`.
        work_dir (str, optional): Specify which directory to save output images to.
            Defaults to ``'./work_dir/tmp_dir'``.
        description (dict): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()

    def __init__(
        self,
        timeout: int = 20,
        user_data_dir: str = 'ENV',
        work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
        max_kernels: Optional[int] = None,
        reuse_kernel: bool = True,
        startup_rate: int = 32,
        connection_dir: str = tempfile.gettempdir(),
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(timeout, user_data_dir, work_dir, description, parser)
        from traitlets.config import Config

        c = Config()
        c.KernelManager.transport = 'ipc'
        self._amkm = AsyncMultiKernelManager(
            config=c, connection_dir=connection_dir)
        self._max_kernels = max_kernels
        self._reuse_kernel = reuse_kernel
        self._sem = asyncio.Semaphore(startup_rate)
        self._lock = asyncio.Lock()

    async def initialize(self, session_id: str):
        session_id = str(session_id)
        while True:
            if session_id in self._KERNEL_CLIENTS:
                return self._KERNEL_CLIENTS[session_id]
            if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
                self._KERNEL_CLIENTS[
                    session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
                return self._KERNEL_CLIENTS[session_id]
            async with self._sem:
                if self._max_kernels is None or len(
                        self._KERNEL_CLIENTS
                ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
                    kernel_id = None
                    try:
                        kernel_id = await self._amkm.start_kernel()
                        kernel = self._amkm.get_kernel(kernel_id)
                        client = kernel.client()
                        _, error_stacktrace, stream_text = await async_run_code(
                            kernel,
                            START_CODE.format(self.user_data_dir),
                            shutdown_kernel=False)
                        # check if the output of START_CODE meets expectations
                        if not (error_stacktrace is None
                                and stream_text == ''):
                            raise RuntimeError
                    except Exception as e:
                        print(f'Starting kernel error: {e}')
                        if kernel_id:
                            await self._amkm.shutdown_kernel(kernel_id)
                            self._amkm.remove_kernel(kernel_id)
                        await asyncio.sleep(1)
                        continue
                    if self._max_kernels is None:
                        self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
                                                            client)
                        return kernel_id, kernel, client
                    async with self._lock:
                        if len(self._KERNEL_CLIENTS
                               ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
                               ) < self._max_kernels:
                            self._KERNEL_CLIENTS[session_id] = (kernel_id,
                                                                kernel, client)
                            return kernel_id, kernel, client
                    await self._amkm.shutdown_kernel(kernel_id)
                    self._amkm.remove_kernel(kernel_id)
            await asyncio.sleep(1)

    async def reset(self, session_id: str):
        session_id = str(session_id)
        if session_id not in self._KERNEL_CLIENTS:
            return
        _, kernel, _ = self._KERNEL_CLIENTS[session_id]
        code = "get_ipython().run_line_magic('reset', '-f')\n" + \
            START_CODE.format(self.user_data_dir)
        await async_run_code(kernel, code, shutdown_kernel=False)

    async def shutdown(self, session_id: str):
        session_id = str(session_id)
        if session_id in self._KERNEL_CLIENTS:
            kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
            await self._amkm.shutdown_kernel(kernel_id)
            self._amkm.remove_kernel(kernel_id)
            del self._KERNEL_CLIENTS[session_id]

    async def close_session(self, session_id: str):
        session_id = str(session_id)
        if self._reuse_kernel:
            if session_id in self._KERNEL_CLIENTS:
                await self.reset(session_id)
                await self._UNBOUND_KERNEL_CLIENTS.put(
                    self._KERNEL_CLIENTS.pop(session_id))
        else:
            await self.shutdown(session_id)

    async def _call(self, command, timeout=None, session_id=None):
        _, kernel, _ = await self.initialize(str(session_id))
        result = await async_run_code(
            kernel,
            extract_code(command),
            interrupt_after=timeout or self.timeout,
            shutdown_kernel=False)
        execute_result, error_stacktrace, stream_text = result
        if error_stacktrace is not None:
            ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
            if ret.endswith('KeyboardInterrupt: '):
                ret = 'The code interpreter encountered a timeout error.'
            status, ret = False, ret.strip()
        elif execute_result is not None:
            status, ret = True, dict(text=execute_result.get('text/plain', ''))
        else:
            status, ret = True, dict(text=stream_text.strip())
        return status, ret

    @tool_api
    async def run(self,
                  command: str,
                  timeout: Optional[int] = None,
                  session_id: Optional[str] = None) -> ActionReturn:
        r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.

        Args:
            command (:class:`str`): Python code
            timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
        """
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        tool_return.args = dict(text=command)
        succeed, result = await self._call(command, timeout, session_id)
        if succeed:
            text = result['text']
            image = result.get('image', [])
            resp = [dict(type='text', content=text)]
            if image:
                resp.extend([dict(type='image', content=im) for im in image])
            tool_return.result = resp
            # tool_return.result = dict(
            #     text=result['text'], image=result.get('image', [])[0])
            tool_return.state = ActionStatusCode.SUCCESS
        else:
            tool_return.errmsg = result.get('text', '') if isinstance(
                result, dict) else result
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return


def extract_code(text):
    import json5

    # Match triple backtick blocks first
    triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    # Match single backtick blocks second
    single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
    if triple_match:
        text = triple_match.group(1)
    elif single_match:
        text = single_match.group(1)
    else:
        try:
            text = json5.loads(text)['code']
        except Exception:
            pass
    # If no code blocks found, return original text
    return text


def escape_ansi(line):
    ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
    return ansi_escape.sub('', line)


def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
    import PIL.Image
    image_file = str(uuid.uuid4()) + '.png'
    local_image_file = os.path.join(work_dir, image_file)

    png_bytes = base64.b64decode(image_base64)
    assert isinstance(png_bytes, bytes)
    bytes_io = io.BytesIO(png_bytes)
    PIL.Image.open(bytes_io).save(local_image_file, 'png')

    return local_image_file


# local test for code interpreter
def get_multiline_input(hint):
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    lines = []
    while True:
        try:
            line = input()
        except EOFError:  # CTRL-D
            break
        lines.append(line)
    print('// Input received.')
    if lines:
        return '\n'.join(lines)
    else:
        return ''


if __name__ == '__main__':
    code_interpreter = IPythonInterpreter()
    while True:
        print(code_interpreter(get_multiline_input('Enter python code:')))
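
A hedged sketch of driving the asynchronous interpreter above with per-session kernels (not part of the commit; the session ids are arbitrary labels chosen for illustration):

# Usage sketch, not part of the commit.
import asyncio

from lagent.actions.ipython_interpreter import AsyncIPythonInterpreter

async def main():
    interpreter = AsyncIPythonInterpreter(timeout=20, max_kernels=2)
    # Each session_id maps to its own Jupyter kernel, so state set in
    # session 's1' is invisible to session 's2'.
    r1 = await interpreter.run("```python\na = 1\nprint(a)\n```", session_id='s1')
    r2 = await interpreter.run("```python\nprint('hello from s2')\n```", session_id='s2')
    print(r1.result, r2.result)
    await interpreter.close_session('s1')  # kernel is reset and pooled for reuse
    await interpreter.close_session('s2')

asyncio.run(main())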
lagent/actions/ipython_manager.py
ADDED
@@ -0,0 +1,220 @@
import re
import sys
from collections import defaultdict
from contextlib import nullcontext
from io import StringIO
from multiprocessing import Process, Queue
from typing import List, Optional, Type, Union

from filelock import FileLock
from timeout_decorator import timeout as tm

from ..schema import ActionReturn, ActionStatusCode
from .base_action import BaseAction
from .parser import BaseParser, JsonParser


class IPythonProcess(Process):

    def __init__(self,
                 in_q: Queue,
                 out_q: Queue,
                 timeout: int = 20,
                 ci_lock: str = None,
                 daemon: bool = True):
        super().__init__(daemon=daemon)
        self.in_q = in_q
        self.out_q = out_q
        self.timeout = timeout
        self.session_id2shell = defaultdict(self.create_shell)
        self.ci_lock = FileLock(
            ci_lock) if ci_lock else nullcontext()  # avoid core corruption
        self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')

    def run(self):
        while True:
            msg = self.in_q.get()
            if msg == 'reset':
                for session_id, shell in self.session_id2shell.items():
                    with self.ci_lock:
                        try:
                            shell.reset(new_session=False)
                            # shell.run_line_magic('reset', '-sf')
                        except Exception:
                            self.session_id2shell[
                                session_id] = self.create_shell()
                self.out_q.put('ok')
            elif isinstance(msg, tuple) and len(msg) == 3:
                i, session_id, code = msg
                res = self.exec(session_id, code)
                self.out_q.put((i, session_id, res))

    def exec(self, session_id, code):
        try:
            shell = self.session_id2shell[session_id]
            with StringIO() as io:
                old_stdout = sys.stdout
                sys.stdout = io
                if self.timeout is False or self.timeout < 0:
                    shell.run_cell(self.extract_code(code))
                else:
                    tm(self.timeout)(shell.run_cell)(self.extract_code(code))
                sys.stdout = old_stdout
                output = self._highlighting.sub('', io.getvalue().strip())
            output = re.sub(r'^Out\[\d+\]: ', '', output)
            if 'Error' in output or 'Traceback' in output:
                output = output.lstrip('-').strip()
                if output.startswith('TimeoutError'):
                    output = 'The code interpreter encountered a timeout error.'
                return {'status': 'FAILURE', 'msg': output, 'code': code}
            return {'status': 'SUCCESS', 'value': output, 'code': code}
        except Exception as e:
            return {'status': 'FAILURE', 'msg': str(e), 'code': code}

    @staticmethod
    def create_shell(enable_history: bool = False, in_memory: bool = True):
        from IPython import InteractiveShell
        from traitlets.config import Config

        c = Config()
        c.HistoryManager.enabled = enable_history
        if in_memory:
            c.HistoryManager.hist_file = ':memory:'
        shell = InteractiveShell(config=c)
        return shell

    @staticmethod
    def extract_code(text: str) -> str:
        """Extract Python code from markup languages.

        Args:
            text (:class:`str`): Markdown-formatted text

        Returns:
            :class:`str`: Python code
        """
        import json5

        # Match triple backtick blocks first
        triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
        # Match single backtick blocks second
        single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
        if triple_match:
            text = triple_match.group(1)
        elif single_match:
            text = single_match.group(1)
        else:
            try:
                text = json5.loads(text)['code']
            except Exception:
                pass
        # If no code blocks found, return original text
        return text


class IPythonInteractiveManager(BaseAction):
    """An interactive IPython shell manager for code execution."""

    def __init__(
        self,
        max_workers: int = 50,
        timeout: int = 20,
        ci_lock: str = None,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        self.max_workers = max_workers
        self.timeout = timeout
        self.ci_lock = ci_lock
        self.id2queue = defaultdict(Queue)
        self.id2process = {}
        self.out_queue = Queue()

    def __call__(self,
                 commands: Union[str, List[str]],
                 session_ids: Union[int, List[int]] = None):
        if isinstance(commands, list):
            batch_size = len(commands)
            is_batch = True
        else:
            batch_size = 1
            commands = [commands]
            is_batch = False
        if session_ids is None:
            session_ids = range(batch_size)
        elif isinstance(session_ids, int):
            session_ids = [session_ids]
        if len(session_ids) != batch_size or len(session_ids) != len(
                set(session_ids)):
            raise ValueError(
                'the size of `session_ids` must equal that of `commands`')
        try:
            exec_results = self.run_code_blocks([
                (session_id, command)
                for session_id, command in zip(session_ids, commands)
            ])
        except KeyboardInterrupt:
            self.clear()
            exit(1)
        action_returns = []
        for result, code in zip(exec_results, commands):
            action_return = ActionReturn({'command': code}, type=self.name)
            if result['status'] == 'SUCCESS':
                action_return.result = [
                    dict(type='text', content=result['value'])
                ]
                action_return.state = ActionStatusCode.SUCCESS
            else:
                action_return.errmsg = result['msg']
                action_return.state = ActionStatusCode.API_ERROR
            action_returns.append(action_return)
        if not is_batch:
            return action_returns[0]
        return action_returns

    def process_code(self, index, session_id, code):
        ipy_id = session_id % self.max_workers
        input_queue = self.id2queue[ipy_id]
        proc = self.id2process.setdefault(
            ipy_id,
            IPythonProcess(
                input_queue,
                self.out_queue,
                self.timeout,
                self.ci_lock,
                daemon=True))
        if not proc.is_alive():
            proc.start()
        input_queue.put((index, session_id, code))

    def run_code_blocks(self, session_code_pairs):
        size = len(session_code_pairs)
        for index, (session_id, code) in enumerate(session_code_pairs):
            self.process_code(index, session_id, code)
        results = []
        while len(results) < size:
            msg = self.out_queue.get()
            if isinstance(msg, tuple) and len(msg) == 3:
                index, _, result = msg
                results.append((index, result))
        results.sort()
        return [item[1] for item in results]

    def clear(self):
        self.id2queue.clear()
        for proc in self.id2process.values():
            proc.terminate()
        self.id2process.clear()
        while not self.out_queue.empty():
            self.out_queue.get()

    def reset(self):
        cnt = 0
        for q in self.id2queue.values():
            q.put('reset')
            cnt += 1
        while cnt > 0:
            msg = self.out_queue.get()
            if msg == 'ok':
                cnt -= 1
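
The manager above accepts either a single command or a batch, dispatching each session to a pooled worker process. A short sketch of the batched path (not part of the commit):

# Usage sketch, not part of the commit.
from lagent.actions.ipython_manager import IPythonInteractiveManager

manager = IPythonInteractiveManager(max_workers=4, timeout=10)
# Batched execution: one code block per session id; results keep input order.
returns = manager(
    ['x = 1\nprint(x)', 'print("second session")'],
    session_ids=[0, 1],
)
for ret in returns:
    print(ret.state, ret.result or ret.errmsg)
manager.clear()  # terminate the worker processes when done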
lagent/actions/parser.py
ADDED
@@ -0,0 +1,146 @@
import json
import re
from ast import literal_eval
from typing import Any, List, Union


class ParseError(Exception):
    """Parsing exception class."""

    def __init__(self, err_msg: str):
        self.err_msg = err_msg


class BaseParser:
    """Base parser to process inputs and outputs of actions.

    Args:
        action (:class:`BaseAction`): action to validate

    Attributes:
        PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
            LLMs should follow when generating arguments for decided tools.
    """

    PARAMETER_DESCRIPTION: str = ''

    def __init__(self, action):
        self.action = action
        self._api2param = {}
        self._api2required = {}
        # perform basic argument validation
        if action.description:
            for api in action.description.get('api_list',
                                              [action.description]):
                name = (f'{action.name}.{api["name"]}'
                        if self.action.is_toolkit else api['name'])
                required_parameters = set(api['required'])
                all_parameters = {j['name'] for j in api['parameters']}
                if not required_parameters.issubset(all_parameters):
                    raise ValueError(
                        f'unknown parameters for function "{name}": '
                        f'{required_parameters - all_parameters}')
                if self.PARAMETER_DESCRIPTION:
                    api['parameter_description'] = self.PARAMETER_DESCRIPTION
                api_name = api['name'] if self.action.is_toolkit else 'run'
                self._api2param[api_name] = api['parameters']
                self._api2required[api_name] = api['required']

    def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
        """Parse inputs LLMs generate for the action.

        Args:
            inputs (:class:`str`): input string extracted from responses

        Returns:
            :class:`dict`: processed input
        """
        inputs = {self._api2param[name][0]['name']: inputs}
        return inputs

    def parse_outputs(self, outputs: Any) -> List[dict]:
        """Parse outputs returned by the action.

        Args:
            outputs (:class:`Any`): raw output of the action

        Returns:
            :class:`List[dict]`: processed output of which each member is a
                dictionary with two keys - 'type' and 'content'.
        """
        if isinstance(outputs, dict):
            outputs = json.dumps(outputs, ensure_ascii=False)
        elif not isinstance(outputs, str):
            outputs = str(outputs)
        return [{
            'type': 'text',
            'content': outputs.encode('gbk', 'ignore').decode('gbk')
        }]


class JsonParser(BaseParser):
    """Json parser to convert input string into a dictionary.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in '
        'the JSON format {key: value}, where the key is the parameter name.')

    def parse_inputs(self,
                     inputs: Union[str, dict],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, dict):
            try:
                match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
                                  re.S)
                if match:
                    inputs = match.group(2).strip()
                inputs = json.loads(inputs)
            except json.JSONDecodeError as exc:
                raise ParseError(f'invalid json format: {inputs}') from exc
        input_keys = set(inputs)
        all_keys = {param['name'] for param in self._api2param[name]}
        if not input_keys.issubset(all_keys):
            raise ParseError(f'unknown arguments: {input_keys - all_keys}')
        required_keys = set(self._api2required[name])
        if not input_keys.issuperset(required_keys):
            raise ParseError(
                f'missing required arguments: {required_keys - input_keys}')
        return inputs


class TupleParser(BaseParser):
    """Tuple parser to convert input string into a tuple.

    Args:
        action (:class:`BaseAction`): action to validate
    """

    PARAMETER_DESCRIPTION = (
        'If you call this tool, you must pass arguments in the tuple format '
        'like (arg1, arg2, arg3), and the arguments are ordered.')

    def parse_inputs(self,
                     inputs: Union[str, tuple],
                     name: str = 'run') -> dict:
        if not isinstance(inputs, tuple):
            try:
                inputs = literal_eval(inputs)
            except Exception as exc:
                raise ParseError(f'invalid tuple format: {inputs}') from exc
        if len(inputs) < len(self._api2required[name]):
            raise ParseError(
                f'API takes {len(self._api2required[name])} required positional '
                f'arguments but {len(inputs)} were given')
        if len(inputs) > len(self._api2param[name]):
            raise ParseError(
                f'API takes {len(self._api2param[name])} positional arguments '
                f'but {len(inputs)} were given')
        inputs = {
            self._api2param[name][i]['name']: item
            for i, item in enumerate(inputs)
        }
        return inputs
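
To illustrate what `JsonParser` enforces, a small sketch (not part of the commit; `CalcAction` is a hypothetical stand-in exposing only the attributes `BaseParser.__init__` reads):

# Usage sketch, not part of the commit. CalcAction is hypothetical.
from lagent.actions.parser import JsonParser, ParseError

class CalcAction:  # stand-in with the attributes BaseParser reads
    name = 'calculator'
    is_toolkit = False
    description = {
        'name': 'calculator',
        'required': ['expression'],
        'parameters': [{'name': 'expression', 'type': 'STRING'}],
    }

parser = JsonParser(CalcAction())
print(parser.parse_inputs('{"expression": "1 + 1"}'))  # {'expression': '1 + 1'}
try:
    parser.parse_inputs('{"expr": "1 + 1"}')  # unknown key -> ParseError
except ParseError as exc:
    print(exc.err_msg)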
lagent/actions/ppt.py
ADDED
@@ -0,0 +1,233 @@
from typing import Dict, Optional, Type

from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser

THEME_MAPPING = {
    'Default': {
        'template': None,
        'title': 'Title Slide',
        'single': 'Title and Content',
        'two': 'Two Content',
    }
}


class PPT(BaseAction):
    """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""

    def __init__(
        self,
        theme_mapping: Optional[Dict[str, dict]] = None,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ):
        super().__init__(description, parser)
        self.theme_mapping = theme_mapping or THEME_MAPPING
        self.pointer = None
        self.location = None

    @tool_api(explode_return=True)
    def create_file(self, theme: str, abs_location: str) -> dict:
        """Create a pptx file with specific themes.

        Args:
            theme (:class:`str`): the theme used. The value should be one of ['Default'].
            abs_location (:class:`str`): the ppt file's absolute location

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        from pptx import Presentation

        self.location = abs_location
        try:
            self.pointer = Presentation(self.theme_mapping[theme]['template'])
            self.pointer.slide_master.name = theme
            # print('created')
        except Exception as e:
            print(e)
        return dict(status='created a ppt file.')

    @tool_api(explode_return=True)
    def add_first_page(self, title: str, subtitle: str) -> dict:
        """Add the first page of ppt.

        Args:
            title (:class:`str`): the title of ppt
            subtitle (:class:`str`): the subtitle of ppt

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        layout_name = self.theme_mapping[self.pointer.slide_master.name]['title']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_subtitle = slide.placeholders
        ph_title.text = title
        if subtitle:
            ph_subtitle.text = subtitle
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_page(self, title: str, bullet_items: str) -> dict:
        """Add text page of ppt.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """  # noqa: E501
        layout_name = self.theme_mapping[self.pointer.slide_master.name]['single']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body = slide.placeholders
        ph_title.text = title
        ph = ph_body
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0
        return dict(status='added page')

    @tool_api(explode_return=True)
    def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
        """Add a text page with one image. Image should be a path.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
            image (:class:`str`): the path of the image

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """  # noqa: E501
        from PIL import Image

        layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
        layout = next(i for i in self.pointer.slide_master.slide_layouts
                      if i.name == layout_name)
        slide = self.pointer.slides.add_slide(layout)
        ph_title, ph_body1, ph_body2 = slide.placeholders
        ph_title.text = title
        ph = ph_body2
        # scale the picture to the placeholder's width, keeping aspect ratio
        image_pil = Image.open(image)
        left = ph.left
        width = ph.width
        height = int(width / image_pil.width * image_pil.height)
        top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
        slide.shapes.add_picture(image, left, top, width, height)

        ph = ph_body1
        tf = ph.text_frame
        for i, item in enumerate(bullet_items.split('[SPAN]')):
            if i == 0:
                p = tf.paragraphs[0]
            else:
                p = tf.add_paragraph()
            p.text = item.strip()
            p.level = 0

        return dict(status='added page')

    @tool_api(explode_return=True)
    def submit_file(self) -> dict:
        """When all steps done, YOU MUST use submit_file() to submit your work.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
        # self.pointer.save(file_path)
        # retreival_url = upload_file(file_path)
        self.pointer.save(self.location)
        return dict(status=f'submitted. view ppt at {self.location}')


class AsyncPPT(AsyncActionMixin, PPT):
    """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""

    @tool_api(explode_return=True)
    @asyncify
    def create_file(self, theme: str, abs_location: str) -> dict:
        """Create a pptx file with specific themes.

        Args:
            theme (:class:`str`): the theme used. The value should be one of ['Default'].
            abs_location (:class:`str`): the ppt file's absolute location

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        return super().create_file(theme, abs_location)

    @tool_api(explode_return=True)
    @asyncify
    def add_first_page(self, title: str, subtitle: str) -> dict:
        """Add the first page of ppt.

        Args:
            title (:class:`str`): the title of ppt
            subtitle (:class:`str`): the subtitle of ppt

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        return super().add_first_page(title, subtitle)

    @tool_api(explode_return=True)
    @asyncify
    def add_text_page(self, title: str, bullet_items: str) -> dict:
        """Add text page of ppt.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """  # noqa: E501
        return super().add_text_page(title, bullet_items)

    @tool_api(explode_return=True)
    @asyncify
    def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
        """Add a text page with one image. Image should be a path.

        Args:
            title (:class:`str`): the title of the page
            bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
            image (:class:`str`): the path of the image

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """  # noqa: E501
        return super().add_text_image_page(title, bullet_items, image)

    @tool_api(explode_return=True)
    @asyncify
    def submit_file(self) -> dict:
        """When all steps done, YOU MUST use submit_file() to submit your work.

        Returns:
            :class:`dict`: operation status
                * status: the result of the execution
        """
        return super().submit_file()
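
A sketch of the intended call sequence for this plugin (not part of the commit; python-pptx must be installed and the output path below is illustrative):

# Usage sketch, not part of the commit. The path below is illustrative.
from lagent.actions.ppt import PPT

ppt = PPT()
ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx')
ppt.add_first_page(title='Quarterly Review', subtitle='2024 Q3')
ppt.add_text_page(
    title='Highlights',
    bullet_items='Revenue up 12%[SPAN]Two new releases[SPAN]Lower churn',
)
print(ppt.submit_file())  # {'status': 'submitted. view ppt at /tmp/demo.pptx'}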
lagent/actions/python_interpreter.py
ADDED
@@ -0,0 +1,176 @@
# flake8: noqa: E501
import copy
import io
from contextlib import redirect_stdout
from typing import Any, Optional, Type

from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.schema import ActionReturn, ActionStatusCode


class GenericRuntime:
    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        return eval(expr, self._global_vars)


class PythonInterpreter(BaseAction):
    """A Python executor that can execute Python scripts.

    Args:
        answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
        answer_expr (str, Optional): the answer function name of the Python
            script. Defaults to ``'solution()'``.
        answer_from_stdout (boolean, Optional): whether the execution result comes from
            stdout. Defaults to ``False``.
        timeout (int, Optional): Upper bound of waiting time for Python script execution.
            Defaults to ``20``.
        description (dict, Optional): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    def __init__(
        self,
        answer_symbol: Optional[str] = None,
        answer_expr: Optional[str] = 'solution()',
        answer_from_stdout: bool = False,
        timeout: int = 20,
        description: Optional[dict] = None,
        parser: Type[BaseParser] = JsonParser,
    ) -> None:
        super().__init__(description, parser)
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.answer_from_stdout = answer_from_stdout
        self.timeout = timeout

    @tool_api
    def run(self, command: str) -> ActionReturn:
        """Execute Python code. The code must be a single function named
        'solution' that mirrors your reasoning. Example format:

        ```python
        # import dependencies
        import xxx
        def solution():
            # initialize some variables
            variable_names_with_real_meaning = xxx
            # step 1
            mid_variable = func(variable_names_with_real_meaning)
            # step x
            mid_variable = func(mid_variable)
            # final result
            final_answer = func(mid_variable)
            return final_answer
        ```

        Args:
            command (:class:`str`): Python code snippet
        """
        from func_timeout import FunctionTimedOut, func_set_timeout

        self.runtime = GenericRuntime()
        try:
            tool_return = func_set_timeout(self.timeout)(self._call)(command)
        except FunctionTimedOut as e:
            tool_return = ActionReturn(type=self.name)
            tool_return.errmsg = repr(e)
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return

    def _call(self, command: str) -> ActionReturn:
        tool_return = ActionReturn(type=self.name)
        try:
            if '```python' in command:
                command = command.split('```python')[1].split('```')[0]
            elif '```' in command:
                command = command.split('```')[1].split('```')[0]
            tool_return.args = dict(text='```python\n' + command + '\n```')
            command = command.split('\n')

            if self.answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    self.runtime.exec_code('\n'.join(command))
                program_io.seek(0)
                res = program_io.readlines()[-1]
            elif self.answer_symbol:
                self.runtime.exec_code('\n'.join(command))
                res = self.runtime._global_vars[self.answer_symbol]
            elif self.answer_expr:
                self.runtime.exec_code('\n'.join(command))
                res = self.runtime.eval_code(self.answer_expr)
            else:
                self.runtime.exec_code('\n'.join(command[:-1]))
                res = self.runtime.eval_code(command[-1])
        except Exception as e:
            tool_return.errmsg = repr(e)
            tool_return.type = self.name
            tool_return.state = ActionStatusCode.API_ERROR
            return tool_return
        try:
            tool_return.result = [dict(type='text', content=str(res))]
            tool_return.state = ActionStatusCode.SUCCESS
        except Exception as e:
            tool_return.errmsg = repr(e)
            tool_return.type = self.name
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return


class AsyncPythonInterpreter(AsyncActionMixin, PythonInterpreter):
    """A Python executor that can execute Python scripts.

    Args:
        answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
        answer_expr (str, Optional): the answer function name of the Python
            script. Defaults to ``'solution()'``.
        answer_from_stdout (boolean, Optional): whether the execution result comes from
            stdout. Defaults to ``False``.
        timeout (int, Optional): Upper bound of waiting time for Python script execution.
            Defaults to ``20``.
        description (dict, Optional): The description of the action. Defaults to ``None``.
        parser (Type[BaseParser]): The parser class to process the
            action's inputs and outputs. Defaults to :class:`JsonParser`.
    """

    @tool_api
    @asyncify
    def run(self, command: str) -> ActionReturn:
        """Execute Python code. The code must be a single function named
        'solution' that mirrors your reasoning. Example format:

        ```python
        # import dependencies
        import xxx
        def solution():
            # initialize some variables
            variable_names_with_real_meaning = xxx
            # step 1
            mid_variable = func(variable_names_with_real_meaning)
            # step x
            mid_variable = func(mid_variable)
            # final result
            final_answer = func(mid_variable)
            return final_answer
        ```

        Args:
            command (:class:`str`): Python code snippet
        """
        return super().run(command)
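A quick sanity-check sketch for the interpreter above (direct invocation assumed; inside an agent the call is routed through the action executor, and the func_timeout package must be installed). With the default answer_expr='solution()', run executes the snippet and then evaluates solution():

    from lagent.actions.python_interpreter import PythonInterpreter

    interpreter = PythonInterpreter(timeout=10)
    code = (
        'def solution():\n'
        '    # sum of the first 100 integers\n'
        '    return sum(range(1, 101))\n'
    )
    ret = interpreter.run(code)  # returns an ActionReturn
    # Expected on success: ret.state == ActionStatusCode.SUCCESS and
    # ret.result == [{'type': 'text', 'content': '5050'}]
    print(ret.state, ret.result)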
lagent/actions/weather_query.py
ADDED
@@ -0,0 +1,71 @@
import os
import requests
from lagent.actions.base_action import BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode

class WeatherQuery(BaseAction):
    def __init__(self):
        super().__init__()
        self.api_key = os.getenv("weather_token")
        print(self.api_key)
        if not self.api_key:
            raise EnvironmentError(
                "Environment variable 'weather_token' not found. Please set your "
                "QWeather API key in the 'weather_token' environment variable, "
                "e.g. export weather_token='xxx'")

    @tool_api
    def run(self, location: str) -> dict:
        """
        Query real-time weather information.

        Args:
            location (str): The place name, LocationID, or longitude-latitude
                coordinates to query (e.g. "101010100" or "116.41,39.92").

        Returns:
            dict: A dictionary of weather information
                * location: place name
                * weather: weather condition
                * temperature: current temperature
                * wind_direction: wind direction
                * wind_speed: wind speed (km/h)
                * humidity: relative humidity (%)
                * report_time: report time of the data
        """
        try:
            # If location is not in coordinate form (e.g. "116.41,39.92"),
            # call the GeoAPI to resolve a LocationID first.
            if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
                # Use the GeoAPI to obtain a LocationID
                geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
                geo_response = requests.get(geo_url)
                geo_data = geo_response.json()

                if geo_data.get("code") != "200" or not geo_data.get("location"):
                    raise Exception(f"GeoAPI returned error code {geo_data.get('code')} or found no location")

                location = geo_data["location"][0]["id"]

            # Build the weather-query API request URL
            weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
            response = requests.get(weather_url)
            data = response.json()

            # Check the API response code
            if data.get("code") != "200":
                raise Exception(f"Weather API returned error code {data.get('code')}")

            # Parse and organise the weather information
            weather_info = {
                "location": location,
                "weather": data["now"]["text"],
                "temperature": data["now"]["temp"] + "°C",
                "wind_direction": data["now"]["windDir"],
                "wind_speed": data["now"]["windSpeed"] + " km/h",
                "humidity": data["now"]["humidity"] + "%",
                "report_time": data["updateTime"]
            }

            return {"result": weather_info}

        except Exception as exc:
            return ActionReturn(
                errmsg=f"WeatherQuery exception: {exc}",
                state=ActionStatusCode.HTTP_ERROR
            )
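A usage sketch for WeatherQuery, assuming a valid QWeather key has been exported beforehand (the value below is a placeholder, not a real key):

    import os

    os.environ.setdefault('weather_token', '<your-qweather-key>')  # placeholder

    from lagent.actions.weather_query import WeatherQuery

    weather = WeatherQuery()  # raises EnvironmentError if the key is missing
    # A non-coordinate input such as a city name goes through the GeoAPI lookup first.
    print(weather.run('Beijing'))
    # -> {'result': {'location': ..., 'weather': ..., 'temperature': ..., ...}}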
lagent/actions/web_browser.py
ADDED
@@ -0,0 +1,908 @@
import asyncio
import hashlib
import hmac
import json
import logging
import random
import re
import time
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from http.client import HTTPSConnection
from typing import List, Optional, Tuple, Type, Union

import aiohttp
import aiohttp.client_exceptions
import requests
from asyncache import cached as acached
from bs4 import BeautifulSoup
from cachetools import TTLCache, cached
from duckduckgo_search import DDGS, AsyncDDGS

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
from lagent.utils import async_as_completed


class BaseSearch:

    def __init__(self, topk: int = 3, black_list: List[str] = None):
        self.topk = topk
        self.black_list = black_list

    def _filter_results(self, results: List[tuple]) -> dict:
        filtered_results = {}
        count = 0
        for url, snippet, title in results:
            if all(domain not in url
                   for domain in self.black_list) and not url.endswith('.pdf'):
                filtered_results[count] = {
                    'url': url,
                    'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
                    'title': title
                }
                count += 1
                if count >= self.topk:
                    break
        return filtered_results


class DuckDuckGoSearch(BaseSearch):

    def __init__(self,
                 topk: int = 3,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ],
                 **kwargs):
        self.proxy = kwargs.get('proxy')
        self.timeout = kwargs.get('timeout', 30)
        super().__init__(topk, black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = self._call_ddgs(
                    query, timeout=self.timeout, proxy=self.proxy)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from DuckDuckGo after retries.')

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
                response = await ddgs.atext(query.strip("'"), max_results=10)
                return self._parse_response(response)
            except Exception as e:
                if isinstance(e, asyncio.TimeoutError):
                    logging.exception('Request to DDGS timed out.')
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from DuckDuckGo after retries.')

    async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
        ddgs = DDGS(**kwargs)
        try:
            response = await asyncio.wait_for(
                asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
                timeout=self.timeout)
            return response
        except asyncio.TimeoutError:
            logging.exception('Request to DDGS timed out.')
            raise

    def _call_ddgs(self, query: str, **kwargs) -> dict:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            response = loop.run_until_complete(
                self._async_call_ddgs(query, **kwargs))
            return response
        finally:
            loop.close()

    def _parse_response(self, response: dict) -> dict:
        raw_results = []
        for item in response:
            raw_results.append(
                (item['href'], item['description']
                 if 'description' in item else item['body'], item['title']))
        return self._filter_results(raw_results)


class BingSearch(BaseSearch):

    def __init__(self,
                 api_key: str,
                 region: str = 'zh-CN',
                 topk: int = 3,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ],
                 **kwargs):
        self.api_key = api_key
        self.market = region
        self.proxy = kwargs.get('proxy')
        super().__init__(topk, black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = self._call_bing_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Bing Search after retries.')

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = await self._async_call_bing_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Bing Search after retries.')

    def _call_bing_api(self, query: str) -> dict:
        endpoint = 'https://api.bing.microsoft.com/v7.0/search'
        params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
        headers = {'Ocp-Apim-Subscription-Key': self.api_key}
        response = requests.get(
            endpoint, headers=headers, params=params, proxies=self.proxy)
        response.raise_for_status()
        return response.json()

    async def _async_call_bing_api(self, query: str) -> dict:
        endpoint = 'https://api.bing.microsoft.com/v7.0/search'
        params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
        headers = {'Ocp-Apim-Subscription-Key': self.api_key}
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.get(
                    endpoint,
                    headers=headers,
                    params=params,
                    proxy=self.proxy and
                (self.proxy.get('http') or self.proxy.get('https'))) as resp:
                return await resp.json()

    def _parse_response(self, response: dict) -> dict:
        webpages = {
            w['id']: w
            for w in response.get('webPages', {}).get('value', [])
        }
        raw_results = []

        for item in response.get('rankingResponse',
                                 {}).get('mainline', {}).get('items', []):
            if item['answerType'] == 'WebPages':
                webpage = webpages.get(item['value']['id'])
                if webpage:
                    raw_results.append(
                        (webpage['url'], webpage['snippet'], webpage['name']))
            elif item['answerType'] == 'News' and item['value'][
                    'id'] == response.get('news', {}).get('id'):
                for news in response.get('news', {}).get('value', []):
                    raw_results.append(
                        (news['url'], news['description'], news['name']))

        return self._filter_results(raw_results)


class BraveSearch(BaseSearch):
    """
    Wrapper around the Brave Search API.

    To use, you should pass your Brave Search API key to the constructor.

    Args:
        api_key (str): API KEY to use Brave Search API.
            You can create a free API key at https://api.search.brave.com/app/keys.
        search_type (str): Brave Search API supports ['web', 'news', 'images', 'videos'],
            currently only supports 'news' and 'web'.
        topk (int): The number of search results returned in response from API search results.
        region (str): The country code string. Specifies the country where the search results come from.
        language (str): The language code string. Specifies the preferred language for the search results.
        extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
        **kwargs: Any other parameters related to the Brave Search API. Find more details at
            https://api.search.brave.com/app/documentation/web-search/get-started.
    """

    def __init__(self,
                 api_key: str,
                 region: str = 'ALL',
                 language: str = 'zh-hans',
                 extra_snippests: bool = True,
                 topk: int = 3,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ],
                 **kwargs):
        self.api_key = api_key
        self.market = region
        self.proxy = kwargs.get('proxy')
        self.language = language
        self.extra_snippests = extra_snippests
        self.search_type = kwargs.get('search_type', 'web')
        self.kwargs = kwargs
        super().__init__(topk, black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = self._call_brave_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Brave Search after retries.')

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = await self._async_call_brave_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Brave Search after retries.')

    def _call_brave_api(self, query: str) -> dict:
        endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
        params = {
            'q': query,
            'country': self.market,
            'search_lang': self.language,
            'extra_snippets': self.extra_snippests,
            'count': self.topk,
            **{
                key: value
                for key, value in self.kwargs.items() if value is not None
            },
        }
        headers = {
            'X-Subscription-Token': self.api_key or '',
            'Accept': 'application/json'
        }
        response = requests.get(
            endpoint, headers=headers, params=params, proxies=self.proxy)
        response.raise_for_status()
        return response.json()

    async def _async_call_brave_api(self, query: str) -> dict:
        endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
        params = {
            'q': query,
            'country': self.market,
            'search_lang': self.language,
            'extra_snippets': self.extra_snippests,
            'count': self.topk,
            **{
                key: value
                for key, value in self.kwargs.items() if value is not None
            },
        }
        headers = {
            'X-Subscription-Token': self.api_key or '',
            'Accept': 'application/json'
        }
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.get(
                    endpoint,
                    headers=headers,
                    params=params,
                    proxy=self.proxy and
                (self.proxy.get('http') or self.proxy.get('https'))) as resp:
                return await resp.json()

    def _parse_response(self, response: dict) -> dict:
        if self.search_type == 'web':
            filtered_result = response.get('web', {}).get('results', [])
        else:
            filtered_result = response.get('results', {})
        raw_results = []

        for item in filtered_result:
            raw_results.append((
                item.get('url', ''),
                ' '.join(
                    filter(None, [
                        item.get('description'),
                        *item.get('extra_snippets', [])
                    ])),
                item.get('title', ''),
            ))
        return self._filter_results(raw_results)


class GoogleSearch(BaseSearch):
    """
    Wrapper around the Serper.dev Google Search API.

    To use, you should pass your serper API key to the constructor.

    Args:
        api_key (str): API KEY to use serper google search API.
            You can create a free API key at https://serper.dev.
        search_type (str): Serper API supports ['search', 'images', 'news',
            'places'] types of search, currently we only support 'search' and 'news'.
        topk (int): The number of search results returned in response from api search results.
        **kwargs: Any other parameters related to the Serper API. Find more details at
            https://serper.dev/playground
    """

    result_key_for_type = {
        'news': 'news',
        'places': 'places',
        'images': 'images',
        'search': 'organic',
    }

    def __init__(self,
                 api_key: str,
                 topk: int = 3,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ],
                 **kwargs):
        self.api_key = api_key
        self.proxy = kwargs.get('proxy')
        self.search_type = kwargs.get('search_type', 'search')
        self.kwargs = kwargs
        super().__init__(topk, black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = self._call_serper_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Google Serper Search after retries.'
        )

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = await self._async_call_serper_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Google Serper Search after retries.'
        )

    def _call_serper_api(self, query: str) -> dict:
        endpoint = f'https://google.serper.dev/{self.search_type}'
        params = {
            'q': query,
            'num': self.topk,
            **{
                key: value
                for key, value in self.kwargs.items() if value is not None
            },
        }
        headers = {
            'X-API-KEY': self.api_key or '',
            'Content-Type': 'application/json'
        }
        response = requests.get(
            endpoint, headers=headers, params=params, proxies=self.proxy)
        response.raise_for_status()
        return response.json()

    async def _async_call_serper_api(self, query: str) -> dict:
        endpoint = f'https://google.serper.dev/{self.search_type}'
        params = {
            'q': query,
            'num': self.topk,
            **{
                key: value
                for key, value in self.kwargs.items() if value is not None
            },
        }
        headers = {
            'X-API-KEY': self.api_key or '',
            'Content-Type': 'application/json'
        }
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.get(
                    endpoint,
                    headers=headers,
                    params=params,
                    proxy=self.proxy and
                (self.proxy.get('http') or self.proxy.get('https'))) as resp:
                return await resp.json()

    def _parse_response(self, response: dict) -> dict:
        raw_results = []

        if response.get('answerBox'):
            answer_box = response.get('answerBox', {})
            if answer_box.get('answer'):
                raw_results.append(('', answer_box.get('answer'), ''))
            elif answer_box.get('snippet'):
                raw_results.append(
                    ('', answer_box.get('snippet').replace('\n', ' '), ''))
            elif answer_box.get('snippetHighlighted'):
                raw_results.append(
                    ('', answer_box.get('snippetHighlighted'), ''))

        if response.get('knowledgeGraph'):
            kg = response.get('knowledgeGraph', {})
            description = kg.get('description', '')
            attributes = '. '.join(
                f'{attribute}: {value}'
                for attribute, value in kg.get('attributes', {}).items())
            raw_results.append(
                (kg.get('descriptionLink', ''),
                 f'{description}. {attributes}' if attributes else description,
                 f"{kg.get('title', '')}: {kg.get('type', '')}."))

        for result in response[self.result_key_for_type[
                self.search_type]][:self.topk]:
            description = result.get('snippet', '')
            attributes = '. '.join(
                f'{attribute}: {value}'
                for attribute, value in result.get('attributes', {}).items())
            raw_results.append(
                (result.get('link', ''),
                 f'{description}. {attributes}' if attributes else description,
                 result.get('title', '')))

        return self._filter_results(raw_results)


class TencentSearch(BaseSearch):
    """Wrapper around the tencentclound Search API.

    To use, you should pass your secret_id and secret_key to the constructor.

    Args:
        secret_id (str): Your Tencent Cloud secret ID for accessing the API.
            For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488.
        secret_key (str): Your Tencent Cloud secret key for accessing the API.
        api_key (str, optional): Additional API key, if required.
        action (str): The action for this interface, use `SearchCommon`.
        version (str): The API version, use `2020-12-29`.
        service (str): The service name, use `tms`.
        host (str): The API host, use `tms.tencentcloudapi.com`.
        topk (int): The maximum number of search results to return.
        tsn (int): Time filter for search results. Valid values:
            1 (within 1 day), 2 (within 1 week), 3 (within 1 month),
            4 (within 1 year), 5 (within 6 months), 6 (within 3 years).
        insite (str): Specify a site to search within (supports only a single site).
            If not specified, the entire web is searched. Example: `zhihu.com`.
        category (str): Vertical category for filtering results. Optional values include:
            `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope).
        vrid (str): Result card type(s). Different `vrid` values represent different types of result cards.
            Supports multiple values separated by commas. Example: `30010255`.
    """

    def __init__(self,
                 secret_id: str = 'Your SecretId',
                 secret_key: str = 'Your SecretKey',
                 api_key: str = '',
                 action: str = 'SearchCommon',
                 version: str = '2020-12-29',
                 service: str = 'tms',
                 host: str = 'tms.tencentcloudapi.com',
                 topk: int = 3,
                 tsn: int = None,
                 insite: str = None,
                 category: str = None,
                 vrid: str = None,
                 black_list: List[str] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ]):
        self.secret_id = secret_id
        self.secret_key = secret_key
        self.api_key = api_key
        self.action = action
        self.version = version
        self.service = service
        self.host = host
        self.tsn = tsn
        self.insite = insite
        self.category = category
        self.vrid = vrid
        super().__init__(topk, black_list=black_list)

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def search(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = self._call_tencent_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                time.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Tencent Search after retries.')

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def asearch(self, query: str, max_retry: int = 3) -> dict:
        for attempt in range(max_retry):
            try:
                response = await self._async_call_tencent_api(query)
                return self._parse_response(response)
            except Exception as e:
                logging.exception(str(e))
                warnings.warn(
                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
                await asyncio.sleep(random.randint(2, 5))
        raise Exception(
            'Failed to get search results from Tencent Search after retries.')

    def _get_headers_and_payload(self, query: str) -> tuple:

        def sign(key, msg):
            return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()

        params = dict(Query=query)
        # if self.topk:
        #     params['Cnt'] = self.topk
        if self.tsn:
            params['Tsn'] = self.tsn
        if self.insite:
            params['Insite'] = self.insite
        if self.category:
            params['Category'] = self.category
        if self.vrid:
            params['Vrid'] = self.vrid
        payload = json.dumps(params)
        algorithm = 'TC3-HMAC-SHA256'
        timestamp = int(time.time())
        date = datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%d')

        # ************* Step 1: build the canonical request string *************
        http_request_method = 'POST'
        canonical_uri = '/'
        canonical_querystring = ''
        ct = 'application/json; charset=utf-8'
        canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
        signed_headers = 'content-type;host;x-tc-action'
        hashed_request_payload = hashlib.sha256(
            payload.encode('utf-8')).hexdigest()
        canonical_request = (
            http_request_method + '\n' + canonical_uri + '\n' +
            canonical_querystring + '\n' + canonical_headers + '\n' +
            signed_headers + '\n' + hashed_request_payload)

        # ************* Step 2: build the string to sign *************
        credential_scope = date + '/' + self.service + '/' + 'tc3_request'
        hashed_canonical_request = hashlib.sha256(
            canonical_request.encode('utf-8')).hexdigest()
        string_to_sign = (
            algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
            '\n' + hashed_canonical_request)

        # ************* Step 3: compute the signature *************
        secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
        secret_service = sign(secret_date, self.service)
        secret_signing = sign(secret_service, 'tc3_request')
        signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
                             hashlib.sha256).hexdigest()

        # ************* Step 4: assemble the Authorization header *************
        authorization = (
            algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
            credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
            ', ' + 'Signature=' + signature)

        # ************* Step 5: build and send the request *************
        headers = {
            'Authorization': authorization,
            'Content-Type': 'application/json; charset=utf-8',
            'Host': self.host,
            'X-TC-Action': self.action,
            'X-TC-Timestamp': str(timestamp),
            'X-TC-Version': self.version
        }
        # if self.region:
        #     headers["X-TC-Region"] = self.region
        if self.api_key:
            headers['X-TC-Token'] = self.api_key
        return headers, payload

    def _call_tencent_api(self, query: str) -> dict:
        headers, payload = self._get_headers_and_payload(query)
        req = HTTPSConnection(self.host)
        req.request('POST', '/', headers=headers, body=payload.encode('utf-8'))
        resp = req.getresponse()
        try:
            resp = json.loads(resp.read().decode('utf-8'))
        except Exception as e:
            logging.warning(str(e))
            import ast
            resp = ast.literal_eval(resp)
        return resp.get('Response', dict())

    async def _async_call_tencent_api(self, query: str):
        headers, payload = self._get_headers_and_payload(query)
        async with aiohttp.ClientSession(raise_for_status=True) as session:
            async with session.post(
                    'https://' + self.host.lstrip('/'),
                    headers=headers,
                    data=payload) as resp:
                return (await resp.json()).get('Response', {})

    def _parse_response(self, response: dict) -> dict:
        raw_results = []
        for item in response.get('Pages', []):
            display = json.loads(item['Display'])
            if not display['url']:
                continue
            raw_results.append((display['url'], display['content']
                                or display['abstract_info'], display['title']))
        return self._filter_results(raw_results)


class ContentFetcher:

    def __init__(self, timeout: int = 5):
        self.timeout = timeout

    @cached(cache=TTLCache(maxsize=100, ttl=600))
    def fetch(self, url: str) -> Tuple[bool, str]:
        try:
            response = requests.get(url, timeout=self.timeout)
            response.raise_for_status()
            html = response.content
        except requests.RequestException as e:
            return False, str(e)

        text = BeautifulSoup(html, 'html.parser').get_text()
        cleaned_text = re.sub(r'\n+', '\n', text)
        return True, cleaned_text

    @acached(cache=TTLCache(maxsize=100, ttl=600))
    async def afetch(self, url: str) -> Tuple[bool, str]:
        try:
            async with aiohttp.ClientSession(
                    raise_for_status=True,
                    timeout=aiohttp.ClientTimeout(self.timeout)) as session:
                async with session.get(url) as resp:
                    html = await resp.text(errors='ignore')
                    text = BeautifulSoup(html, 'html.parser').get_text()
                    cleaned_text = re.sub(r'\n+', '\n', text)
                    return True, cleaned_text
        except Exception as e:
            return False, str(e)


class WebBrowser(BaseAction):
    """Wrapper around the Web Browser Tool."""

    def __init__(self,
                 searcher_type: str = 'DuckDuckGoSearch',
                 timeout: int = 5,
                 black_list: Optional[List[str]] = [
                     'enoN',
                     'youtube.com',
                     'bilibili.com',
                     'researchgate.net',
                 ],
                 topk: int = 20,
                 description: Optional[dict] = None,
                 parser: Type[BaseParser] = JsonParser,
                 **kwargs):
        self.searcher = eval(searcher_type)(
            black_list=black_list, topk=topk, **kwargs)
        self.fetcher = ContentFetcher(timeout=timeout)
        self.search_results = None
        super().__init__(description, parser)

    @tool_api
    def search(self, query: Union[str, List[str]]) -> dict:
        """BING search API

        Args:
            query (List[str]): list of search query strings
        """
        queries = query if isinstance(query, list) else [query]
        search_results = {}

        with ThreadPoolExecutor() as executor:
            future_to_query = {
                executor.submit(self.searcher.search, q): q
                for q in queries
            }

            for future in as_completed(future_to_query):
                query = future_to_query[future]
                try:
                    results = future.result()
                except Exception as exc:
                    warnings.warn(f'{query} generated an exception: {exc}')
                else:
                    for result in results.values():
                        if result['url'] not in search_results:
                            search_results[result['url']] = result
                        else:
                            search_results[
                                result['url']]['summ'] += f"\n{result['summ']}"

        self.search_results = {
            idx: result
            for idx, result in enumerate(search_results.values())
        }
        return self.search_results

    @tool_api
    def select(self, select_ids: List[int]) -> dict:
        """get the detailed content on the selected pages.

        Args:
            select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
        """
        if not self.search_results:
            raise ValueError('No search results to select from.')

        new_search_results = {}
        with ThreadPoolExecutor() as executor:
            future_to_id = {
                executor.submit(self.fetcher.fetch,
                                self.search_results[select_id]['url']):
                select_id
                for select_id in select_ids if select_id in self.search_results
            }
            for future in as_completed(future_to_id):
                select_id = future_to_id[future]
                try:
                    web_success, web_content = future.result()
                except Exception as exc:
                    warnings.warn(f'{select_id} generated an exception: {exc}')
                else:
                    if web_success:
                        self.search_results[select_id][
                            'content'] = web_content[:8192]
                        new_search_results[select_id] = self.search_results[
                            select_id].copy()
                        new_search_results[select_id].pop('summ')

        return new_search_results

    @tool_api
    def open_url(self, url: str) -> dict:
        print(f'Start Browsing: {url}')
        web_success, web_content = self.fetcher.fetch(url)
        if web_success:
            return {'type': 'text', 'content': web_content}
        else:
            return {'error': web_content}


class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
    """Wrapper around the Web Browser Tool."""

    @tool_api
    async def search(self, query: Union[str, List[str]]) -> dict:
        """BING search API

        Args:
            query (List[str]): list of search query strings
        """
        queries = query if isinstance(query, list) else [query]
        search_results = {}

        tasks = []
        for q in queries:
            task = asyncio.create_task(self.searcher.asearch(q))
            task.query = q
            tasks.append(task)
        async for future in async_as_completed(tasks):
            query = future.query
            try:
                results = await future
            except Exception as exc:
                warnings.warn(f'{query} generated an exception: {exc}')
            else:
                for result in results.values():
                    if result['url'] not in search_results:
                        search_results[result['url']] = result
                    else:
                        search_results[
                            result['url']]['summ'] += f"\n{result['summ']}"

        self.search_results = {
            idx: result
            for idx, result in enumerate(search_results.values())
        }
        return self.search_results

    @tool_api
    async def select(self, select_ids: List[int]) -> dict:
        """get the detailed content on the selected pages.

        Args:
            select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
        """
        if not self.search_results:
            raise ValueError('No search results to select from.')

        new_search_results = {}
        tasks = []
        for select_id in select_ids:
            if select_id in self.search_results:
                task = asyncio.create_task(
                    self.fetcher.afetch(self.search_results[select_id]['url']))
                task.select_id = select_id
                tasks.append(task)
        async for future in async_as_completed(tasks):
            select_id = future.select_id
            try:
                web_success, web_content = await future
            except Exception as exc:
                warnings.warn(f'{select_id} generated an exception: {exc}')
            else:
                if web_success:
                    self.search_results[select_id][
                        'content'] = web_content[:8192]
                    new_search_results[select_id] = self.search_results[
                        select_id].copy()
                    new_search_results[select_id].pop('summ')
        return new_search_results

    @tool_api
    async def open_url(self, url: str) -> dict:
        print(f'Start Browsing: {url}')
        web_success, web_content = await self.fetcher.afetch(url)
        if web_success:
            return {'type': 'text', 'content': web_content}
        else:
            return {'error': web_content}
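A synchronous usage sketch for WebBrowser (assumed flow; it needs network access and the duckduckgo_search dependency for the default searcher):

    from lagent.actions.web_browser import WebBrowser

    browser = WebBrowser(searcher_type='DuckDuckGoSearch', topk=5)
    hits = browser.search('LLM agent frameworks')
    # hits: {0: {'url': ..., 'summ': ..., 'title': ...}, 1: {...}, ...}
    pages = browser.select([0, 1])  # fetch full page text for the chosen indices
    page = browser.open_url('https://example.com')  # or fetch a single URL directly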
lagent/agents/__init__.py
ADDED
@@ -0,0 +1,9 @@
from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
from .react import AsyncReAct, ReAct
from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder

__all__ = [
    'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
    'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
    'AsyncReAct', 'Sequential', 'AsyncSequential'
]
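A minimal wiring sketch for these re-exports, assuming an OpenAI-compatible backend through GPTAPI (the model name and key below are illustrative):

    from lagent.agents import Agent
    from lagent.llms import GPTAPI

    llm = GPTAPI(model_type='gpt-4o-mini', key='sk-...')  # illustrative values
    bot = Agent(llm=llm, template='You are a concise assistant.')
    reply = bot('Hello!')  # plain strings are wrapped into AgentMessage(sender='user')
    print(reply.content)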
lagent/agents/agent.py
ADDED
@@ -0,0 +1,400 @@
import copy
import warnings
from collections import OrderedDict, UserDict, UserList, abc
from functools import wraps
from itertools import chain, repeat
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import Hook, RemovableHandle
from lagent.llms import BaseLLM
from lagent.memory import Memory, MemoryManager
from lagent.prompts.parsers import StrParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object


class Agent:
    """Agent is the basic unit of the system. It is responsible for
    communicating with the LLM, managing the memory, and handling the
    message aggregation and parsing. It can also be extended with hooks

    Args:
        llm (Union[BaseLLM, Dict]): The language model used by the agent.
        template (Union[PromptTemplate, str]): The template used to format the
            messages.
        memory (Dict): The memory used by the agent.
        output_format (Dict): The output format used by the agent.
        aggregator (Dict): The aggregator used by the agent.
        name (Optional[str]): The name of the agent.
        description (Optional[str]): The description of the agent.
        hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the agent.

    Returns:
        AgentMessage: The response message.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict] = None,
        template: Union[PromptTemplate, str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Optional[Dict] = None,
        aggregator: Dict = dict(type=DefaultAggregator),
        name: Optional[str] = None,
        description: Optional[str] = None,
        hooks: Optional[Union[List[Dict], Dict]] = None,
    ):
        self.name = name or self.__class__.__name__
        self.llm: BaseLLM = create_object(llm)
        self.memory: MemoryManager = MemoryManager(memory) if memory else None
        self.output_format: StrParser = create_object(output_format)
        self.template = template
        self.description = description
        self.aggregator: DefaultAggregator = create_object(aggregator)
        self._hooks: Dict[int, Hook] = OrderedDict()
        if hooks:
            for hook in hooks:
                hook = create_object(hook)
                self.register_hook(hook)

    def update_memory(self, message, session_id=0):
        if self.memory:
            self.memory.add(message, session_id=session_id)

    def __call__(
        self,
        *message: Union[str, AgentMessage, List[AgentMessage]],
        session_id=0,
        **kwargs,
    ) -> AgentMessage:
        # message.receiver = self.name
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    def forward(self,
                *message: AgentMessage,
                session_id=0,
                **kwargs) -> Union[AgentMessage, str]:
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = self.llm.chat(formatted_messages, **kwargs)
        if self.output_format:
            formatted_messages = self.output_format.parse_response(
                llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=formatted_messages,
            )
        return llm_response

    def __setattr__(self, __name: str, __value: Any) -> None:
        if isinstance(__value, Agent):
            _agents = getattr(self, '_agents', OrderedDict())
            _agents[__name] = __value
            super().__setattr__('_agents', _agents)
        super().__setattr__(__name, __value)

    def state_dict(self, session_id=0):
        state_dict, stack = {}, [('', self)]
        while stack:
            prefix, node = stack.pop()
            key = prefix + 'memory'
            if node.memory is not None:
                if session_id not in node.memory.memory_map:
                    warnings.warn(f'No session id {session_id} in {key}')
                memory = node.memory.get(session_id)
                state_dict[key] = memory and memory.save() or []
            if hasattr(node, '_agents'):
                for name, value in reversed(node._agents.items()):
                    stack.append((prefix + name + '.', value))
        return state_dict

    def load_state_dict(self, state_dict: Dict, session_id=0):
        _state_dict = self.state_dict()
        missing_keys = set(_state_dict) - set(state_dict)
        if missing_keys:
            raise KeyError(f'Missing keys: {missing_keys}')
        extra_keys = set(state_dict) - set(_state_dict)
        if extra_keys:
            warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
        for key in _state_dict:
            obj = self
            for attr in key.split('.')[:-1]:
                if isinstance(obj, AgentList):
                    assert attr.isdigit()
                    obj = obj[int(attr)]
                elif isinstance(obj, AgentDict):
                    obj = obj[attr]
                else:
                    obj = getattr(obj, attr)
            if obj.memory is not None:
                if session_id not in obj.memory.memory_map:
                    obj.memory.create_instance(session_id)
                obj.memory.memory_map[session_id].load(state_dict[key] or [])

    def register_hook(self, hook: Callable):
        handle = RemovableHandle(self._hooks)
        self._hooks[handle.id] = hook
        return handle

    def reset(self,
              session_id=0,
              keypath: Optional[str] = None,
              recursive: bool = False):
        assert not (keypath and
                    recursive), 'keypath and recursive can\'t be used together'
        if keypath:
            keys, agent = keypath.split('.'), self
            for key in keys:
                agents = getattr(agent, '_agents', {})
                if key not in agents:
                    raise KeyError(f'No sub-agent named {key} in {agent}')
                agent = agents[key]
            agent.reset(session_id, recursive=False)
        else:
            if self.memory:
                self.memory.reset(session_id=session_id)
            if recursive:
                for agent in getattr(self, '_agents', {}).values():
                    agent.reset(session_id, recursive=True)

    def __repr__(self):

        def _rcsv_repr(agent, n_indent=1):
            res = agent.__class__.__name__ + (f"(name='{agent.name}')"
                                              if agent.name else '')
            modules = [
                f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
                for name, agent in getattr(agent, '_agents', {}).items()
            ]
            if modules:
                res += '(\n' + '\n'.join(
                    modules) + f'\n{(n_indent - 1) * " "})'
            elif not res.endswith(')'):
                res += '()'
            return res

        return _rcsv_repr(self)


class AsyncAgent(Agent):

    async def __call__(self,
                       *message: AgentMessage | List[AgentMessage],
                       session_id=0,
                       **kwargs) -> AgentMessage:
        message = [
            AgentMessage(sender='user', content=m)
            if isinstance(m, str) else copy.deepcopy(m) for m in message
        ]
        for hook in self._hooks.values():
            result = hook.before_agent(self, message, session_id)
            if result:
                message = result
        self.update_memory(message, session_id=session_id)
        response_message = await self.forward(
            *message, session_id=session_id, **kwargs)
        if not isinstance(response_message, AgentMessage):
            response_message = AgentMessage(
                sender=self.name,
                content=response_message,
            )
        self.update_memory(response_message, session_id=session_id)
        response_message = copy.deepcopy(response_message)
        for hook in self._hooks.values():
            result = hook.after_agent(self, response_message, session_id)
            if result:
                response_message = result
        return response_message

    async def forward(self,
                      *message: AgentMessage,
                      session_id=0,
                      **kwargs) -> Union[AgentMessage, str]:
        formatted_messages = self.aggregator.aggregate(
            self.memory.get(session_id),
            self.name,
            self.output_format,
            self.template,
        )
        llm_response = await self.llm.chat(formatted_messages, session_id,
                                           **kwargs)
        if self.output_format:
            formatted_messages = self.output_format.parse_response(
                llm_response)
            return AgentMessage(
                sender=self.name,
                content=llm_response,
                formatted=formatted_messages,
            )
        return llm_response


class Sequential(Agent):
    """Sequential is an agent container that forwards messages to each agent
    in the order they are added."""

    def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
        super().__init__(**kwargs)
        self._agents = OrderedDict()
        if not agents:
            raise ValueError('At least one agent should be provided')
        if isinstance(agents[0],
                      Iterable) and not isinstance(agents[0], Agent):
            if not agents[0]:
                raise ValueError('At least one agent should be provided')
            agents = agents[0]
        for key, agent in enumerate(agents):
            if isinstance(agents, Mapping):
                key, agent = agent, agents[agent]
            elif isinstance(agent, tuple):
                key, agent = agent
            self.add_agent(key, agent)

    def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
        assert isinstance(
            agent, (Agent, AsyncAgent
                    )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
        self._agents[str(name)] = agent

    def forward(self,
                *message: AgentMessage,
|
290 |
+
session_id=0,
|
291 |
+
exit_at: Optional[int] = None,
|
292 |
+
**kwargs) -> AgentMessage:
|
293 |
+
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
|
294 |
+
if exit_at is None:
|
295 |
+
exit_at = len(self) - 1
|
296 |
+
iterator = chain.from_iterable(repeat(self._agents.values()))
|
297 |
+
for _ in range(exit_at + 1):
|
298 |
+
agent = next(iterator)
|
299 |
+
if isinstance(message, AgentMessage):
|
300 |
+
message = (message, )
|
301 |
+
message = agent(*message, session_id=session_id, **kwargs)
|
302 |
+
return message
|
303 |
+
|
304 |
+
def __getitem__(self, key):
|
305 |
+
if isinstance(key, int) and key < 0:
|
306 |
+
assert key >= -len(self), 'index out of range'
|
307 |
+
key = len(self) + key
|
308 |
+
return self._agents[str(key)]
|
309 |
+
|
310 |
+
def __len__(self):
|
311 |
+
return len(self._agents)
|
312 |
+
|
313 |
+
|
314 |
+
class AsyncSequential(Sequential, AsyncAgent):
|
315 |
+
|
316 |
+
async def forward(self,
|
317 |
+
*message: AgentMessage,
|
318 |
+
session_id=0,
|
319 |
+
exit_at: Optional[int] = None,
|
320 |
+
**kwargs) -> AgentMessage:
|
321 |
+
assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
|
322 |
+
if exit_at is None:
|
323 |
+
exit_at = len(self) - 1
|
324 |
+
iterator = chain.from_iterable(repeat(self._agents.values()))
|
325 |
+
for _ in range(exit_at + 1):
|
326 |
+
agent = next(iterator)
|
327 |
+
if isinstance(message, AgentMessage):
|
328 |
+
message = (message, )
|
329 |
+
message = await agent(*message, session_id=session_id, **kwargs)
|
330 |
+
return message
|
331 |
+
|
332 |
+
|
333 |
+
class AgentContainerMixin:
|
334 |
+
|
335 |
+
def __init_subclass__(cls):
|
336 |
+
super().__init_subclass__()
|
337 |
+
|
338 |
+
def wrap_api(func):
|
339 |
+
|
340 |
+
@wraps(func)
|
341 |
+
def wrapped_func(self, *args, **kwargs):
|
342 |
+
data = self.data.copy() if hasattr(self, 'data') else None
|
343 |
+
|
344 |
+
def _backup(d):
|
345 |
+
if d is None:
|
346 |
+
self.data.clear()
|
347 |
+
else:
|
348 |
+
self.data = d
|
349 |
+
|
350 |
+
ret = func(self, *args, **kwargs)
|
351 |
+
agents = OrderedDict()
|
352 |
+
for k, item in (self.data.items() if isinstance(
|
353 |
+
self.data, abc.Mapping) else enumerate(self.data)):
|
354 |
+
if isinstance(self.data,
|
355 |
+
abc.Mapping) and not isinstance(k, str):
|
356 |
+
_backup(data)
|
357 |
+
raise KeyError(
|
358 |
+
f'agent name should be a string, got {type(k)}')
|
359 |
+
if isinstance(k, str) and '.' in k:
|
360 |
+
_backup(data)
|
361 |
+
raise KeyError(
|
362 |
+
f'agent name can\'t contain ".", got {k}')
|
363 |
+
if not isinstance(item, (Agent, AsyncAgent)):
|
364 |
+
_backup(data)
|
365 |
+
raise TypeError(
|
366 |
+
f'{type(item)} is not an Agent or AsyncAgent subclass'
|
367 |
+
)
|
368 |
+
agents[str(k)] = item
|
369 |
+
self._agents = agents
|
370 |
+
return ret
|
371 |
+
|
372 |
+
return wrapped_func
|
373 |
+
|
374 |
+
for method in [
|
375 |
+
'append', 'sort', 'reverse', 'pop', 'clear', 'update',
|
376 |
+
'insert', 'extend', 'remove', '__init__', '__setitem__',
|
377 |
+
'__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
|
378 |
+
'__imul__', '__rmul__'
|
379 |
+
]:
|
380 |
+
if hasattr(cls, method):
|
381 |
+
setattr(cls, method, wrap_api(getattr(cls, method)))
|
382 |
+
|
383 |
+
|
384 |
+
class AgentList(Agent, UserList, AgentContainerMixin):
|
385 |
+
|
386 |
+
def __init__(self,
|
387 |
+
agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
|
388 |
+
Agent.__init__(self, memory=None)
|
389 |
+
UserList.__init__(self, agents)
|
390 |
+
self.name = None
|
391 |
+
|
392 |
+
|
393 |
+
class AgentDict(Agent, UserDict, AgentContainerMixin):
|
394 |
+
|
395 |
+
def __init__(self,
|
396 |
+
agents: Optional[Mapping[str, Union[Agent,
|
397 |
+
AsyncAgent]]] = None):
|
398 |
+
Agent.__init__(self, memory=None)
|
399 |
+
UserDict.__init__(self, agents)
|
400 |
+
self.name = None
|
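
For reference, a minimal usage sketch of these containers and the memory snapshot API (not part of the commit; the GPTAPI config is a placeholder and any lagent LLM would do):

from lagent.agents.agent import Agent, Sequential
from lagent.llms import GPTAPI
from lagent.schema import AgentMessage

llm = dict(type=GPTAPI, model_type='gpt-4o-2024-05-13', key='YOUR_KEY')  # placeholder LLM config
translator = Agent(llm, template='Translate the input into English.')
summarizer = Agent(llm, template='Summarize the input in one sentence.')
pipeline = Sequential(translator, summarizer)

reply = pipeline(AgentMessage(sender='user', content='Bonjour le monde'), session_id=0)
snapshot = pipeline.state_dict(session_id=0)      # keys like 'memory', '0.memory', '1.memory'
pipeline.load_state_dict(snapshot, session_id=0)  # restores each sub-agent's conversation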
lagent/agents/aggregator/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .default_aggregator import DefaultAggregator
from .tool_aggregator import InternLMToolAggregator

__all__ = ['DefaultAggregator', 'InternLMToolAggregator']
lagent/agents/aggregator/default_aggregator.py
ADDED
@@ -0,0 +1,44 @@
from typing import Dict, List

from lagent.memory import Memory
from lagent.prompts import StrParser


class DefaultAggregator:

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: StrParser = None,
                  system_instruction: str = None) -> List[Dict[str, str]]:
        _message = []
        messages = messages.get_memory()
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
        for message in messages:
            if message.sender == name:
                _message.append(
                    dict(role='assistant', content=str(message.content)))
            else:
                user_message = message.content
                if len(_message) > 0 and _message[-1]['role'] == 'user':
                    _message[-1]['content'] += user_message
                else:
                    _message.append(dict(role='user', content=user_message))
        return _message

    @staticmethod
    def aggregate_system_intruction(system_intruction) -> List[dict]:
        if isinstance(system_intruction, str):
            system_intruction = dict(role='system', content=system_intruction)
        if isinstance(system_intruction, dict):
            system_intruction = [system_intruction]
        if isinstance(system_intruction, list):
            for msg in system_intruction:
                if not isinstance(msg, dict):
                    raise TypeError(f'Unsupported message type: {type(msg)}')
                if not ('role' in msg and 'content' in msg):
                    raise KeyError(
                        f"Missing required key 'role' or 'content': {msg}")
        return system_intruction
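
A quick sketch of the aggregation behavior (senders and contents are made up): the aggregator flattens an agent's Memory into OpenAI-style role/content dicts, merging consecutive user turns.

from lagent.agents.aggregator import DefaultAggregator
from lagent.memory import Memory
from lagent.schema import AgentMessage

memory = Memory()
memory.add(AgentMessage(sender='user', content='Hi'))
memory.add(AgentMessage(sender='assistant_a', content='Hello!'))

msgs = DefaultAggregator().aggregate(
    memory, name='assistant_a', system_instruction='You are helpful.')
# -> [{'role': 'system', 'content': 'You are helpful.'},
#     {'role': 'user', 'content': 'Hi'},
#     {'role': 'assistant', 'content': 'Hello!'}]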
lagent/agents/aggregator/tool_aggregator.py
ADDED
@@ -0,0 +1,106 @@
from typing import Dict, List, Optional, Union

from lagent.agents.aggregator.default_aggregator import DefaultAggregator
from lagent.memory.base_memory import Memory
from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode


class InternLMToolAggregator(DefaultAggregator):

    def __init__(self,
                 environment_role='environment',
                 environment_begin='',
                 environment_end='',
                 user_names: Optional[List[str]] = None,
                 few_shot: Optional[List[List[dict]]] = None):
        self.environment_role = environment_role
        self.environment_begin = environment_begin
        self.environment_end = environment_end
        self.user_names = user_names or ['user']
        self.few_shot = few_shot or []

    def aggregate(self,
                  messages: Memory,
                  name: str,
                  parser: Union[ToolParser, MixedToolParser],
                  system_instruction: str = None) -> List[Dict[str, str]]:
        _message = []
        messages = messages.get_memory()
        if system_instruction:
            _message.extend(
                self.aggregate_system_intruction(system_instruction))
        tool_instruction = parser.format_instruction()
        if tool_instruction:
            if isinstance(tool_instruction, str):
                tool_instruction = dict(
                    role='system', content=tool_instruction)
                if parser.tool_type:
                    tool_instruction['name'] = parser.tool_type
            if isinstance(tool_instruction, dict):
                tool_instruction = [tool_instruction]
            _message.extend(tool_instruction)

        for shot in self.few_shot:
            i = 0
            while i < len(shot):
                msg = shot[i]
                if msg['role'] in ['assistant', 'user', 'system']:
                    _message.append(msg)
                elif msg['role'] == self.environment_role:
                    if not msg['content'].startswith(self.environment_begin):
                        msg['content'] = self.environment_begin + msg['content']
                    if not msg['content'].endswith(self.environment_end):
                        msg['content'] += self.environment_end
                    _message.append(msg)
                elif msg['role'] in ['thought', 'language']:
                    if i < len(shot) - 1 and shot[i + 1]['role'] == 'tool':
                        _message.append(
                            dict(
                                role='assistant',
                                content=parser.format_response(
                                    dict(
                                        tool_type=shot[i + 1]['name'],
                                        thought=msg['content'],
                                        action=shot[i + 1]['content'],
                                        status=None))))
                        i += 1
                    else:
                        _message.append(
                            dict(
                                role='assistant',
                                content=parser.format_response(
                                    dict(
                                        tool_type=None,
                                        thought=msg['content'],
                                        action=None,
                                        status=None))))
                else:
                    raise KeyError(f'Unknown role: {msg["role"]}')
                i += 1

        tool_type = None
        for message in messages:
            if message.sender == name:
                if isinstance(message.formatted, dict):
                    parsed = message.formatted
                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
                        continue
                    _message.append(
                        dict(
                            role='assistant',
                            content=parser.format_response(parsed)))
                    tool_type = parsed['tool_type']
                else:
                    _message.append(
                        dict(role='assistant', content=str(message.content)))
            elif message.sender in self.user_names:
                _message.append(dict(role='user', content=message.content))
            else:
                msg = dict(
                    role=self.environment_role,
                    content=self.environment_begin + str(message.content) +
                    self.environment_end)
                if tool_type:
                    msg['name'] = tool_type
                _message.append(msg)
        return _message
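
Illustrative configuration sketch (the marker strings below are made-up placeholders, not mandated by the class): the aggregator wraps every tool/environment message in the given delimiters and maps the listed senders onto the user role.

from lagent.agents.aggregator import InternLMToolAggregator

aggregator = InternLMToolAggregator(
    environment_role='environment',
    environment_begin='<ENV>\n',   # prepended to every environment message
    environment_end='\n</ENV>',    # appended likewise
    user_names=['user', 'human'],  # senders aggregated as the 'user' role
)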
lagent/agents/react.py
ADDED
@@ -0,0 +1,161 @@
import json
from typing import Callable, Dict, List, Union

from pydantic import BaseModel, Field

from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import DefaultAggregator
from lagent.hooks import ActionPreprocessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers.json_parser import JSONParser
from lagent.prompts.prompt_template import PromptTemplate
from lagent.schema import AgentMessage
from lagent.utils import create_object

select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
{action_info}
{output_format}
开始!"""

output_format_template = """如果使用工具请遵循以下格式回复:
{function_format}

如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
{finish_format}"""


class ReAct(Agent):

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        actions = dict(
            type=ActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: ActionExecutor = create_object(actions)
        select_agent = dict(
            type=Agent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        for _ in range(self.max_turn):
            message = self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = self.actions(message)
        return message


class AsyncReAct(AsyncAgent):

    def __init__(self,
                 llm: Union[BaseLLM, Dict],
                 actions: Union[BaseAction, List[BaseAction]],
                 template: Union[PromptTemplate, str] = None,
                 memory: Dict = dict(type=Memory),
                 output_format: Dict = dict(type=JSONParser),
                 aggregator: Dict = dict(type=DefaultAggregator),
                 hooks: List = [dict(type=ActionPreprocessor)],
                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
                 'conclusion' in m.content or 'conclusion' in m.formatted,
                 max_turn: int = 5,
                 **kwargs):
        self.max_turn = max_turn
        self.finish_condition = finish_condition
        actions = dict(
            type=AsyncActionExecutor,
            actions=actions,
            hooks=hooks,
        )
        self.actions: AsyncActionExecutor = create_object(actions)
        select_agent = dict(
            type=AsyncAgent,
            llm=llm,
            template=template.format(
                action_info=json.dumps(self.actions.description()),
                output_format=output_format.format_instruction()),
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=hooks,
        )
        self.select_agent = create_object(select_agent)
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
        for _ in range(self.max_turn):
            message = await self.select_agent(message)
            if self.finish_condition(message):
                return message
            message = await self.actions(message)
        return message


if __name__ == '__main__':
    from lagent.llms import GPTAPI

    class ActionCall(BaseModel):
        name: str = Field(description='调用的函数名称')
        parameters: Dict = Field(description='调用函数的参数')

    class ActionFormat(BaseModel):
        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')

    class FinishFormat(BaseModel):
        thought_process: str = Field(
            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')

    prompt_template = PromptTemplate(select_action_template)
    output_format = JSONParser(
        output_format_template,
        function_format=ActionFormat,
        finish_format=FinishFormat)

    llm = dict(
        type=GPTAPI,
        model_type='gpt-4o-2024-05-13',
        key=None,
        max_new_tokens=4096,
        proxies=dict(),
        retry=1000)

    agent = ReAct(
        llm=llm,
        template=prompt_template,
        output_format=output_format,
        aggregator=dict(type='DefaultAggregator'),
        actions=[dict(type='PythonInterpreter')],
    )
    response = agent(
        AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
    print(response)
    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
    print(response)
lagent/agents/stream.py
ADDED
@@ -0,0 +1,316 @@
import json
import warnings
from copy import deepcopy
from typing import Callable, Dict, List, Union

from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive
from lagent.agents.agent import Agent, AsyncAgent
from lagent.agents.aggregator import InternLMToolAggregator
from lagent.hooks import InternLMActionProcessor
from lagent.llms import BaseLLM
from lagent.memory import Memory
from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode
from lagent.schema import AgentMessage
from lagent.utils import create_object

API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
    'The description of this function is: \n{description}')

META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')

INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
                  '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
                  '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
                  '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
                  '文本处理和分析(比如文本解析和自然语言处理),'
                  '机器学习和数据科学(用于展示模型训练和数据可视化),'
                  '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')

PLUGIN_CN = ('你可以使用如下工具:'
             '\n{prompt}\n'
             '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
             '同时注意你可以使用的工具,不要随意捏造!')


def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
    plugin_descriptions = []
    for action in actions if isinstance(actions, list) else [actions]:
        action = create_object(action)
        action_desc = deepcopy(action.description)
        if action.is_toolkit:
            for api in action_desc['api_list']:
                api['name'] = f"{action.name}.{api['name']}"
                api['description'] = api_desc_template.format(
                    tool_name=action.name, description=api['description'])
                api['parameters'] = [
                    param for param in api['parameters']
                    if param['name'] in api['required']
                ]
                plugin_descriptions.append(api)
        else:
            action_desc['description'] = api_desc_template.format(
                tool_name=action.name, description=action_desc['description'])
            action_desc['parameters'] = [
                param for param in action_desc['parameters']
                if param['name'] in action_desc['required']
            ]
            plugin_descriptions.append(action_desc)
    return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)


class AgentForInternLM(Agent):

    _INTERNAL_AGENT_CLS = Agent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        agent = dict(
            type=self._INTERNAL_AGENT_CLS,
            llm=llm,
            template=template,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
        self.agent = create_object(agent)
        self.plugin_executor = plugins and ActionExecutor(
            plugins, hooks=action_hooks)
        self.interpreter_executor = interpreter and ActionExecutor(
            interpreter, hooks=action_hooks)
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent calls a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender='user', content=message)
        for _ in range(self.max_turn):
            message = self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
            if message.formatted['tool_type']:
                tool_type = message.formatted['tool_type']
                executor = getattr(self, f'{tool_type}_executor', None)
                if not executor:
                    raise RuntimeError(f'No available {tool_type} executor')
                message = executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
                steps.append(
                    dict(role='thought', content=msg.formatted['thought']))
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
                    steps.append(
                        dict(
                            role='tool',
                            content=msg.formatted['action'],
                            name=tool_type))
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
                    feedback['name'] = tool_type
                steps.append(feedback)
        return steps


class MathCoder(AgentForInternLM):

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(
            type=IPythonInteractive, timeout=20, max_out_len=8192),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=
            ('Integrate step-by-step reasoning and Python code to solve math problems '
             'using the following guidelines:\n'
             '- Analyze the question and write jupyter code to solve the problem;\n'
             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
             'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        kwargs.pop('plugins', None)
        super().__init__(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
            **kwargs)


class AsyncAgentForInternLM(AsyncAgent):

    _INTERNAL_AGENT_CLS = AsyncAgent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ]),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        agent = dict(
            type=self._INTERNAL_AGENT_CLS,
            llm=llm,
            template=template,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=kwargs.pop('hooks', None),
        )
        self.agent = create_object(agent)
        self.plugin_executor = plugins and AsyncActionExecutor(
            plugins, hooks=action_hooks)
        self.interpreter_executor = interpreter and AsyncActionExecutor(
            interpreter, hooks=action_hooks)
        if not (self.plugin_executor or self.interpreter_executor):
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent calls a tool.')
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        if isinstance(message, str):
            message = AgentMessage(sender='user', content=message)
        for _ in range(self.max_turn):
            message = await self.agent(
                message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                return message
            if message.formatted['tool_type']:
                tool_type = message.formatted['tool_type']
                executor = getattr(self, f'{tool_type}_executor', None)
                if not executor:
                    raise RuntimeError(f'No available {tool_type} executor')
                message = await executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        steps, tool_type = [], None
        for msg in self.agent.memory.get_memory(session_id):
            if msg.sender == self.agent.name:
                steps.append(
                    dict(role='thought', content=msg.formatted['thought']))
                if msg.formatted['tool_type']:
                    tool_type = msg.formatted['tool_type']
                    steps.append(
                        dict(
                            role='tool',
                            content=msg.formatted['action'],
                            name=tool_type))
            elif msg.sender != 'user':
                feedback = dict(role='environment', content=msg.content)
                if tool_type:
                    feedback['name'] = tool_type
                steps.append(feedback)
        return steps


class AsyncMathCoder(AsyncAgentForInternLM):

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(type=AsyncIPythonInterpreter),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=
            ('Integrate step-by-step reasoning and Python code to solve math problems '
             'using the following guidelines:\n'
             '- Analyze the question and write jupyter code to solve the problem;\n'
             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
             'units. \n')),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[
            [AgentMessage],
            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        kwargs.pop('plugins', None)
        super().__init__(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
            **kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        try:
            return await super().forward(message, session_id, **kwargs)
        finally:
            interpreter = next(
                iter(self.interpreter_executor.actions.values()))
            if interpreter.name == 'AsyncIPythonInterpreter':
                await interpreter.close_session(session_id)
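
A plugin-only configuration sketch in the spirit of the web demos (the model type and key are placeholders): `get_plugin_prompt` renders the tool schema into the `{prompt}` slot of PLUGIN_CN.

from lagent.actions import ArxivSearch
from lagent.agents.stream import AgentForInternLM, PLUGIN_CN, get_plugin_prompt
from lagent.llms import GPTAPI
from lagent.prompts.parsers import PluginParser

plugins = [dict(type=ArxivSearch)]
agent = AgentForInternLM(
    llm=dict(type=GPTAPI, model_type='gpt-4o-2024-05-13', key='YOUR_KEY'),
    plugins=plugins,
    output_format=dict(
        type=PluginParser,
        template=PLUGIN_CN,
        prompt=get_plugin_prompt(plugins)),
)
print(agent('帮我查一下 InternLM2 相关的论文', session_id=0).content)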
lagent/distributed/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .http_serve import AgentAPIServer, AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer
from .ray_serve import AgentRayActor, AsyncAgentRayActor

__all__ = [
    'AsyncAgentRayActor', 'AgentRayActor', 'HTTPAgentServer',
    'HTTPAgentClient', 'AsyncHTTPAgentServer', 'AsyncHTTPAgentClient',
    'AgentAPIServer'
]
lagent/distributed/http_serve/__init__.py
ADDED
@@ -0,0 +1,7 @@
from .api_server import AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer
from .app import AgentAPIServer

__all__ = [
    'HTTPAgentServer', 'HTTPAgentClient', 'AsyncHTTPAgentClient',
    'AsyncHTTPAgentServer', 'AgentAPIServer'
]
lagent/distributed/http_serve/api_server.py
ADDED
@@ -0,0 +1,123 @@
import json
import os
import subprocess
import sys
import time

import aiohttp
import requests

from lagent.schema import AgentMessage


class HTTPAgentClient:

    def __init__(self, host='127.0.0.1', port=8090, timeout=None):
        self.host = host
        self.port = port
        self.timeout = timeout

    @property
    def is_alive(self):
        try:
            resp = requests.get(
                f'http://{self.host}:{self.port}/health_check',
                timeout=self.timeout)
            return resp.status_code == 200
        except Exception:
            return False

    def __call__(self, *message, session_id: int = 0, **kwargs):
        response = requests.post(
            f'http://{self.host}:{self.port}/chat_completion',
            json={
                'message': [
                    m if isinstance(m, str) else m.model_dump()
                    for m in message
                ],
                'session_id': session_id,
                **kwargs,
            },
            headers={'Content-Type': 'application/json'},
            timeout=self.timeout)
        resp = response.json()
        if response.status_code != 200:
            return resp
        return AgentMessage.model_validate(resp)

    def state_dict(self, session_id: int = 0):
        resp = requests.get(
            f'http://{self.host}:{self.port}/memory/{session_id}',
            timeout=self.timeout)
        return resp.json()


class HTTPAgentServer(HTTPAgentClient):

    def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
        super().__init__(host, port)
        self.gpu_id = gpu_id
        self.config = config
        self.start_server()

    def start_server(self):
        # set CUDA_VISIBLE_DEVICES in subprocess
        env = os.environ.copy()
        env['CUDA_VISIBLE_DEVICES'] = self.gpu_id
        cmds = [
            sys.executable, 'lagent/distributed/http_serve/app.py', '--host',
            self.host, '--port',
            str(self.port), '--config',
            json.dumps(self.config)
        ]
        self.process = subprocess.Popen(
            cmds,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True)

        while True:
            output = self.process.stdout.readline()
            if not output:  # stop on EOF
                break
            sys.stdout.write(output)  # echo the child's log to stdout
            sys.stdout.flush()
            if 'Uvicorn running on' in output:  # adjust to the actual startup log
                break
            time.sleep(0.1)

    def shutdown(self):
        self.process.terminate()
        self.process.wait()


class AsyncHTTPAgentMixin:

    async def __call__(self, *message, session_id: int = 0, **kwargs):
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(self.timeout)) as session:
            async with session.post(
                    f'http://{self.host}:{self.port}/chat_completion',
                    json={
                        'message': [
                            m if isinstance(m, str) else m.model_dump()
                            for m in message
                        ],
                        'session_id': session_id,
                        **kwargs,
                    },
                    headers={'Content-Type': 'application/json'},
            ) as response:
                resp = await response.json()
                if response.status != 200:
                    return resp
                return AgentMessage.model_validate(resp)


class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
    pass


class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
    pass
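
For context, a hypothetical end-to-end sketch (the agent config values are placeholders, not from this commit; the config must be JSON-serializable because the server passes it to the subprocess via `--config`):

from lagent.distributed.http_serve import HTTPAgentServer

agent_config = {  # placeholder config; 'type' is a dotted class path
    'type': 'lagent.agents.AsyncAgent',
    'llm': {'type': 'lagent.llms.AsyncGPTAPI', 'model_type': 'gpt-4o-2024-05-13'},
}
server = HTTPAgentServer(gpu_id='0', config=agent_config, port=8090)
assert server.is_alive
print(server('hello', session_id=0))  # POSTs to /chat_completion
server.shutdown()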
lagent/distributed/http_serve/app.py
ADDED
@@ -0,0 +1,96 @@
import argparse
import json
import logging
import time

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.requests import Request

from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string


class AgentAPIServer:

    def __init__(self,
                 config: dict,
                 host: str = '127.0.0.1',
                 port: int = 8090):
        self.app = FastAPI(docs_url='/')
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=['*'],
            allow_credentials=True,
            allow_methods=['*'],
            allow_headers=['*'],
        )
        cls_name = config.pop('type')
        python_path = config.pop('python_path', None)
        cls_name = load_class_from_string(cls_name, python_path) if isinstance(
            cls_name, str) else cls_name
        self.agent = cls_name(**config)
        self.setup_routes()
        self.run(host, port)

    def setup_routes(self):

        def heartbeat():
            return {'status': 'success', 'timestamp': time.time()}

        async def process_message(request: Request):
            try:
                body = await request.json()
                message = [
                    m if isinstance(m, str) else AgentMessage.model_validate(m)
                    for m in body.pop('message')
                ]
                result = await self.agent(*message, **body)
                return result
            except Exception as e:
                logging.error(f'Error processing message: {str(e)}')
                raise HTTPException(
                    status_code=500, detail='Internal Server Error')

        def get_memory(session_id: int = 0):
            try:
                result = self.agent.state_dict(session_id)
                return result
            except KeyError:
                raise HTTPException(
                    status_code=404, detail='Session ID not found')
            except Exception as e:
                logging.error(f'Error processing message: {str(e)}')
                raise HTTPException(
                    status_code=500, detail='Internal Server Error')

        self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
        self.app.add_api_route(
            '/chat_completion', process_message, methods=['POST'])
        self.app.add_api_route(
            '/memory/{session_id}', get_memory, methods=['GET'])

    def run(self, host='127.0.0.1', port=8090):
        logging.info(f'Starting server at {host}:{port}')
        uvicorn.run(self.app, host=host, port=port)


def parse_args():
    parser = argparse.ArgumentParser(description='Async Agent API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1')
    parser.add_argument('--port', type=int, default=8090)
    parser.add_argument(
        '--config',
        type=json.loads,
        required=True,
        help='JSON configuration for the agent')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    AgentAPIServer(args.config, host=args.host, port=args.port)
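
This is the entry point that `HTTPAgentServer` above launches as a subprocess with `--host/--port/--config`; it can also be started in-process. A minimal sketch with a placeholder config (the construction blocks inside `uvicorn.run`):

from lagent.distributed.http_serve.app import AgentAPIServer

AgentAPIServer(
    {  # placeholder config; AgentAPIServer pops 'type' and instantiates the agent
        'type': 'lagent.agents.AsyncAgent',
        'llm': {'type': 'lagent.llms.AsyncGPTAPI', 'model_type': 'gpt-4o-2024-05-13'},
    },
    host='127.0.0.1',
    port=8090)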
lagent/distributed/ray_serve/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .ray_warpper import AgentRayActor, AsyncAgentRayActor

__all__ = ['AsyncAgentRayActor', 'AgentRayActor']
lagent/distributed/ray_serve/ray_warpper.py
ADDED
@@ -0,0 +1,48 @@
import importlib
import sys
from typing import Dict

import ray

from lagent.schema import AgentMessage
from lagent.utils import load_class_from_string


class AsyncAgentRayActor:

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        cls_name = config.pop('type')
        python_path = config.pop('python_path', None)
        cls_name = load_class_from_string(cls_name, python_path) if isinstance(
            cls_name, str) else cls_name
        AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
        self.agent_actor = AsyncAgentActor.remote(**config)

    async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        response = await self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return response


class AgentRayActor:

    def __init__(
        self,
        config: Dict,
        num_gpus: int,
    ):
        cls_name = config.pop('type')
        python_path = config.pop('python_path', None)
        cls_name = load_class_from_string(cls_name, python_path) if isinstance(
            cls_name, str) else cls_name
        AgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
        self.agent_actor = AgentActor.remote(**config)

    def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
        response = self.agent_actor.__call__.remote(
            *message, session_id=session_id, **kwargs)
        return ray.get(response)
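
Illustrative usage sketch (cluster setup and config values are placeholders): the wrapper turns any agent class into a Ray actor pinned to the requested GPUs.

import ray

from lagent.distributed import AgentRayActor

ray.init()
actor = AgentRayActor(
    config={  # placeholder config, same shape as the HTTP servers expect
        'type': 'lagent.agents.Agent',
        'llm': {'type': 'lagent.llms.GPTAPI', 'model_type': 'gpt-4o-2024-05-13'},
    },
    num_gpus=1)
print(actor('hello', session_id=0))  # resolved with ray.get internally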
lagent/hooks/__init__.py
ADDED
@@ -0,0 +1,8 @@
from .action_preprocessor import ActionPreprocessor, InternLMActionProcessor
from .hook import Hook, RemovableHandle
from .logger import MessageLogger

__all__ = [
    'Hook', 'RemovableHandle', 'ActionPreprocessor', 'InternLMActionProcessor',
    'MessageLogger'
]
lagent/hooks/action_preprocessor.py
ADDED
@@ -0,0 +1,62 @@
from copy import deepcopy

from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall
from .hook import Hook


class ActionPreprocessor(Hook):
    """The ActionPreprocessor is a hook that preprocesses the action message
    and postprocesses the action return message.
    """

    def before_action(self, executor, message, session_id):
        assert isinstance(message.formatted, FunctionCall) or (
            isinstance(message.formatted, dict) and 'name' in message.formatted
            and 'parameters' in message.formatted) or (
                'action' in message.formatted
                and 'parameters' in message.formatted['action']
                and 'name' in message.formatted['action'])
        if isinstance(message.formatted, dict):
            name = message.formatted.get('name',
                                         message.formatted['action']['name'])
            parameters = message.formatted.get(
                'parameters', message.formatted['action']['parameters'])
        else:
            name = message.formatted.name
            parameters = message.formatted.parameters
        message.content = dict(name=name, parameters=parameters)
        return message

    def after_action(self, executor, message, session_id):
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message


class InternLMActionProcessor(ActionPreprocessor):

    def __init__(self, code_parameter: str = 'command'):
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        message = deepcopy(message)
        assert isinstance(message.formatted, dict) and set(
            message.formatted).issuperset(
                {'tool_type', 'thought', 'action', 'status'})
        if isinstance(message.formatted['action'], str):
            # encapsulate code interpreter arguments
            action_name = next(iter(executor.actions))
            parameters = {self.code_parameter: message.formatted['action']}
            if action_name in ['AsyncIPythonInterpreter']:
                parameters['session_id'] = session_id
            message.formatted['action'] = dict(
                name=action_name, parameters=parameters)
        return super().before_action(executor, message, session_id)
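
A sketch of the payload InternLMActionProcessor consumes (the values are illustrative): `formatted` carries tool_type/thought/action/status, and a plain-string `action` is wrapped into a FunctionCall-style dict targeting the executor's first action.

from lagent.schema import AgentMessage

msg = AgentMessage(
    sender='assistant',
    content='',
    formatted=dict(
        tool_type='interpreter',
        thought='Run the computation in Python.',
        action='print(3 ** 5)',  # raw code; the hook wraps it as {'name': ..., 'parameters': {'command': ...}}
        status=0,                # a ToolStatusCode value; 0 is illustrative
    ))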
lagent/hooks/hook.py
ADDED
@@ -0,0 +1,50 @@
from itertools import count
from typing import Tuple

from lagent.schema import AgentMessage


class Hook:

    def before_agent(
        self,
        agent,
        message: Tuple[AgentMessage],
        session_id: int,
    ):
        pass

    def after_agent(
        self,
        agent,
        message: AgentMessage,
        session_id: int,
    ):
        pass

    def before_action(
        self,
        executor,
        message: AgentMessage,
        session_id: int,
    ):
        pass

    def after_action(
        self,
        executor,
        message: AgentMessage,
        session_id: int,
    ):
        pass


class RemovableHandle:
    _id_iter = count(0)

    def __init__(self, hooks_dict):
        self.hooks_dict = hooks_dict
        self.id = next(self._id_iter)

    def remove(self):
        del self.hooks_dict[self.id]
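
A hypothetical hook for illustration, attached and detached via the handle returned by `register_hook` (the PrintHook class is made up):

from lagent.agents.agent import Agent
from lagent.hooks import Hook

class PrintHook(Hook):
    def after_agent(self, agent, message, session_id):
        print(f'[session {session_id}] {message.sender}: {message.content}')

agent = Agent(name='demo')             # any lagent Agent works here
handle = agent.register_hook(PrintHook())
# ... run the agent ...
handle.remove()                        # deletes the hook from the agent's registry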
lagent/hooks/logger.py
ADDED
@@ -0,0 +1,37 @@
import random
from typing import Optional

from termcolor import COLORS, colored

from lagent.utils import get_logger
from .hook import Hook


class MessageLogger(Hook):

    def __init__(self, name: str = 'lagent'):
        self.logger = get_logger(
            name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s')
        self.sender2color = {}

    def before_agent(self, agent, messages, session_id):
        for message in messages:
            self._process_message(message, session_id)

    def after_agent(self, agent, message, session_id):
        self._process_message(message, session_id)

    def before_action(self, executor, message, session_id):
        self._process_message(message, session_id)

    def after_action(self, executor, message, session_id):
        self._process_message(message, session_id)

    def _process_message(self, message, session_id):
        sender = message.sender
        color = self.sender2color.setdefault(sender,
                                             random.choice(list(COLORS)))
        self.logger.info(
            colored(
                f'session id: {session_id}, message sender: {sender}\n'
                f'{message.content}', color))
lagent/llms/__init__.py
ADDED
@@ -0,0 +1,32 @@
from .base_api import AsyncBaseAPILLM, BaseAPILLM
from .base_llm import AsyncBaseLLM, BaseLLM
from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline,
                               AsyncLMDeployServer, LMDeployClient,
                               LMDeployPipeline, LMDeployServer)
from .meta_template import INTERNLM2_META
from .openai import GPTAPI, AsyncGPTAPI
from .sensenova import SensenovaAPI
from .vllm_wrapper import AsyncVllmModel, VllmModel

__all__ = [
    'AsyncBaseLLM',
    'BaseLLM',
    'AsyncBaseAPILLM',
    'BaseAPILLM',
    'AsyncGPTAPI',
    'GPTAPI',
    'LMDeployClient',
    'AsyncLMDeployClient',
    'LMDeployPipeline',
    'AsyncLMDeployPipeline',
    'LMDeployServer',
    'AsyncLMDeployServer',
    'HFTransformer',
    'HFTransformerCasualLM',
    'INTERNLM2_META',
    'HFTransformerChat',
    'VllmModel',
    'AsyncVllmModel',
    'SensenovaAPI',
]
lagent/llms/base_api.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import warnings
from typing import Dict, List, Optional, Tuple, Union

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM


class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (Dict): The meta template for the model.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # Check meta template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()  # maps role name to config
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog: List[Union[str, List]]):
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a list, the return value will be a list containing the full
        conversation history. Each item looks like:

        .. code-block:: python

            {'role': 'user', 'content': '...'}

        Args:
            dialog (List[str or list]): An intermediate prompt
                template (potentially before being wrapped by meta template).

        Returns:
            List[str or list]: The finalized prompt or a conversation.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:

            prompt = list()
            # Whether to keep generating the prompt
            generate = True
            for i, item in enumerate(dialog):
                if not generate:
                    break
                if isinstance(item, str):
                    if item.strip():
                        # TODO: logger
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                else:
                    api_prompts = self._prompt2api(item)
                    prompt.append(api_prompts)

            # merge the consecutive prompts assigned to the same role
            new_prompt = list([prompt[0]])
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['content'] += '\n' + item['content']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt

        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
        """Convert the prompts to API-style prompts, given an updated
        role_dict.

        Args:
            prompts (Union[List, str]): The prompts to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[str, bool]: The converted string, and whether the follow-up
            conversion should be proceeded.
        """
        if isinstance(prompts, str):
            return prompts
        elif isinstance(prompts, dict):
            api_role = self._role2api_role(prompts)
            return api_role

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explicit role is not '
                                'allowed in API models!')
            else:
                api_role = self._role2api_role(prompt)
                res.append(api_role)
        return res

    def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
        merged_prompt = self.roles[role_prompt['role']]
        if merged_prompt.get('fallback_role'):
            # look up the fallback role's config directly
            merged_prompt = self.roles[merged_prompt['fallback_role']]
        res = role_prompt.copy()
        res['role'] = merged_prompt['api_role']
        res['content'] = merged_prompt.get('begin', '')
        res['content'] += role_prompt.get('content', '')
        res['content'] += merged_prompt.get('end', '')
        return res


class BaseAPILLM(BaseLLM):
    """Base class for API model wrapper.

    Args:
        model_type (str): The type of model.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str,
                 retry: int = 2,
                 template_parser: 'APITemplateParser' = APITemplateParser,
                 meta_template: Optional[Dict] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 0.0,
                 stop_words: Union[List[str], str] = None):
        self.model_type = model_type
        self.meta_template = meta_template
        self.retry = retry
        if template_parser:
            self.template_parser = template_parser(meta_template)

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words,
            skip_special_tokens=False)


class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
    pass
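To make the role handling concrete, here is a minimal sketch of how the parser converts a dialog into API-style messages and merges consecutive same-role entries. The meta template below is illustrative, not one shipped with the package:

# Hypothetical meta template mapping internal roles to API roles.
meta_template = [
    dict(role='system', api_role='system'),
    dict(role='user', api_role='user'),
    dict(role='assistant', api_role='assistant'),
]
parser = APITemplateParser(meta_template)
dialog = [
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='user', content='Hi.'),
    dict(role='user', content='What is 1 + 1?'),  # same role as previous item
]
# The two consecutive 'user' items are merged into one message whose
# contents are joined by a newline.
print(parser(dialog))
# [{'role': 'system', 'content': 'You are a helpful assistant.'},
#  {'role': 'user', 'content': 'Hi.\nWhat is 1 + 1?'}]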
lagent/llms/base_llm.py
ADDED
@@ -0,0 +1,305 @@
from copy import copy
from typing import Dict, List, Optional, Tuple, Union


class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        if meta_template:
            assert isinstance(meta_template, list)
            self.roles: Dict[str, dict] = dict()  # maps role name to config
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (List[str or PromptList]): A prompt
                template (potentially before being wrapped by meta template).

        Returns:
            str: The final string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:

            prompt = ''
            for index, item in enumerate(dialog):
                if isinstance(item, str):
                    prompt += item
                else:
                    new_str = self._prompt2str(item, index == len(dialog) - 1)
                    prompt += new_str
        else:
            # in case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in dialog:
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('content', ''):
                    prompt += last_sep + item.get('content', '')
                last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg, message):
        name = message.get('name', None)
        if name is not None:
            begin = role_cfg['begin'].get('with_name', '')
            if name in role_cfg['begin'].get('name', {}):
                begin = begin.format(name=role_cfg['begin']['name'][name])
            else:
                begin = begin.format(name=name)
        else:
            if isinstance(role_cfg.get('begin', ''), str):
                begin = role_cfg.get('begin', '')
            elif isinstance(role_cfg['begin'], dict):
                begin = role_cfg['begin'].get('without_name', '')
        return begin

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> Tuple[str, bool]:
        if isinstance(prompt, str):
            return prompt
        merged_prompt = self.roles.get(prompt['role'])

        if merged_prompt.get('fallback_role'):
            merged_prompt = self.roles.get(merged_prompt['fallback_role'])
        begin = self._format_begin(merged_prompt, prompt)
        res = begin
        if last and merged_prompt.get('generate', False):
            res += prompt.get('content', '')
            return res
        res += prompt.get('content', '') + merged_prompt.get('end', '')
        if last and merged_prompt['role'] != 'assistant':
            res += self._format_begin(self.roles['assistant'], {})
            return res
        return res
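As a sketch of what the begin/end/generate fields do, the following uses an illustrative ChatML-like meta template; the markers are assumptions for the example, not a template bundled with lagent:

# Hypothetical meta template with explicit begin/end markers per role.
meta_template = [
    dict(role='system', begin='<|im_start|>system\n', end='<|im_end|>\n'),
    dict(role='user', begin='<|im_start|>user\n', end='<|im_end|>\n'),
    dict(
        role='assistant',
        begin='<|im_start|>assistant\n',
        end='<|im_end|>\n',
        generate=True),
]
parser = LMTemplateParser(meta_template)
prompt = parser([dict(role='user', content='Hello!')])
# Because the last item is a non-assistant role, the assistant 'begin'
# marker is appended so the model continues from there:
print(prompt)
# '<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n'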
class BaseLLM:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        max_new_tokens (int): Maximum length of output expected to be
            generated by the model. Defaults to 512.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: 'LMTemplateParser' = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Union[List[str], str] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]):
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str):
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Union[int, List[int]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        if isinstance(inputs[0], list):
            _inputs = list()
            for msg in inputs:
                _inputs.append(self.template_parser(msg))
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params


class AsyncLLMMixin:

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]):
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str):
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Union[int, List[int]] = None,
                   **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        if isinstance(inputs[0], list):
            _inputs = list()
            for msg in inputs:
                _inputs.append(self.template_parser(msg))
        else:
            _inputs = self.template_parser(inputs)
        return await self.generate(_inputs, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]):
            gen_params (dict): The input params for generation.
        Returns:
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError


class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    pass
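The classes above only define the interface. A minimal sketch of a concrete subclass (the echo behaviour is a stand-in for a real decoding backend, not part of lagent) shows how the inherited `chat` delegates template parsing and batching to `generate`:

# Hypothetical subclass: only `generate` is implemented.
class EchoLLM(BaseLLM):

    def generate(self, inputs, **gen_params):
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        response = [f'[echo] {text}' for text in inputs]
        return response if batched else response[0]

llm = EchoLLM(path='dummy')  # no meta_template: contents are joined by '\n'
print(llm.chat([dict(role='user', content='ping')]))
# '[echo] ping'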
lagent/llms/huggingface.py
ADDED
@@ -0,0 +1,337 @@
import copy
import logging
from typing import Dict, List, Optional, Union

from lagent.schema import ModelStatusCode
from .base_api import APITemplateParser
from .base_llm import BaseLLM

logger = logging.getLogger(__name__)


class HFTransformer(BaseLLM):
    """Model wrapper around HuggingFace general models.

    Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/
    chat/web_demo.py)

    Args:
        path (str): The name or path to HuggingFace's model.
        tokenizer_path (str): The path to the tokenizer. Defaults to None.
        tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
            Defaults to {}.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        model_kwargs (dict): Keyword arguments for the model, used in loader.
            Defaults to dict(device_map='auto').
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
    """

    def __init__(self,
                 path: str,
                 tokenizer_path: Optional[str] = None,
                 tokenizer_kwargs: dict = dict(),
                 tokenizer_only: bool = False,
                 model_kwargs: dict = dict(device_map='auto'),
                 meta_template: Optional[Dict] = None,
                 stop_words_id: Union[List[int], int] = None,
                 **kwargs):
        super().__init__(
            path=path,
            tokenizer_only=tokenizer_only,
            meta_template=meta_template,
            **kwargs)
        if isinstance(stop_words_id, int):
            stop_words_id = [stop_words_id]
        self.gen_params.update(stop_words_id=stop_words_id)
        if self.gen_params['stop_words'] is not None and \
                self.gen_params['stop_words_id'] is not None:
            logger.warning('Both stop_words and stop_words_id are specified, '
                           'only stop_words_id will be used.')

        self._load_tokenizer(
            path=path,
            tokenizer_path=tokenizer_path,
            tokenizer_kwargs=tokenizer_kwargs)
        if not tokenizer_only:
            self._load_model(path=path, model_kwargs=model_kwargs)

        from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
        self.logits_processor = LogitsProcessorList()
        self.stopping_criteria = StoppingCriteriaList()
        self.prefix_allowed_tokens_fn = None

        stop_words_id = []
        if self.gen_params.get('stop_words_id'):
            stop_words_id = self.gen_params.get('stop_words_id')
        elif self.gen_params.get('stop_words'):
            for sw in self.gen_params.get('stop_words'):
                stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
        self.additional_eos_token_id = stop_words_id

    def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                        tokenizer_kwargs: dict):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_path if tokenizer_path else path,
            trust_remote_code=True,
            **tokenizer_kwargs)

        # `gcfg` is only loaded when the tokenizer lacks both pad and eos
        # tokens; keep a None default so later reads can be guarded.
        self.gcfg = None
        if self.tokenizer.pad_token_id is None:
            if self.tokenizer.eos_token is not None:
                logger.warning(
                    f'Using eos_token {self.tokenizer.eos_token} '
                    'as pad_token.')
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                from transformers.generation import GenerationConfig
                self.gcfg = GenerationConfig.from_pretrained(path)

                if self.gcfg.pad_token_id is not None:
                    logger.warning(
                        f'Using pad_token_id {self.gcfg.pad_token_id} '
                        'as pad_token_id.')
                    self.tokenizer.pad_token_id = self.gcfg.pad_token_id
                else:
                    raise ValueError(
                        'pad_token_id is not set for this tokenizer. Try to '
                        'set pad_token_id via passing '
                        '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModel
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModel.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()

    def tokenize(self, inputs: str):
        assert isinstance(inputs, str)
        inputs = self.tokenizer(
            inputs, return_tensors='pt', return_length=True)
        return inputs['input_ids'].tolist()

    def generate(
        self,
        inputs: Union[str, List[str]],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            (a list of/batched) text/chat completion
        """
        for status, chunk, _ in self.stream_generate(inputs, do_sample,
                                                     **kwargs):
            response = chunk
        return response

    def stream_generate(
        self,
        inputs: List[str],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        import torch
        from torch import nn
        with torch.no_grad():
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            inputs = self.tokenizer(
                inputs, padding=True, return_tensors='pt', return_length=True)
            input_length = inputs['length']
            for k, v in inputs.items():
                inputs[k] = v.cuda()
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']
            batch_size = input_ids.shape[0]
            input_ids_seq_length = input_ids.shape[-1]
            generation_config = self.model.generation_config
            generation_config = copy.deepcopy(generation_config)
            new_gen_params = self.update_gen_params(**kwargs)
            generation_config.update(**new_gen_params)
            generation_config.update(**kwargs)
            model_kwargs = generation_config.to_dict()
            model_kwargs['attention_mask'] = attention_mask
            _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
                generation_config.bos_token_id,
                generation_config.eos_token_id,
            )
            if eos_token_id is None:
                # `self.gcfg` may be None if the tokenizer already had a
                # pad token, so guard against that before reading it
                if self.gcfg is not None and \
                        self.gcfg.eos_token_id is not None:
                    eos_token_id = self.gcfg.eos_token_id
                else:
                    eos_token_id = []
            if isinstance(eos_token_id, int):
                eos_token_id = [eos_token_id]
            if self.additional_eos_token_id is not None:
                eos_token_id.extend(self.additional_eos_token_id)
            eos_token_id_tensor = torch.tensor(eos_token_id).to(
                input_ids.device) if eos_token_id is not None else None
            generation_config.max_length = (
                generation_config.max_new_tokens + input_ids_seq_length)
            # Set generation parameters if not already defined
            logits_processor = self.logits_processor
            stopping_criteria = self.stopping_criteria

            logits_processor = self.model._get_logits_processor(
                generation_config=generation_config,
                input_ids_seq_length=input_ids_seq_length,
                encoder_input_ids=input_ids,
                prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
                logits_processor=logits_processor,
            )

            stopping_criteria = self.model._get_stopping_criteria(
                generation_config=generation_config,
                stopping_criteria=stopping_criteria)
            logits_warper = self.model._get_logits_warper(generation_config)

            unfinished_sequences = input_ids.new(batch_size).fill_(1)
            scores = None
            while True:
                model_inputs = self.model.prepare_inputs_for_generation(
                    input_ids, **model_kwargs)
                # forward pass to get next token
                outputs = self.model(
                    **model_inputs,
                    return_dict=True,
                    output_attentions=False,
                    output_hidden_states=False,
                )

                next_token_logits = outputs.logits[:, -1, :]

                # pre-process distribution
                next_token_scores = logits_processor(input_ids,
                                                     next_token_logits)
                next_token_scores = logits_warper(input_ids,
                                                  next_token_scores)

                # sample
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                if do_sample:
                    next_tokens = torch.multinomial(
                        probs, num_samples=1).squeeze(1)
                else:
                    next_tokens = torch.argmax(probs, dim=-1)

                # update generated ids, model inputs,
                # and length for next step
                input_ids = torch.cat([input_ids, next_tokens[:, None]],
                                      dim=-1)
                model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
                    outputs,
                    model_kwargs,
                    is_encoder_decoder=False)
                unfinished_sequences = unfinished_sequences.mul(
                    next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
                        eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
                output_token_ids = input_ids.cpu().tolist()
                for i in range(len(output_token_ids)):
                    output_token_ids[i] = output_token_ids[i][
                        input_length[i]:]
                    # Find the first occurrence of
                    # an EOS token in the sequence
                    first_eos_idx = next(
                        (idx
                         for idx, token_id in enumerate(output_token_ids[i])
                         if token_id in eos_token_id), None)
                    # If an EOS token is found, only the part
                    # before it is retained
                    if first_eos_idx is not None:
                        output_token_ids[i] = output_token_ids[
                            i][:first_eos_idx]

                response = self.tokenizer.batch_decode(output_token_ids)
                if not batched:
                    response = response[0]
                yield ModelStatusCode.STREAM_ING, response, None
                # stop when each sentence is finished,
                # or if we exceed the maximum length
                if (unfinished_sequences.max() == 0
                        or stopping_criteria(input_ids, scores)):
                    break
            yield ModelStatusCode.END, response, None

    def stream_chat(
        self,
        inputs: List[dict],
        do_sample: bool = True,
        **kwargs,
    ):
        """Return the chat completions in stream mode.

        Args:
            inputs (List[dict]): input messages to be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        prompt = self.template_parser(inputs)
        yield from self.stream_generate(prompt, do_sample, **kwargs)


class HFTransformerCasualLM(HFTransformer):

    def _load_model(self, path: str, model_kwargs: dict):
        import torch
        from transformers import AutoModelForCausalLM
        model_kwargs.setdefault('torch_dtype', torch.float16)
        self.model = AutoModelForCausalLM.from_pretrained(
            path, trust_remote_code=True, **model_kwargs)
        self.model.eval()


class HFTransformerChat(HFTransformerCasualLM):

    def __init__(self, template_parser=APITemplateParser, **kwargs):
        super().__init__(template_parser=template_parser, **kwargs)

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             do_sample: bool = True,
             **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): input messages to
                be completed.
            do_sample (bool): do sampling if enabled
        Returns:
            the text/chat completion
        """
        # handle batch inference with a vanilla for loop
        if isinstance(inputs[0], list):
            resps = []
            for inp in inputs:
                resps.append(self.chat(inp, do_sample, **kwargs))
            return resps
        prompt = self.template_parser(inputs)
        query = prompt[-1]['content']
        history = prompt[:-1]
        try:
            response, history = self.model.chat(
                self.tokenizer, query, history=history)
        except Exception as e:
            # handle over-length input error
            logger.warning(str(e))
            response = ''
        return response
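A usage sketch of the wrapper above. The model id and stop word are illustrative (any HuggingFace chat model the wrapper supports could be substituted), downloading the weights is required, and a CUDA device is assumed since the wrapper moves inputs to the GPU:

# Hypothetical invocation; `internlm/internlm2-chat-7b` stands in for any
# compatible causal chat model on the Hub.
llm = HFTransformerCasualLM(
    path='internlm/internlm2-chat-7b',
    max_new_tokens=64,
    stop_words=['<|im_end|>'])
for status, text, _ in llm.stream_generate('Briefly introduce yourself.'):
    pass  # `text` accumulates the full completion as tokens stream in
print(text)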
lagent/llms/lmdeploy_wrapper.py
ADDED
@@ -0,0 +1,790 @@
import asyncio
import copy
import logging
from dataclasses import asdict
from typing import List, Optional, Union

import aiohttp

from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
from lagent.schema import ModelStatusCode
from lagent.utils.util import filter_suffix


class TritonClient(BaseLLM):
    """TritonClient is a wrapper around the Triton inference server client
    for LLMs.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
            triton inference server
        model_name (str): the name of the model
        session_len (int): the context size
        max_tokens (int): the expected number of generated tokens
    """

    def __init__(self,
                 tritonserver_addr: str,
                 model_name: str,
                 session_len: int = 32768,
                 log_level: str = 'WARNING',
                 **kwargs):
        super().__init__(path=None, **kwargs)
        try:
            from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
        except Exception as e:
            logging.error(f'{e}')
            raise RuntimeError('DO NOT use turbomind.chatbot since it has '
                               'been removed by lmdeploy since v0.5.2')
        self.state_map = {
            StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
            StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
            StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
            StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
            StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
            ModelStatusCode.SESSION_OUT_OF_LIMIT,
            StatusCode.TRITON_SESSION_INVALID_ARG:
            ModelStatusCode.SESSION_INVALID_ARG,
            StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
        }
        self.chatbot = Chatbot(
            tritonserver_addr=tritonserver_addr,
            model_name=model_name,
            session_len=session_len,
            log_level=log_level,
            **kwargs)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 request_id: str = '',
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 skip_special_tokens: bool = False,
                 **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        if isinstance(inputs, str):
            inputs = [inputs]
        prompt = inputs

        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ''

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            if status < ModelStatusCode.END:
                return ''
            elif status == ModelStatusCode.END:
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                # remove stop_words
                res = filter_suffix(res, self.gen_params.get('stop_words'))
        return res

    def stream_chat(self,
                    inputs: List[dict],
                    session_id: int = 2967,
                    request_id: str = '',
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    skip_special_tokens: bool = False,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            request_id (str): the identical id of this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        from lmdeploy.serve.turbomind.chatbot import Session, get_logger
        assert isinstance(session_id, int), \
            f'INT session id is required, but got {type(session_id)}'

        self.chatbot.cfg = self._update_gen_params(**kwargs)
        max_new_tokens = self.chatbot.cfg.max_new_tokens

        logger = get_logger('service.ft', log_level=self.chatbot.log_level)
        logger.info(f'session {session_id}, request_id {request_id}, '
                    f'max_out_len {max_new_tokens}')

        if self.chatbot._session is None:
            sequence_start = True
            self.chatbot._session = Session(session_id=session_id)
        elif self.chatbot._session.status == 0:
            logger.error(f'session {session_id} has been ended. Please set '
                         f'`sequence_start` to True if you want to restart it')
            return ModelStatusCode.SESSION_CLOSED, '', 0

        self.chatbot._session.status = 1
        self.chatbot._session.request_id = request_id
        self.chatbot._session.response = ''

        prompt = self.template_parser(inputs)
        status, res, _ = None, '', 0
        for status, res, _ in self.chatbot._stream_infer(
                self.chatbot._session,
                prompt,
                max_new_tokens,
                sequence_start,
                sequence_end,
                skip_special_tokens=skip_special_tokens):
            status = self.state_map.get(status)
            # The stop symbol also appears in the output of the last
            # STREAM_ING state.
            res = filter_suffix(res, self.gen_params.get('stop_words'))
            if status < ModelStatusCode.END:
                return status, res, _
            elif status == ModelStatusCode.END:  # remove stop_words
                self.chatbot._session.histories = (
                    self.chatbot._session.histories +
                    self.chatbot._session.prompt +
                    self.chatbot._session.response)
                yield status, res, _
                break
            else:
                yield status, res, _

    def _update_gen_params(self, **kwargs):
        import mmengine
        new_gen_params = self.update_gen_params(**kwargs)
        self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
        stop_words = self.chatbot._stop_words(
            self.gen_params.get('stop_words'))
        cfg = mmengine.Config(
            dict(
                session_len=self.chatbot.model.session_len,
                stop_words=stop_words,
                bad_words=self.chatbot.cfg.bad_words,
                **new_gen_params))
        return cfg


class LMDeployPipeline(BaseLLM):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of an lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 tp: int = 1,
                 pipeline_cfg=dict(),
                 **kwargs):
        import lmdeploy
        from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info

        self.str_version = lmdeploy.__version__
        self.version = version_info
        self.do_sample = kwargs.pop('do_sample', None)
        if self.do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        super().__init__(path=path, **kwargs)
        backend_config = copy.deepcopy(pipeline_cfg)
        backend_config.update(tp=tp)
        backend_config = {
            k: v
            for k, v in backend_config.items()
            if hasattr(TurbomindEngineConfig, k)
        }
        backend_config = TurbomindEngineConfig(**backend_config)
        chat_template_config = ChatTemplateConfig(
            model_name=model_name) if model_name else None
        self.model = pipeline(
            model_path=self.path,
            backend_config=backend_config,
            chat_template_config=chat_template_config,
            log_level='WARNING')

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig
        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        do_sample = kwargs.pop('do_sample', None)
        gen_params = self.update_gen_params(**kwargs)

        if do_sample is None:
            do_sample = self.do_sample
        if do_sample is not None and self.version < (0, 6, 0):
            raise RuntimeError(
                '`do_sample` parameter is not supported by lmdeploy until '
                f'v0.6.0, but currently using lmdeploy {self.str_version}')
        if self.version >= (0, 6, 0):
            if do_sample is None:
                do_sample = gen_params['top_k'] > 1 or gen_params[
                    'temperature'] > 0
            gen_params.update(do_sample=do_sample)

        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)
        response = self.model.batch_infer(
            prompt, gen_config=gen_config, do_preprocess=do_preprocess)
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]
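Before the server-based wrappers below, a usage sketch of the pipeline above. The model id is illustrative; running it requires lmdeploy to be installed and the model weights to be available:

# Hypothetical invocation of non-stream batched inference.
llm = LMDeployPipeline(
    path='internlm/internlm2-chat-7b',
    tp=1,
    max_new_tokens=128,
    stop_words=['<|im_end|>'])
texts = llm.generate(['Hello!', 'What is lmdeploy?'])
print(texts)  # two completions, with stop words stripped from the suffix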
class LMDeployServer(BaseLLM):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded from
                ii) and iii).
            - ii) The model_id of an lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    def __init__(self,
                 path: str,
                 model_name: Optional[str] = None,
                 server_name: str = '0.0.0.0',
                 server_port: int = 23333,
                 tp: int = 1,
                 log_level: str = 'WARNING',
                 serve_cfg=dict(),
                 **kwargs):
        super().__init__(path=path, **kwargs)
        self.model_name = model_name
        # TODO get_logger issue in multi processing
        import lmdeploy
        self.client = lmdeploy.serve(
            model_path=self.path,
            model_name=model_name,
            server_name=server_name,
            server_port=server_port,
            tp=tp,
            log_level=log_level,
            **serve_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 session_id: int = 2967,
                 sequence_start: bool = True,
                 sequence_end: bool = True,
                 ignore_eos: bool = False,
                 skip_special_tokens: Optional[bool] = False,
                 timeout: int = 30,
                 **kwargs) -> List[str]:
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_id (int): the identical id of a session
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            (a list of/batched) text/chat completion
        """

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        resp = [''] * len(inputs)
        for text in self.client.completions_v1(
                self.model_name,
                inputs,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=False,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp = [
                resp[i] + item['text']
                for i, item in enumerate(text['choices'])
            ]
        # remove stop_words
        resp = filter_suffix(resp, self.gen_params.get('stop_words'))
        if not batched:
            return resp[0]
        return resp

    def stream_chat(self,
                    inputs: List[dict],
                    session_id=0,
                    sequence_start: bool = True,
                    sequence_end: bool = True,
                    stream: bool = True,
                    ignore_eos: bool = False,
                    skip_special_tokens: Optional[bool] = False,
                    timeout: int = 30,
                    **kwargs):
        """Start a new round conversation of a session. Return the chat
        completions in stream mode.

        Args:
            session_id (int): the identical id of a session
            inputs (List[dict]): user's inputs in this round conversation
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            stream (bool): return in a streaming format if enabled
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            tuple(Status, str, int): status, text/chat completion,
            generated token number
        """
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)
        prompt = self.template_parser(inputs)

        resp = ''
        finished = False
        stop_words = self.gen_params.get('stop_words')
        for text in self.client.completions_v1(
                self.model_name,
                prompt,
                session_id=session_id,
                sequence_start=sequence_start,
                sequence_end=sequence_end,
                stream=stream,
                ignore_eos=ignore_eos,
                skip_special_tokens=skip_special_tokens,
                timeout=timeout,
                **gen_params):
            resp += text['choices'][0]['text']
            if not resp:
                continue
            # remove stop_words (`stop_words` may be unset)
            for sw in (stop_words or []):
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None


class LMDeployClient(LMDeployServer):
    """

    Args:
        url (str): communicating address 'http://<ip>:<port>' of
            api_server
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
    """

    def __init__(self, url: str, model_name: str, **kwargs):
        BaseLLM.__init__(self, path=url, **kwargs)
        from lmdeploy.serve.openai.api_client import APIClient
        self.client = APIClient(url)
        self.model_name = model_name


class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded
                from ii) and iii).
            - ii) The model_id of an lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        tp (int): tensor parallel
        pipeline_cfg (dict): config of pipeline
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages.
                Defaults to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.messages import GenerationConfig, Response

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        gen_config = GenerationConfig(
            skip_special_tokens=skip_special_tokens, **gen_params)

        async def _inner_generate(uid, text):
            resp = Response('', 0, 0, uid)
            async for out in self.model.generate(
                    text,
                    uid,
                    gen_config,
                    stream_response=True,
                    sequence_start=True,
                    sequence_end=True,
                    do_preprocess=do_preprocess,
                    **kwargs):
                resp.text += out.response
                resp.generate_token_len = out.generate_token_len
                resp.input_token_len = out.input_token_len
                resp.finish_reason = out.finish_reason
                if out.token_ids:
                    resp.token_ids.extend(out.token_ids)
                if out.logprobs:
                    if resp.logprobs is None:
                        resp.logprobs = []
                    resp.logprobs.extend(out.logprobs)
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict(resp)
                    for resp in response] if return_dict else texts
        return asdict(response[0]) if return_dict else texts[0]


class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer):
    """

    Args:
        path (str): The path to the model.
            It could be one of the following options:
            - i) A local directory path of a turbomind model which is
                converted by `lmdeploy convert` command or downloaded from
                ii) and iii).
            - ii) The model_id of an lmdeploy-quantized model hosted
                inside a model repo on huggingface.co, such as
                "InternLM/internlm-chat-20b-4bit",
                "lmdeploy/llama2-chat-70b-4bit", etc.
            - iii) The model_id of a model hosted inside a model repo
                on huggingface.co, such as "internlm/internlm-chat-7b",
                "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                and so on.
        model_name (str): needed when model_path is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat", "Baichuan2-7B-Chat" and so on.
        server_name (str): host ip for serving
        server_port (int): server port
        tp (int): tensor parallel
        log_level (str): set log level whose value among
            [CRITICAL, ERROR, WARNING, INFO, DEBUG]
    """

    async def generate(
        self,
        inputs: Union[str, List[str]],
        session_ids: Union[int, List[int]] = None,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        skip_special_tokens: Optional[bool] = False,
        timeout: int = 30,
        **kwargs,
    ):
        """Start a new round conversation of a session. Return the chat
        completions in non-stream mode.

        Args:
            inputs (str, List[str]): user's prompt(s) in this round
            session_ids (int, List[int]): session id(s)
            sequence_start (bool): start flag of a session
            sequence_end (bool): end flag of a session
            ignore_eos (bool): indicator for ignoring eos
            skip_special_tokens (bool): Whether or not to remove special
                tokens in the decoding. Defaults to False.
            timeout (int): max time to wait for response
        Returns:
            (a list of/batched) text/chat completion
        """
        from lmdeploy.serve.openai.api_client import json_loads

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False

        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        gen_params.update(max_tokens=max_new_tokens)

        responses = [''] * len(inputs)
        pload = dict(
            model=self.model_name,
            prompt=inputs,
            sequence_start=sequence_start,
            sequence_end=sequence_end,
            stream=False,
            ignore_eos=ignore_eos,
            skip_special_tokens=skip_special_tokens,
            timeout=timeout,
            **gen_params)
        async with aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
            async with session.post(
                    self.client.completions_v1_url,
                    headers=self.client.headers,
                    json=pload) as resp:
                async for chunk in resp.content:
                    if chunk:
                        decoded = chunk.decode('utf-8')
684 |
+
output = json_loads(decoded)
|
685 |
+
responses = [
|
686 |
+
response + item['text'] for response, item in zip(
|
687 |
+
responses, output['choices'])
|
688 |
+
]
|
689 |
+
# remove stop_words
|
690 |
+
responses = filter_suffix(responses, self.gen_params.get('stop_words'))
|
691 |
+
if not batched:
|
692 |
+
return responses[0]
|
693 |
+
return responses
|
694 |
+
|
695 |
+
async def stream_chat(
|
696 |
+
self,
|
697 |
+
inputs: List[dict],
|
698 |
+
session_id: int = None,
|
699 |
+
sequence_start: bool = True,
|
700 |
+
sequence_end: bool = True,
|
701 |
+
stream: bool = True,
|
702 |
+
ignore_eos: bool = False,
|
703 |
+
skip_special_tokens: Optional[bool] = False,
|
704 |
+
timeout: int = 30,
|
705 |
+
**kwargs,
|
706 |
+
):
|
707 |
+
"""Start a new round conversation of a session. Return the chat
|
708 |
+
completions in stream mode.
|
709 |
+
|
710 |
+
Args:
|
711 |
+
inputs (List[dict]): user's inputs in this round conversation
|
712 |
+
session_id (int): session id
|
713 |
+
sequence_start (bool): start flag of a session
|
714 |
+
sequence_end (bool): end flag of a session
|
715 |
+
stream (bool): return in a streaming format if enabled
|
716 |
+
ignore_eos (bool): indicator for ignoring eos
|
717 |
+
skip_special_tokens (bool): Whether or not to remove special tokens
|
718 |
+
in the decoding. Default to be False.
|
719 |
+
timeout (int): max time to wait for response
|
720 |
+
Returns:
|
721 |
+
tuple(Status, str, int): status, text/chat completion,
|
722 |
+
generated token number
|
723 |
+
"""
|
724 |
+
from lmdeploy.serve.openai.api_client import json_loads
|
725 |
+
|
726 |
+
gen_params = self.update_gen_params(**kwargs)
|
727 |
+
max_new_tokens = gen_params.pop('max_new_tokens')
|
728 |
+
gen_params.update(max_tokens=max_new_tokens)
|
729 |
+
prompt = self.template_parser(inputs)
|
730 |
+
|
731 |
+
response = ''
|
732 |
+
finished = False
|
733 |
+
stop_words = self.gen_params.get('stop_words')
|
734 |
+
|
735 |
+
pload = dict(
|
736 |
+
model=self.model_name,
|
737 |
+
prompt=prompt,
|
738 |
+
sequence_start=sequence_start,
|
739 |
+
sequence_end=sequence_end,
|
740 |
+
stream=stream,
|
741 |
+
ignore_eos=ignore_eos,
|
742 |
+
skip_special_tokens=skip_special_tokens,
|
743 |
+
timeout=timeout,
|
744 |
+
**gen_params)
|
745 |
+
async with aiohttp.ClientSession(
|
746 |
+
timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
|
747 |
+
async with session.post(
|
748 |
+
self.client.completions_v1_url,
|
749 |
+
headers=self.client.headers,
|
750 |
+
json=pload) as resp:
|
751 |
+
async for chunk in resp.content:
|
752 |
+
if chunk:
|
753 |
+
decoded = chunk.decode('utf-8')
|
754 |
+
if not decoded.strip() or decoded.rstrip(
|
755 |
+
) == 'data: [DONE]':
|
756 |
+
continue
|
757 |
+
if decoded[:6] == 'data: ':
|
758 |
+
decoded = decoded[6:]
|
759 |
+
output = json_loads(decoded)
|
760 |
+
response += output['choices'][0]['text']
|
761 |
+
if not response:
|
762 |
+
continue
|
763 |
+
# remove stop_words
|
764 |
+
for sw in stop_words:
|
765 |
+
if sw in response:
|
766 |
+
response = filter_suffix(response, stop_words)
|
767 |
+
finished = True
|
768 |
+
break
|
769 |
+
yield ModelStatusCode.STREAM_ING, response, None
|
770 |
+
if finished:
|
771 |
+
break
|
772 |
+
yield ModelStatusCode.END, response, None
|
773 |
+
|
774 |
+
|
775 |
+
class AsyncLMDeployClient(AsyncLMDeployServer):
|
776 |
+
"""
|
777 |
+
|
778 |
+
Args:
|
779 |
+
url (str): communicating address 'http://<ip>:<port>' of
|
780 |
+
api_server
|
781 |
+
model_name (str): needed when model_path is a pytorch model on
|
782 |
+
huggingface.co, such as "internlm-chat-7b",
|
783 |
+
"Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
|
784 |
+
"""
|
785 |
+
|
786 |
+
def __init__(self, url: str, model_name: str, **kwargs):
|
787 |
+
BaseLLM.__init__(self, path=url, **kwargs)
|
788 |
+
from lmdeploy.serve.openai.api_client import APIClient
|
789 |
+
self.client = APIClient(url)
|
790 |
+
self.model_name = model_name
|
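For orientation, a minimal usage sketch of the async client above. The endpoint address and model name are placeholders (it assumes an `lmdeploy serve api_server` instance is already running there), and `max_new_tokens` is assumed to be accepted as a default generation parameter by the base class; this is a sketch, not part of the committed files.

import asyncio

from lagent.llms.lmdeploy_wrapper import AsyncLMDeployClient


async def main():
    # Hypothetical endpoint and model name.
    llm = AsyncLMDeployClient(
        url='http://127.0.0.1:23333',
        model_name='internlm2-chat-7b',
        max_new_tokens=256)
    # Batched, non-stream completion: both prompts run in parallel sessions.
    texts = await llm.generate(
        ['Hi, who are you?', 'Explain tensor parallelism.'])
    print(texts)


asyncio.run(main())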
lagent/llms/meta_template.py
ADDED
@@ -0,0 +1,40 @@
INTERNLM2_META = [
    dict(
        role='system',
        begin=dict(
            with_name='<|im_start|>system name={name}\n',
            without_name='<|im_start|>system\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n',
    ),
    dict(
        role='user',
        begin=dict(
            with_name='<|im_start|>user name={name}\n',
            without_name='<|im_start|>user\n',
        ),
        end='<|im_end|>\n'),
    dict(
        role='assistant',
        begin=dict(
            with_name='<|im_start|>assistant name={name}\n',
            without_name='<|im_start|>assistant\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n'),
    dict(
        role='environment',
        begin=dict(
            with_name='<|im_start|>environment name={name}\n',
            without_name='<|im_start|>environment\n',
            name={
                'interpreter': '<|interpreter|>',
                'plugin': '<|plugin|>',
            }),
        end='<|im_end|>\n'),
]
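To illustrate the wire format this template encodes, here is a small hand-rolled rendering sketch. The `render` helper is hypothetical (lagent's own prompt parsers consume this structure; nothing like this function ships in the commit), but the output format follows directly from the dictionaries above.

from lagent.llms.meta_template import INTERNLM2_META


def render(msg, meta=INTERNLM2_META):
    # Hypothetical helper: wrap one message in its role's begin/end markers.
    cfg = next(item for item in meta if item['role'] == msg['role'])
    name = msg.get('name')
    if name:
        # e.g. name='plugin' is decorated as '<|plugin|>' where a mapping exists
        decorated = cfg['begin'].get('name', {}).get(name, name)
        begin = cfg['begin']['with_name'].format(name=decorated)
    else:
        begin = cfg['begin']['without_name']
    return begin + msg['content'] + cfg['end']


print(render(dict(role='user', content='What is 2+2?')))
# -> '<|im_start|>user\nWhat is 2+2?<|im_end|>\n'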
lagent/llms/openai.py
ADDED
@@ -0,0 +1,924 @@
import asyncio
import json
import os
import time
import traceback
import warnings
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from threading import Lock
from typing import AsyncGenerator, Dict, List, Optional, Union

import aiohttp
import requests

from ..schema import ModelStatusCode
from ..utils import filter_suffix
from .base_api import AsyncBaseAPILLM, BaseAPILLM

warnings.simplefilter('default')

OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'


class GPTAPI(BaseAPILLM):
    """Model wrapper around OpenAI's models.

    Args:
        model_type (str): The name of OpenAI's model.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): OpenAI key(s). In particular, when it
            is set to "ENV", the key will be fetched from the environment
            variable $OPENAI_API_KEY, as openai defaults to. If it's a
            list, the keys will be used in round-robin manner. Defaults to
            'ENV'.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, OpenAI uses the default organization bound to each API
            key. If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case there is a requirement of injecting
            or wrapping any meta instructions.
        api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
        gen_params: Default generation configuration which could be
            overridden on the fly during generation.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str = 'gpt-3.5-turbo',
                 retry: int = 2,
                 json_mode: bool = False,
                 key: Union[str, List[str]] = 'ENV',
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant'),
                     dict(role='environment', api_role='system')
                 ],
                 api_base: str = OPENAI_API_BASE,
                 proxies: Optional[Dict] = None,
                 **gen_params):
        if 'top_k' in gen_params:
            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
                          DeprecationWarning)
            gen_params.pop('top_k')
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            retry=retry,
            **gen_params)
        self.gen_params.pop('top_k')
        self.logger = getLogger(__name__)

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key

        # record invalid keys and skip them when requesting the API
        # - keys with insufficient_quota
        self.invalid_keys = set()

        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = api_base
        self.model_type = model_type
        self.proxies = proxies
        self.json_mode = json_mode

    def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or a list of lists of messages
            gen_params: additional generation configuration

        Returns:
            Union[str, List[str]]: generated string(s)
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = {**self.gen_params, **gen_params}
        with ThreadPoolExecutor(max_workers=20) as executor:
            tasks = [
                executor.submit(self._chat,
                                self.template_parser._prompt2api(messages),
                                **gen_params)
                for messages in (
                    [inputs] if isinstance(inputs[0], dict) else inputs)
            ]
        ret = [task.result() for task in tasks]
        return ret[0] if isinstance(inputs[0], dict) else ret

    def stream_chat(
        self,
        inputs: List[dict],
        **gen_params,
    ):
        """Generate responses given the contexts.

        Args:
            inputs (List[dict]): a list of messages
            gen_params: additional generation configuration

        Yields:
            tuple(ModelStatusCode, str, None): status code and the
                accumulated generated string
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = self.update_gen_params(**gen_params)
        gen_params['stream'] = True

        resp = ''
        finished = False
        stop_words = gen_params.get('stop_words')
        if stop_words is None:
            stop_words = []
        # map to the roles that openai supports
        messages = self.template_parser._prompt2api(inputs)
        for text in self._stream_chat(messages, **gen_params):
            if self.model_type.lower().startswith('qwen'):
                resp = text
            else:
                resp += text
            if not resp:
                continue
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None

    def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """
        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            with Lock():
                if len(self.invalid_keys) == len(self.keys):
                    raise RuntimeError('All keys have insufficient quota.')

                # find the next valid key
                while True:
                    self.key_ctr += 1
                    if self.key_ctr == len(self.keys):
                        self.key_ctr = 0

                    if self.keys[self.key_ctr] not in self.invalid_keys:
                        break

                key = self.keys[self.key_ctr]
                header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                with Lock():
                    self.org_ctr += 1
                    if self.org_ctr == len(self.orgs):
                        self.org_ctr = 0
                    header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies)
                response = raw_response.json()
                return response['choices'][0]['message']['content'].strip()
            except requests.ConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except requests.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + str(raw_response.content)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def _stream_chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """

        def streaming(raw_response):
            for chunk in raw_response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
                if chunk:
                    decoded = chunk.decode('utf-8')
                    if decoded.startswith('data: [DONE]'):
                        return
                    if decoded[:5] == 'data:':
                        decoded = decoded[5:]
                        if decoded[0] == ' ':
                            decoded = decoded[1:]
                    else:
                        print(decoded)
                        continue
                    try:
                        response = json.loads(decoded)
                        if 'code' in response and response['code'] == -20003:
                            # context exceeds the maximum length
                            yield ''
                            return
                        if self.model_type.lower().startswith('qwen'):
                            choice = response['output']['choices'][0]
                            yield choice['message']['content']
                            if choice['finish_reason'] == 'stop':
                                return
                        else:
                            choice = response['choices'][0]
                            if choice['finish_reason'] == 'stop':
                                return
                            yield choice['delta'].get('content', '')
                    except Exception as exc:
                        msg = f'response {decoded} led to exception: {str(exc)}'
                        self.logger.error(msg)
                        raise Exception(msg) from exc

        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies)
                return streaming(raw_response)
            except requests.ConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except requests.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + str(raw_response.content)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def generate_request_data(self,
                              model_type,
                              messages,
                              gen_params,
                              json_mode=False):
        """Generate the request data for different model types.

        Args:
            model_type (str): The type of the model (e.g., 'gpt', 'internlm',
                'qwen').
            messages (list): The list of messages to be sent to the model.
            gen_params (dict): The generation parameters.
            json_mode (bool): Flag to determine if the response format should
                be JSON.

        Returns:
            tuple: A tuple containing the header and the request data.
        """
        # Copy generation parameters to avoid modifying the original dict
        gen_params = gen_params.copy()

        # Hold out 100 tokens due to potential errors in token calculation
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # Initialize the header
        header = {
            'content-type': 'application/json',
        }

        # Common parameters processing
        gen_params['max_tokens'] = max_tokens
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        if 'repetition_penalty' in gen_params:
            gen_params['frequency_penalty'] = gen_params.pop(
                'repetition_penalty')

        # Model-specific processing
        data = {}
        if model_type.lower().startswith('gpt'):
            if 'top_k' in gen_params:
                warnings.warn(
                    '`top_k` parameter is deprecated in OpenAI APIs.',
                    DeprecationWarning)
                gen_params.pop('top_k')
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('internlm'):
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('qwen'):
            header['X-DashScope-SSE'] = 'enable'
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            if 'frequency_penalty' in gen_params:
                gen_params['repetition_penalty'] = gen_params.pop(
                    'frequency_penalty')
            gen_params['result_format'] = 'message'
            data = {
                'model': model_type,
                'input': {
                    'messages': messages
                },
                'parameters': {
                    **gen_params
                }
            }
        else:
            raise NotImplementedError(
                f'Model type {model_type} is not supported')

        return header, data

    def tokenize(self, prompt: str) -> list:
        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
            list: token ids
        """
        import tiktoken
        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model(self.model_type)
        return enc.encode(prompt)

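As a quick orientation between the two classes, here is a minimal synchronous usage sketch of the wrapper above. The prompt is illustrative, and it assumes OPENAI_API_KEY is set in the environment and that `max_new_tokens` is a default generation parameter accepted by the base class.

from lagent.llms.openai import GPTAPI

llm = GPTAPI(model_type='gpt-3.5-turbo', key='ENV', max_new_tokens=512)
# A single conversation (a list of message dicts); roles are mapped to
# OpenAI roles through meta_template before the request is posted.
answer = llm.chat([dict(role='user', content='Summarize what an LLM agent is.')])
print(answer)
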
class AsyncGPTAPI(AsyncBaseAPILLM):
    """Asynchronous model wrapper around OpenAI's models.

    Args:
        model_type (str): The name of OpenAI's model.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        key (str or List[str]): OpenAI key(s). In particular, when it
            is set to "ENV", the key will be fetched from the environment
            variable $OPENAI_API_KEY, as openai defaults to. If it's a
            list, the keys will be used in round-robin manner. Defaults to
            'ENV'.
        org (str or List[str], optional): OpenAI organization(s). If not
            specified, OpenAI uses the default organization bound to each API
            key. If specified, the orgs will be posted with each request in
            round-robin manner. Defaults to None.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case there is a requirement of injecting
            or wrapping any meta instructions.
        api_base (str): The base url of OpenAI's API. Defaults to
            'https://api.openai.com/v1/chat/completions'.
        gen_params: Default generation configuration which could be
            overridden on the fly during generation.
    """

    is_api: bool = True

    def __init__(self,
                 model_type: str = 'gpt-3.5-turbo',
                 retry: int = 2,
                 json_mode: bool = False,
                 key: Union[str, List[str]] = 'ENV',
                 org: Optional[Union[str, List[str]]] = None,
                 meta_template: Optional[Dict] = [
                     dict(role='system', api_role='system'),
                     dict(role='user', api_role='user'),
                     dict(role='assistant', api_role='assistant')
                 ],
                 api_base: str = OPENAI_API_BASE,
                 proxies: Optional[Dict] = None,
                 **gen_params):
        if 'top_k' in gen_params:
            warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
                          DeprecationWarning)
            gen_params.pop('top_k')
        super().__init__(
            model_type=model_type,
            meta_template=meta_template,
            retry=retry,
            **gen_params)
        self.gen_params.pop('top_k')
        self.logger = getLogger(__name__)

        if isinstance(key, str):
            self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
        else:
            self.keys = key

        # record invalid keys and skip them when requesting the API
        # - keys with insufficient_quota
        self.invalid_keys = set()

        self.key_ctr = 0
        if isinstance(org, str):
            self.orgs = [org]
        else:
            self.orgs = org
        self.org_ctr = 0
        self.url = api_base
        self.model_type = model_type
        self.proxies = proxies or {}
        self.json_mode = json_mode

    async def chat(
        self,
        inputs: Union[List[dict], List[List[dict]]],
        session_ids: Union[int, List[int]] = None,
        **gen_params,
    ) -> Union[str, List[str]]:
        """Generate responses given the contexts.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): a list of messages
                or a list of lists of messages
            gen_params: additional generation configuration

        Returns:
            Union[str, List[str]]: generated string(s)
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = {**self.gen_params, **gen_params}
        tasks = [
            self._chat(messages, **gen_params) for messages in (
                [inputs] if isinstance(inputs[0], dict) else inputs)
        ]
        ret = await asyncio.gather(*tasks)
        return ret[0] if isinstance(inputs[0], dict) else ret

    async def stream_chat(
        self,
        inputs: List[dict],
        **gen_params,
    ):
        """Generate responses given the contexts.

        Args:
            inputs (List[dict]): a list of messages
            gen_params: additional generation configuration

        Yields:
            tuple(ModelStatusCode, str, None): status code and the
                accumulated generated string
        """
        assert isinstance(inputs, list)
        if 'max_tokens' in gen_params:
            raise NotImplementedError('unsupported parameter: max_tokens')
        gen_params = self.update_gen_params(**gen_params)
        gen_params['stream'] = True

        resp = ''
        finished = False
        stop_words = gen_params.get('stop_words')
        if stop_words is None:
            stop_words = []
        # map to the roles that openai supports
        messages = self.template_parser._prompt2api(inputs)
        async for text in self._stream_chat(messages, **gen_params):
            if self.model_type.lower().startswith('qwen'):
                resp = text
            else:
                resp += text
            if not resp:
                continue
            # remove stop_words
            for sw in stop_words:
                if sw in resp:
                    resp = filter_suffix(resp, stop_words)
                    finished = True
                    break
            yield ModelStatusCode.STREAM_ING, resp, None
            if finished:
                break
        yield ModelStatusCode.END, resp, None

    async def _chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """
        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                            self.url,
                            headers=header,
                            json=data,
                            proxy=self.proxies.get(
                                'https', self.proxies.get('http'))) as resp:
                        response = await resp.json()
                        return response['choices'][0]['message'][
                            'content'].strip()
            except aiohttp.ClientConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except aiohttp.ClientResponseError as e:
                errmsg = 'Response error, got ' + str(e)
                self.logger.error(errmsg)
                continue
            except json.JSONDecodeError:
                errmsg = 'JsonDecode error, got ' + (await resp.text(
                    errors='replace'))
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    async def _stream_chat(self, messages: List[dict],
                           **gen_params) -> AsyncGenerator[str, None]:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """

        async def streaming(raw_response):
            async for chunk in raw_response.content:
                if chunk:
                    decoded = chunk.decode('utf-8')
                    if decoded.startswith('data: [DONE]'):
                        return
                    if decoded[:5] == 'data:':
                        decoded = decoded[5:]
                        if decoded[0] == ' ':
                            decoded = decoded[1:]
                    else:
                        print(decoded)
                        continue
                    try:
                        response = json.loads(decoded)
                        if 'code' in response and response['code'] == -20003:
                            # context exceeds the maximum length
                            yield ''
                            return
                        if self.model_type.lower().startswith('qwen'):
                            choice = response['output']['choices'][0]
                            yield choice['message']['content']
                            if choice['finish_reason'] == 'stop':
                                return
                        else:
                            choice = response['choices'][0]
                            if choice['finish_reason'] == 'stop':
                                return
                            yield choice['delta'].get('content', '')
                    except Exception as exc:
                        msg = f'response {decoded} led to exception: {str(exc)}'
                        self.logger.error(msg)
                        raise Exception(msg) from exc

        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode)

        max_num_retries, errmsg = 0, ''
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            if self.orgs:
                self.org_ctr += 1
                if self.org_ctr == len(self.orgs):
                    self.org_ctr = 0
                header['OpenAI-Organization'] = self.orgs[self.org_ctr]

            response = dict()
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                            self.url,
                            headers=header,
                            json=data,
                            proxy=self.proxies.get(
                                'https',
                                self.proxies.get('http'))) as raw_response:
                        async for msg in streaming(raw_response):
                            yield msg
                        return
            except aiohttp.ClientConnectionError:
                errmsg = 'Got connection error ' + str(traceback.format_exc())
                self.logger.error(errmsg)
                continue
            except aiohttp.ClientResponseError as e:
                errmsg = 'Response error, got ' + str(e)
                self.logger.error(errmsg)
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue

                    errmsg = 'Found error message in response: ' + str(
                        response['error'])
                    self.logger.error(errmsg)
            except Exception as error:
                errmsg = str(error) + '\n' + str(traceback.format_exc())
                self.logger.error(errmsg)
            max_num_retries += 1

        raise RuntimeError('Calling OpenAI failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           f'details. errmsg: {errmsg}')

    def generate_request_data(self,
                              model_type,
                              messages,
                              gen_params,
                              json_mode=False):
        """Generate the request data for different model types.

        Args:
            model_type (str): The type of the model (e.g., 'gpt', 'internlm',
                'qwen').
            messages (list): The list of messages to be sent to the model.
            gen_params (dict): The generation parameters.
            json_mode (bool): Flag to determine if the response format should
                be JSON.

        Returns:
            tuple: A tuple containing the header and the request data.
        """
        # Copy generation parameters to avoid modifying the original dict
        gen_params = gen_params.copy()

        # Hold out 100 tokens due to potential errors in token calculation
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # Initialize the header
        header = {
            'content-type': 'application/json',
        }

        # Common parameters processing
        gen_params['max_tokens'] = max_tokens
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        if 'repetition_penalty' in gen_params:
            gen_params['frequency_penalty'] = gen_params.pop(
                'repetition_penalty')

        # Model-specific processing
        data = {}
        if model_type.lower().startswith('gpt'):
            if 'top_k' in gen_params:
                warnings.warn(
                    '`top_k` parameter is deprecated in OpenAI APIs.',
                    DeprecationWarning)
                gen_params.pop('top_k')
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('internlm'):
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        elif model_type.lower().startswith('qwen'):
            header['X-DashScope-SSE'] = 'enable'
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            if 'frequency_penalty' in gen_params:
                gen_params['repetition_penalty'] = gen_params.pop(
                    'frequency_penalty')
            gen_params['result_format'] = 'message'
            data = {
                'model': model_type,
                'input': {
                    'messages': messages
                },
                'parameters': {
                    **gen_params
                }
            }
        else:
            raise NotImplementedError(
                f'Model type {model_type} is not supported')

        return header, data

    def tokenize(self, prompt: str) -> list:
        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
            list: token ids
        """
        import tiktoken
        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model(self.model_type)
        return enc.encode(prompt)
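A matching async usage sketch for AsyncGPTAPI: `chat` dispatches the per-conversation `_chat` calls concurrently through `asyncio.gather`, so a batch of conversations completes in roughly the time of the slowest one. The prompts are illustrative and OPENAI_API_KEY is assumed to be set.

import asyncio

from lagent.llms.openai import AsyncGPTAPI


async def main():
    llm = AsyncGPTAPI(model_type='gpt-3.5-turbo', key='ENV')
    # Two independent conversations, requested concurrently.
    answers = await llm.chat([
        [dict(role='user', content='Define tensor parallelism.')],
        [dict(role='user', content='Define pipeline parallelism.')],
    ])
    print(answers)


asyncio.run(main())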
lagent/llms/sensenova.py
ADDED
@@ -0,0 +1,406 @@
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
from concurrent.futures import ThreadPoolExecutor
|
6 |
+
from logging import getLogger
|
7 |
+
from threading import Lock
|
8 |
+
from typing import Dict, Generator, List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
import requests
|
11 |
+
|
12 |
+
from lagent.schema import ModelStatusCode
|
13 |
+
from lagent.utils.util import filter_suffix
|
14 |
+
from .base_api import BaseAPILLM
|
15 |
+
|
16 |
+
warnings.simplefilter('default')
|
17 |
+
|
18 |
+
SENSENOVA_API_BASE = 'https://api.sensenova.cn/v1/llm/chat-completions'
|
19 |
+
|
20 |
+
sensechat_models = {'SenseChat-5': 131072, 'SenseChat-5-Cantonese': 32768}
|
21 |
+
|
22 |
+
|
23 |
+
class SensenovaAPI(BaseAPILLM):
|
24 |
+
"""Model wrapper around SenseTime's models.
|
25 |
+
|
26 |
+
Args:
|
27 |
+
model_type (str): The name of SenseTime's model.
|
28 |
+
retry (int): Number of retires if the API call fails. Defaults to 2.
|
29 |
+
key (str or List[str]): SenseTime key(s). In particular, when it
|
30 |
+
is set to "ENV", the key will be fetched from the environment
|
31 |
+
variable $SENSENOVA_API_KEY. If it's a list, the keys will be
|
32 |
+
used in round-robin manner. Defaults to 'ENV'.
|
33 |
+
meta_template (Dict, optional): The model's meta prompt
|
34 |
+
template if needed, in case the requirement of injecting or
|
35 |
+
wrapping of any meta instructions.
|
36 |
+
sensenova_api_base (str): The base url of SenseTime's API. Defaults to
|
37 |
+
'https://api.sensenova.cn/v1/llm/chat-completions'.
|
38 |
+
gen_params: Default generation configuration which could be overridden
|
39 |
+
on the fly of generation.
|
40 |
+
"""
|
41 |
+
|
42 |
+
is_api: bool = True
|
43 |
+
|
44 |
+
def __init__(
|
45 |
+
self,
|
46 |
+
model_type: str = 'SenseChat-5-Cantonese',
|
47 |
+
retry: int = 2,
|
48 |
+
json_mode: bool = False,
|
49 |
+
key: Union[str, List[str]] = 'ENV',
|
50 |
+
meta_template: Optional[Dict] = [
|
51 |
+
dict(role='system', api_role='system'),
|
52 |
+
dict(role='user', api_role='user'),
|
53 |
+
dict(role='assistant', api_role='assistant'),
|
54 |
+
dict(role='environment', api_role='system'),
|
55 |
+
],
|
56 |
+
sensenova_api_base: str = SENSENOVA_API_BASE,
|
57 |
+
proxies: Optional[Dict] = None,
|
58 |
+
**gen_params,
|
59 |
+
):
|
60 |
+
|
61 |
+
super().__init__(
|
62 |
+
model_type=model_type,
|
63 |
+
meta_template=meta_template,
|
64 |
+
retry=retry,
|
65 |
+
**gen_params,
|
66 |
+
)
|
67 |
+
self.logger = getLogger(__name__)
|
68 |
+
|
69 |
+
if isinstance(key, str):
|
70 |
+
# First, apply for SenseNova's ak and sk from SenseTime staff
|
71 |
+
# Then, generated SENSENOVA_API_KEY using lagent.utils.gen_key.auto_gen_jwt_token(ak, sk)
|
72 |
+
self.keys = [
|
73 |
+
os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key
|
74 |
+
]
|
75 |
+
else:
|
76 |
+
self.keys = key
|
77 |
+
|
78 |
+
# record invalid keys and skip them when requesting API
|
79 |
+
# - keys have insufficient_quota
|
80 |
+
self.invalid_keys = set()
|
81 |
+
|
82 |
+
self.key_ctr = 0
|
83 |
+
self.url = sensenova_api_base
|
84 |
+
self.model_type = model_type
|
85 |
+
self.proxies = proxies
|
86 |
+
self.json_mode = json_mode
|
87 |
+
|
88 |
+
def chat(
|
89 |
+
self,
|
90 |
+
inputs: Union[List[dict], List[List[dict]]],
|
91 |
+
**gen_params,
|
92 |
+
) -> Union[str, List[str]]:
|
93 |
+
"""Generate responses given the contexts.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
inputs (Union[List[dict], List[List[dict]]]): a list of messages
|
97 |
+
or list of lists of messages
|
98 |
+
gen_params: additional generation configuration
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Union[str, List[str]]: generated string(s)
|
102 |
+
"""
|
103 |
+
assert isinstance(inputs, list)
|
104 |
+
if 'max_tokens' in gen_params:
|
105 |
+
raise NotImplementedError('unsupported parameter: max_tokens')
|
106 |
+
gen_params = {**self.gen_params, **gen_params}
|
107 |
+
with ThreadPoolExecutor(max_workers=20) as executor:
|
108 |
+
tasks = [
|
109 |
+
executor.submit(self._chat,
|
110 |
+
self.template_parser._prompt2api(messages),
|
111 |
+
**gen_params)
|
112 |
+
for messages in (
|
113 |
+
[inputs] if isinstance(inputs[0], dict) else inputs)
|
114 |
+
]
|
115 |
+
ret = [task.result() for task in tasks]
|
116 |
+
return ret[0] if isinstance(inputs[0], dict) else ret
|
117 |
+
|
118 |
+
def stream_chat(
|
119 |
+
self,
|
120 |
+
inputs: List[dict],
|
121 |
+
**gen_params,
|
122 |
+
) -> Generator[Tuple[ModelStatusCode, str, Optional[str]], None, None]:
|
123 |
+
"""Generate responses given the contexts.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
inputs (List[dict]): a list of messages
|
127 |
+
gen_params: additional generation configuration
|
128 |
+
|
129 |
+
Yields:
|
130 |
+
Tuple[ModelStatusCode, str, Optional[str]]: Status code, generated string, and optional metadata
|
131 |
+
"""
|
132 |
+
assert isinstance(inputs, list)
|
133 |
+
if 'max_tokens' in gen_params:
|
134 |
+
raise NotImplementedError('unsupported parameter: max_tokens')
|
135 |
+
gen_params = self.update_gen_params(**gen_params)
|
136 |
+
gen_params['stream'] = True
|
137 |
+
|
138 |
+
resp = ''
|
139 |
+
finished = False
|
140 |
+
stop_words = gen_params.get('stop_words') or []
|
141 |
+
messages = self.template_parser._prompt2api(inputs)
|
142 |
+
for text in self._stream_chat(messages, **gen_params):
|
143 |
+
# TODO 测试 resp = text 还是 resp += text
|
144 |
+
resp += text
|
145 |
+
if not resp:
|
146 |
+
continue
|
147 |
+
# remove stop_words
|
148 |
+
for sw in stop_words:
|
149 |
+
if sw in resp:
|
150 |
+
resp = filter_suffix(resp, stop_words)
|
151 |
+
finished = True
|
152 |
+
break
|
153 |
+
yield ModelStatusCode.STREAM_ING, resp, None
|
154 |
+
if finished:
|
155 |
+
break
|
156 |
+
yield ModelStatusCode.END, resp, None
|
157 |
+
|
158 |
+
def _chat(self, messages: List[dict], **gen_params) -> str:
|
159 |
+
"""Generate completion from a list of templates.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
messages (List[dict]): a list of prompt dictionaries
|
163 |
+
gen_params: additional generation configuration
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
str: The generated string.
|
167 |
+
"""
|
168 |
+
assert isinstance(messages, list)
|
169 |
+
|
170 |
+
header, data = self.generate_request_data(
|
171 |
+
model_type=self.model_type,
|
172 |
+
messages=messages,
|
173 |
+
gen_params=gen_params,
|
174 |
+
json_mode=self.json_mode,
|
175 |
+
)
|
176 |
+
|
177 |
+
max_num_retries = 0
|
178 |
+
while max_num_retries < self.retry:
|
179 |
+
self._wait()
|
180 |
+
|
181 |
+
with Lock():
|
182 |
+
if len(self.invalid_keys) == len(self.keys):
|
183 |
+
raise RuntimeError('All keys have insufficient quota.')
|
184 |
+
|
185 |
+
# find the next valid key
|
186 |
+
while True:
|
187 |
+
self.key_ctr += 1
|
188 |
+
if self.key_ctr == len(self.keys):
|
189 |
+
self.key_ctr = 0
|
190 |
+
|
191 |
+
if self.keys[self.key_ctr] not in self.invalid_keys:
|
192 |
+
break
|
193 |
+
|
194 |
+
key = self.keys[self.key_ctr]
|
195 |
+
header['Authorization'] = f'Bearer {key}'
|
196 |
+
|
197 |
+
response = dict()
|
198 |
+
try:
|
199 |
+
raw_response = requests.post(
|
200 |
+
self.url,
|
201 |
+
headers=header,
|
202 |
+
data=json.dumps(data),
|
203 |
+
proxies=self.proxies,
|
204 |
+
)
|
205 |
+
response = raw_response.json()
|
206 |
+
return response['choices'][0]['message']['content'].strip()
|
207 |
+
except requests.ConnectionError:
|
208 |
+
print('Got connection error, retrying...')
|
209 |
+
continue
|
210 |
+
except requests.JSONDecodeError:
|
211 |
+
print('JsonDecode error, got', str(raw_response.content))
|
212 |
+
continue
|
213 |
+
except KeyError:
|
214 |
+
if 'error' in response:
|
215 |
+
if response['error']['code'] == 'rate_limit_exceeded':
|
216 |
+
time.sleep(1)
|
217 |
+
continue
|
218 |
+
elif response['error']['code'] == 'insufficient_quota':
|
219 |
+
self.invalid_keys.add(key)
|
220 |
+
self.logger.warn(f'insufficient_quota key: {key}')
|
221 |
+
continue
|
222 |
+
|
223 |
+
print('Find error message in response: ',
|
224 |
+
str(response['error']))
|
225 |
+
except Exception as error:
|
226 |
+
print(str(error))
|
227 |
+
max_num_retries += 1
|
228 |
+
|
229 |
+
raise RuntimeError('Calling SenseTime failed after retrying for '
|
230 |
+
f'{max_num_retries} times. Check the logs for '
|
231 |
+
'details.')
|
232 |
+
|
    def _stream_chat(self, messages: List[dict], **gen_params) -> str:
        """Generate completion from a list of templates.

        Args:
            messages (List[dict]): a list of prompt dictionaries
            gen_params: additional generation configuration

        Returns:
            str: The generated string.
        """

        def streaming(raw_response):
            for chunk in raw_response.iter_lines():
                if chunk:
                    try:
                        decoded_chunk = chunk.decode('utf-8')
                        # print(f"Decoded chunk: {decoded_chunk}")

                        if decoded_chunk == 'data:[DONE]':
                            # print("Stream ended")
                            break

                        if decoded_chunk.startswith('data:'):
                            json_str = decoded_chunk[5:]
                            chunk_data = json.loads(json_str)

                            if 'data' in chunk_data and 'choices' in chunk_data['data']:
                                choice = chunk_data['data']['choices'][0]
                                if 'delta' in choice:
                                    content = choice['delta']
                                    yield content
                        else:
                            print(f'Unexpected format: {decoded_chunk}')

                    except json.JSONDecodeError as e:
                        print(f'JSON parsing error: {e}')
                    except Exception as e:
                        print(f'An error occurred while processing the chunk: {e}')

        assert isinstance(messages, list)

        header, data = self.generate_request_data(
            model_type=self.model_type,
            messages=messages,
            gen_params=gen_params,
            json_mode=self.json_mode,
        )

        max_num_retries = 0
        while max_num_retries < self.retry:
            if len(self.invalid_keys) == len(self.keys):
                raise RuntimeError('All keys have insufficient quota.')

            # find the next valid key
            while True:
                self.key_ctr += 1
                if self.key_ctr == len(self.keys):
                    self.key_ctr = 0

                if self.keys[self.key_ctr] not in self.invalid_keys:
                    break

            key = self.keys[self.key_ctr]
            header['Authorization'] = f'Bearer {key}'

            response = dict()
            try:
                raw_response = requests.post(
                    self.url,
                    headers=header,
                    data=json.dumps(data),
                    proxies=self.proxies,
                )
                return streaming(raw_response)
            except requests.ConnectionError:
                print('Got connection error, retrying...')
                continue
            except requests.JSONDecodeError:
                print('JsonDecode error, got', str(raw_response.content))
                continue
            except KeyError:
                if 'error' in response:
                    if response['error']['code'] == 'rate_limit_exceeded':
                        time.sleep(1)
                        continue
                    elif response['error']['code'] == 'insufficient_quota':
                        self.invalid_keys.add(key)
                        self.logger.warn(f'insufficient_quota key: {key}')
                        continue

                    print('Find error message in response: ',
                          str(response['error']))
            except Exception as error:
                print(str(error))
            max_num_retries += 1

        raise RuntimeError('Calling SenseTime failed after retrying for '
                           f'{max_num_retries} times. Check the logs for '
                           'details.')
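The `streaming` generator above assumes an SSE-style wire format: each line is `data:<json>` with the incremental text nested under `data.choices[0].delta`, and a literal `data:[DONE]` terminates the stream. A small self-contained sketch of that parsing, with hard-coded chunks standing in for a live response:

import json

# Sample chunks in the 'data:<json>' format that streaming() expects.
chunks = [
    'data:{"data": {"choices": [{"delta": "Hel"}]}}',
    'data:{"data": {"choices": [{"delta": "lo"}]}}',
    'data:[DONE]',
]

for line in chunks:
    if line == 'data:[DONE]':
        break  # end-of-stream sentinel
    payload = json.loads(line[len('data:'):])
    print(payload['data']['choices'][0]['delta'], end='')  # prints: Hello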
    def generate_request_data(self,
                              model_type,
                              messages,
                              gen_params,
                              json_mode=False):
        """
        Generates the request data for different model types.

        Args:
            model_type (str): The type of the model (e.g., 'sense').
            messages (list): The list of messages to be sent to the model.
            gen_params (dict): The generation parameters.
            json_mode (bool): Flag to determine if the response format should be JSON.

        Returns:
            tuple: A tuple containing the header and the request data.
        """
        # Copy generation parameters to avoid modifying the original dictionary
        gen_params = gen_params.copy()

        # Cap the completion length at the API limit of 4096 tokens
        max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
        if max_tokens <= 0:
            return '', ''

        # Initialize the header
        header = {
            'content-type': 'application/json',
        }

        # Common parameters processing
        gen_params['max_tokens'] = max_tokens
        if 'stop_words' in gen_params:
            gen_params['stop'] = gen_params.pop('stop_words')
        if 'repetition_penalty' in gen_params:
            gen_params['frequency_penalty'] = gen_params.pop('repetition_penalty')

        # Model-specific processing
        data = {}
        if model_type.lower().startswith('sense'):
            gen_params.pop('skip_special_tokens', None)
            gen_params.pop('session_id', None)
            data = {
                'model': model_type,
                'messages': messages,
                'n': 1,
                **gen_params
            }
            if json_mode:
                data['response_format'] = {'type': 'json_object'}
        else:
            raise NotImplementedError(f'Model type {model_type} is not supported')

        return header, data
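The effect is to map lagent-style generation parameters onto OpenAI-compatible request fields: `max_new_tokens` becomes a capped `max_tokens`, `stop_words` becomes `stop`, and `repetition_penalty` becomes `frequency_penalty`. A hypothetical call might look like this (the model name and parameter values are placeholders, not values from this repo):

# Hypothetical invocation; model_type must start with 'sense' to match above.
header, data = model.generate_request_data(
    model_type='SenseChat-5',
    messages=[{'role': 'user', 'content': 'Hi'}],
    gen_params={'max_new_tokens': 512, 'stop_words': ['<|im_end|>'],
                'repetition_penalty': 1.02, 'temperature': 0.7},
)
# data == {'model': 'SenseChat-5', 'messages': [...], 'n': 1,
#          'max_tokens': 512, 'stop': ['<|im_end|>'],
#          'frequency_penalty': 1.02, 'temperature': 0.7}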
    def tokenize(self, prompt: str) -> list:
        """Tokenize the input prompt.

        Args:
            prompt (str): Input string.

        Returns:
            list: token ids
        """
        import tiktoken

        self.tiktoken = tiktoken
        enc = self.tiktoken.encoding_for_model('gpt-4o')
        return enc.encode(prompt)
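Note that `tokenize` borrows OpenAI's `gpt-4o` encoding from tiktoken as an approximation rather than a SenseNova-specific tokenizer, so the counts are indicative only. A quick standalone check of what it returns:

import tiktoken

enc = tiktoken.encoding_for_model('gpt-4o')
ids = enc.encode('Hello, world!')
print(ids)              # a short list of integer token ids
print(enc.decode(ids))  # -> 'Hello, world!'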
lagent/llms/vllm_wrapper.py
ADDED
@@ -0,0 +1,176 @@
import asyncio
from typing import List, Union

from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
from lagent.utils.util import filter_suffix


def asdict_completion(output):
    return {
        key: getattr(output, key)
        for key in [
            'text', 'token_ids', 'cumulative_logprob', 'logprobs',
            'finish_reason', 'stop_reason'
        ]
    }


class VllmModel(BaseLLM):
    """
    A wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a huggingface model.
                - ii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        tp (int): tensor parallel
        vllm_cfg (dict): Other kwargs for vllm model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import LLM
        self.model = LLM(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)

    def generate(self,
                 inputs: Union[str, List[str]],
                 do_preprocess: bool = None,
                 skip_special_tokens: bool = False,
                 return_dict: bool = False,
                 **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages. Defaults
                to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)
        response = self.model.generate(prompt, sampling_params=sampling_config)
        texts = [resp.outputs[0].text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.outputs[0].text = text
        if batched:
            return [asdict_completion(resp.outputs[0])
                    for resp in response] if return_dict else texts
        return asdict_completion(
            response[0].outputs[0]) if return_dict else texts[0]
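A minimal usage sketch for the synchronous wrapper (assumes vLLM is installed and the model weights are reachable; the path and sampling values are placeholders):

# Illustrative only: any local or hub-hosted chat model path works here.
model = VllmModel(path='internlm/internlm2-chat-7b', tp=1)
text = model.generate('The capital of France is',
                      max_new_tokens=32, temperature=0.0)
print(text)
# Pass return_dict=True to also get token_ids, logprobs and finish_reason.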
class AsyncVllmModel(AsyncBaseLLM):
    """
    An asynchronous wrapper of vLLM model.

    Args:
        path (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a huggingface model.
                - ii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        tp (int): tensor parallel
        vllm_cfg (dict): Other kwargs for vllm model initialization.
    """

    def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
        super().__init__(path=path, **kwargs)
        from vllm import AsyncEngineArgs, AsyncLLMEngine

        engine_args = AsyncEngineArgs(
            model=self.path,
            trust_remote_code=True,
            tensor_parallel_size=tp,
            **vllm_cfg)
        self.model = AsyncLLMEngine.from_engine_args(engine_args)

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       do_preprocess: bool = None,
                       skip_special_tokens: bool = False,
                       return_dict: bool = False,
                       **kwargs):
        """Return the chat completions in non-stream mode.

        Args:
            inputs (Union[str, List[str]]): input texts to be completed.
            do_preprocess (bool): whether to pre-process the messages. Defaults
                to True, which means chat_template will be applied.
            skip_special_tokens (bool): Whether or not to remove special tokens
                in the decoding. Defaults to False.
        Returns:
            (a list of/batched) text/chat completion
        """
        from vllm import SamplingParams

        batched = True
        if isinstance(inputs, str):
            inputs = [inputs]
            batched = False
        if session_ids is None:
            session_ids = list(range(len(inputs)))
        elif isinstance(session_ids, (int, str)):
            session_ids = [session_ids]
        assert len(inputs) == len(session_ids)

        prompt = inputs
        gen_params = self.update_gen_params(**kwargs)
        max_new_tokens = gen_params.pop('max_new_tokens')
        stop_words = gen_params.pop('stop_words')

        sampling_config = SamplingParams(
            skip_special_tokens=skip_special_tokens,
            max_tokens=max_new_tokens,
            stop=stop_words,
            **gen_params)

        async def _inner_generate(uid, text):
            resp, generator = '', self.model.generate(
                text, sampling_params=sampling_config, request_id=uid)
            async for out in generator:
                resp = out.outputs[0]
            return resp

        response = await asyncio.gather(*[
            _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
        ])
        texts = [resp.text for resp in response]
        # remove stop_words
        texts = filter_suffix(texts, self.gen_params.get('stop_words'))
        for resp, text in zip(response, texts):
            resp.text = text
        if batched:
            return [asdict_completion(resp)
                    for resp in response] if return_dict else texts
        return asdict_completion(response[0]) if return_dict else texts[0]
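The async variant issues one engine request per session id and gathers them with `asyncio.gather`, so a batch of prompts runs through `AsyncLLMEngine` concurrently. A hedged usage sketch (the model path is again a placeholder):

import asyncio

model = AsyncVllmModel(path='internlm/internlm2-chat-7b', tp=1)

async def main():
    # Two prompts completed concurrently, one request id per prompt.
    texts = await model.generate(
        ['1 + 1 =', 'The sky is'], session_ids=[0, 1], max_new_tokens=16)
    for t in texts:
        print(t)

asyncio.run(main())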
lagent/memory/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .base_memory import Memory
from .manager import MemoryManager

__all__ = ['Memory', 'MemoryManager']
lagent/memory/base_memory.py
ADDED
@@ -0,0 +1,60 @@
from typing import Callable, Dict, List, Optional, Union

from lagent.schema import AgentMessage


class Memory:

    def __init__(self, recent_n=None) -> None:
        self.memory: List[AgentMessage] = []
        self.recent_n = recent_n

    def get_memory(
        self,
        recent_n: Optional[int] = None,
        filter_func: Optional[Callable[[int, dict], bool]] = None,
    ) -> list:
        recent_n = recent_n or self.recent_n
        if recent_n is not None:
            memory = self.memory[-recent_n:]
        else:
            memory = self.memory
        if filter_func is not None:
            memory = [m for i, m in enumerate(memory) if filter_func(i, m)]
        return memory

    def add(self, memories: Union[List[Dict], Dict, None]) -> None:
        for memory in memories if isinstance(memories,
                                             (list, tuple)) else [memories]:
            if isinstance(memory, str):
                memory = AgentMessage(sender='user', content=memory)
            if isinstance(memory, AgentMessage):
                self.memory.append(memory)

    def delete(self, index: Union[List, int]) -> None:
        if isinstance(index, int):
            del self.memory[index]
        else:
            for i in index:
                del self.memory[i]

    def load(
        self,
        memories: Union[str, Dict, List],
        overwrite: bool = True,
    ) -> None:
        if overwrite:
            self.memory = []
        if isinstance(memories, dict):
            self.memory.append(AgentMessage(**memories))
        elif isinstance(memories, list):
            for m in memories:
                self.memory.append(AgentMessage(**m))
        else:
            raise TypeError(f'{type(memories)} is not supported')

    def save(self) -> List[dict]:
        memory = []
        for m in self.memory:
            memory.append(m.model_dump())
        return memory
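`Memory` keeps `AgentMessage` objects in insertion order; `recent_n` bounds the retrieval window and `filter_func` receives `(index, message)` pairs. A short sketch of the intended usage:

from lagent.schema import AgentMessage

mem = Memory(recent_n=2)
mem.add('hello')  # bare strings are wrapped as user messages
mem.add(AgentMessage(sender='assistant', content='hi'))
mem.add(AgentMessage(sender='user', content='bye'))

print(len(mem.get_memory()))  # 2 -- the recent_n window applies
users = mem.get_memory(recent_n=3, filter_func=lambda i, m: m.sender == 'user')
print([m.content for m in users])  # ['hello', 'bye']
snapshot = mem.save()  # list of dicts; mem.load(snapshot) restores it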
lagent/memory/manager.py
ADDED
@@ -0,0 +1,29 @@
from typing import Dict

from ..utils import create_object
from .base_memory import Memory


class MemoryManager:

    def __init__(self, cfg) -> None:
        self.cfg = cfg
        self.memory_map: Dict[str, Memory] = {}

    def create_instance(self, session_id):
        self.memory_map[session_id] = create_object(self.cfg)

    def get_memory(self, session_id=0, **kwargs) -> list:
        return self.memory_map[session_id].get_memory(**kwargs)

    def add(self, memory, session_id=0, **kwargs) -> None:
        if session_id not in self.memory_map:
            self.create_instance(session_id)
        self.memory_map[session_id].add(memory, **kwargs)

    def get(self, session_id=0) -> Memory:
        return self.memory_map.get(session_id, None)

    def reset(self, session_id=0) -> None:
        if session_id in self.memory_map:
            del self.memory_map[session_id]
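`MemoryManager` lazily creates one `Memory` per session from its config, keeping concurrent conversations isolated. A sketch, assuming `create_object` accepts a `{'type': ...}` config dict as elsewhere in lagent:

from lagent.memory import Memory, MemoryManager

manager = MemoryManager({'type': Memory})  # assumed config form

manager.add('hello from session 0', session_id=0)
manager.add('hello from session 1', session_id=1)

print(len(manager.get_memory(session_id=0)))  # 1 -- sessions do not mix
manager.reset(session_id=1)   # drops session 1's history entirely
print(manager.get(session_id=1))  # None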
lagent/prompts/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .parsers import *  # noqa
from .prompt_template import PromptTemplate

__all__ = ['PromptTemplate']