DESKTOP-P3A1PV5\inShine committed
Commit a5ab2ca · 1 Parent(s): e3846de

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +196 -0
  2. examples/agent_api_web_demo.py +196 -0
  3. examples/multi_agents_api_web_demo.py +198 -0
  4. lagent/__init__.py +4 -0
  5. lagent/actions/__init__.py +26 -0
  6. lagent/actions/action_executor.py +198 -0
  7. lagent/actions/arxiv_search.py +79 -0
  8. lagent/actions/base_action.py +434 -0
  9. lagent/actions/bing_map.py +268 -0
  10. lagent/actions/builtin_actions.py +109 -0
  11. lagent/actions/google_scholar_search.py +438 -0
  12. lagent/actions/google_search.py +244 -0
  13. lagent/actions/ipython_interactive.py +273 -0
  14. lagent/actions/ipython_interpreter.py +584 -0
  15. lagent/actions/ipython_manager.py +220 -0
  16. lagent/actions/parser.py +146 -0
  17. lagent/actions/ppt.py +233 -0
  18. lagent/actions/python_interpreter.py +176 -0
  19. lagent/actions/weather_query.py +71 -0
  20. lagent/actions/web_browser.py +908 -0
  21. lagent/agents/__init__.py +9 -0
  22. lagent/agents/agent.py +400 -0
  23. lagent/agents/aggregator/__init__.py +4 -0
  24. lagent/agents/aggregator/default_aggregator.py +44 -0
  25. lagent/agents/aggregator/tool_aggregator.py +106 -0
  26. lagent/agents/react.py +161 -0
  27. lagent/agents/stream.py +316 -0
  28. lagent/distributed/__init__.py +8 -0
  29. lagent/distributed/http_serve/__init__.py +7 -0
  30. lagent/distributed/http_serve/api_server.py +123 -0
  31. lagent/distributed/http_serve/app.py +96 -0
  32. lagent/distributed/ray_serve/__init__.py +3 -0
  33. lagent/distributed/ray_serve/ray_warpper.py +48 -0
  34. lagent/hooks/__init__.py +8 -0
  35. lagent/hooks/action_preprocessor.py +62 -0
  36. lagent/hooks/hook.py +50 -0
  37. lagent/hooks/logger.py +37 -0
  38. lagent/llms/__init__.py +32 -0
  39. lagent/llms/base_api.py +175 -0
  40. lagent/llms/base_llm.py +305 -0
  41. lagent/llms/huggingface.py +337 -0
  42. lagent/llms/lmdeploy_wrapper.py +790 -0
  43. lagent/llms/meta_template.py +40 -0
  44. lagent/llms/openai.py +924 -0
  45. lagent/llms/sensenova.py +406 -0
  46. lagent/llms/vllm_wrapper.py +176 -0
  47. lagent/memory/__init__.py +4 -0
  48. lagent/memory/base_memory.py +60 -0
  49. lagent/memory/manager.py +29 -0
  50. lagent/prompts/__init__.py +4 -0
app.py ADDED
@@ -0,0 +1,196 @@
+ import copy
+ import os
+ from typing import List
+ import streamlit as st
+ from lagent.actions import ArxivSearch, WeatherQuery
+ from lagent.prompts.parsers import PluginParser
+ from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
+ from lagent.llms import GPTAPI
+
+ class SessionState:
+     """Class that manages the session state."""
+
+     def init_state(self):
+         """Initialize the session-state variables."""
+         st.session_state['assistant'] = []  # assistant message history
+         st.session_state['user'] = []  # user message history
+         # Initialize the plugin list
+         action_list = [
+             ArxivSearch(),
+             WeatherQuery()
+         ]
+         st.session_state['plugin_map'] = {action.name: action for action in action_list}
+         st.session_state['model_map'] = {}  # stores model instances
+         st.session_state['model_selected'] = None  # currently selected model
+         st.session_state['plugin_actions'] = set()  # currently active plugins
+         st.session_state['history'] = []  # chat history
+         st.session_state['api_base'] = None  # initialize the API base address
+
+     def clear_state(self):
+         """Clear the current session state."""
+         st.session_state['assistant'] = []
+         st.session_state['user'] = []
+         st.session_state['model_selected'] = None
+
+
+ class StreamlitUI:
+     """Class that manages the Streamlit interface."""
+
+     def __init__(self, session_state: SessionState):
+         self.session_state = session_state
+         self.plugin_action = []  # currently selected plugins
+         # Initialize the prompts
+         self.meta_prompt = META_CN
+         self.plugin_prompt = PLUGIN_CN
+         self.init_streamlit()
+
+     def init_streamlit(self):
+         """Initialize the Streamlit UI settings."""
+         st.set_page_config(
+             layout='wide',
+             page_title='lagent-web',
+             page_icon='./docs/imgs/lagent_icon.png'
+         )
+         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
+
+     def setup_sidebar(self):
+         """Set up the sidebar for selecting the model and plugins."""
+         # Input boxes for the model name and API base address
+         model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
+
+         # ================================== SiliconFlow API ==================================
+         # Note: with the SiliconFlow API the model name must be changed to
+         # internlm/internlm2_5-7b-chat or internlm/internlm2_5-20b-chat.
+         # api_base = st.sidebar.text_input(
+         #     'API Base 地址:', value='https://api.siliconflow.cn/v1/chat/completions'
+         # )
+         # ================================== official InternLM (puyu) API =====================
+         api_base = st.sidebar.text_input(
+             'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
+         )
+         # =====================================================================================
+         # Plugin selection
+         plugin_name = st.sidebar.multiselect(
+             '插件选择',
+             options=list(st.session_state['plugin_map'].keys()),
+             default=[],
+         )
+
+         # Build the action list from the selected plugins
+         self.plugin_action = [st.session_state['plugin_map'][name] for name in plugin_name]
+
+         # Generate the plugin prompt dynamically
+         if self.plugin_action:
+             self.plugin_prompt = get_plugin_prompt(self.plugin_action)
+
+         # Clear-conversation button
+         if st.sidebar.button('清空对话', key='clear'):
+             self.session_state.clear_state()
+
+         return model_name, api_base, self.plugin_action
+
+     def initialize_chatbot(self, model_name, api_base, plugin_action):
+         """Initialize a GPTAPI instance as the chatbot."""
+         token = os.getenv("INTERNLM_API_KEY")
+         if not token:
+             st.error("未检测到环境变量 `INTERNLM_API_KEY`,请设置环境变量,例如 `export INTERNLM_API_KEY='your_token_here'` 后重新运行 X﹏X")
+             st.stop()  # stop the app
+
+         # Build the full meta_prompt, keeping the original structure while
+         # dynamically inserting the sidebar configuration
+         meta_prompt = [
+             {"role": "system", "content": self.meta_prompt, "api_role": "system"},
+             {"role": "user", "content": "", "api_role": "user"},
+             {"role": "assistant", "content": self.plugin_prompt, "api_role": "assistant"},
+             {"role": "environment", "content": "", "api_role": "environment"}
+         ]
+
+         api_model = GPTAPI(
+             model_type=model_name,
+             api_base=api_base,
+             key=token,  # authorization token taken from the environment variable
+             meta_template=meta_prompt,
+             max_new_tokens=512,
+             temperature=0.8,
+             top_p=0.9
+         )
+         return api_model
+
+     def render_user(self, prompt: str):
+         """Render the user input."""
+         with st.chat_message('user'):
+             st.markdown(prompt)
+
+     def render_assistant(self, agent_return):
+         """Render the assistant response."""
+         with st.chat_message('assistant'):
+             content = getattr(agent_return, "content", str(agent_return))
+             st.markdown(content if isinstance(content, str) else str(content))
+
+
+ def main():
+     """Main function that runs the Streamlit app."""
+     if 'ui' not in st.session_state:
+         session_state = SessionState()
+         session_state.init_state()
+         st.session_state['ui'] = StreamlitUI(session_state)
+     else:
+         st.set_page_config(
+             layout='wide',
+             page_title='lagent-web',
+             page_icon='./docs/imgs/lagent_icon.png'
+         )
+         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
+
+     # Set up the sidebar and fetch the model and plugin info
+     model_name, api_base, plugin_action = st.session_state['ui'].setup_sidebar()
+     plugins = [dict(type=f"lagent.actions.{plugin.__class__.__name__}") for plugin in plugin_action]
+
+     if (
+         'chatbot' not in st.session_state or
+         model_name != st.session_state['chatbot'].model_type or
+         'last_plugin_action' not in st.session_state or
+         plugin_action != st.session_state['last_plugin_action'] or
+         api_base != st.session_state['api_base']
+     ):
+         # Rebuild the chatbot
+         st.session_state['chatbot'] = st.session_state['ui'].initialize_chatbot(model_name, api_base, plugin_action)
+         st.session_state['last_plugin_action'] = plugin_action  # update the plugin state
+         st.session_state['api_base'] = api_base  # update the API base address
+
+         # Initialize AgentForInternLM
+         st.session_state['agent'] = AgentForInternLM(
+             llm=st.session_state['chatbot'],
+             plugins=plugins,
+             output_format=dict(
+                 type=PluginParser,
+                 template=PLUGIN_CN,
+                 prompt=get_plugin_prompt(plugin_action)
+             )
+         )
+         # Clear the conversation history
+         st.session_state['session_history'] = []
+
+     if 'agent' not in st.session_state:
+         st.session_state['agent'] = None
+
+     agent = st.session_state['agent']
+     for prompt, agent_return in zip(st.session_state['user'], st.session_state['assistant']):
+         st.session_state['ui'].render_user(prompt)
+         st.session_state['ui'].render_assistant(agent_return)
+
+     # Handle user input
+     if user_input := st.chat_input(''):
+         st.session_state['ui'].render_user(user_input)
+
+         # Ensure the sidebar's system and plugin prompts take effect when calling the model
+         res = agent(user_input, session_id=0)
+         st.session_state['ui'].render_assistant(res)
+
+         # Update the session state
+         st.session_state['user'].append(user_input)
+         st.session_state['assistant'].append(copy.deepcopy(res))
+
+     st.session_state['last_status'] = None
+
+
+ if __name__ == '__main__':
+     main()
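
For reference, the same GPTAPI/AgentForInternLM wiring can be exercised outside Streamlit. A minimal sketch, assuming `INTERNLM_API_KEY` is set, the puyu endpoint above is reachable, and the query string is illustrative:

    import os
    from lagent.actions import ArxivSearch
    from lagent.agents.stream import AgentForInternLM, PLUGIN_CN, get_plugin_prompt
    from lagent.llms import GPTAPI
    from lagent.prompts.parsers import PluginParser

    plugin_instances = [ArxivSearch()]
    llm = GPTAPI(
        model_type='internlm2.5-latest',
        api_base='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions',
        key=os.environ['INTERNLM_API_KEY'],  # same token the app reads above
        max_new_tokens=512,
    )
    agent = AgentForInternLM(
        llm=llm,
        plugins=[dict(type='lagent.actions.ArxivSearch')],
        output_format=dict(type=PluginParser, template=PLUGIN_CN,
                           prompt=get_plugin_prompt(plugin_instances)),
    )
    # illustrative query; the agent may route it through the ArxivSearch plugin
    print(agent('Find recent arXiv papers on self-supervised learning', session_id=0))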
examples/agent_api_web_demo.py ADDED
@@ -0,0 +1,196 @@
+ import copy
+ import os
+ from typing import List
+ import streamlit as st
+ from lagent.actions import ArxivSearch, WeatherQuery
+ from lagent.prompts.parsers import PluginParser
+ from lagent.agents.stream import INTERPRETER_CN, META_CN, PLUGIN_CN, AgentForInternLM, get_plugin_prompt
+ from lagent.llms import GPTAPI
+
+ class SessionState:
+     """Class that manages the session state."""
+
+     def init_state(self):
+         """Initialize the session-state variables."""
+         st.session_state['assistant'] = []  # assistant message history
+         st.session_state['user'] = []  # user message history
+         # Initialize the plugin list
+         action_list = [
+             ArxivSearch(),
+             WeatherQuery()
+         ]
+         st.session_state['plugin_map'] = {action.name: action for action in action_list}
+         st.session_state['model_map'] = {}  # stores model instances
+         st.session_state['model_selected'] = None  # currently selected model
+         st.session_state['plugin_actions'] = set()  # currently active plugins
+         st.session_state['history'] = []  # chat history
+         st.session_state['api_base'] = None  # initialize the API base address
+
+     def clear_state(self):
+         """Clear the current session state."""
+         st.session_state['assistant'] = []
+         st.session_state['user'] = []
+         st.session_state['model_selected'] = None
+
+
+ class StreamlitUI:
+     """Class that manages the Streamlit interface."""
+
+     def __init__(self, session_state: SessionState):
+         self.session_state = session_state
+         self.plugin_action = []  # currently selected plugins
+         # Initialize the prompts
+         self.meta_prompt = META_CN
+         self.plugin_prompt = PLUGIN_CN
+         self.init_streamlit()
+
+     def init_streamlit(self):
+         """Initialize the Streamlit UI settings."""
+         st.set_page_config(
+             layout='wide',
+             page_title='lagent-web',
+             page_icon='./docs/imgs/lagent_icon.png'
+         )
+         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
+
+     def setup_sidebar(self):
+         """Set up the sidebar for selecting the model and plugins."""
+         # Input boxes for the model name and API base address
+         model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
+
+         # ================================== SiliconFlow API ==================================
+         # Note: with the SiliconFlow API the model name must be changed to
+         # internlm/internlm2_5-7b-chat or internlm/internlm2_5-20b-chat.
+         # api_base = st.sidebar.text_input(
+         #     'API Base 地址:', value='https://api.siliconflow.cn/v1/chat/completions'
+         # )
+         # ================================== official InternLM (puyu) API =====================
+         api_base = st.sidebar.text_input(
+             'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
+         )
+         # =====================================================================================
+         # Plugin selection
+         plugin_name = st.sidebar.multiselect(
+             '插件选择',
+             options=list(st.session_state['plugin_map'].keys()),
+             default=[],
+         )
+
+         # Build the action list from the selected plugins
+         self.plugin_action = [st.session_state['plugin_map'][name] for name in plugin_name]
+
+         # Generate the plugin prompt dynamically
+         if self.plugin_action:
+             self.plugin_prompt = get_plugin_prompt(self.plugin_action)
+
+         # Clear-conversation button
+         if st.sidebar.button('清空对话', key='clear'):
+             self.session_state.clear_state()
+
+         return model_name, api_base, self.plugin_action
+
+     def initialize_chatbot(self, model_name, api_base, plugin_action):
+         """Initialize a GPTAPI instance as the chatbot."""
+         token = os.getenv("INTERNLM_API_KEY")
+         if not token:
+             st.error("未检测到环境变量 `INTERNLM_API_KEY`,请设置环境变量,例如 `export INTERNLM_API_KEY='your_token_here'` 后重新运行 X﹏X")
+             st.stop()  # stop the app
+
+         # Build the full meta_prompt, keeping the original structure while
+         # dynamically inserting the sidebar configuration
+         meta_prompt = [
+             {"role": "system", "content": self.meta_prompt, "api_role": "system"},
+             {"role": "user", "content": "", "api_role": "user"},
+             {"role": "assistant", "content": self.plugin_prompt, "api_role": "assistant"},
+             {"role": "environment", "content": "", "api_role": "environment"}
+         ]
+
+         api_model = GPTAPI(
+             model_type=model_name,
+             api_base=api_base,
+             key=token,  # authorization token taken from the environment variable
+             meta_template=meta_prompt,
+             max_new_tokens=512,
+             temperature=0.8,
+             top_p=0.9
+         )
+         return api_model
+
+     def render_user(self, prompt: str):
+         """Render the user input."""
+         with st.chat_message('user'):
+             st.markdown(prompt)
+
+     def render_assistant(self, agent_return):
+         """Render the assistant response."""
+         with st.chat_message('assistant'):
+             content = getattr(agent_return, "content", str(agent_return))
+             st.markdown(content if isinstance(content, str) else str(content))
+
+
+ def main():
+     """Main function that runs the Streamlit app."""
+     if 'ui' not in st.session_state:
+         session_state = SessionState()
+         session_state.init_state()
+         st.session_state['ui'] = StreamlitUI(session_state)
+     else:
+         st.set_page_config(
+             layout='wide',
+             page_title='lagent-web',
+             page_icon='./docs/imgs/lagent_icon.png'
+         )
+         st.header(':robot_face: :blue[Lagent] Web Demo ', divider='rainbow')
+
+     # Set up the sidebar and fetch the model and plugin info
+     model_name, api_base, plugin_action = st.session_state['ui'].setup_sidebar()
+     plugins = [dict(type=f"lagent.actions.{plugin.__class__.__name__}") for plugin in plugin_action]
+
+     if (
+         'chatbot' not in st.session_state or
+         model_name != st.session_state['chatbot'].model_type or
+         'last_plugin_action' not in st.session_state or
+         plugin_action != st.session_state['last_plugin_action'] or
+         api_base != st.session_state['api_base']
+     ):
+         # Rebuild the chatbot
+         st.session_state['chatbot'] = st.session_state['ui'].initialize_chatbot(model_name, api_base, plugin_action)
+         st.session_state['last_plugin_action'] = plugin_action  # update the plugin state
+         st.session_state['api_base'] = api_base  # update the API base address
+
+         # Initialize AgentForInternLM
+         st.session_state['agent'] = AgentForInternLM(
+             llm=st.session_state['chatbot'],
+             plugins=plugins,
+             output_format=dict(
+                 type=PluginParser,
+                 template=PLUGIN_CN,
+                 prompt=get_plugin_prompt(plugin_action)
+             )
+         )
+         # Clear the conversation history
+         st.session_state['session_history'] = []
+
+     if 'agent' not in st.session_state:
+         st.session_state['agent'] = None
+
+     agent = st.session_state['agent']
+     for prompt, agent_return in zip(st.session_state['user'], st.session_state['assistant']):
+         st.session_state['ui'].render_user(prompt)
+         st.session_state['ui'].render_assistant(agent_return)
+
+     # Handle user input
+     if user_input := st.chat_input(''):
+         st.session_state['ui'].render_user(user_input)
+
+         # Ensure the sidebar's system and plugin prompts take effect when calling the model
+         res = agent(user_input, session_id=0)
+         st.session_state['ui'].render_assistant(res)
+
+         # Update the session state
+         st.session_state['user'].append(user_input)
+         st.session_state['assistant'].append(copy.deepcopy(res))
+
+     st.session_state['last_status'] = None
+
+
+ if __name__ == '__main__':
+     main()
examples/multi_agents_api_web_demo.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ import asyncio
+ import json
+ import re
+ import requests
+ import streamlit as st
+
+ from lagent.agents import Agent
+ from lagent.prompts.parsers import PluginParser
+ from lagent.agents.stream import PLUGIN_CN, get_plugin_prompt
+ from lagent.schema import AgentMessage
+ from lagent.actions import ArxivSearch
+ from lagent.hooks import Hook
+ from lagent.llms import GPTAPI
+
+ YOUR_TOKEN_HERE = os.getenv("INTERNLM_API_KEY")
+ if not YOUR_TOKEN_HERE:
+     raise EnvironmentError("未找到环境变量 'INTERNLM_API_KEY',请设置后再运行程序。")
+
+ # Hook class that prepends a prefix to messages
+ class PrefixedMessageHook(Hook):
+     def __init__(self, prefix, senders=None):
+         """
+         Initialize the hook.
+         :param prefix: message prefix
+         :param senders: list of senders the prefix applies to
+         """
+         self.prefix = prefix
+         self.senders = senders or []
+
+     def before_agent(self, agent, messages, session_id):
+         """
+         Modify message content before the agent processes it.
+         :param agent: the current agent
+         :param messages: the message list
+         :param session_id: the session ID
+         """
+         for message in messages:
+             if message.sender in self.senders:
+                 message.content = self.prefix + message.content
+
+ class AsyncBlogger:
+     """Blog-generation class that combines a writer and a critic."""
+
+     def __init__(self, model_type, api_base, writer_prompt, critic_prompt, critic_prefix='', max_turn=2):
+         """
+         Initialize the blog generator.
+         :param model_type: model type
+         :param api_base: API base address
+         :param writer_prompt: writer prompt
+         :param critic_prompt: critic prompt
+         :param critic_prefix: prefix for critic messages
+         :param max_turn: maximum number of turns
+         """
+         self.model_type = model_type
+         self.api_base = api_base
+         self.llm = GPTAPI(
+             model_type=model_type,
+             api_base=api_base,
+             key=YOUR_TOKEN_HERE,
+             max_new_tokens=4096,
+         )
+         self.plugins = [dict(type='lagent.actions.ArxivSearch')]
+         self.writer = Agent(
+             self.llm,
+             writer_prompt,
+             name='写作者',
+             output_format=dict(
+                 type=PluginParser,
+                 template=PLUGIN_CN,
+                 prompt=get_plugin_prompt(self.plugins)
+             )
+         )
+         self.critic = Agent(
+             self.llm,
+             critic_prompt,
+             name='批评者',
+             hooks=[PrefixedMessageHook(critic_prefix, ['写作者'])]
+         )
+         self.max_turn = max_turn
+
+     async def forward(self, message: AgentMessage, update_placeholder):
+         """
+         Run the multi-stage blog-generation pipeline.
+         :param message: initial message
+         :param update_placeholder: Streamlit placeholder
+         :return: the final, refined blog content
+         """
+         step1_placeholder = update_placeholder.container()
+         step2_placeholder = update_placeholder.container()
+         step3_placeholder = update_placeholder.container()
+
+         # Step 1: generate the initial content
+         step1_placeholder.markdown("**Step 1: 生成初始内容...**")
+         message = self.writer(message)
+         if message.content:
+             step1_placeholder.markdown(f"**生成的初始内容**:\n\n{message.content}")
+         else:
+             step1_placeholder.markdown("**生成的初始内容为空,请检查生成逻辑。**")
+
+         # Step 2: the critic provides feedback
+         # Defaults so that step 3 is well defined even if the critic returns nothing
+         feedback, arxiv_results = '未提供批评建议', '未提供推荐文献'
+         step2_placeholder.markdown("**Step 2: 批评者正在提供反馈和文献推荐...**")
+         message = self.critic(message)
+         if message.content:
+             # Parse the critic's feedback
+             suggestions = re.search(r"1\. 批评建议:\n(.*?)2\. 推荐的关键词:", message.content, re.S)
+             keywords = re.search(r"2\. 推荐的关键词:\n- (.*)", message.content)
+             feedback = suggestions.group(1).strip() if suggestions else "未提供批评建议"
+             keywords = keywords.group(1).strip() if keywords else "未提供关键词"
+
+             # Arxiv literature search
+             arxiv_search = ArxivSearch()
+             arxiv_results = arxiv_search.get_arxiv_article_information(keywords)
+
+             # Show the critique and the recommended literature
+             message.content = f"**批评建议**:\n{feedback}\n\n**推荐的文献**:\n{arxiv_results}"
+             step2_placeholder.markdown(f"**批评和文献推荐**:\n\n{message.content}")
+         else:
+             step2_placeholder.markdown("**批评内容为空,请检查批评逻辑。**")
+
+         # Step 3: the writer refines the content based on the feedback
+         step3_placeholder.markdown("**Step 3: 根据反馈改进内容...**")
+         improvement_prompt = AgentMessage(
+             sender="critic",
+             content=(
+                 f"根据以下批评建议和推荐文献对内容进行改进:\n\n"
+                 f"批评建议:\n{feedback}\n\n"
+                 f"推荐文献:\n{arxiv_results}\n\n"
+                 f"请优化初始内容,使其更加清晰、丰富,并符合专业水准。"
+             ),
+         )
+         message = self.writer(improvement_prompt)
+         if message.content:
+             step3_placeholder.markdown(f"**最终优化的博客内容**:\n\n{message.content}")
+         else:
+             step3_placeholder.markdown("**最终优化的博客内容为空,请检查生成逻辑。**")
+
+         return message
+
+ def setup_sidebar():
+     """Set up the sidebar for selecting the model."""
+     model_name = st.sidebar.text_input('模型名称:', value='internlm2.5-latest')
+     api_base = st.sidebar.text_input(
+         'API Base 地址:', value='https://internlm-chat.intern-ai.org.cn/puyu/api/v1/chat/completions'
+     )
+
+     return model_name, api_base
+
+ def main():
+     """
+     Main function: build the Streamlit interface and handle user interaction.
+     """
+     st.set_page_config(layout='wide', page_title='Lagent Web Demo', page_icon='🤖')
+     st.title("多代理博客优化助手")
+
+     model_type, api_base = setup_sidebar()
+     topic = st.text_input('输入一个话题:', 'Self-Supervised Learning')
+     generate_button = st.button('生成博客内容')
+
+     if (
+         'blogger' not in st.session_state or
+         st.session_state['model_type'] != model_type or
+         st.session_state['api_base'] != api_base
+     ):
+         st.session_state['blogger'] = AsyncBlogger(
+             model_type=model_type,
+             api_base=api_base,
+             writer_prompt="你是一位优秀的AI内容写作者,请撰写一篇有吸引力且信息丰富的博客内容。",
+             critic_prompt="""
+                 作为一位严谨的批评者,请给出建设性的批评和改进建议,并基于相关主题使用已有的工具推荐一些参考文献,推荐的关键词应该是英语形式,简洁且切题。
+                 请按照以下格式提供反馈:
+                 1. 批评建议:
+                 - (具体建议)
+                 2. 推荐的关键词:
+                 - (关键词1, 关键词2, ...)
+             """,
+             critic_prefix="请批评以下内容,并提供改进建议:\n\n"
+         )
+         st.session_state['model_type'] = model_type
+         st.session_state['api_base'] = api_base
+
+     if generate_button:
+         update_placeholder = st.empty()
+
+         async def run_async_blogger():
+             message = AgentMessage(
+                 sender='user',
+                 content=f"请撰写一篇关于{topic}的博客文章,要求表达专业,生动有趣,并且易于理解。"
+             )
+             result = await st.session_state['blogger'].forward(message, update_placeholder)
+             return result
+
+         loop = asyncio.new_event_loop()
+         asyncio.set_event_loop(loop)
+         loop.run_until_complete(run_async_blogger())
+
+ if __name__ == '__main__':
+     main()
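
For reference, the prefix hook above can be exercised in isolation; a minimal sketch (the draft text is illustrative):

    from lagent.schema import AgentMessage

    hook = PrefixedMessageHook('请批评以下内容,并提供改进建议:\n\n', senders=['写作者'])
    messages = [AgentMessage(sender='写作者', content='一篇博客草稿')]
    # before_agent mutates matching messages in place; the agent argument is unused here
    hook.before_agent(agent=None, messages=messages, session_id=0)
    assert messages[0].content.startswith('请批评以下内容')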
lagent/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from .version import __version__, version_info
+
+ __all__ = ['__version__', 'version_info']
lagent/actions/__init__.py ADDED
@@ -0,0 +1,26 @@
+ from .action_executor import ActionExecutor, AsyncActionExecutor
+ from .arxiv_search import ArxivSearch, AsyncArxivSearch
+ from .base_action import BaseAction, tool_api
+ from .bing_map import AsyncBINGMap, BINGMap
+ from .builtin_actions import FinishAction, InvalidAction, NoAction
+ from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
+ from .google_search import AsyncGoogleSearch, GoogleSearch
+ from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
+ from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
+ from .ipython_manager import IPythonInteractiveManager
+ from .parser import BaseParser, JsonParser, TupleParser
+ from .ppt import PPT, AsyncPPT
+ from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
+ from .web_browser import AsyncWebBrowser, WebBrowser
+ from .weather_query import WeatherQuery
+
+ __all__ = [
+     'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
+     'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
+     'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
+     'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
+     'IPythonInteractive', 'AsyncIPythonInteractive',
+     'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
+     'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
+     'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'
+ ]
lagent/actions/action_executor.py ADDED
@@ -0,0 +1,198 @@
+ import inspect
+ from collections import OrderedDict
+ from typing import Callable, Dict, List, Union
+
+ from lagent.actions.base_action import BaseAction
+ from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
+ from lagent.hooks import Hook, RemovableHandle
+ from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
+ from lagent.utils import create_object
+
+
+ class ActionExecutor:
+     """The action executor class.
+
+     Args:
+         actions (Union[BaseAction, List[BaseAction]]): The action or actions.
+         invalid_action (BaseAction, optional): The invalid action. Defaults to
+             InvalidAction().
+         no_action (BaseAction, optional): The no action.
+             Defaults to NoAction().
+         finish_action (BaseAction, optional): The finish action. Defaults to
+             FinishAction().
+         finish_in_action (bool, optional): Whether the finish action is in the
+             action list. Defaults to False.
+     """
+
+     def __init__(
+         self,
+         actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
+         invalid_action: BaseAction = dict(type=InvalidAction),
+         no_action: BaseAction = dict(type=NoAction),
+         finish_action: BaseAction = dict(type=FinishAction),
+         finish_in_action: bool = False,
+         hooks: List[Dict] = None,
+     ):
+
+         if not isinstance(actions, list):
+             actions = [actions]
+         finish_action = create_object(finish_action)
+         if finish_in_action:
+             actions.append(finish_action)
+         for i, action in enumerate(actions):
+             actions[i] = create_object(action)
+         self.actions = {action.name: action for action in actions}
+
+         self.invalid_action = create_object(invalid_action)
+         self.no_action = create_object(no_action)
+         self.finish_action = finish_action
+         self._hooks: Dict[int, Hook] = OrderedDict()
+         if hooks:
+             for hook in hooks:
+                 hook = create_object(hook)
+                 self.register_hook(hook)
+
+     def description(self) -> List[Dict]:
+         actions = []
+         for action_name, action in self.actions.items():
+             if action.is_toolkit:
+                 for api in action.description['api_list']:
+                     api_desc = api.copy()
+                     api_desc['name'] = f"{action_name}.{api_desc['name']}"
+                     actions.append(api_desc)
+             else:
+                 action_desc = action.description.copy()
+                 actions.append(action_desc)
+         return actions
+
+     def __contains__(self, name: str):
+         return name in self.actions
+
+     def keys(self):
+         return list(self.actions.keys())
+
+     def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
+         action = create_object(action)
+         self.actions[action.name] = action
+
+     def __delitem__(self, name: str):
+         del self.actions[name]
+
+     def forward(self, name, parameters, **kwargs) -> ActionReturn:
+         action_name, api_name = (
+             name.split('.') if '.' in name else (name, 'run'))
+         action_return: ActionReturn = ActionReturn()
+         if action_name not in self:
+             if name == self.no_action.name:
+                 action_return = self.no_action(parameters)
+             elif name == self.finish_action.name:
+                 action_return = self.finish_action(parameters)
+             else:
+                 action_return = self.invalid_action(parameters)
+         else:
+             action_return = self.actions[action_name](parameters, api_name)
+             action_return.valid = ActionValidCode.OPEN
+         return action_return
+
+     def __call__(self,
+                  message: AgentMessage,
+                  session_id=0,
+                  **kwargs) -> AgentMessage:
+         # message.receiver = self.name
+         for hook in self._hooks.values():
+             result = hook.before_action(self, message, session_id)
+             if result:
+                 message = result
+
+         assert isinstance(message.content, FunctionCall) or (
+             isinstance(message.content, dict) and 'name' in message.content
+             and 'parameters' in message.content)
+         if isinstance(message.content, dict):
+             name = message.content.get('name')
+             parameters = message.content.get('parameters')
+         else:
+             name = message.content.name
+             parameters = message.content.parameters
+
+         response_message = self.forward(
+             name=name, parameters=parameters, **kwargs)
+         if not isinstance(response_message, AgentMessage):
+             response_message = AgentMessage(
+                 sender=self.__class__.__name__,
+                 content=response_message,
+             )
+
+         for hook in self._hooks.values():
+             result = hook.after_action(self, response_message, session_id)
+             if result:
+                 response_message = result
+         return response_message
+
+     def register_hook(self, hook: Callable):
+         handle = RemovableHandle(self._hooks)
+         self._hooks[handle.id] = hook
+         return handle
+
+
+ class AsyncActionExecutor(ActionExecutor):
+
+     async def forward(self, name, parameters, **kwargs) -> ActionReturn:
+         action_name, api_name = (
+             name.split('.') if '.' in name else (name, 'run'))
+         action_return: ActionReturn = ActionReturn()
+         if action_name not in self:
+             if name == self.no_action.name:
+                 action_return = self.no_action(parameters)
+             elif name == self.finish_action.name:
+                 action_return = self.finish_action(parameters)
+             else:
+                 action_return = self.invalid_action(parameters)
+         else:
+             action = self.actions[action_name]
+             if inspect.iscoroutinefunction(action.__call__):
+                 action_return = await action(parameters, api_name)
+             else:
+                 action_return = action(parameters, api_name)
+             action_return.valid = ActionValidCode.OPEN
+         return action_return
+
+     async def __call__(self,
+                        message: AgentMessage,
+                        session_id=0,
+                        **kwargs) -> AgentMessage:
+         # message.receiver = self.name
+         for hook in self._hooks.values():
+             if inspect.iscoroutinefunction(hook.before_action):
+                 result = await hook.before_action(self, message, session_id)
+             else:
+                 result = hook.before_action(self, message, session_id)
+             if result:
+                 message = result
+
+         assert isinstance(message.content, FunctionCall) or (
+             isinstance(message.content, dict) and 'name' in message.content
+             and 'parameters' in message.content)
+         if isinstance(message.content, dict):
+             name = message.content.get('name')
+             parameters = message.content.get('parameters')
+         else:
+             name = message.content.name
+             parameters = message.content.parameters
+
+         response_message = await self.forward(
+             name=name, parameters=parameters, **kwargs)
+         if not isinstance(response_message, AgentMessage):
+             response_message = AgentMessage(
+                 sender=self.__class__.__name__,
+                 content=response_message,
+             )
+
+         for hook in self._hooks.values():
+             if inspect.iscoroutinefunction(hook.after_action):
+                 result = await hook.after_action(self, response_message,
+                                                  session_id)
+             else:
+                 result = hook.after_action(self, response_message, session_id)
+             if result:
+                 response_message = result
+         return response_message
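
As a usage sketch (the tool and query are illustrative): the executor resolves dotted `Toolkit.api` names against its registered actions and falls back to the builtin invalid/no/finish actions for unknown names.

    from lagent.actions import ActionExecutor, ArxivSearch
    from lagent.schema import AgentMessage

    executor = ActionExecutor(actions=[ArxivSearch()])
    call = AgentMessage(
        sender='agent',
        content=dict(name='ArxivSearch.get_arxiv_article_information',
                     parameters={'query': 'agent frameworks'}),
    )
    # Dispatches to ArxivSearch.get_arxiv_article_information and wraps the
    # resulting ActionReturn in an AgentMessage
    result = executor(call, session_id=0)
    print(result.content)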
lagent/actions/arxiv_search.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Optional, Type
+
+ from asyncer import asyncify
+
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
+ from lagent.actions.parser import BaseParser, JsonParser
+ from lagent.schema import ActionReturn, ActionStatusCode
+
+
+ class ArxivSearch(BaseAction):
+     """Search information from Arxiv.org. \
+ Useful for when you need to answer questions about Physics, Mathematics, \
+ Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
+ Electrical Engineering, and Economics from scientific articles on arxiv.org.
+     """
+
+     def __init__(
+         self,
+         top_k_results: int = 3,
+         max_query_len: int = 300,
+         doc_content_chars_max: int = 1500,
+         description: Optional[dict] = None,
+         parser: Type[BaseParser] = JsonParser,
+     ):
+         super().__init__(description, parser)
+         self.top_k_results = top_k_results
+         self.max_query_len = max_query_len
+         self.doc_content_chars_max = doc_content_chars_max
+
+     @tool_api(explode_return=True)
+     def get_arxiv_article_information(self, query: str) -> dict:
+         """Run Arxiv search and get the article meta information.
+
+         Args:
+             query (:class:`str`): the content of search query
+
+         Returns:
+             :class:`dict`: article information
+                 * content (str): a list of 3 arxiv search papers
+         """
+         import arxiv
+
+         try:
+             results = arxiv.Search(  # type: ignore
+                 query[: self.max_query_len], max_results=self.top_k_results
+             ).results()
+         except Exception as exc:
+             return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
+         docs = [
+             f'Published: {result.updated.date()}\nTitle: {result.title}\n'
+             f'Authors: {", ".join(a.name for a in result.authors)}\n'
+             f'Summary: {result.summary[:self.doc_content_chars_max]}'
+             for result in results
+         ]
+         if docs:
+             return {'content': '\n\n'.join(docs)}
+         return {'content': 'No good Arxiv Result was found'}
+
+
+ class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
+     """Search information from Arxiv.org. \
+ Useful for when you need to answer questions about Physics, Mathematics, \
+ Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
+ Electrical Engineering, and Economics from scientific articles on arxiv.org.
+     """
+
+     @tool_api(explode_return=True)
+     @asyncify
+     def get_arxiv_article_information(self, query: str) -> dict:
+         """Run Arxiv search and get the article meta information.
+
+         Args:
+             query (:class:`str`): the content of search query
+
+         Returns:
+             :class:`dict`: article information
+                 * content (str): a list of 3 arxiv search papers
+         """
+         return super().get_arxiv_article_information(query)
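
Called directly (assuming network access and the `arxiv` package are available), the API returns a plain dict; on failure it returns an ActionReturn carrying the error:

    action = ArxivSearch(top_k_results=1)
    info = action.get_arxiv_article_information('mixture of experts')
    print(info['content'])  # published date, title, authors and summary of the top hit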
lagent/actions/base_action.py ADDED
@@ -0,0 +1,434 @@
+ import inspect
+ import logging
+ import re
+ from abc import ABCMeta
+ from copy import deepcopy
+ from functools import wraps
+ from typing import Callable, Optional, Type, get_args, get_origin
+
+ try:
+     from typing import Annotated
+ except ImportError:
+     from typing_extensions import Annotated
+
+ from griffe import Docstring
+
+ try:
+     from griffe import DocstringSectionKind
+ except ImportError:
+     from griffe.enumerations import DocstringSectionKind
+
+ from ..schema import ActionReturn, ActionStatusCode
+ from .parser import BaseParser, JsonParser, ParseError
+
+ logging.getLogger('griffe').setLevel(logging.ERROR)
+
+
+ def tool_api(func: Optional[Callable] = None,
+              *,
+              explode_return: bool = False,
+              returns_named_value: bool = False,
+              **kwargs):
+     """Turn functions into tools. It will parse typehints as well as docstrings
+     to build the tool description and attach it to functions via an attribute
+     ``api_description``.
+
+     Examples:
+
+         .. code-block:: python
+
+             # typehints has higher priority than docstrings
+             from typing import Annotated
+
+             @tool_api
+             def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
+                 '''Add operation
+
+                 Args:
+                     x (int): a
+                     y (int): b
+                 '''
+                 return a + b
+
+             print(add.api_description)
+
+     Args:
+         func (Optional[Callable]): function to decorate. Defaults to ``None``.
+         explode_return (bool): whether to flatten the dictionary or tuple return
+             as the ``return_data`` field. When enabled, it is recommended to
+             annotate the member in docstrings. Defaults to ``False``.
+
+             .. code-block:: python
+
+                 @tool_api(explode_return=True)
+                 def foo(a, b):
+                     '''A simple function
+
+                     Args:
+                         a (int): a
+                         b (int): b
+
+                     Returns:
+                         dict: information of inputs
+                             * x: value of a
+                             * y: value of b
+                     '''
+                     return {'x': a, 'y': b}
+
+                 print(foo.api_description)
+
+         returns_named_value (bool): whether to parse ``thing: Description`` in
+             returns sections as a name and description, rather than a type and
+             description. When true, type must be wrapped in parentheses:
+             ``(int): Description``. When false, parentheses are optional but
+             the items cannot be named: ``int: Description``. Defaults to ``False``.
+
+     Returns:
+         Callable: wrapped function or partial decorator
+
+     Important:
+         ``return_data`` field will be added to ``api_description`` only
+         when ``explode_return`` or ``returns_named_value`` is enabled.
+     """
+
+     def _detect_type(string):
+         field_type = 'STRING'
+         if 'list' in string:
+             field_type = 'Array'
+         elif 'str' not in string:
+             if 'float' in string:
+                 field_type = 'FLOAT'
+             elif 'int' in string:
+                 field_type = 'NUMBER'
+             elif 'bool' in string:
+                 field_type = 'BOOLEAN'
+         return field_type
+
+     def _explode(desc):
+         kvs = []
+         desc = '\nArgs:\n' + '\n'.join([
+             '    ' + item.lstrip(' -+*#.')
+             for item in desc.split('\n')[1:] if item.strip()
+         ])
+         docs = Docstring(desc).parse('google')
+         if not docs:
+             return kvs
+         if docs[0].kind is DocstringSectionKind.parameters:
+             for d in docs[0].value:
+                 d = d.as_dict()
+                 if not d['annotation']:
+                     d.pop('annotation')
+                 else:
+                     d['type'] = _detect_type(d.pop('annotation').lower())
+                 kvs.append(d)
+         return kvs
+
+     def _parse_tool(function):
+         # remove rst syntax
+         docs = Docstring(
+             re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
+                 'google', returns_named_value=returns_named_value, **kwargs)
+         desc = dict(
+             name=function.__name__,
+             description=docs[0].value
+             if docs[0].kind is DocstringSectionKind.text else '',
+             parameters=[],
+             required=[],
+         )
+         args_doc, returns_doc = {}, []
+         for doc in docs:
+             if doc.kind is DocstringSectionKind.parameters:
+                 for d in doc.value:
+                     d = d.as_dict()
+                     d['type'] = _detect_type(d.pop('annotation').lower())
+                     args_doc[d['name']] = d
+             if doc.kind is DocstringSectionKind.returns:
+                 for d in doc.value:
+                     d = d.as_dict()
+                     if not d['name']:
+                         d.pop('name')
+                     if not d['annotation']:
+                         d.pop('annotation')
+                     else:
+                         d['type'] = _detect_type(d.pop('annotation').lower())
+                     returns_doc.append(d)
+
+         sig = inspect.signature(function)
+         for name, param in sig.parameters.items():
+             if name == 'self':
+                 continue
+             parameter = dict(
+                 name=param.name,
+                 type='STRING',
+                 description=args_doc.get(param.name,
+                                          {}).get('description', ''))
+             annotation = param.annotation
+             if annotation is inspect.Signature.empty:
+                 parameter['type'] = args_doc.get(param.name,
+                                                  {}).get('type', 'STRING')
+             else:
+                 if get_origin(annotation) is Annotated:
+                     annotation, info = get_args(annotation)
+                     if info:
+                         parameter['description'] = info
+                 while get_origin(annotation):
+                     annotation = get_args(annotation)
+                 parameter['type'] = _detect_type(str(annotation))
+             desc['parameters'].append(parameter)
+             if param.default is inspect.Signature.empty:
+                 desc['required'].append(param.name)
+
+         return_data = []
+         if explode_return:
+             return_data = _explode(returns_doc[0]['description'])
+         elif returns_named_value:
+             return_data = returns_doc
+         if return_data:
+             desc['return_data'] = return_data
+         return desc
+
+     if callable(func):
+
+         if inspect.iscoroutinefunction(func):
+
+             @wraps(func)
+             async def wrapper(self, *args, **kwargs):
+                 return await func(self, *args, **kwargs)
+
+         else:
+
+             @wraps(func)
+             def wrapper(self, *args, **kwargs):
+                 return func(self, *args, **kwargs)
+
+         wrapper.api_description = _parse_tool(func)
+         return wrapper
+
+     def decorate(func):
+
+         if inspect.iscoroutinefunction(func):
+
+             @wraps(func)
+             async def wrapper(self, *args, **kwargs):
+                 return await func(self, *args, **kwargs)
+
+         else:
+
+             @wraps(func)
+             def wrapper(self, *args, **kwargs):
+                 return func(self, *args, **kwargs)
+
+         wrapper.api_description = _parse_tool(func)
+         return wrapper
+
+     return decorate
+
+
+ class ToolMeta(ABCMeta):
+     """Metaclass of tools."""
+
+     def __new__(mcs, name, base, attrs):
+         is_toolkit, tool_desc = True, dict(
+             name=name,
+             description=Docstring(attrs.get('__doc__',
+                                             '')).parse('google')[0].value)
+         for key, value in attrs.items():
+             if callable(value) and hasattr(value, 'api_description'):
+                 api_desc = getattr(value, 'api_description')
+                 if key == 'run':
+                     tool_desc['parameters'] = api_desc['parameters']
+                     tool_desc['required'] = api_desc['required']
+                     if api_desc['description']:
+                         tool_desc['description'] = api_desc['description']
+                     if api_desc.get('return_data'):
+                         tool_desc['return_data'] = api_desc['return_data']
+                     is_toolkit = False
+                 else:
+                     tool_desc.setdefault('api_list', []).append(api_desc)
+         if not is_toolkit and 'api_list' in tool_desc:
+             raise KeyError('`run` and other tool APIs can not be implemented '
+                            'at the same time')
+         if is_toolkit and 'api_list' not in tool_desc:
+             is_toolkit = False
+             if callable(attrs.get('run')):
+                 run_api = tool_api(attrs['run'])
+                 api_desc = run_api.api_description
+                 tool_desc['parameters'] = api_desc['parameters']
+                 tool_desc['required'] = api_desc['required']
+                 if api_desc['description']:
+                     tool_desc['description'] = api_desc['description']
+                 if api_desc.get('return_data'):
+                     tool_desc['return_data'] = api_desc['return_data']
+                 attrs['run'] = run_api
+             else:
+                 tool_desc['parameters'], tool_desc['required'] = [], []
+         attrs['_is_toolkit'] = is_toolkit
+         attrs['__tool_description__'] = tool_desc
+         return super().__new__(mcs, name, base, attrs)
+
+
+ class BaseAction(metaclass=ToolMeta):
+     """Base class for all actions.
+
+     Args:
+         description (:class:`Optional[dict]`): The description of the action.
+             Defaults to ``None``.
+         parser (:class:`Type[BaseParser]`): The parser class to process the
+             action's inputs and outputs. Defaults to :class:`JsonParser`.
+
+     Examples:
+
+         * simple tool
+
+         .. code-block:: python
+
+             class Bold(BaseAction):
+                 '''Make text bold'''
+
+                 def run(self, text: str):
+                     '''
+                     Args:
+                         text (str): input text
+
+                     Returns:
+                         str: bold text
+                     '''
+                     return '**' + text + '**'
+
+             action = Bold()
+
+         * toolkit with multiple APIs
+
+         .. code-block:: python
+
+             class Calculator(BaseAction):
+                 '''Calculator'''
+
+                 @tool_api
+                 def add(self, a, b):
+                     '''Add operation
+
+                     Args:
+                         a (int): augend
+                         b (int): addend
+
+                     Returns:
+                         int: sum
+                     '''
+                     return a + b
+
+                 @tool_api
+                 def sub(self, a, b):
+                     '''Subtraction operation
+
+                     Args:
+                         a (int): minuend
+                         b (int): subtrahend
+
+                     Returns:
+                         int: difference
+                     '''
+                     return a - b
+
+             action = Calculator()
+     """
+
+     def __init__(
+         self,
+         description: Optional[dict] = None,
+         parser: Type[BaseParser] = JsonParser,
+     ):
+         self._description = deepcopy(description or self.__tool_description__)
+         self._name = self._description['name']
+         self._parser = parser(self)
+
+     def __call__(self, inputs: str, name='run') -> ActionReturn:
+         fallback_args = {'inputs': inputs, 'name': name}
+         if not hasattr(self, name):
+             return ActionReturn(
+                 fallback_args,
+                 type=self.name,
+                 errmsg=f'invalid API: {name}',
+                 state=ActionStatusCode.API_ERROR)
+         try:
+             inputs = self._parser.parse_inputs(inputs, name)
+         except ParseError as exc:
+             return ActionReturn(
+                 fallback_args,
+                 type=self.name,
+                 errmsg=exc.err_msg,
+                 state=ActionStatusCode.ARGS_ERROR)
+         try:
+             outputs = getattr(self, name)(**inputs)
+         except Exception as exc:
+             return ActionReturn(
+                 inputs,
+                 type=self.name,
+                 errmsg=str(exc),
+                 state=ActionStatusCode.API_ERROR)
+         if isinstance(outputs, ActionReturn):
+             action_return = outputs
+             if not action_return.args:
+                 action_return.args = inputs
+             if not action_return.type:
+                 action_return.type = self.name
+         else:
+             result = self._parser.parse_outputs(outputs)
+             action_return = ActionReturn(inputs, type=self.name, result=result)
+         return action_return
+
+     @property
+     def name(self):
+         return self._name
+
+     @property
+     def is_toolkit(self):
+         return self._is_toolkit
+
+     @property
+     def description(self) -> dict:
+         """Description of the tool."""
+         return self._description
+
+     def __repr__(self):
+         return f'{self.description}'
+
+     __str__ = __repr__
+
+
+ class AsyncActionMixin:
+
+     async def __call__(self, inputs: str, name='run') -> ActionReturn:
+         fallback_args = {'inputs': inputs, 'name': name}
+         if not hasattr(self, name):
+             return ActionReturn(
+                 fallback_args,
+                 type=self.name,
+                 errmsg=f'invalid API: {name}',
+                 state=ActionStatusCode.API_ERROR)
+         try:
+             inputs = self._parser.parse_inputs(inputs, name)
+         except ParseError as exc:
+             return ActionReturn(
+                 fallback_args,
+                 type=self.name,
+                 errmsg=exc.err_msg,
+                 state=ActionStatusCode.ARGS_ERROR)
+         try:
+             outputs = await getattr(self, name)(**inputs)
+         except Exception as exc:
+             return ActionReturn(
+                 inputs,
+                 type=self.name,
+                 errmsg=str(exc),
+                 state=ActionStatusCode.API_ERROR)
+         if isinstance(outputs, ActionReturn):
+             action_return = outputs
+             if not action_return.args:
+                 action_return.args = inputs
+             if not action_return.type:
+                 action_return.type = self.name
+         else:
+             result = self._parser.parse_outputs(outputs)
+             action_return = ActionReturn(inputs, type=self.name, result=result)
+         return action_return
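
A quick sanity check of the metaclass behavior, mirroring the `Bold` example from the class docstring above:

    class Bold(BaseAction):
        '''Make text bold'''

        def run(self, text: str):
            '''
            Args:
                text (str): input text
            '''
            return '**' + text + '**'

    action = Bold()
    print(action.name)         # 'Bold' (taken from the auto-built description)
    print(action.is_toolkit)   # False: `run` is the single entry point
    print(action.description)  # parameters/required fields parsed from the docstring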
lagent/actions/bing_map.py ADDED
@@ -0,0 +1,268 @@
1
+ # flake8: noqa: E501
2
+ import json
3
+ import os
4
+ from typing import Optional, Type
5
+
6
+ import aiohttp
7
+ import requests
8
+
9
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
+ from lagent.actions.parser import BaseParser, JsonParser
11
+
12
+
13
+ class BINGMap(BaseAction):
14
+ """BING Map plugin for looking up map information."""
15
+
16
+ def __init__(
17
+ self,
18
+ key: Optional[str] = None,
19
+ description: Optional[dict] = None,
20
+ parser: Type[BaseParser] = JsonParser,
21
+ ) -> None:
22
+ super().__init__(description, parser)
23
+ key = os.environ.get('BING_MAP_KEY', key)
24
+ if key is None:
25
+ raise ValueError(
26
+ 'Please set BING Map API key either in the environment '
27
+ 'as BING_MAP_KEY or pass it as `key` parameter.')
28
+ self.key = key
29
+ self.base_url = 'http://dev.virtualearth.net/REST/V1/'
30
+
31
+ @tool_api(explode_return=True)
32
+ def get_distance(self, start: str, end: str) -> dict:
33
+ """Get the distance between two locations in km.
34
+
35
+ Args:
36
+ start (:class:`str`): The start location
37
+ end (:class:`str`): The end location
38
+
39
+ Returns:
40
+ :class:`dict`: distance information
41
+ * distance (str): the distance in km.
42
+ """
43
+ # Request URL
44
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
45
+ # GET request
46
+ r = requests.get(url)
47
+ # TODO check request status?
48
+ data = json.loads(r.text)
49
+ # Extract route information
50
+ route = data['resourceSets'][0]['resources'][0]
51
+ # Extract distance in miles
52
+ distance = route['travelDistance']
53
+ return dict(distance=distance)
54
+
55
+ @tool_api(explode_return=True)
56
+ def get_route(self, start: str, end: str) -> dict:
57
+ """Get the route between two locations in km.
58
+
59
+ Args:
60
+ start (:class:`str`): The start location
61
+ end (:class:`str`): The end location
62
+
63
+ Returns:
64
+ :class:`dict`: route information
65
+ * route (list): the route, a list of actions.
66
+ """
67
+ # Request URL
68
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
69
+ # GET request
70
+ r = requests.get(url)
71
+ data = json.loads(r.text)
72
+ # Extract route information
73
+ route = data['resourceSets'][0]['resources'][0]
74
+ itinerary = route['routeLegs'][0]['itineraryItems']
75
+ # Extract route text information
76
+ route_text = []
77
+ for item in itinerary:
78
+ if 'instruction' in item:
79
+ route_text.append(item['instruction']['text'])
80
+ return dict(route=route_text)
81
+
82
+ @tool_api(explode_return=True)
83
+ def get_coordinates(self, location: str) -> dict:
84
+ """Get the coordinates of a location.
85
+
86
+ Args:
87
+ location (:class:`str`): the location need to get coordinates.
88
+
89
+ Returns:
90
+ :class:`dict`: coordinates information
91
+ * latitude (float): the latitude of the location.
92
+ * longitude (float): the longitude of the location.
93
+ """
94
+ url = self.base_url + 'Locations'
95
+ params = {'query': location, 'key': self.key}
96
+ response = requests.get(url, params=params)
97
+ json_data = response.json()
98
+ coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
99
+ 'coordinates']
100
+ return dict(latitude=coordinates[0], longitude=coordinates[1])
101
+
102
+ @tool_api(explode_return=True)
103
+ def search_nearby(self,
104
+ search_term: str,
105
+ places: str = 'unknown',
106
+ latitude: float = 0.0,
107
+ longitude: float = 0.0,
108
+ radius: int = 5000) -> dict:
109
+ """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
110
+
111
+ Args:
112
+ search_term (:class:`str`): the place name.
113
+ places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
114
+ latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
115
+ longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
116
+ radius (:class:`int`): radius in meters. Defaults to ``5000``.
117
+
118
+ Returns:
119
+ :class:`dict`: places information
120
+ * places (list): the list of places, each place is a dict with name and address, at most 5 places.
121
+ """
122
+ url = self.base_url + 'LocalSearch'
123
+ if places != 'unknown':
124
+ pos = self.get_coordinates(**{'location': places})
125
+ latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
126
+ # Build the request query string
127
+ params = {
128
+ 'query': search_term,
129
+ 'userLocation': f'{latitude},{longitude}',
130
+ 'radius': radius,
131
+ 'key': self.key
132
+ }
133
+ # Make the request
134
+ response = requests.get(url, params=params)
135
+ # Parse the response
136
+ response_data = json.loads(response.content)
137
+ # Get the results
138
+ results = response_data['resourceSets'][0]['resources']
139
+ addresses = []
140
+ for result in results:
141
+ name = result['name']
142
+ address = result['Address']['formattedAddress']
143
+ addresses.append(dict(name=name, address=address))
144
+ if len(addresses) == 5:
145
+ break
146
+ return dict(place=addresses)
147
+
148
+
149
+ class AsyncBINGMap(AsyncActionMixin, BINGMap):
150
+ """BING Map plugin for looking up map information."""
151
+
152
+ @tool_api(explode_return=True)
153
+ async def get_distance(self, start: str, end: str) -> dict:
154
+ """Get the distance between two locations in km.
155
+
156
+ Args:
157
+ start (:class:`str`): The start location
158
+ end (:class:`str`): The end location
159
+
160
+ Returns:
161
+ :class:`dict`: distance information
162
+ * distance (str): the distance in km.
163
+ """
164
+ # Request URL
165
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
166
+ # GET request
167
+ async with aiohttp.ClientSession() as session:
168
+ async with session.get(url) as resp:
169
+ # TODO check request status?
170
+ data = await resp.json()
171
+ # Extract route information
172
+ route = data['resourceSets'][0]['resources'][0]
173
+ # Extract distance in miles
174
+ distance = route['travelDistance']
175
+ return dict(distance=distance)
176
+
177
+ @tool_api(explode_return=True)
178
+ async def get_route(self, start: str, end: str) -> dict:
179
+ """Get the route between two locations in km.
180
+
181
+ Args:
182
+ start (:class:`str`): The start location
183
+ end (:class:`str`): The end location
184
+
185
+ Returns:
186
+ :class:`dict`: route information
187
+ * route (list): the route, a list of actions.
188
+ """
189
+ # Request URL
190
+ url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
191
+ # GET request
192
+ async with aiohttp.ClientSession() as session:
193
+ async with session.get(url) as resp:
194
+ data = await resp.json()
195
+ # Extract route information
196
+ route = data['resourceSets'][0]['resources'][0]
197
+ itinerary = route['routeLegs'][0]['itineraryItems']
198
+ # Extract route text information
199
+ route_text = []
200
+ for item in itinerary:
201
+ if 'instruction' in item:
202
+ route_text.append(item['instruction']['text'])
203
+ return dict(route=route_text)
204
+
205
+ @tool_api(explode_return=True)
206
+ async def get_coordinates(self, location: str) -> dict:
207
+ """Get the coordinates of a location.
208
+
209
+ Args:
210
+ location (:class:`str`): the location need to get coordinates.
211
+
212
+ Returns:
213
+ :class:`dict`: coordinates information
214
+ * latitude (float): the latitude of the location.
215
+ * longitude (float): the longitude of the location.
216
+ """
217
+ url = self.base_url + 'Locations'
218
+ params = {'query': location, 'key': self.key}
219
+ async with aiohttp.ClientSession() as session:
220
+ async with session.get(url, params=params) as resp:
221
+ data = await resp.json()
222
+ coordinates = data['resourceSets'][0]['resources'][0]['point'][
223
+ 'coordinates']
224
+ return dict(latitude=coordinates[0], longitude=coordinates[1])
225
+
226
+ @tool_api(explode_return=True)
227
+ async def search_nearby(self,
228
+ search_term: str,
229
+ places: str = 'unknown',
230
+ latitude: float = 0.0,
231
+ longitude: float = 0.0,
232
+ radius: int = 5000) -> dict:
233
+ """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
234
+
235
+ Args:
236
+ search_term (:class:`str`): the place name.
237
+ places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
238
+ latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
239
+ longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
240
+ radius (:class:`int`): radius in meters. Defaults to ``5000``.
241
+
242
+ Returns:
243
+ :class:`dict`: places information
244
+ * places (list): the list of places, each place is a dict with name and address, at most 5 places.
245
+ """
246
+ url = self.base_url + 'LocalSearch'
247
+ if places != 'unknown':
248
+ pos = self.get_coordinates(**{'location': places})
249
+ latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
250
+ # Build the request query string
251
+ params = {
252
+ 'query': search_term,
253
+ 'userLocation': f'{latitude},{longitude}',
254
+ 'radius': radius,
255
+ 'key': self.key
256
+ }
257
+ async with aiohttp.ClientSession() as session:
258
+ async with session.get(url, params=params) as resp:
259
+ data = await resp.json()
260
+ results = data['resourceSets'][0]['resources']
261
+ addresses = []
262
+ for result in results:
263
+ name = result['name']
264
+ address = result['Address']['formattedAddress']
265
+ addresses.append(dict(name=name, address=address))
266
+ if len(addresses) == 5:
267
+ break
268
+ return dict(place=addresses)
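A minimal usage sketch for the map plugin above. It assumes the `BINGMap` constructor (defined earlier in this file) accepts the Bing Maps API key as a `key` argument; the key value is a placeholder.

import asyncio

from lagent.actions.bing_map import AsyncBINGMap, BINGMap

# Placeholder key; `key` as the constructor argument name is an assumption.
bing = BINGMap(key='<your-bing-maps-key>')
print(bing.get_coordinates(location='Shanghai'))
print(bing.search_nearby(search_term='coffee', places='Shanghai'))

async def main():
    abing = AsyncBINGMap(key='<your-bing-maps-key>')
    # All tool methods on the async plugin are coroutines.
    print(await abing.get_distance(start='Beijing', end='Shanghai'))

asyncio.run(main())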
lagent/actions/builtin_actions.py ADDED
@@ -0,0 +1,109 @@
+from typing import Optional
+
+from lagent.actions.base_action import BaseAction, tool_api
+from lagent.actions.parser import BaseParser
+from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
+
+
+class InvalidAction(BaseAction):
+    """An invalid action class, used to return an error message when the
+    requested action is invalid.
+
+    Args:
+        err_msg (str): The error message. Defaults to 'The action is invalid,
+            please check the action name'.
+
+    Returns:
+        ActionReturn: The action return.
+    """
+
+    def __init__(self,
+                 err_msg: str = 'The action is invalid, please check the action name.',
+                 description: Optional[dict] = None,
+                 parser=BaseParser) -> None:
+        super().__init__(description, parser)
+        self._err_msg = err_msg
+
+    @tool_api
+    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
+        """Return the error message.
+
+        Args:
+            err_msg (str, optional): The error message. If err_msg is not None,
+                it will be returned, otherwise the default error message will
+                be returned. Defaults to None.
+        """
+        action_return = ActionReturn(
+            url=None,
+            args=dict(text=err_msg),
+            errmsg=err_msg or self._err_msg,
+            type=self.name,
+            valid=ActionValidCode.INVALID,
+            state=ActionStatusCode.API_ERROR)
+        return action_return
+
+
+class NoAction(BaseAction):
+    """A no-op action class, used to return an error message when the
+    response does not follow the expected format.
+
+    Args:
+        err_msg (str): The error message. Defaults to
+            'Please follow the format'.
+    """
+
+    def __init__(self,
+                 err_msg: str = 'Please follow the format',
+                 description: Optional[dict] = None,
+                 parser=BaseParser):
+        super().__init__(description, parser)
+        self._err_msg = err_msg
+
+    @tool_api
+    def run(self, err_msg: Optional[str] = None) -> ActionReturn:
+        """Return the error message.
+
+        Args:
+            err_msg (str, optional): The error message. If err_msg is not None,
+                it will be returned, otherwise the default error message will
+                be returned. Defaults to None.
+
+        Returns:
+            ActionReturn: The action return.
+        """
+        action_return = ActionReturn(
+            url=None,
+            args=dict(text=err_msg),
+            type=self.name,
+            errmsg=err_msg or self._err_msg,
+            valid=ActionValidCode.INVALID,
+            state=ActionStatusCode.API_ERROR)
+        return action_return
+
+
+class FinishAction(BaseAction):
+    """A finish action class, used to return the final result."""
+
+    def __init__(self, description: Optional[dict] = None, parser=BaseParser):
+        super().__init__(description, parser)
+
+    @tool_api
+    def run(self, response: str) -> ActionReturn:
+        """Return the final result.
+
+        Args:
+            response (str): The final result.
+
+        Returns:
+            ActionReturn: The action return.
+        """
+        action_return = ActionReturn(
+            url=None,
+            args=dict(text=response),
+            result=[dict(type='text', content=response)],
+            type=self.name,
+            valid=ActionValidCode.FINISH,
+            state=ActionStatusCode.SUCCESS)
+        return action_return
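A short sketch of how these built-in actions behave, assuming the `@tool_api`-decorated `run` methods can be invoked directly (dispatch normally goes through `BaseAction`, defined in base_action.py in this commit):

from lagent.actions.builtin_actions import FinishAction, InvalidAction

# InvalidAction reports an API_ERROR state with the default message.
ret = InvalidAction().run()
print(ret.errmsg)  # -> 'The action is invalid, please check the action name.'

# FinishAction wraps the final answer as a successful text result.
done = FinishAction().run(response='All tasks completed.')
print(done.result)  # -> [{'type': 'text', 'content': 'All tasks completed.'}]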
lagent/actions/google_scholar_search.py ADDED
@@ -0,0 +1,438 @@
+# flake8: noqa: E501
+import os
+from typing import Optional, Type
+
+from asyncer import asyncify
+
+from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
+from lagent.schema import ActionReturn, ActionStatusCode
+from .parser import BaseParser, JsonParser
+
+
+class GoogleScholar(BaseAction):
+    """Plugin for Google Scholar search.
+
+    Args:
+        api_key (str): API key for SerpApi (https://serpapi.com), which this
+            plugin calls through the ``serpapi`` package. If not given, it is
+            read from the ``SERPER_API_KEY`` environment variable.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        description: Optional[dict] = None,
+        parser: Type[BaseParser] = JsonParser,
+    ):
+        super().__init__(description, parser)
+        api_key = os.environ.get('SERPER_API_KEY', api_key)
+        if api_key is None:
+            raise ValueError(
+                'Please set the API key either in the environment '
+                'as SERPER_API_KEY or pass it as `api_key` parameter.'
+            )
+        self.api_key = api_key
+
+    @tool_api(explode_return=True)
+    def search_google_scholar(
+        self,
+        query: str,
+        cites: Optional[str] = None,
+        as_ylo: Optional[int] = None,
+        as_yhi: Optional[int] = None,
+        scisbd: Optional[int] = None,
+        cluster: Optional[str] = None,
+        hl: Optional[str] = None,
+        lr: Optional[str] = None,
+        start: Optional[int] = None,
+        num: Optional[int] = None,
+        as_sdt: Optional[str] = None,
+        safe: Optional[str] = None,
+        filter: Optional[str] = None,
+        as_vis: Optional[str] = None,
+    ) -> dict:
+        """Search for scholarly articles based on a query, using Google Scholar.
+
+        Args:
+            query (str): The query to search for.
+            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
+            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
+            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
+            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
+            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
+            hl (Optional[str]): The language to use for the Google Scholar search.
+            lr (Optional[str]): One or multiple languages to limit the search to.
+            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
+            num (Optional[int]): The maximum number of results to return, limited to 20.
+            as_sdt (Optional[str]): Can be used either as a search type or a filter.
+            safe (Optional[str]): The level of filtering for adult content.
+            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
+            as_vis (Optional[str]): Defines whether to include citations or not.
+
+        Returns:
+            :class:`dict`: article information
+                - title: a list of the titles of the three selected papers
+                - cited_by: a list of the citation numbers of the three selected papers
+                - organic_id: a list of the organic results' ids of the three selected papers
+                - pub_info: publication information of the selected papers
+                - snippets: snippets of the selected papers
+        """
+        from serpapi import GoogleSearch
+
+        params = {
+            'q': query,
+            'engine': 'google_scholar',
+            'api_key': self.api_key,
+            'cites': cites,
+            'as_ylo': as_ylo,
+            'as_yhi': as_yhi,
+            'scisbd': scisbd,
+            'cluster': cluster,
+            'hl': hl,
+            'lr': lr,
+            'start': start,
+            'num': num,
+            'as_sdt': as_sdt,
+            'safe': safe,
+            'filter': filter,
+            'as_vis': as_vis,
+        }
+        search = GoogleSearch(params)
+        try:
+            r = search.get_dict()
+            results = r['organic_results']
+            title = []
+            snippets = []
+            cited_by = []
+            organic_id = []
+            pub_info = []
+            for item in results[:3]:
+                title.append(item['title'])
+                pub_info.append(item['publication_info']['summary'])
+                citation = item['inline_links'].get('cited_by', {'total': ''})
+                cited_by.append(citation['total'])
+                snippets.append(item['snippet'])
+                organic_id.append(item['result_id'])
+            # Include pub_info so the result matches the documented fields
+            return dict(
+                title=title,
+                cited_by=cited_by,
+                organic_id=organic_id,
+                pub_info=pub_info,
+                snippets=snippets)
+        except Exception as e:
+            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
+
+    @tool_api(explode_return=True)
+    def get_author_information(
+        self,
+        author_id: str,
+        hl: Optional[str] = None,
+        view_op: Optional[str] = None,
+        sort: Optional[str] = None,
+        citation_id: Optional[str] = None,
+        start: Optional[int] = None,
+        num: Optional[int] = None,
+        no_cache: Optional[bool] = None,
+        async_req: Optional[bool] = None,
+        output: Optional[str] = None,
+    ) -> dict:
+        """Search for an author's information by the author's id provided by get_author_id.
+
+        Args:
+            author_id (str): Required. The ID of an author.
+            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
+            view_op (Optional[str]): Used for viewing specific parts of a page.
+            sort (Optional[str]): Used for sorting and refining articles.
+            citation_id (Optional[str]): Used for retrieving individual article citation.
+            start (Optional[int]): Defines the result offset. Default is 0.
+            num (Optional[int]): Defines the number of results to return. Default is 20.
+            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
+            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
+            output (Optional[str]): Defines the final output you want. Default is 'json'.
+
+        Returns:
+            :class:`dict`: author information
+                * name: author's name
+                * affiliations: the affiliations of the author
+                * articles: at most 3 articles by the author
+                * website: the author's homepage url
+        """
+        from serpapi import GoogleSearch
+
+        params = {
+            'engine': 'google_scholar_author',
+            'author_id': author_id,
+            'api_key': self.api_key,
+            'hl': hl,
+            'view_op': view_op,
+            'sort': sort,
+            'citation_id': citation_id,
+            'start': start,
+            'num': num,
+            'no_cache': no_cache,
+            'async': async_req,
+            'output': output,
+        }
+        try:
+            search = GoogleSearch(params)
+            results = search.get_dict()
+            author = results['author']
+            articles = results.get('articles', [])
+            return dict(
+                name=author['name'],
+                affiliations=author.get('affiliations', ''),
+                website=author.get('website', ''),
+                articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
+            )
+        except Exception as e:
+            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
+
+    @tool_api(explode_return=True)
+    def get_citation_format(
+        self,
+        q: str,
+        no_cache: Optional[bool] = None,
+        async_: Optional[bool] = None,
+        output: Optional[str] = 'json',
+    ) -> dict:
+        """Get the MLA citation format of an organic result by its id, as provided by search_google_scholar.
+
+        Args:
+            q (str): ID of an individual Google Scholar organic search result.
+            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
+            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
+            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
+
+        Returns:
+            :class:`dict`: citation format
+                * citation: the MLA citation snippet of the article
+        """
+        from serpapi import GoogleSearch
+
+        params = {
+            'q': q,
+            'engine': 'google_scholar_cite',
+            'api_key': self.api_key,
+            'no_cache': no_cache,
+            'async': async_,
+            'output': output,
+        }
+        try:
+            search = GoogleSearch(params)
+            results = search.get_dict()
+            citation = results['citations']
+            citation_info = citation[0]['snippet']
+            # Wrap in a dict so the return value matches the annotation above
+            return dict(citation=citation_info)
+        except Exception as e:
+            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
+
+    @tool_api(explode_return=True)
+    def get_author_id(
+        self,
+        mauthors: str,
+        hl: Optional[str] = 'en',
+        after_author: Optional[str] = None,
+        before_author: Optional[str] = None,
+        no_cache: Optional[bool] = False,
+        _async: Optional[bool] = False,
+        output: Optional[str] = 'json',
+    ) -> dict:
+        """Get an author's id by the author's name.
+
+        Args:
+            mauthors (str): Defines the author you want to search for.
+            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
+            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. This parameter takes precedence over the before_author parameter. Defaults to None.
+            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
+            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
+            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
+            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
+
+        Returns:
+            :class:`dict`: author id
+                * author_id: the author_id of the author
+        """
+        from serpapi import GoogleSearch
+
+        params = {
+            'mauthors': mauthors,
+            'engine': 'google_scholar_profiles',
+            'api_key': self.api_key,
+            'hl': hl,
+            'after_author': after_author,
+            'before_author': before_author,
+            'no_cache': no_cache,
+            'async': _async,
+            'output': output,
+        }
+        try:
+            search = GoogleSearch(params)
+            results = search.get_dict()
+            profile = results['profiles']
+            author_info = dict(author_id=profile[0]['author_id'])
+            return author_info
+        except Exception as e:
+            return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
+
+
+class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
+    """Plugin for Google Scholar search.
+
+    Args:
+        api_key (str): API key for SerpApi (https://serpapi.com), which this
+            plugin calls through the ``serpapi`` package. If not given, it is
+            read from the ``SERPER_API_KEY`` environment variable.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def search_google_scholar(
+        self,
+        query: str,
+        cites: Optional[str] = None,
+        as_ylo: Optional[int] = None,
+        as_yhi: Optional[int] = None,
+        scisbd: Optional[int] = None,
+        cluster: Optional[str] = None,
+        hl: Optional[str] = None,
+        lr: Optional[str] = None,
+        start: Optional[int] = None,
+        num: Optional[int] = None,
+        as_sdt: Optional[str] = None,
+        safe: Optional[str] = None,
+        filter: Optional[str] = None,
+        as_vis: Optional[str] = None,
+    ) -> dict:
+        """Search for scholarly articles based on a query, using Google Scholar.
+
+        Args:
+            query (str): The query to search for.
+            cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
+            as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
+            as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
+            scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
+            cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
+            hl (Optional[str]): The language to use for the Google Scholar search.
+            lr (Optional[str]): One or multiple languages to limit the search to.
+            start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
+            num (Optional[int]): The maximum number of results to return, limited to 20.
+            as_sdt (Optional[str]): Can be used either as a search type or a filter.
+            safe (Optional[str]): The level of filtering for adult content.
+            filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
+            as_vis (Optional[str]): Defines whether to include citations or not.
+
+        Returns:
+            :class:`dict`: article information
+                - title: a list of the titles of the three selected papers
+                - cited_by: a list of the citation numbers of the three selected papers
+                - organic_id: a list of the organic results' ids of the three selected papers
+                - pub_info: publication information of the selected papers
+                - snippets: snippets of the selected papers
+        """
+        return super().search_google_scholar(
+            query,
+            cites,
+            as_ylo,
+            as_yhi,
+            scisbd,
+            cluster,
+            hl,
+            lr,
+            start,
+            num,
+            as_sdt,
+            safe,
+            filter,
+            as_vis,
+        )
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def get_author_information(
+        self,
+        author_id: str,
+        hl: Optional[str] = None,
+        view_op: Optional[str] = None,
+        sort: Optional[str] = None,
+        citation_id: Optional[str] = None,
+        start: Optional[int] = None,
+        num: Optional[int] = None,
+        no_cache: Optional[bool] = None,
+        async_req: Optional[bool] = None,
+        output: Optional[str] = None,
+    ) -> dict:
+        """Search for an author's information by the author's id provided by get_author_id.
+
+        Args:
+            author_id (str): Required. The ID of an author.
+            hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
+            view_op (Optional[str]): Used for viewing specific parts of a page.
+            sort (Optional[str]): Used for sorting and refining articles.
+            citation_id (Optional[str]): Used for retrieving individual article citation.
+            start (Optional[int]): Defines the result offset. Default is 0.
+            num (Optional[int]): Defines the number of results to return. Default is 20.
+            no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
+            async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
+            output (Optional[str]): Defines the final output you want. Default is 'json'.
+
+        Returns:
+            :class:`dict`: author information
+                * name: author's name
+                * affiliations: the affiliations of the author
+                * articles: at most 3 articles by the author
+                * website: the author's homepage url
+        """
+        return super().get_author_information(
+            author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
+        )
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def get_citation_format(
+        self,
+        q: str,
+        no_cache: Optional[bool] = None,
+        async_: Optional[bool] = None,
+        output: Optional[str] = 'json',
+    ) -> dict:
+        """Get the MLA citation format of an organic result by its id, as provided by search_google_scholar.
+
+        Args:
+            q (str): ID of an individual Google Scholar organic search result.
+            no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
+            async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
+            output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
+
+        Returns:
+            :class:`dict`: citation format
+                * citation: the MLA citation snippet of the article
+        """
+        return super().get_citation_format(q, no_cache, async_, output)
+
+    @tool_api(explode_return=True)
+    @asyncify
+    def get_author_id(
+        self,
+        mauthors: str,
+        hl: Optional[str] = 'en',
+        after_author: Optional[str] = None,
+        before_author: Optional[str] = None,
+        no_cache: Optional[bool] = False,
+        _async: Optional[bool] = False,
+        output: Optional[str] = 'json',
+    ) -> dict:
+        """Get an author's id by the author's name.
+
+        Args:
+            mauthors (str): Defines the author you want to search for.
+            hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
+            after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. This parameter takes precedence over the before_author parameter. Defaults to None.
+            before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
+            no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
+            _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
+            output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
+
+        Returns:
+            :class:`dict`: author id
+                * author_id: the author_id of the author
+        """
+        return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
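A usage sketch chaining the scholar tools, assuming a valid SerpApi key (the key value below is a placeholder):

import os

from lagent.actions.google_scholar_search import GoogleScholar

os.environ['SERPER_API_KEY'] = '<your-serpapi-key>'  # placeholder
scholar = GoogleScholar()

# Search papers, then resolve an author profile from a separate lookup.
papers = scholar.search_google_scholar(query='large language model agents')
print(papers['title'], papers['cited_by'])

author = scholar.get_author_id(mauthors='Geoffrey Hinton')
info = scholar.get_author_information(author_id=author['author_id'])
print(info['name'], info['affiliations'])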
lagent/actions/google_search.py ADDED
@@ -0,0 +1,244 @@
+import os
+from typing import List, Optional, Tuple, Type, Union
+
+import aiohttp
+import requests
+
+from lagent.schema import ActionReturn, ActionStatusCode
+from .base_action import AsyncActionMixin, BaseAction, tool_api
+from .parser import BaseParser, JsonParser
+
+
+class GoogleSearch(BaseAction):
+    """Wrapper around the Serper.dev Google Search API.
+
+    To use, you should pass your serper API key to the constructor.
+
+    Code is modified from lang-chain GoogleSerperAPIWrapper
+    (https://github.com/langchain-ai/langchain/blob/ba5fbaba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/langchain/utilities/google_serper.py)
+
+    Args:
+        api_key (str): API KEY to use the serper google search API.
+            You can create a free API key at https://serper.dev.
+        timeout (int): Upper bound of waiting time for a serper request.
+        search_type (str): Serper API supports ['search', 'images', 'news',
+            'places'] types of search; currently only 'search' is supported.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+    result_key_for_type = {
+        'news': 'news',
+        'places': 'places',
+        'images': 'images',
+        'search': 'organic',
+    }
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        timeout: int = 5,
+        search_type: str = 'search',
+        description: Optional[dict] = None,
+        parser: Type[BaseParser] = JsonParser,
+    ):
+        super().__init__(description, parser)
+        api_key = os.environ.get('SERPER_API_KEY', api_key)
+        if api_key is None:
+            raise ValueError(
+                'Please set Serper API key either in the environment '
+                'as SERPER_API_KEY or pass it as `api_key` parameter.')
+        self.api_key = api_key
+        self.timeout = timeout
+        self.search_type = search_type
+
+    @tool_api
+    def run(self, query: str, k: int = 10) -> ActionReturn:
+        """An API that can retrieve Google search results. Use it when you need to find a short and clear answer to a specific question. The input should be a search query.
+
+        Args:
+            query (str): the search content
+            k (int): select the first k results in the search results as the response
+        """
+        tool_return = ActionReturn(type=self.name)
+        status_code, response = self._search(query, k=k)
+        # convert search results to ToolReturn format
+        if status_code == -1:
+            tool_return.errmsg = response
+            tool_return.state = ActionStatusCode.HTTP_ERROR
+        elif status_code == 200:
+            parsed_res = self._parse_results(response, k)
+            tool_return.result = [dict(type='text', content=str(parsed_res))]
+            tool_return.state = ActionStatusCode.SUCCESS
+        else:
+            tool_return.errmsg = str(status_code)
+            tool_return.state = ActionStatusCode.API_ERROR
+        return tool_return
+
+    def _parse_results(self, results: dict, k: int) -> Union[str, List[str]]:
+        """Parse the search results from Serper API.
+
+        Args:
+            results (dict): The search content from Serper API
+                in json format.
+            k (int): the maximum number of result entries to parse.
+
+        Returns:
+            List[str]: The parsed search results.
+        """
+        snippets = []
+
+        if results.get('answerBox'):
+            answer_box = results.get('answerBox', {})
+            if answer_box.get('answer'):
+                return [answer_box.get('answer')]
+            elif answer_box.get('snippet'):
+                return [answer_box.get('snippet').replace('\n', ' ')]
+            elif answer_box.get('snippetHighlighted'):
+                return answer_box.get('snippetHighlighted')
+
+        if results.get('knowledgeGraph'):
+            kg = results.get('knowledgeGraph', {})
+            title = kg.get('title')
+            entity_type = kg.get('type')
+            if entity_type:
+                snippets.append(f'{title}: {entity_type}.')
+            description = kg.get('description')
+            if description:
+                snippets.append(description)
+            for attribute, value in kg.get('attributes', {}).items():
+                snippets.append(f'{title} {attribute}: {value}.')
+
+        for result in results[self.result_key_for_type[self.search_type]][:k]:
+            if 'snippet' in result:
+                snippets.append(result['snippet'])
+            for attribute, value in result.get('attributes', {}).items():
+                snippets.append(f'{attribute}: {value}.')
+
+        if len(snippets) == 0:
+            return ['No good Google Search Result was found']
+        return snippets
+
+    def _search(self,
+                search_term: str,
+                search_type: Optional[str] = None,
+                **kwargs) -> Tuple[int, Union[dict, str]]:
+        """HTTP requests to Serper API.
+
+        Args:
+            search_term (str): The search query.
+            search_type (str): search type supported by Serper API,
+                defaults to 'search'.
+
+        Returns:
+            tuple: the return value is a tuple containing:
+                - status_code (int): HTTP status code from Serper API.
+                - response (dict): response content in json format.
+        """
+        headers = {
+            'X-API-KEY': self.api_key or '',
+            'Content-Type': 'application/json',
+        }
+        params = {
+            'q': search_term,
+            **{
+                key: value
+                for key, value in kwargs.items() if value is not None
+            },
+        }
+        try:
+            response = requests.post(
+                f'https://google.serper.dev/{search_type or self.search_type}',
+                headers=headers,
+                params=params,
+                timeout=self.timeout)
+        except Exception as e:
+            return -1, str(e)
+        return response.status_code, response.json()
+
+
+class AsyncGoogleSearch(AsyncActionMixin, GoogleSearch):
+    """Wrapper around the Serper.dev Google Search API.
+
+    To use, you should pass your serper API key to the constructor.
+
+    Code is modified from lang-chain GoogleSerperAPIWrapper
+    (https://github.com/langchain-ai/langchain/blob/ba5fbaba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/langchain/utilities/google_serper.py)
+
+    Args:
+        api_key (str): API KEY to use the serper google search API.
+            You can create a free API key at https://serper.dev.
+        timeout (int): Upper bound of waiting time for a serper request.
+        search_type (str): Serper API supports ['search', 'images', 'news',
+            'places'] types of search; currently only 'search' is supported.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+
+    @tool_api
+    async def run(self, query: str, k: int = 10) -> ActionReturn:
+        """An API that can retrieve Google search results. Use it when you need to find a short and clear answer to a specific question. The input should be a search query.
+
+        Args:
+            query (str): the search content
+            k (int): select the first k results in the search results as the response
+        """
+        tool_return = ActionReturn(type=self.name)
+        status_code, response = await self._search(query, k=k)
+        # convert search results to ToolReturn format
+        if status_code == -1:
+            tool_return.errmsg = response
+            tool_return.state = ActionStatusCode.HTTP_ERROR
+        elif status_code == 200:
+            # pass `k` through, matching the synchronous implementation
+            parsed_res = self._parse_results(response, k)
+            tool_return.result = [dict(type='text', content=str(parsed_res))]
+            tool_return.state = ActionStatusCode.SUCCESS
+        else:
+            tool_return.errmsg = str(status_code)
+            tool_return.state = ActionStatusCode.API_ERROR
+        return tool_return
+
+    async def _search(self,
+                      search_term: str,
+                      search_type: Optional[str] = None,
+                      **kwargs) -> Tuple[int, Union[dict, str]]:
+        """HTTP requests to Serper API.
+
+        Args:
+            search_term (str): The search query.
+            search_type (str): search type supported by Serper API,
+                defaults to 'search'.
+
+        Returns:
+            tuple: the return value is a tuple containing:
+                - status_code (int): HTTP status code from Serper API.
+                - response (dict): response content in json format.
+        """
+        headers = {
+            'X-API-KEY': self.api_key or '',
+            'Content-Type': 'application/json',
+        }
+        params = {
+            'q': search_term,
+            **{
+                key: value
+                for key, value in kwargs.items() if value is not None
+            },
+        }
+        timeout = aiohttp.ClientTimeout(total=self.timeout)
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            try:
+                async with session.post(
+                        f'https://google.serper.dev/{search_type or self.search_type}',
+                        headers=headers,
+                        params=params) as resp:
+                    code, ret = resp.status, await resp.json()
+            except aiohttp.ClientError as e:
+                code, ret = -1, str(e)
+        return code, ret
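A minimal sketch of the search tool, assuming a valid serper.dev key (placeholder below) and that the decorated `run` can be called directly; `run` returns an `ActionReturn` whose `result` holds the parsed snippets:

from lagent.actions.google_search import GoogleSearch

search = GoogleSearch(api_key='<your-serper-key>')  # or set SERPER_API_KEY
ret = search.run(query='what is the tallest mountain on Earth', k=3)
print(ret.state)   # ActionStatusCode.SUCCESS on a 200 response
print(ret.result)  # [{'type': 'text', 'content': '...'}]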
lagent/actions/ipython_interactive.py ADDED
@@ -0,0 +1,273 @@
+import re
+import signal
+from contextlib import contextmanager, redirect_stdout
+from dataclasses import dataclass
+from enum import Enum
+from io import StringIO
+from typing import Optional, Type
+
+from ..schema import ActionReturn, ActionStatusCode
+from .base_action import AsyncActionMixin, BaseAction, tool_api
+from .parser import BaseParser, JsonParser
+
+
+class Status(str, Enum):
+    """Execution status."""
+    SUCCESS = 'success'
+    FAILURE = 'failure'
+
+
+@dataclass
+class ExecutionResult:
+    """Execution result."""
+    status: Status
+    value: Optional[str] = None
+    msg: Optional[str] = None
+
+
+@contextmanager
+def _raise_timeout(timeout):
+
+    def _handler(signum, frame):
+        raise TimeoutError()
+
+    signal.signal(signal.SIGALRM, _handler)
+    signal.alarm(timeout)
+
+    try:
+        yield
+    finally:
+        signal.alarm(0)
+
+
+class IPythonInteractive(BaseAction):
+    """An interactive IPython shell for code execution.
+
+    Args:
+        timeout (int): Upper bound of waiting time for Python script execution.
+            Defaults to ``30``.
+        max_out_len (int): maximum output length. No truncation occurs if negative.
+            Defaults to ``8192``.
+        use_signals (bool): whether to use signals to time out execution
+            instead of multiprocessing. Set to ``False`` when not running in
+            the main thread, e.g. in web applications. Defaults to ``True``.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+
+    def __init__(
+        self,
+        timeout: int = 30,
+        max_out_len: int = 8192,
+        use_signals: bool = True,
+        description: Optional[dict] = None,
+        parser: Type[BaseParser] = JsonParser,
+    ):
+        super().__init__(description, parser)
+        self.timeout = timeout
+        self._executor = self.create_shell()
+        # Pattern that strips ANSI escape sequences from tracebacks
+        self._highlighting = re.compile(
+            r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
+        self._max_out_len = max_out_len if max_out_len >= 0 else None
+        self._use_signals = use_signals
+
+    def reset(self):
+        """Clear the context."""
+        self._executor.reset()
+
+    @tool_api
+    def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
+        """Launch an IPython Interactive Shell to execute code.
+
+        Args:
+            command (:class:`str`): Python code snippet
+            timeout (:class:`Optional[int]`): timeout for execution.
+                This argument only works in the main thread. Defaults to ``None``.
+        """
+        from timeout_decorator import timeout as timer
+        tool_return = ActionReturn(args={'text': command}, type=self.name)
+        ret = (
+            timer(timeout or self.timeout)(self.exec)(command)
+            if self._use_signals else self.exec(command))
+        if ret.status is Status.SUCCESS:
+            tool_return.result = [{'type': 'text', 'content': ret.value}]
+            tool_return.state = ActionStatusCode.SUCCESS
+        else:
+            tool_return.errmsg = ret.msg
+            tool_return.state = ActionStatusCode.API_ERROR
+        return tool_return
+
+    def exec(self, code: str) -> ExecutionResult:
+        """Run Python scripts in the IPython shell.
+
+        Args:
+            code (:class:`str`): code block
+
+        Returns:
+            :py:class:`ExecutionResult`: execution result
+        """
+        with StringIO() as io:
+            with redirect_stdout(io):
+                ret = self._executor.run_cell(self.extract_code(code))
+            result = ret.result
+            if result is not None:
+                return ExecutionResult(Status.SUCCESS,
+                                       str(result)[:self._max_out_len])
+            outs = io.getvalue().strip().split('\n')
+        if not outs:
+            return ExecutionResult(Status.SUCCESS, '')
+        for i, out in enumerate(outs):
+            if re.search('Error|Traceback', out, re.S):
+                if 'TimeoutError' in out:
+                    return ExecutionResult(
+                        Status.FAILURE,
+                        msg=('The code interpreter encountered '
+                             'a timeout error.'))
+                err_idx = i
+                break
+        else:
+            return ExecutionResult(Status.SUCCESS,
+                                   outs[-1].strip()[:self._max_out_len])
+        return ExecutionResult(
+            Status.FAILURE,
+            msg=self._highlighting.sub(
+                '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
+        )
+
+    @staticmethod
+    def create_shell():
+        from IPython import InteractiveShell
+        from traitlets.config import Config
+
+        c = Config()
+        c.HistoryManager.enabled = False
+        c.HistoryManager.hist_file = ':memory:'
+        return InteractiveShell(
+            user_ns={'_raise_timeout': _raise_timeout}, config=c)
+
+    @staticmethod
+    def extract_code(text: str) -> str:
+        """Extract Python code from markup languages.
+
+        Args:
+            text (:class:`str`): Markdown-formatted text
+
+        Returns:
+            :class:`str`: Python code
+        """
+        import json5
+
+        # Match triple backtick blocks first
+        triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
+        # Match single backtick blocks second
+        single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
+        if triple_match:
+            text = triple_match.group(1)
+        elif single_match:
+            text = single_match.group(1)
+        else:
+            # Fall back to a JSON payload with a 'code' field; if that
+            # fails, return the original text unchanged
+            try:
+                text = json5.loads(text)['code']
+            except Exception:
+                pass
+        return text
+
+    @staticmethod
+    def wrap_code_with_timeout(code: str, timeout: int) -> str:
+        if not code.strip():
+            return code
+        code = code.strip('\n').rstrip()
+        indent = len(code) - len(code.lstrip())
+        handle = ' ' * indent + f'with _raise_timeout({timeout}):\n'
+        block = '\n'.join(['    ' + line for line in code.split('\n')])
+        wrapped_code = handle + block
+        last_line = code.split('\n')[-1]
+        is_expression = True
+        try:
+            compile(last_line.lstrip(), '<stdin>', 'eval')
+        except SyntaxError:
+            is_expression = False
+        if is_expression:
+            # Re-evaluate the trailing expression outside the `with` block
+            # so its value is captured as the cell result
+            wrapped_code += '\n' * 5 + last_line
+        return wrapped_code
+
+
+class AsyncIPythonInteractive(AsyncActionMixin, IPythonInteractive):
+    """An interactive IPython shell for code execution.
+
+    Args:
+        timeout (int): Upper bound of waiting time for Python script execution.
+            Defaults to ``30``.
+        max_out_len (int): maximum output length. No truncation occurs if negative.
+            Defaults to ``8192``.
+        use_signals (bool): whether to use signals to time out execution
+            instead of multiprocessing. Set to ``False`` when not running in
+            the main thread, e.g. in web applications. Defaults to ``True``.
+        description (dict): The description of the action. Defaults to ``None``.
+        parser (Type[BaseParser]): The parser class to process the
+            action's inputs and outputs. Defaults to :class:`JsonParser`.
+    """
+
+    @tool_api
+    async def run(self,
+                  command: str,
+                  timeout: Optional[int] = None) -> ActionReturn:
+        """Launch an IPython Interactive Shell to execute code.
+
+        Args:
+            command (:class:`str`): Python code snippet
+            timeout (:class:`Optional[int]`): timeout for execution.
+                This argument only works in the main thread. Defaults to ``None``.
+        """
+        tool_return = ActionReturn(args={'text': command}, type=self.name)
+        ret = await self.exec(command, timeout)
+        if ret.status is Status.SUCCESS:
+            tool_return.result = [{'type': 'text', 'content': ret.value}]
+            tool_return.state = ActionStatusCode.SUCCESS
+        else:
+            tool_return.errmsg = ret.msg
+            tool_return.state = ActionStatusCode.API_ERROR
+        return tool_return
+
+    async def exec(self, code: str, timeout: Optional[int] = None) -> ExecutionResult:
+        """Asynchronously run Python scripts in the IPython shell.
+
+        Args:
+            code (:class:`str`): code block
+            timeout (:class:`int`): max waiting time for code execution
+
+        Returns:
+            :py:class:`ExecutionResult`: execution result
+        """
+        with StringIO() as io:
+            with redirect_stdout(io):
+                ret = await self._executor.run_cell_async(
+                    self.wrap_code_with_timeout(
+                        self.extract_code(code), timeout or self.timeout))
+            result = ret.result
+            if result is not None:
+                return ExecutionResult(Status.SUCCESS,
+                                       str(result)[:self._max_out_len])
+            outs = io.getvalue().strip().split('\n')
+        if not outs:
+            return ExecutionResult(Status.SUCCESS, '')
+        for i, out in enumerate(outs):
+            if re.search('Error|Traceback', out, re.S):
+                if 'TimeoutError' in out:
+                    return ExecutionResult(
+                        Status.FAILURE,
+                        msg=('The code interpreter encountered a '
+                             'timeout error.'))
+                err_idx = i
+                break
+        else:
+            return ExecutionResult(Status.SUCCESS,
+                                   outs[-1].strip()[:self._max_out_len])
+        return ExecutionResult(
+            Status.FAILURE,
+            msg=self._highlighting.sub(
+                '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
+        )
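A minimal sketch of the interactive shell, assuming the decorated `run` can be called directly from the main thread; `extract_code` strips the markdown fence before execution, and `reset` clears the namespace between tasks:

from lagent.actions.ipython_interactive import IPythonInteractive

interp = IPythonInteractive(timeout=10)
ret = interp.run('```python\nx = 21 * 2\nprint(x)\n```')
print(ret.state, ret.result)  # the fenced snippet runs and prints 42
interp.reset()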
lagent/actions/ipython_interpreter.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+ import asyncio
3
+ import base64
4
+ import io
5
+ import json
6
+ import logging
7
+ import os
8
+ import queue
9
+ import re
10
+ import signal
11
+ import sys
12
+ import tempfile
13
+ import traceback
14
+ import uuid
15
+ from typing import Optional, Tuple, Type
16
+
17
+ from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager
18
+ from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed
19
+
20
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
21
+ from lagent.actions.parser import BaseParser, JsonParser
22
+ from lagent.schema import ActionReturn, ActionStatusCode
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ START_CODE = """
27
+ def input(*args, **kwargs):
28
+ raise NotImplementedError('Python input() function is disabled.')
29
+
30
+ get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
31
+ {}
32
+ """ # noqa
33
+
34
+
35
+ class TimeoutError(Exception):
36
+ pass
37
+
38
+
39
+ class KernelDeath(Exception):
40
+ pass
41
+
42
+
43
+ async def async_run_code(
44
+ km: AsyncKernelManager,
45
+ code,
46
+ *,
47
+ interrupt_after=30,
48
+ iopub_timeout=40,
49
+ wait_for_ready_timeout=60,
50
+ shutdown_kernel=True,
51
+ ):
52
+ assert iopub_timeout > interrupt_after
53
+ try:
54
+
55
+ async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
56
+ *,
57
+ timeout=None):
58
+ loop = asyncio.get_running_loop()
59
+ dead_fut = loop.create_future()
60
+
61
+ def restarting():
62
+ assert (
63
+ False
64
+ ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"
65
+
66
+ def dead():
67
+ logger.info("Kernel has died, will NOT restart")
68
+ dead_fut.set_result(None)
69
+
70
+ msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
71
+ km.add_restart_callback(restarting, "restart")
72
+ km.add_restart_callback(dead, "dead")
73
+ try:
74
+ done, _ = await asyncio.wait(
75
+ [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
76
+ if dead_fut in done:
77
+ raise KernelDeath()
78
+ assert msg_task in done
79
+ return await msg_task
80
+ finally:
81
+ msg_task.cancel()
82
+ km.remove_restart_callback(restarting, "restart")
83
+ km.remove_restart_callback(dead, "dead")
84
+
85
+ async def send_interrupt():
86
+ await asyncio.sleep(interrupt_after)
87
+ logger.info("Sending interrupt to kernel")
88
+ await km.interrupt_kernel()
89
+
90
+ @retry(
91
+ retry=retry_if_result(lambda ret: ret[-1].strip() in [
92
+ 'KeyboardInterrupt',
93
+ f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
94
+ ] if isinstance(ret, tuple) else False),
95
+ stop=stop_after_attempt(3),
96
+ wait=wait_fixed(1),
97
+ retry_error_callback=lambda state: state.outcome.result())
98
+ async def run():
99
+ execute_result = None
100
+ error_traceback = None
101
+ stream_text_list = []
102
+ kc = km.client()
103
+ assert isinstance(kc, AsyncKernelClient)
104
+ kc.start_channels()
105
+ try:
106
+ await kc.wait_for_ready(timeout=wait_for_ready_timeout)
107
+ msg_id = kc.execute(code)
108
+ while True:
109
+ message = await get_iopub_msg_with_death_detection(
110
+ kc, timeout=iopub_timeout)
111
+ if logger.isEnabledFor(logging.DEBUG):
112
+ logger.debug(
113
+ json.dumps(message, indent=2, default=str))
114
+ assert message["parent_header"]["msg_id"] == msg_id
115
+ msg_type = message["msg_type"]
116
+ if msg_type == "status":
117
+ if message["content"]["execution_state"] == "idle":
118
+ break
119
+ elif msg_type == "stream":
120
+ stream_name = message["content"]["name"]
121
+ stream_text = message["content"]["text"]
122
+ stream_text_list.append(stream_text)
123
+ elif msg_type == "execute_result":
124
+ execute_result = message["content"]["data"]
125
+ elif msg_type == "error":
126
+ error_traceback_lines = message["content"]["traceback"]
127
+ error_traceback = "\n".join(error_traceback_lines)
128
+ elif msg_type == "execute_input":
129
+ pass
130
+ else:
131
+ assert False, f"Unknown message_type: {msg_type}"
132
+ finally:
133
+ kc.stop_channels()
134
+ return execute_result, error_traceback, "".join(stream_text_list)
135
+
136
+ if interrupt_after:
137
+ run_task = asyncio.create_task(run())
138
+ send_interrupt_task = asyncio.create_task(send_interrupt())
139
+ done, _ = await asyncio.wait([run_task, send_interrupt_task],
140
+ return_when=asyncio.FIRST_COMPLETED)
141
+ if run_task in done:
142
+ send_interrupt_task.cancel()
143
+ else:
144
+ assert send_interrupt_task in done
145
+ result = await run_task
146
+ else:
147
+ result = await run()
148
+ return result
149
+ finally:
150
+ if shutdown_kernel:
151
+ await km.shutdown_kernel()
152
+
153
+
154
+ class IPythonInterpreter(BaseAction):
155
+ """A IPython executor that can execute Python scripts in a jupyter manner.
156
+
157
+ Args:
158
+ timeout (int): Upper bound of waiting time for Python script execution.
159
+ Defaults to 20.
160
+ user_data_dir (str, optional): Specified the user data directory for files
161
+ loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
162
+ Defaults to `ENV`.
163
+ work_dir (str, optional): Specify which directory to save output images to.
164
+ Defaults to ``'./work_dir/tmp_dir'``.
165
+ description (dict): The description of the action. Defaults to ``None``.
166
+ parser (Type[BaseParser]): The parser class to process the
167
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
168
+ """
169
+
170
+ _KERNEL_CLIENTS = {}
171
+
172
+ def __init__(
173
+ self,
174
+ timeout: int = 20,
175
+ user_data_dir: str = 'ENV',
176
+ work_dir='./work_dir/tmp_dir',
177
+ description: Optional[dict] = None,
178
+ parser: Type[BaseParser] = JsonParser,
179
+ ):
180
+ super().__init__(description, parser)
181
+
182
+ self.timeout = timeout
183
+ if user_data_dir == 'ENV':
184
+ user_data_dir = os.environ.get('USER_DATA_DIR', '')
185
+
186
+ if user_data_dir:
187
+ user_data_dir = os.path.dirname(user_data_dir)
188
+ user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
189
+ self.user_data_dir = user_data_dir
190
+ self._initialized = False
191
+ self.work_dir = work_dir
192
+ if not os.path.exists(self.work_dir):
193
+ os.makedirs(self.work_dir, exist_ok=True)
194
+
195
+ @staticmethod
196
+ def start_kernel():
197
+ from jupyter_client import KernelManager
198
+
199
+ # start the kernel and manager
200
+ km = KernelManager()
201
+ km.start_kernel()
202
+ kc = km.client()
203
+ return km, kc
204
+
205
+ def initialize(self):
206
+ if self._initialized:
207
+ return
208
+ pid = os.getpid()
209
+ if pid not in self._KERNEL_CLIENTS:
210
+ self._KERNEL_CLIENTS[pid] = self.start_kernel()
211
+ self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
212
+ self._initialized = True
213
+ self._call(START_CODE.format(self.user_data_dir), None)
214
+
215
+ def reset(self):
216
+ if not self._initialized:
217
+ self.initialize()
218
+ else:
219
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + \
220
+ START_CODE.format(self.user_data_dir)
221
+ self._call(code, None)
222
+
223
+ def _call(self,
224
+ command: str,
225
+ timeout: Optional[int] = None) -> Tuple[str, bool]:
226
+ self.initialize()
227
+ command = extract_code(command)
228
+
229
+ # check previous remaining result
230
+ while True:
231
+ try:
232
+ msg = self.kernel_client.get_iopub_msg(timeout=5)
233
+ msg_type = msg['msg_type']
234
+ if msg_type == 'status':
235
+ if msg['content'].get('execution_state') == 'idle':
236
+ break
237
+ except queue.Empty:
238
+ # assume no result
239
+ break
240
+
241
+ self.kernel_client.execute(command)
242
+
243
+ def _inner_call():
244
+ result = ''
245
+ images = []
246
+ succeed = True
247
+ image_idx = 0
248
+
249
+ while True:
250
+ text = ''
251
+ image = ''
252
+ finished = False
253
+ msg_type = 'error'
254
+ try:
255
+ msg = self.kernel_client.get_iopub_msg(timeout=20)
256
+ msg_type = msg['msg_type']
257
+ if msg_type == 'status':
258
+ if msg['content'].get('execution_state') == 'idle':
259
+ finished = True
260
+ elif msg_type == 'execute_result':
261
+ text = msg['content']['data'].get('text/plain', '')
262
+ if 'image/png' in msg['content']['data']:
263
+ image_b64 = msg['content']['data']['image/png']
264
+ image_url = publish_image_to_local(
265
+ image_b64, self.work_dir)
266
+ image_idx += 1
267
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
268
+
269
+ elif msg_type == 'display_data':
270
+ if 'image/png' in msg['content']['data']:
271
+ image_b64 = msg['content']['data']['image/png']
272
+ image_url = publish_image_to_local(
273
+ image_b64, self.work_dir)
274
+ image_idx += 1
275
+ image = '![fig-%03d](%s)' % (image_idx, image_url)
276
+
277
+ else:
278
+ text = msg['content']['data'].get('text/plain', '')
279
+ elif msg_type == 'stream':
280
+ msg_type = msg['content']['name'] # stdout, stderr
281
+ text = msg['content']['text']
282
+ elif msg_type == 'error':
283
+ succeed = False
284
+ text = escape_ansi('\n'.join(
285
+ msg['content']['traceback']))
286
+ if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
287
+ text = f'Timeout. No response after {timeout} seconds.' # noqa
288
+ except queue.Empty:
289
+ # stop current task in case break next input.
290
+ self.kernel_manager.interrupt_kernel()
291
+ succeed = False
292
+ text = f'Timeout. No response after {timeout} seconds.'
293
+ finished = True
294
+ except Exception:
295
+ succeed = False
296
+ msg = ''.join(traceback.format_exception(*sys.exc_info()))
297
+ # text = 'The code interpreter encountered an unexpected error.' # noqa
298
+ text = msg
299
+ logging.warning(msg)
300
+ finished = True
301
+ if text:
302
+ # result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
303
+ result += f'{text}'
304
+
305
+ if image:
306
+ images.append(image_url)
307
+ if finished:
308
+ return succeed, dict(text=result, image=images)
309
+
310
+ try:
311
+ if timeout:
312
+
313
+ def handler(signum, frame):
314
+ raise TimeoutError()
315
+
316
+ signal.signal(signal.SIGALRM, handler)
317
+ signal.alarm(timeout)
318
+ succeed, result = _inner_call()
319
+ except TimeoutError:
320
+ succeed = False
321
+ text = 'The code interpreter encountered an unexpected error.'
322
+ result = f'\n\nerror:\n\n```\n{text}\n```'
323
+ finally:
324
+ if timeout:
325
+ signal.alarm(0)
326
+
327
+ # result = result.strip('\n')
328
+ return succeed, result
329
+
330
+ @tool_api
331
+ def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
332
+ r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
333
+
334
+ Args:
335
+ command (:class:`str`): Python code
336
+ timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
337
+ """
338
+ tool_return = ActionReturn(url=None, args=None, type=self.name)
339
+ tool_return.args = dict(text=command)
340
+ succeed, result = self._call(command, timeout)
341
+ if succeed:
342
+ text = result['text']
343
+ image = result.get('image', [])
344
+ resp = [dict(type='text', content=text)]
345
+ if image:
346
+ resp.extend([dict(type='image', content=im) for im in image])
347
+ tool_return.result = resp
348
+ # tool_return.result = dict(
349
+ # text=result['text'], image=result.get('image', [])[0])
350
+ tool_return.state = ActionStatusCode.SUCCESS
351
+ else:
352
+ tool_return.errmsg = result.get('text', '') if isinstance(
353
+ result, dict) else result
354
+ tool_return.state = ActionStatusCode.API_ERROR
355
+ return tool_return
356
+
357
+
358
+ class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
359
+     """An IPython executor that can execute Python scripts in a Jupyter manner.
360
+
361
+ Args:
362
+ timeout (int): Upper bound of waiting time for Python script execution.
363
+ Defaults to 20.
364
+         user_data_dir (str, optional): Specifies the user data directory for file
365
+ loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
366
+ Defaults to `ENV`.
367
+ work_dir (str, optional): Specify which directory to save output images to.
368
+ Defaults to ``'./work_dir/tmp_dir'``.
369
+ description (dict): The description of the action. Defaults to ``None``.
370
+ parser (Type[BaseParser]): The parser class to process the
371
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
372
+ """
373
+
374
+ _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()
375
+
376
+ def __init__(
377
+ self,
378
+ timeout: int = 20,
379
+ user_data_dir: str = 'ENV',
380
+ work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
381
+ max_kernels: Optional[int] = None,
382
+ reuse_kernel: bool = True,
383
+         startup_rate: int = 32,
384
+ connection_dir: str = tempfile.gettempdir(),
385
+ description: Optional[dict] = None,
386
+ parser: Type[BaseParser] = JsonParser,
387
+ ):
388
+ super().__init__(timeout, user_data_dir, work_dir, description, parser)
389
+ from traitlets.config import Config
390
+
391
+ c = Config()
392
+ c.KernelManager.transport = 'ipc'
393
+ self._amkm = AsyncMultiKernelManager(
394
+ config=c, connection_dir=connection_dir)
395
+ self._max_kernels = max_kernels
396
+ self._reuse_kernel = reuse_kernel
397
+ self._sem = asyncio.Semaphore(startup_rate)
398
+ self._lock = asyncio.Lock()
399
+
400
+ async def initialize(self, session_id: str):
401
+ session_id = str(session_id)
402
+ while True:
403
+ if session_id in self._KERNEL_CLIENTS:
404
+ return self._KERNEL_CLIENTS[session_id]
405
+ if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
406
+ self._KERNEL_CLIENTS[
407
+ session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
408
+ return self._KERNEL_CLIENTS[session_id]
409
+ async with self._sem:
410
+ if self._max_kernels is None or len(
411
+ self._KERNEL_CLIENTS
412
+ ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
413
+ kernel_id = None
414
+ try:
415
+ kernel_id = await self._amkm.start_kernel()
416
+ kernel = self._amkm.get_kernel(kernel_id)
417
+ client = kernel.client()
418
+ _, error_stacktrace, stream_text = await async_run_code(
419
+ kernel,
420
+ START_CODE.format(self.user_data_dir),
421
+ shutdown_kernel=False)
422
+ # check if the output of START_CODE meets expectations
423
+ if not (error_stacktrace is None
424
+ and stream_text == ''):
425
+ raise RuntimeError
426
+ except Exception as e:
427
+ print(f'Starting kernel error: {e}')
428
+ if kernel_id:
429
+ await self._amkm.shutdown_kernel(kernel_id)
430
+ self._amkm.remove_kernel(kernel_id)
431
+ await asyncio.sleep(1)
432
+ continue
433
+ if self._max_kernels is None:
434
+ self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
435
+ client)
436
+ return kernel_id, kernel, client
437
+ async with self._lock:
438
+ if len(self._KERNEL_CLIENTS
439
+ ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
440
+ ) < self._max_kernels:
441
+ self._KERNEL_CLIENTS[session_id] = (kernel_id,
442
+ kernel, client)
443
+ return kernel_id, kernel, client
444
+ await self._amkm.shutdown_kernel(kernel_id)
445
+ self._amkm.remove_kernel(kernel_id)
446
+ await asyncio.sleep(1)
447
+
448
+ async def reset(self, session_id: str):
449
+ session_id = str(session_id)
450
+ if session_id not in self._KERNEL_CLIENTS:
451
+ return
452
+ _, kernel, _ = self._KERNEL_CLIENTS[session_id]
453
+ code = "get_ipython().run_line_magic('reset', '-f')\n" + \
454
+ START_CODE.format(self.user_data_dir)
455
+ await async_run_code(kernel, code, shutdown_kernel=False)
456
+
457
+ async def shutdown(self, session_id: str):
458
+ session_id = str(session_id)
459
+ if session_id in self._KERNEL_CLIENTS:
460
+ kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
461
+ await self._amkm.shutdown_kernel(kernel_id)
462
+ self._amkm.remove_kernel(kernel_id)
463
+ del self._KERNEL_CLIENTS[session_id]
464
+
465
+ async def close_session(self, session_id: str):
466
+ session_id = str(session_id)
467
+ if self._reuse_kernel:
468
+ if session_id in self._KERNEL_CLIENTS:
469
+ await self.reset(session_id)
470
+ await self._UNBOUND_KERNEL_CLIENTS.put(
471
+ self._KERNEL_CLIENTS.pop(session_id))
472
+ else:
473
+ await self.shutdown(session_id)
474
+
475
+ async def _call(self, command, timeout=None, session_id=None):
476
+ _, kernel, _ = await self.initialize(str(session_id))
477
+ result = await async_run_code(
478
+ kernel,
479
+ extract_code(command),
480
+ interrupt_after=timeout or self.timeout,
481
+ shutdown_kernel=False)
482
+ execute_result, error_stacktrace, stream_text = result
483
+ if error_stacktrace is not None:
484
+ ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
485
+ if ret.endswith('KeyboardInterrupt: '):
486
+ ret = 'The code interpreter encountered a timeout error.'
487
+ status, ret = False, ret.strip()
488
+ elif execute_result is not None:
489
+ status, ret = True, dict(text=execute_result.get('text/plain', ''))
490
+ else:
491
+ status, ret = True, dict(text=stream_text.strip())
492
+ return status, ret
493
+
494
+ @tool_api
495
+ async def run(self,
496
+ command: str,
497
+ timeout: Optional[int] = None,
498
+ session_id: Optional[str] = None) -> ActionReturn:
499
+ r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
500
+
501
+ Args:
502
+ command (:class:`str`): Python code
503
+ timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
504
+ """
505
+ tool_return = ActionReturn(url=None, args=None, type=self.name)
506
+ tool_return.args = dict(text=command)
507
+ succeed, result = await self._call(command, timeout, session_id)
508
+ if succeed:
509
+ text = result['text']
510
+ image = result.get('image', [])
511
+ resp = [dict(type='text', content=text)]
512
+ if image:
513
+ resp.extend([dict(type='image', content=im) for im in image])
514
+ tool_return.result = resp
515
+ # tool_return.result = dict(
516
+ # text=result['text'], image=result.get('image', [])[0])
517
+ tool_return.state = ActionStatusCode.SUCCESS
518
+ else:
519
+ tool_return.errmsg = result.get('text', '') if isinstance(
520
+ result, dict) else result
521
+ tool_return.state = ActionStatusCode.API_ERROR
522
+ return tool_return
523
+
524
+
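A minimal usage sketch for the asynchronous interpreter above, assuming `lagent` and `jupyter_client` are installed; the IPC kernel transport configured in `__init__` generally requires a Unix-like platform, and the session id `'demo'` is illustrative only.

````python
import asyncio

from lagent.actions.ipython_interpreter import AsyncIPythonInterpreter


async def main():
    interpreter = AsyncIPythonInterpreter(timeout=20, max_kernels=2)
    # each session id is bound to its own kernel; kernels are recycled
    # through close_session() when reuse_kernel=True (the default)
    ret = await interpreter.run('```python\nprint(1 + 1)\n```', session_id='demo')
    print(ret.state, ret.result)
    await interpreter.close_session('demo')


asyncio.run(main())
````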
525
+ def extract_code(text):
526
+ import json5
527
+
528
+ # Match triple backtick blocks first
529
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
530
+ # Match single backtick blocks second
531
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
532
+ if triple_match:
533
+ text = triple_match.group(1)
534
+ elif single_match:
535
+ text = single_match.group(1)
536
+ else:
537
+ try:
538
+ text = json5.loads(text)['code']
539
+ except Exception:
540
+ pass
541
+ # If no code blocks found, return original text
542
+ return text
543
+
544
+
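To make the extraction precedence above concrete, here is a small demonstration (assuming `lagent` and `json5` are importable): fenced blocks win over inline backticks, a JSON payload with a `code` field is tried next, and anything else is returned verbatim.

````python
from lagent.actions.ipython_interpreter import extract_code

print(extract_code('```python\nprint(1)\n```'))  # fenced block     -> 'print(1)\n'
print(extract_code('`print(2)`'))                # inline backticks -> 'print(2)'
print(extract_code('{"code": "print(3)"}'))      # JSON payload     -> 'print(3)'
print(extract_code('plain code'))                # fallback         -> 'plain code'
````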
545
+ def escape_ansi(line):
546
+ ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
547
+ return ansi_escape.sub('', line)
548
+
549
+
550
+ def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
551
+ import PIL.Image
552
+ image_file = str(uuid.uuid4()) + '.png'
553
+ local_image_file = os.path.join(work_dir, image_file)
554
+
555
+ png_bytes = base64.b64decode(image_base64)
556
+ assert isinstance(png_bytes, bytes)
557
+ bytes_io = io.BytesIO(png_bytes)
558
+ PIL.Image.open(bytes_io).save(local_image_file, 'png')
559
+
560
+ return local_image_file
561
+
562
+
563
+ # local test for code interpreter
564
+ def get_multiline_input(hint):
565
+ print(hint)
566
+ print('// Press ENTER to make a new line. Press CTRL-D to end input.')
567
+ lines = []
568
+ while True:
569
+ try:
570
+ line = input()
571
+ except EOFError: # CTRL-D
572
+ break
573
+ lines.append(line)
574
+ print('// Input received.')
575
+ if lines:
576
+ return '\n'.join(lines)
577
+ else:
578
+ return ''
579
+
580
+
581
+ if __name__ == '__main__':
582
+ code_interpreter = IPythonInterpreter()
583
+ while True:
584
+ print(code_interpreter(get_multiline_input('Enter python code:')))
lagent/actions/ipython_manager.py ADDED
@@ -0,0 +1,220 @@
1
+ import re
2
+ import sys
3
+ from collections import defaultdict
4
+ from contextlib import nullcontext
5
+ from io import StringIO
6
+ from multiprocessing import Process, Queue
7
+ from typing import List, Optional, Type, Union
8
+
9
+ from filelock import FileLock
10
+ from timeout_decorator import timeout as tm
11
+
12
+ from ..schema import ActionReturn, ActionStatusCode
13
+ from .base_action import BaseAction
14
+ from .parser import BaseParser, JsonParser
15
+
16
+
17
+ class IPythonProcess(Process):
18
+
19
+ def __init__(self,
20
+ in_q: Queue,
21
+ out_q: Queue,
22
+ timeout: int = 20,
23
+ ci_lock: str = None,
24
+ daemon: bool = True):
25
+ super().__init__(daemon=daemon)
26
+ self.in_q = in_q
27
+ self.out_q = out_q
28
+ self.timeout = timeout
29
+ self.session_id2shell = defaultdict(self.create_shell)
30
+ self.ci_lock = FileLock(
31
+ ci_lock) if ci_lock else nullcontext() # avoid core corruption
32
+ self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
33
+
34
+ def run(self):
35
+ while True:
36
+ msg = self.in_q.get()
37
+ if msg == 'reset':
38
+ for session_id, shell in self.session_id2shell.items():
39
+ with self.ci_lock:
40
+ try:
41
+ shell.reset(new_session=False)
42
+ # shell.run_line_magic('reset', '-sf')
43
+ except Exception:
44
+ self.session_id2shell[
45
+ session_id] = self.create_shell()
46
+ self.out_q.put('ok')
47
+ elif isinstance(msg, tuple) and len(msg) == 3:
48
+ i, session_id, code = msg
49
+ res = self.exec(session_id, code)
50
+ self.out_q.put((i, session_id, res))
51
+
52
+ def exec(self, session_id, code):
53
+ try:
54
+ shell = self.session_id2shell[session_id]
55
+             with StringIO() as io:
56
+                 old_stdout = sys.stdout
+                 sys.stdout = io
57
+                 try:
58
+                     if self.timeout is False or self.timeout < 0:
59
+                         shell.run_cell(self.extract_code(code))
60
+                     else:
61
+                         tm(self.timeout)(shell.run_cell)(self.extract_code(code))
62
+                 finally:
+                     # restore stdout even if run_cell raises or the timeout fires
+                     sys.stdout = old_stdout
63
+ output = self._highlighting.sub('', io.getvalue().strip())
64
+ output = re.sub(r'^Out\[\d+\]: ', '', output)
65
+ if 'Error' in output or 'Traceback' in output:
66
+ output = output.lstrip('-').strip()
67
+ if output.startswith('TimeoutError'):
68
+ output = 'The code interpreter encountered a timeout error.'
69
+ return {'status': 'FAILURE', 'msg': output, 'code': code}
70
+ return {'status': 'SUCCESS', 'value': output, 'code': code}
71
+ except Exception as e:
72
+ return {'status': 'FAILURE', 'msg': str(e), 'code': code}
73
+
74
+ @staticmethod
75
+ def create_shell(enable_history: bool = False, in_memory: bool = True):
76
+ from IPython import InteractiveShell
77
+ from traitlets.config import Config
78
+
79
+ c = Config()
80
+ c.HistoryManager.enabled = enable_history
81
+ if in_memory:
82
+ c.HistoryManager.hist_file = ':memory:'
83
+ shell = InteractiveShell(config=c)
84
+ return shell
85
+
86
+ @staticmethod
87
+ def extract_code(text: str) -> str:
88
+ """Extract Python code from markup languages.
89
+
90
+ Args:
91
+ text (:class:`str`): Markdown-formatted text
92
+
93
+ Returns:
94
+ :class:`str`: Python code
95
+ """
96
+ import json5
97
+
98
+ # Match triple backtick blocks first
99
+ triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
100
+ # Match single backtick blocks second
101
+ single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
102
+ if triple_match:
103
+ text = triple_match.group(1)
104
+ elif single_match:
105
+ text = single_match.group(1)
106
+ else:
107
+ try:
108
+ text = json5.loads(text)['code']
109
+ except Exception:
110
+ pass
111
+ # If no code blocks found, return original text
112
+ return text
113
+
114
+
115
+ class IPythonInteractiveManager(BaseAction):
116
+ """An interactive IPython shell manager for code execution"""
117
+
118
+ def __init__(
119
+ self,
120
+ max_workers: int = 50,
121
+ timeout: int = 20,
122
+ ci_lock: str = None,
123
+ description: Optional[dict] = None,
124
+ parser: Type[BaseParser] = JsonParser,
125
+ ):
126
+ super().__init__(description, parser)
127
+ self.max_workers = max_workers
128
+ self.timeout = timeout
129
+ self.ci_lock = ci_lock
130
+ self.id2queue = defaultdict(Queue)
131
+ self.id2process = {}
132
+ self.out_queue = Queue()
133
+
134
+ def __call__(self,
135
+ commands: Union[str, List[str]],
136
+ session_ids: Union[int, List[int]] = None):
137
+ if isinstance(commands, list):
138
+ batch_size = len(commands)
139
+ is_batch = True
140
+ else:
141
+ batch_size = 1
142
+ commands = [commands]
143
+ is_batch = False
144
+ if session_ids is None:
145
+ session_ids = range(batch_size)
146
+ elif isinstance(session_ids, int):
147
+ session_ids = [session_ids]
148
+ if len(session_ids) != batch_size or len(session_ids) != len(
149
+ set(session_ids)):
150
+ raise ValueError(
151
+ 'the size of `session_ids` must equal that of `commands`')
152
+ try:
153
+ exec_results = self.run_code_blocks([
154
+ (session_id, command)
155
+ for session_id, command in zip(session_ids, commands)
156
+ ])
157
+ except KeyboardInterrupt:
158
+ self.clear()
159
+ exit(1)
160
+ action_returns = []
161
+ for result, code in zip(exec_results, commands):
162
+ action_return = ActionReturn({'command': code}, type=self.name)
163
+ if result['status'] == 'SUCCESS':
164
+ action_return.result = [
165
+ dict(type='text', content=result['value'])
166
+ ]
167
+ action_return.state = ActionStatusCode.SUCCESS
168
+ else:
169
+ action_return.errmsg = result['msg']
170
+ action_return.state = ActionStatusCode.API_ERROR
171
+ action_returns.append(action_return)
172
+ if not is_batch:
173
+ return action_returns[0]
174
+ return action_returns
175
+
176
+ def process_code(self, index, session_id, code):
177
+ ipy_id = session_id % self.max_workers
178
+ input_queue = self.id2queue[ipy_id]
179
+ proc = self.id2process.setdefault(
180
+ ipy_id,
181
+ IPythonProcess(
182
+ input_queue,
183
+ self.out_queue,
184
+ self.timeout,
185
+ self.ci_lock,
186
+ daemon=True))
187
+ if not proc.is_alive():
188
+ proc.start()
189
+ input_queue.put((index, session_id, code))
190
+
191
+ def run_code_blocks(self, session_code_pairs):
192
+ size = len(session_code_pairs)
193
+ for index, (session_id, code) in enumerate(session_code_pairs):
194
+ self.process_code(index, session_id, code)
195
+ results = []
196
+ while len(results) < size:
197
+ msg = self.out_queue.get()
198
+ if isinstance(msg, tuple) and len(msg) == 3:
199
+ index, _, result = msg
200
+ results.append((index, result))
201
+ results.sort()
202
+ return [item[1] for item in results]
203
+
204
+ def clear(self):
205
+ self.id2queue.clear()
206
+ for proc in self.id2process.values():
207
+ proc.terminate()
208
+ self.id2process.clear()
209
+ while not self.out_queue.empty():
210
+ self.out_queue.get()
211
+
212
+ def reset(self):
213
+ cnt = 0
214
+ for q in self.id2queue.values():
215
+ q.put('reset')
216
+ cnt += 1
217
+ while cnt > 0:
218
+ msg = self.out_queue.get()
219
+ if msg == 'ok':
220
+ cnt -= 1
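A minimal sketch of driving the manager above, assuming `lagent`, `IPython`, `filelock`, and `timeout_decorator` are installed; the commands and session ids are illustrative. Each session id maps to a worker process via `session_id % max_workers`, so distinct ids execute in isolated shells.

```python
from lagent.actions.ipython_manager import IPythonInteractiveManager


def main():
    manager = IPythonInteractiveManager(max_workers=2, timeout=10)
    try:
        # batch execution: one command per session, run in parallel workers
        returns = manager(
            ['x = 1\nprint(x)', 'print("hello")'],
            session_ids=[0, 1],
        )
        for ret in returns:
            print(ret.state, ret.result or ret.errmsg)
    finally:
        manager.clear()  # terminate the worker processes


if __name__ == '__main__':  # needed on spawn-based platforms
    main()
```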
lagent/actions/parser.py ADDED
@@ -0,0 +1,146 @@
1
+ import json
2
+ import re
3
+ from ast import literal_eval
4
+ from typing import Any, List, Union
5
+
6
+
7
+ class ParseError(Exception):
8
+ """Parsing exception class."""
9
+
10
+ def __init__(self, err_msg: str):
11
+ self.err_msg = err_msg
12
+
13
+
14
+ class BaseParser:
15
+ """Base parser to process inputs and outputs of actions.
16
+
17
+ Args:
18
+ action (:class:`BaseAction`): action to validate
19
+
20
+ Attributes:
21
+ PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
22
+ LLMs should follow when generating arguments for decided tools.
23
+ """
24
+
25
+ PARAMETER_DESCRIPTION: str = ''
26
+
27
+ def __init__(self, action):
28
+ self.action = action
29
+ self._api2param = {}
30
+ self._api2required = {}
31
+ # perform basic argument validation
32
+ if action.description:
33
+ for api in action.description.get('api_list',
34
+ [action.description]):
35
+ name = (f'{action.name}.{api["name"]}'
36
+ if self.action.is_toolkit else api['name'])
37
+ required_parameters = set(api['required'])
38
+ all_parameters = {j['name'] for j in api['parameters']}
39
+ if not required_parameters.issubset(all_parameters):
40
+ raise ValueError(
41
+ f'unknown parameters for function "{name}": '
42
+ f'{required_parameters - all_parameters}')
43
+ if self.PARAMETER_DESCRIPTION:
44
+ api['parameter_description'] = self.PARAMETER_DESCRIPTION
45
+ api_name = api['name'] if self.action.is_toolkit else 'run'
46
+ self._api2param[api_name] = api['parameters']
47
+ self._api2required[api_name] = api['required']
48
+
49
+ def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
50
+ """Parse inputs LLMs generate for the action.
51
+
52
+ Args:
53
+ inputs (:class:`str`): input string extracted from responses
54
+
55
+ Returns:
56
+ :class:`dict`: processed input
57
+ """
58
+ inputs = {self._api2param[name][0]['name']: inputs}
59
+ return inputs
60
+
61
+ def parse_outputs(self, outputs: Any) -> List[dict]:
62
+         """Parse outputs returned by the action.
63
+
64
+ Args:
65
+ outputs (:class:`Any`): raw output of the action
66
+
67
+ Returns:
68
+ :class:`List[dict]`: processed output of which each member is a
69
+ dictionary with two keys - 'type' and 'content'.
70
+ """
71
+ if isinstance(outputs, dict):
72
+ outputs = json.dumps(outputs, ensure_ascii=False)
73
+ elif not isinstance(outputs, str):
74
+ outputs = str(outputs)
75
+ return [{
76
+ 'type': 'text',
77
+ 'content': outputs.encode('gbk', 'ignore').decode('gbk')
78
+ }]
79
+
80
+
81
+ class JsonParser(BaseParser):
82
+ """Json parser to convert input string into a dictionary.
83
+
84
+ Args:
85
+ action (:class:`BaseAction`): action to validate
86
+ """
87
+
88
+ PARAMETER_DESCRIPTION = (
89
+ 'If you call this tool, you must pass arguments in '
90
+ 'the JSON format {key: value}, where the key is the parameter name.')
91
+
92
+ def parse_inputs(self,
93
+ inputs: Union[str, dict],
94
+ name: str = 'run') -> dict:
95
+ if not isinstance(inputs, dict):
96
+ try:
97
+ match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
98
+ re.S)
99
+ if match:
100
+ inputs = match.group(2).strip()
101
+ inputs = json.loads(inputs)
102
+ except json.JSONDecodeError as exc:
103
+ raise ParseError(f'invalid json format: {inputs}') from exc
104
+ input_keys = set(inputs)
105
+ all_keys = {param['name'] for param in self._api2param[name]}
106
+ if not input_keys.issubset(all_keys):
107
+ raise ParseError(f'unknown arguments: {input_keys - all_keys}')
108
+ required_keys = set(self._api2required[name])
109
+ if not input_keys.issuperset(required_keys):
110
+ raise ParseError(
111
+ f'missing required arguments: {required_keys - input_keys}')
112
+ return inputs
113
+
114
+
115
+ class TupleParser(BaseParser):
116
+ """Tuple parser to convert input string into a tuple.
117
+
118
+ Args:
119
+ action (:class:`BaseAction`): action to validate
120
+ """
121
+
122
+ PARAMETER_DESCRIPTION = (
123
+ 'If you call this tool, you must pass arguments in the tuple format '
124
+ 'like (arg1, arg2, arg3), and the arguments are ordered.')
125
+
126
+ def parse_inputs(self,
127
+ inputs: Union[str, tuple],
128
+ name: str = 'run') -> dict:
129
+ if not isinstance(inputs, tuple):
130
+ try:
131
+ inputs = literal_eval(inputs)
132
+ except Exception as exc:
133
+ raise ParseError(f'invalid tuple format: {inputs}') from exc
134
+ if len(inputs) < len(self._api2required[name]):
135
+ raise ParseError(
136
+ f'API takes {len(self._api2required[name])} required positional '
137
+ f'arguments but {len(inputs)} were given')
138
+ if len(inputs) > len(self._api2param[name]):
139
+ raise ParseError(
140
+ f'API takes {len(self._api2param[name])} positional arguments '
141
+ f'but {len(inputs)} were given')
142
+ inputs = {
143
+ self._api2param[name][i]['name']: item
144
+ for i, item in enumerate(inputs)
145
+ }
146
+ return inputs
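To illustrate the two argument formats side by side, the sketch below feeds the same single-parameter API to both parsers; the `calculator` description is a hypothetical stand-in built with `SimpleNamespace` rather than a real `BaseAction`.

```python
from types import SimpleNamespace

from lagent.actions.parser import JsonParser, TupleParser

# hypothetical action exposing one required parameter
action = SimpleNamespace(
    name='calculator',
    is_toolkit=False,
    description={
        'name': 'calculator',
        'parameters': [{'name': 'expression', 'type': 'STRING'}],
        'required': ['expression'],
    },
)

json_parser = JsonParser(action)
print(json_parser.parse_inputs('{"expression": "1 + 2"}'))   # {'expression': '1 + 2'}

tuple_parser = TupleParser(action)
print(tuple_parser.parse_inputs("('1 + 2',)"))               # {'expression': '1 + 2'}
```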
lagent/actions/ppt.py ADDED
@@ -0,0 +1,233 @@
1
+ from typing import Dict, Optional, Type
2
+
3
+ from asyncer import asyncify
4
+
5
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
6
+ from lagent.actions.parser import BaseParser, JsonParser
7
+
8
+ THEME_MAPPING = {
9
+ 'Default': {
10
+ 'template': None,
11
+ 'title': 'Title Slide',
12
+ 'single': 'Title and Content',
13
+ 'two': 'Two Content',
14
+ }
15
+ }
16
+
17
+
18
+ class PPT(BaseAction):
19
+     """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""
20
+
21
+ def __init__(
22
+ self,
23
+ theme_mapping: Optional[Dict[str, dict]] = None,
24
+ description: Optional[dict] = None,
25
+ parser: Type[BaseParser] = JsonParser,
26
+ ):
27
+ super().__init__(description, parser)
28
+ self.theme_mapping = theme_mapping or THEME_MAPPING
29
+ self.pointer = None
30
+ self.location = None
31
+
32
+ @tool_api(explode_return=True)
33
+ def create_file(self, theme: str, abs_location: str) -> dict:
34
+ """Create a pptx file with specific themes.
35
+
36
+ Args:
37
+ theme (:class:`str`): the theme used. The value should be one of ['Default'].
38
+ abs_location (:class:`str`): the ppt file's absolute location
39
+
40
+ Returns:
41
+ :class:`dict`: operation status
42
+ * status: the result of the execution
43
+ """
44
+ from pptx import Presentation
45
+
46
+ self.location = abs_location
47
+ try:
48
+ self.pointer = Presentation(self.theme_mapping[theme]['template'])
49
+ self.pointer.slide_master.name = theme
50
+ # print('created')
51
+ except Exception as e:
52
+ print(e)
53
+ return dict(status='created a ppt file.')
54
+
55
+ @tool_api(explode_return=True)
56
+ def add_first_page(self, title: str, subtitle: str) -> dict:
57
+         """Add the first page of the ppt.
58
+
59
+ Args:
60
+ title (:class:`str`): the title of ppt
61
+ subtitle (:class:`str`): the subtitle of ppt
62
+
63
+ Returns:
64
+ :class:`dict`: operation status
65
+ * status: the result of the execution
66
+ """
67
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['title']
68
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
69
+ slide = self.pointer.slides.add_slide(layout)
70
+ ph_title, ph_subtitle = slide.placeholders
71
+ ph_title.text = title
72
+ if subtitle:
73
+ ph_subtitle.text = subtitle
74
+ return dict(status='added page')
75
+
76
+ @tool_api(explode_return=True)
77
+ def add_text_page(self, title: str, bullet_items: str) -> dict:
78
+         """Add a text page to the ppt.
79
+
80
+ Args:
81
+ title (:class:`str`): the title of the page
82
+             bullet_items (:class:`str`): bullet_items should be a string; to supply multiple bullet items, separate them with [SPAN].
83
+
84
+ Returns:
85
+ :class:`dict`: operation status
86
+ * status: the result of the execution
87
+ """ # noqa: E501
88
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['single']
89
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
90
+ slide = self.pointer.slides.add_slide(layout)
91
+ ph_title, ph_body = slide.placeholders
92
+ ph_title.text = title
93
+ ph = ph_body
94
+ tf = ph.text_frame
95
+ for i, item in enumerate(bullet_items.split('[SPAN]')):
96
+ if i == 0:
97
+ p = tf.paragraphs[0]
98
+ else:
99
+ p = tf.add_paragraph()
100
+ p.text = item.strip()
101
+ p.level = 0
102
+ return dict(status='added page')
103
+
104
+ @tool_api(explode_return=True)
105
+ def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
106
+ """Add a text page with one image. Image should be a path.
107
+
108
+ Args:
109
+ title (:class:`str`): the title of the page
110
+             bullet_items (:class:`str`): bullet_items should be a string; to supply multiple bullet items, separate them with [SPAN].
111
+ image (:class:`str`): the path of the image
112
+
113
+ Returns:
114
+ :class:`dict`: operation status
115
+ * status: the result of the execution
116
+ """ # noqa: E501
117
+ from PIL import Image
118
+
119
+ layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
120
+ layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
121
+ slide = self.pointer.slides.add_slide(layout)
122
+ ph_title, ph_body1, ph_body2 = slide.placeholders
123
+ ph_title.text = title
124
+ ph = ph_body2
125
+         image_path = image
126
+         image_pil = Image.open(image_path)
127
+ left = ph.left
128
+ width = ph.width
129
+ height = int(width / image_pil.width * image_pil.height)
130
+ top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
131
+         slide.shapes.add_picture(image_path, left, top, width, height)
132
+
133
+ ph = ph_body1
134
+ tf = ph.text_frame
135
+ for i, item in enumerate(bullet_items.split('[SPAN]')):
136
+ if i == 0:
137
+ p = tf.paragraphs[0]
138
+ else:
139
+ p = tf.add_paragraph()
140
+ p.text = item.strip()
141
+ p.level = 0
142
+
143
+ return dict(status='added page')
144
+
145
+ @tool_api(explode_return=True)
146
+ def submit_file(self) -> dict:
147
+         """When all steps are done, YOU MUST use submit_file() to submit your work.
148
+
149
+ Returns:
150
+ :class:`dict`: operation status
151
+ * status: the result of the execution
152
+ """
153
+ # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
154
+ # self.pointer.save(file_path)
155
+ # retreival_url = upload_file(file_path)
156
+ self.pointer.save(self.location)
157
+ return dict(status=f'submitted. view ppt at {self.location}')
158
+
159
+
160
+ class AsyncPPT(AsyncActionMixin, PPT):
161
+     """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""
162
+
163
+ @tool_api(explode_return=True)
164
+ @asyncify
165
+ def create_file(self, theme: str, abs_location: str) -> dict:
166
+ """Create a pptx file with specific themes.
167
+
168
+ Args:
169
+ theme (:class:`str`): the theme used. The value should be one of ['Default'].
170
+ abs_location (:class:`str`): the ppt file's absolute location
171
+
172
+ Returns:
173
+ :class:`dict`: operation status
174
+ * status: the result of the execution
175
+ """
176
+ return super().create_file(theme, abs_location)
177
+
178
+ @tool_api(explode_return=True)
179
+ @asyncify
180
+ def add_first_page(self, title: str, subtitle: str) -> dict:
181
+         """Add the first page of the ppt.
182
+
183
+ Args:
184
+ title (:class:`str`): the title of ppt
185
+ subtitle (:class:`str`): the subtitle of ppt
186
+
187
+ Returns:
188
+ :class:`dict`: operation status
189
+ * status: the result of the execution
190
+ """
191
+ return super().add_first_page(title, subtitle)
192
+
193
+ @tool_api(explode_return=True)
194
+ @asyncify
195
+ def add_text_page(self, title: str, bullet_items: str) -> dict:
196
+         """Add a text page to the ppt.
197
+
198
+ Args:
199
+ title (:class:`str`): the title of the page
200
+             bullet_items (:class:`str`): bullet_items should be a string; to supply multiple bullet items, separate them with [SPAN].
201
+
202
+ Returns:
203
+ :class:`dict`: operation status
204
+ * status: the result of the execution
205
+ """ # noqa: E501
206
+ return super().add_text_page(title, bullet_items)
207
+
208
+ @tool_api(explode_return=True)
209
+ @asyncify
210
+ def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
211
+ """Add a text page with one image. Image should be a path.
212
+
213
+ Args:
214
+ title (:class:`str`): the title of the page
215
+             bullet_items (:class:`str`): bullet_items should be a string; to supply multiple bullet items, separate them with [SPAN].
216
+ image (:class:`str`): the path of the image
217
+
218
+ Returns:
219
+ :class:`dict`: operation status
220
+ * status: the result of the execution
221
+ """ # noqa: E501
222
+ return super().add_text_image_page(title, bullet_items, image)
223
+
224
+ @tool_api(explode_return=True)
225
+ @asyncify
226
+ def submit_file(self) -> dict:
227
+         """When all steps are done, YOU MUST use submit_file() to submit your work.
228
+
229
+ Returns:
230
+ :class:`dict`: operation status
231
+ * status: the result of the execution
232
+ """
233
+ return super().submit_file()
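A short, hedged walk-through of the plugin above, assuming `python-pptx` is installed and that the `tool_api`-decorated methods can be called directly; the output path and slide contents are illustrative.

```python
from lagent.actions.ppt import PPT

ppt = PPT()
print(ppt.create_file(theme='Default', abs_location='/tmp/demo.pptx'))
print(ppt.add_first_page(title='Quarterly Review', subtitle='2024 Q3'))
print(ppt.add_text_page(
    title='Highlights',
    bullet_items='Revenue up 12%[SPAN]Churn down 3%[SPAN]Two new regions'))
print(ppt.submit_file())  # saves the deck to abs_location
```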
lagent/actions/python_interpreter.py ADDED
@@ -0,0 +1,176 @@
1
+ # flake8: noqa: E501
2
+ import copy
3
+ import io
4
+ from contextlib import redirect_stdout
5
+ from typing import Any, Optional, Type
6
+
7
+ from asyncer import asyncify
8
+
9
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
+ from lagent.actions.parser import BaseParser, JsonParser
11
+ from lagent.schema import ActionReturn, ActionStatusCode
12
+
13
+
14
+ class GenericRuntime:
15
+ GLOBAL_DICT = {}
16
+ LOCAL_DICT = None
17
+ HEADERS = []
18
+
19
+ def __init__(self):
20
+ self._global_vars = copy.copy(self.GLOBAL_DICT)
21
+ self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
22
+
23
+ for c in self.HEADERS:
24
+ self.exec_code(c)
25
+
26
+ def exec_code(self, code_piece: str) -> None:
27
+ exec(code_piece, self._global_vars)
28
+
29
+ def eval_code(self, expr: str) -> Any:
30
+ return eval(expr, self._global_vars)
31
+
32
+
33
+ class PythonInterpreter(BaseAction):
34
+ """A Python executor that can execute Python scripts.
35
+
36
+ Args:
37
+ answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
38
+ answer_expr (str, Optional): the answer function name of the Python
39
+ script. Defaults to ``'solution()'``.
40
+         answer_from_stdout (boolean, Optional): whether the execution result comes from
41
+ stdout. Defaults to ``False``.
42
+ timeout (int, Optional): Upper bound of waiting time for Python script execution.
43
+ Defaults to ``20``.
44
+ description (dict, Optional): The description of the action. Defaults to ``None``.
45
+ parser (Type[BaseParser]): The parser class to process the
46
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ answer_symbol: Optional[str] = None,
52
+ answer_expr: Optional[str] = 'solution()',
53
+ answer_from_stdout: bool = False,
54
+ timeout: int = 20,
55
+ description: Optional[dict] = None,
56
+ parser: Type[BaseParser] = JsonParser,
57
+ ) -> None:
58
+ super().__init__(description, parser)
59
+ self.answer_symbol = answer_symbol
60
+ self.answer_expr = answer_expr
61
+ self.answer_from_stdout = answer_from_stdout
62
+ self.timeout = timeout
63
+
64
+ @tool_api
65
+ def run(self, command: str) -> ActionReturn:
66
+         """Execute Python code. The code must be a single function named 'solution', and the code should reflect your reasoning process. An example format is shown below:
67
+
68
+         ```python
69
+         # import dependencies
70
+         import xxx
71
+         def solution():
72
+             # initialize some variables
73
+             variable_names_with_real_meaning = xxx
74
+             # step 1
75
+             mid_variable = func(variable_names_with_real_meaning)
76
+             # step x
77
+             mid_variable = func(mid_variable)
78
+             # final result
79
+ final_answer = func(mid_variable)
80
+ return final_answer
81
+ ```
82
+
83
+ Args:
84
+ command (:class:`str`): Python code snippet
85
+ """
86
+ from func_timeout import FunctionTimedOut, func_set_timeout
87
+
88
+ self.runtime = GenericRuntime()
89
+ try:
90
+ tool_return = func_set_timeout(self.timeout)(self._call)(command)
91
+ except FunctionTimedOut as e:
92
+ tool_return = ActionReturn(type=self.name)
93
+ tool_return.errmsg = repr(e)
94
+ tool_return.state = ActionStatusCode.API_ERROR
95
+ return tool_return
96
+
97
+ def _call(self, command: str) -> ActionReturn:
98
+ tool_return = ActionReturn(type=self.name)
99
+ try:
100
+ if '```python' in command:
101
+ command = command.split('```python')[1].split('```')[0]
102
+ elif '```' in command:
103
+ command = command.split('```')[1].split('```')[0]
104
+ tool_return.args = dict(text='```python\n' + command + '\n```')
105
+ command = command.split('\n')
106
+
107
+ if self.answer_from_stdout:
108
+ program_io = io.StringIO()
109
+ with redirect_stdout(program_io):
110
+ self.runtime.exec_code('\n'.join(command))
111
+ program_io.seek(0)
112
+ res = program_io.readlines()[-1]
113
+ elif self.answer_symbol:
114
+ self.runtime.exec_code('\n'.join(command))
115
+ res = self.runtime._global_vars[self.answer_symbol]
116
+ elif self.answer_expr:
117
+ self.runtime.exec_code('\n'.join(command))
118
+ res = self.runtime.eval_code(self.answer_expr)
119
+ else:
120
+ self.runtime.exec_code('\n'.join(command[:-1]))
121
+ res = self.runtime.eval_code(command[-1])
122
+ except Exception as e:
123
+ tool_return.errmsg = repr(e)
124
+ tool_return.type = self.name
125
+ tool_return.state = ActionStatusCode.API_ERROR
126
+ return tool_return
127
+ try:
128
+ tool_return.result = [dict(type='text', content=str(res))]
129
+ tool_return.state = ActionStatusCode.SUCCESS
130
+ except Exception as e:
131
+ tool_return.errmsg = repr(e)
132
+ tool_return.type = self.name
133
+ tool_return.state = ActionStatusCode.API_ERROR
134
+ return tool_return
135
+
136
+
137
+ class AsyncPythonInterpreter(AsyncActionMixin, PythonInterpreter):
138
+ """A Python executor that can execute Python scripts.
139
+
140
+ Args:
141
+ answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
142
+ answer_expr (str, Optional): the answer function name of the Python
143
+ script. Defaults to ``'solution()'``.
144
+         answer_from_stdout (boolean, Optional): whether the execution result comes from
145
+ stdout. Defaults to ``False``.
146
+ timeout (int, Optional): Upper bound of waiting time for Python script execution.
147
+ Defaults to ``20``.
148
+ description (dict, Optional): The description of the action. Defaults to ``None``.
149
+ parser (Type[BaseParser]): The parser class to process the
150
+ action's inputs and outputs. Defaults to :class:`JsonParser`.
151
+ """
152
+
153
+ @tool_api
154
+ @asyncify
155
+ def run(self, command: str) -> ActionReturn:
156
+         """Execute Python code. The code must be a single function named 'solution', and the code should reflect your reasoning process. An example format is shown below:
157
+
158
+         ```python
159
+         # import dependencies
160
+         import xxx
161
+         def solution():
162
+             # initialize some variables
163
+             variable_names_with_real_meaning = xxx
164
+             # step 1
165
+             mid_variable = func(variable_names_with_real_meaning)
166
+             # step x
167
+             mid_variable = func(mid_variable)
168
+             # final result
169
+ final_answer = func(mid_variable)
170
+ return final_answer
171
+ ```
172
+
173
+ Args:
174
+ command (:class:`str`): Python code snippet
175
+ """
176
+ return super().run(command)
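A minimal sketch of the expected calling convention, assuming `lagent` and `func_timeout` are installed: the command is a fenced snippet defining `solution()`, and the interpreter evaluates `solution()` to obtain the answer.

````python
from lagent.actions.python_interpreter import PythonInterpreter

interpreter = PythonInterpreter(timeout=20)
command = '''```python
def solution():
    # sum of the first 100 positive integers
    return sum(range(1, 101))
```'''
ret = interpreter.run(command)
print(ret.state, ret.result)  # expect SUCCESS with text content '5050'
````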
lagent/actions/weather_query.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import requests
3
+ from lagent.actions.base_action import BaseAction, tool_api
4
+ from lagent.schema import ActionReturn, ActionStatusCode
5
+
6
+ class WeatherQuery(BaseAction):
7
+ def __init__(self):
8
+ super().__init__()
9
+ self.api_key = os.getenv("weather_token")
10
+         # avoid printing the API key; it is a secret
11
+ if not self.api_key:
12
+             raise EnvironmentError("Environment variable 'weather_token' not found. Please set your QWeather API key, e.g. export weather_token='xxx'")
13
+
14
+ @tool_api
15
+ def run(self, location: str) -> dict:
16
+ """
17
+         Query real-time weather information.
18
+
19
+         Args:
20
+             location (str): The place name, LocationID, or longitude-latitude coordinates to query (e.g. "101010100" or "116.41,39.92").
21
+
22
+         Returns:
23
+             dict: a dictionary containing the weather information
24
+                 * location: place name
25
+                 * weather: weather condition
26
+                 * temperature: current temperature
27
+                 * wind_direction: wind direction
28
+                 * wind_speed: wind speed (km/h)
29
+                 * humidity: relative humidity (%)
30
+                 * report_time: report time of the data
31
+ """
32
+ try:
33
+             # If location is not a coordinate pair (e.g. "116.41,39.92"), resolve it to a LocationID via the GeoAPI
34
+             if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
35
+                 # look up the LocationID through the GeoAPI
36
+ geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
37
+ geo_response = requests.get(geo_url)
38
+ geo_data = geo_response.json()
39
+
40
+ if geo_data.get("code") != "200" or not geo_data.get("location"):
41
+                     raise Exception(f"GeoAPI returned error code {geo_data.get('code')} or no location was found")
42
+
43
+ location = geo_data["location"][0]["id"]
44
+
45
+             # build the request URL for the weather-query API
46
+ weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
47
+ response = requests.get(weather_url)
48
+ data = response.json()
49
+
50
+             # check the API response code
51
+ if data.get("code") != "200":
52
+                 raise Exception(f"Weather API returned error code {data.get('code')}")
53
+
54
+             # parse and organize the weather information
55
+ weather_info = {
56
+ "location": location,
57
+ "weather": data["now"]["text"],
58
+ "temperature": data["now"]["temp"] + "°C",
59
+ "wind_direction": data["now"]["windDir"],
60
+ "wind_speed": data["now"]["windSpeed"] + " km/h",
61
+ "humidity": data["now"]["humidity"] + "%",
62
+ "report_time": data["updateTime"]
63
+ }
64
+
65
+ return {"result": weather_info}
66
+
67
+ except Exception as exc:
68
+ return ActionReturn(
69
+                 errmsg=f"WeatherQuery exception: {exc}",
70
+ state=ActionStatusCode.HTTP_ERROR
71
+ )
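A hedged usage sketch for the action above; it assumes a valid QWeather key has been exported as `weather_token` and that network access is available (the placeholder key below will not work as-is).

```python
import os

from lagent.actions.weather_query import WeatherQuery

# assumption: export weather_token='xxx' was run beforehand;
# the value below is only a placeholder
os.environ.setdefault('weather_token', 'your-qweather-key')

action = WeatherQuery()
# resolves the place name to a LocationID via the GeoAPI, then queries live weather
print(action.run('北京'))
```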
lagent/actions/web_browser.py ADDED
@@ -0,0 +1,908 @@
1
+ import asyncio
2
+ import hashlib
3
+ import hmac
4
+ import json
5
+ import logging
6
+ import random
7
+ import re
8
+ import time
9
+ import warnings
10
+ from concurrent.futures import ThreadPoolExecutor, as_completed
11
+ from datetime import datetime
12
+ from http.client import HTTPSConnection
13
+ from typing import List, Optional, Tuple, Type, Union
14
+
15
+ import aiohttp
16
+ import aiohttp.client_exceptions
17
+ import requests
18
+ from asyncache import cached as acached
19
+ from bs4 import BeautifulSoup
20
+ from cachetools import TTLCache, cached
21
+ from duckduckgo_search import DDGS, AsyncDDGS
22
+
23
+ from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
24
+ from lagent.actions.parser import BaseParser, JsonParser
25
+ from lagent.utils import async_as_completed
26
+
27
+
28
+ class BaseSearch:
29
+
30
+ def __init__(self, topk: int = 3, black_list: List[str] = None):
31
+ self.topk = topk
32
+         self.black_list = black_list or []  # guard against None so _filter_results can iterate
33
+
34
+ def _filter_results(self, results: List[tuple]) -> dict:
35
+ filtered_results = {}
36
+ count = 0
37
+ for url, snippet, title in results:
38
+ if all(domain not in url
39
+ for domain in self.black_list) and not url.endswith('.pdf'):
40
+ filtered_results[count] = {
41
+ 'url': url,
42
+ 'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
43
+ 'title': title
44
+ }
45
+ count += 1
46
+ if count >= self.topk:
47
+ break
48
+ return filtered_results
49
+
50
+
51
+ class DuckDuckGoSearch(BaseSearch):
52
+
53
+ def __init__(self,
54
+ topk: int = 3,
55
+ black_list: List[str] = [
56
+ 'enoN',
57
+ 'youtube.com',
58
+ 'bilibili.com',
59
+ 'researchgate.net',
60
+ ],
61
+ **kwargs):
62
+ self.proxy = kwargs.get('proxy')
63
+ self.timeout = kwargs.get('timeout', 30)
64
+ super().__init__(topk, black_list)
65
+
66
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
67
+ def search(self, query: str, max_retry: int = 3) -> dict:
68
+ for attempt in range(max_retry):
69
+ try:
70
+ response = self._call_ddgs(
71
+ query, timeout=self.timeout, proxy=self.proxy)
72
+ return self._parse_response(response)
73
+ except Exception as e:
74
+ logging.exception(str(e))
75
+ warnings.warn(
76
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
77
+ time.sleep(random.randint(2, 5))
78
+ raise Exception(
79
+ 'Failed to get search results from DuckDuckGo after retries.')
80
+
81
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
82
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
83
+ for attempt in range(max_retry):
84
+ try:
85
+ ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
86
+ response = await ddgs.atext(query.strip("'"), max_results=10)
87
+ return self._parse_response(response)
88
+ except Exception as e:
89
+ if isinstance(e, asyncio.TimeoutError):
90
+ logging.exception('Request to DDGS timed out.')
91
+ logging.exception(str(e))
92
+ warnings.warn(
93
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
94
+ await asyncio.sleep(random.randint(2, 5))
95
+ raise Exception(
96
+ 'Failed to get search results from DuckDuckGo after retries.')
97
+
98
+ async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
99
+ ddgs = DDGS(**kwargs)
100
+ try:
101
+ response = await asyncio.wait_for(
102
+ asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
103
+ timeout=self.timeout)
104
+ return response
105
+ except asyncio.TimeoutError:
106
+ logging.exception('Request to DDGS timed out.')
107
+ raise
108
+
109
+ def _call_ddgs(self, query: str, **kwargs) -> dict:
110
+ loop = asyncio.new_event_loop()
111
+ asyncio.set_event_loop(loop)
112
+ try:
113
+ response = loop.run_until_complete(
114
+ self._async_call_ddgs(query, **kwargs))
115
+ return response
116
+ finally:
117
+ loop.close()
118
+
119
+ def _parse_response(self, response: dict) -> dict:
120
+ raw_results = []
121
+ for item in response:
122
+ raw_results.append(
123
+ (item['href'], item['description']
124
+ if 'description' in item else item['body'], item['title']))
125
+ return self._filter_results(raw_results)
126
+
127
+
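A quick sketch of the keyless searcher above, assuming the `duckduckgo_search` package and network access; results come back as an index-keyed dict of `url`/`summ`/`title` entries filtered through the blacklist.

```python
from lagent.actions.web_browser import DuckDuckGoSearch

searcher = DuckDuckGoSearch(topk=3, timeout=15)
results = searcher.search('large language model agents')
for idx, item in results.items():
    print(idx, item['title'], item['url'])
```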
128
+ class BingSearch(BaseSearch):
129
+
130
+ def __init__(self,
131
+ api_key: str,
132
+ region: str = 'zh-CN',
133
+ topk: int = 3,
134
+ black_list: List[str] = [
135
+ 'enoN',
136
+ 'youtube.com',
137
+ 'bilibili.com',
138
+ 'researchgate.net',
139
+ ],
140
+ **kwargs):
141
+ self.api_key = api_key
142
+ self.market = region
143
+ self.proxy = kwargs.get('proxy')
144
+ super().__init__(topk, black_list)
145
+
146
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
147
+ def search(self, query: str, max_retry: int = 3) -> dict:
148
+ for attempt in range(max_retry):
149
+ try:
150
+ response = self._call_bing_api(query)
151
+ return self._parse_response(response)
152
+ except Exception as e:
153
+ logging.exception(str(e))
154
+ warnings.warn(
155
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
156
+ time.sleep(random.randint(2, 5))
157
+ raise Exception(
158
+ 'Failed to get search results from Bing Search after retries.')
159
+
160
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
161
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
162
+ for attempt in range(max_retry):
163
+ try:
164
+ response = await self._async_call_bing_api(query)
165
+ return self._parse_response(response)
166
+ except Exception as e:
167
+ logging.exception(str(e))
168
+ warnings.warn(
169
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
170
+ await asyncio.sleep(random.randint(2, 5))
171
+ raise Exception(
172
+ 'Failed to get search results from Bing Search after retries.')
173
+
174
+ def _call_bing_api(self, query: str) -> dict:
175
+ endpoint = 'https://api.bing.microsoft.com/v7.0/search'
176
+ params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
177
+ headers = {'Ocp-Apim-Subscription-Key': self.api_key}
178
+ response = requests.get(
179
+ endpoint, headers=headers, params=params, proxies=self.proxy)
180
+ response.raise_for_status()
181
+ return response.json()
182
+
183
+ async def _async_call_bing_api(self, query: str) -> dict:
184
+ endpoint = 'https://api.bing.microsoft.com/v7.0/search'
185
+ params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
186
+ headers = {'Ocp-Apim-Subscription-Key': self.api_key}
187
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
188
+ async with session.get(
189
+ endpoint,
190
+ headers=headers,
191
+ params=params,
192
+ proxy=self.proxy and
193
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
194
+ return await resp.json()
195
+
196
+ def _parse_response(self, response: dict) -> dict:
197
+ webpages = {
198
+ w['id']: w
199
+ for w in response.get('webPages', {}).get('value', [])
200
+ }
201
+ raw_results = []
202
+
203
+ for item in response.get('rankingResponse',
204
+ {}).get('mainline', {}).get('items', []):
205
+ if item['answerType'] == 'WebPages':
206
+ webpage = webpages.get(item['value']['id'])
207
+ if webpage:
208
+ raw_results.append(
209
+ (webpage['url'], webpage['snippet'], webpage['name']))
210
+ elif item['answerType'] == 'News' and item['value'][
211
+ 'id'] == response.get('news', {}).get('id'):
212
+ for news in response.get('news', {}).get('value', []):
213
+ raw_results.append(
214
+ (news['url'], news['description'], news['name']))
215
+
216
+ return self._filter_results(raw_results)
217
+
218
+
219
+ class BraveSearch(BaseSearch):
220
+ """
221
+ Wrapper around the Brave Search API.
222
+
223
+ To use, you should pass your Brave Search API key to the constructor.
224
+
225
+ Args:
226
+ api_key (str): API KEY to use Brave Search API.
227
+ You can create a free API key at https://api.search.brave.com/app/keys.
228
+ search_type (str): Brave Search API supports ['web', 'news', 'images', 'videos'],
229
+ currently only supports 'news' and 'web'.
230
+         topk (int): The number of search results returned in the API response.
231
+ region (str): The country code string. Specifies the country where the search results come from.
232
+ language (str): The language code string. Specifies the preferred language for the search results.
233
+ extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
234
+ **kwargs: Any other parameters related to the Brave Search API. Find more details at
235
+ https://api.search.brave.com/app/documentation/web-search/get-started.
236
+ """
237
+
238
+ def __init__(self,
239
+ api_key: str,
240
+ region: str = 'ALL',
241
+ language: str = 'zh-hans',
242
+                  extra_snippets: bool = True,
243
+ topk: int = 3,
244
+ black_list: List[str] = [
245
+ 'enoN',
246
+ 'youtube.com',
247
+ 'bilibili.com',
248
+ 'researchgate.net',
249
+ ],
250
+ **kwargs):
251
+ self.api_key = api_key
252
+ self.market = region
253
+ self.proxy = kwargs.get('proxy')
254
+ self.language = language
255
+         self.extra_snippets = extra_snippets
256
+ self.search_type = kwargs.get('search_type', 'web')
257
+ self.kwargs = kwargs
258
+ super().__init__(topk, black_list)
259
+
260
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
261
+ def search(self, query: str, max_retry: int = 3) -> dict:
262
+ for attempt in range(max_retry):
263
+ try:
264
+ response = self._call_brave_api(query)
265
+ return self._parse_response(response)
266
+ except Exception as e:
267
+ logging.exception(str(e))
268
+ warnings.warn(
269
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
270
+ time.sleep(random.randint(2, 5))
271
+ raise Exception(
272
+ 'Failed to get search results from Brave Search after retries.')
273
+
274
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
275
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
276
+ for attempt in range(max_retry):
277
+ try:
278
+ response = await self._async_call_brave_api(query)
279
+ return self._parse_response(response)
280
+ except Exception as e:
281
+ logging.exception(str(e))
282
+ warnings.warn(
283
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
284
+ await asyncio.sleep(random.randint(2, 5))
285
+ raise Exception(
286
+ 'Failed to get search results from Brave Search after retries.')
287
+
288
+ def _call_brave_api(self, query: str) -> dict:
289
+ endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
290
+ params = {
291
+ 'q': query,
292
+ 'country': self.market,
293
+ 'search_lang': self.language,
294
+             'extra_snippets': self.extra_snippets,
295
+ 'count': self.topk,
296
+ **{
297
+ key: value
298
+ for key, value in self.kwargs.items() if value is not None
299
+ },
300
+ }
301
+ headers = {
302
+ 'X-Subscription-Token': self.api_key or '',
303
+ 'Accept': 'application/json'
304
+ }
305
+ response = requests.get(
306
+ endpoint, headers=headers, params=params, proxies=self.proxy)
307
+ response.raise_for_status()
308
+ return response.json()
309
+
310
+ async def _async_call_brave_api(self, query: str) -> dict:
311
+ endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
312
+ params = {
313
+ 'q': query,
314
+ 'country': self.market,
315
+ 'search_lang': self.language,
316
+             'extra_snippets': self.extra_snippets,
317
+ 'count': self.topk,
318
+ **{
319
+ key: value
320
+ for key, value in self.kwargs.items() if value is not None
321
+ },
322
+ }
323
+ headers = {
324
+ 'X-Subscription-Token': self.api_key or '',
325
+ 'Accept': 'application/json'
326
+ }
327
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
328
+ async with session.get(
329
+ endpoint,
330
+ headers=headers,
331
+ params=params,
332
+ proxy=self.proxy and
333
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
334
+ return await resp.json()
335
+
336
+ def _parse_response(self, response: dict) -> dict:
337
+ if self.search_type == 'web':
338
+ filtered_result = response.get('web', {}).get('results', [])
339
+ else:
340
+ filtered_result = response.get('results', {})
341
+ raw_results = []
342
+
343
+ for item in filtered_result:
344
+ raw_results.append((
345
+ item.get('url', ''),
346
+ ' '.join(
347
+ filter(None, [
348
+ item.get('description'),
349
+ *item.get('extra_snippets', [])
350
+ ])),
351
+ item.get('title', ''),
352
+ ))
353
+ return self._filter_results(raw_results)
354
+
355
+
356
+ class GoogleSearch(BaseSearch):
357
+ """
358
+ Wrapper around the Serper.dev Google Search API.
359
+
360
+ To use, you should pass your serper API key to the constructor.
361
+
362
+ Args:
363
+ api_key (str): API KEY to use serper google search API.
364
+ You can create a free API key at https://serper.dev.
365
+ search_type (str): Serper API supports ['search', 'images', 'news',
366
+ 'places'] types of search, currently we only support 'search' and 'news'.
367
+         topk (int): The number of search results returned in the API response.
368
+ **kwargs: Any other parameters related to the Serper API. Find more details at
369
+ https://serper.dev/playground
370
+ """
371
+
372
+ result_key_for_type = {
373
+ 'news': 'news',
374
+ 'places': 'places',
375
+ 'images': 'images',
376
+ 'search': 'organic',
377
+ }
378
+
379
+ def __init__(self,
380
+ api_key: str,
381
+ topk: int = 3,
382
+ black_list: List[str] = [
383
+ 'enoN',
384
+ 'youtube.com',
385
+ 'bilibili.com',
386
+ 'researchgate.net',
387
+ ],
388
+ **kwargs):
389
+ self.api_key = api_key
390
+ self.proxy = kwargs.get('proxy')
391
+ self.search_type = kwargs.get('search_type', 'search')
392
+ self.kwargs = kwargs
393
+ super().__init__(topk, black_list)
394
+
395
+ @cached(cache=TTLCache(maxsize=100, ttl=600))
396
+ def search(self, query: str, max_retry: int = 3) -> dict:
397
+ for attempt in range(max_retry):
398
+ try:
399
+ response = self._call_serper_api(query)
400
+ return self._parse_response(response)
401
+ except Exception as e:
402
+ logging.exception(str(e))
403
+ warnings.warn(
404
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
405
+ time.sleep(random.randint(2, 5))
406
+ raise Exception(
407
+ 'Failed to get search results from Google Serper Search after retries.'
408
+ )
409
+
410
+ @acached(cache=TTLCache(maxsize=100, ttl=600))
411
+ async def asearch(self, query: str, max_retry: int = 3) -> dict:
412
+ for attempt in range(max_retry):
413
+ try:
414
+ response = await self._async_call_serper_api(query)
415
+ return self._parse_response(response)
416
+ except Exception as e:
417
+ logging.exception(str(e))
418
+ warnings.warn(
419
+ f'Retry {attempt + 1}/{max_retry} due to error: {e}')
420
+ await asyncio.sleep(random.randint(2, 5))
421
+ raise Exception(
422
+ 'Failed to get search results from Google Serper Search after retries.'
423
+ )
424
+
425
+ def _call_serper_api(self, query: str) -> dict:
426
+ endpoint = f'https://google.serper.dev/{self.search_type}'
427
+ params = {
428
+ 'q': query,
429
+ 'num': self.topk,
430
+ **{
431
+ key: value
432
+ for key, value in self.kwargs.items() if value is not None
433
+ },
434
+ }
435
+ headers = {
436
+ 'X-API-KEY': self.api_key or '',
437
+ 'Content-Type': 'application/json'
438
+ }
439
+ response = requests.get(
440
+ endpoint, headers=headers, params=params, proxies=self.proxy)
441
+ response.raise_for_status()
442
+ return response.json()
443
+
444
+ async def _async_call_serper_api(self, query: str) -> dict:
445
+ endpoint = f'https://google.serper.dev/{self.search_type}'
446
+ params = {
447
+ 'q': query,
448
+ 'num': self.topk,
449
+ **{
450
+ key: value
451
+ for key, value in self.kwargs.items() if value is not None
452
+ },
453
+ }
454
+ headers = {
455
+ 'X-API-KEY': self.api_key or '',
456
+ 'Content-Type': 'application/json'
457
+ }
458
+ async with aiohttp.ClientSession(raise_for_status=True) as session:
459
+ async with session.get(
460
+ endpoint,
461
+ headers=headers,
462
+ params=params,
463
+ proxy=self.proxy and
464
+ (self.proxy.get('http') or self.proxy.get('https'))) as resp:
465
+ return await resp.json()
466
+
467
+ def _parse_response(self, response: dict) -> dict:
468
+ raw_results = []
469
+
470
+ if response.get('answerBox'):
471
+ answer_box = response.get('answerBox', {})
472
+ if answer_box.get('answer'):
473
+ raw_results.append(('', answer_box.get('answer'), ''))
474
+ elif answer_box.get('snippet'):
475
+ raw_results.append(
476
+ ('', answer_box.get('snippet').replace('\n', ' '), ''))
477
+ elif answer_box.get('snippetHighlighted'):
478
+ raw_results.append(
479
+ ('', answer_box.get('snippetHighlighted'), ''))
480
+
481
+ if response.get('knowledgeGraph'):
482
+ kg = response.get('knowledgeGraph', {})
483
+ description = kg.get('description', '')
484
+ attributes = '. '.join(
485
+ f'{attribute}: {value}'
486
+ for attribute, value in kg.get('attributes', {}).items())
487
+ raw_results.append(
488
+ (kg.get('descriptionLink', ''),
489
+ f'{description}. {attributes}' if attributes else description,
490
+ f"{kg.get('title', '')}: {kg.get('type', '')}."))
491
+
492
+ for result in response[self.result_key_for_type[
493
+ self.search_type]][:self.topk]:
494
+ description = result.get('snippet', '')
495
+ attributes = '. '.join(
496
+ f'{attribute}: {value}'
497
+ for attribute, value in result.get('attributes', {}).items())
498
+ raw_results.append(
499
+ (result.get('link', ''),
500
+ f'{description}. {attributes}' if attributes else description,
501
+ result.get('title', '')))
502
+
503
+ return self._filter_results(raw_results)
504
+
505
+
506
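
A minimal usage sketch for GoogleSearch, assuming a valid Serper key exported as SERPER_API_KEY (the environment variable name and the query are illustrative):

    import os
    from lagent.actions.web_browser import GoogleSearch

    searcher = GoogleSearch(api_key=os.environ['SERPER_API_KEY'], topk=3)
    # search() returns a dict of index -> {'url', 'summ', 'title'},
    # the same shape WebBrowser.search consumes further below.
    for idx, hit in searcher.search('InternLM agents').items():
        print(idx, hit['title'], hit['url'])
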
+class TencentSearch(BaseSearch):
+    """Wrapper around the Tencent Cloud Search API.
+
+    To use, you should pass your secret_id and secret_key to the constructor.
+
+    Args:
+        secret_id (str): Your Tencent Cloud secret ID for accessing the API.
+            For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488.
+        secret_key (str): Your Tencent Cloud secret key for accessing the API.
+        api_key (str, optional): Additional API key, if required.
+        action (str): The action for this interface, use `SearchCommon`.
+        version (str): The API version, use `2020-12-29`.
+        service (str): The service name, use `tms`.
+        host (str): The API host, use `tms.tencentcloudapi.com`.
+        topk (int): The maximum number of search results to return.
+        tsn (int): Time filter for search results. Valid values:
+            1 (within 1 day), 2 (within 1 week), 3 (within 1 month),
+            4 (within 1 year), 5 (within 6 months), 6 (within 3 years).
+        insite (str): Specify a site to search within (supports only a single site).
+            If not specified, the entire web is searched. Example: `zhihu.com`.
+        category (str): Vertical category for filtering results. Optional values include:
+            `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope).
+        vrid (str): Result card type(s). Different `vrid` values represent different types of result cards.
+            Supports multiple values separated by commas. Example: `30010255`.
+    """
+
+    def __init__(self,
+                 secret_id: str = 'Your SecretId',
+                 secret_key: str = 'Your SecretKey',
+                 api_key: str = '',
+                 action: str = 'SearchCommon',
+                 version: str = '2020-12-29',
+                 service: str = 'tms',
+                 host: str = 'tms.tencentcloudapi.com',
+                 topk: int = 3,
+                 tsn: int = None,
+                 insite: str = None,
+                 category: str = None,
+                 vrid: str = None,
+                 black_list: List[str] = [
+                     'enoN',
+                     'youtube.com',
+                     'bilibili.com',
+                     'researchgate.net',
+                 ]):
+        self.secret_id = secret_id
+        self.secret_key = secret_key
+        self.api_key = api_key
+        self.action = action
+        self.version = version
+        self.service = service
+        self.host = host
+        self.tsn = tsn
+        self.insite = insite
+        self.category = category
+        self.vrid = vrid
+        super().__init__(topk, black_list=black_list)
+
+    @cached(cache=TTLCache(maxsize=100, ttl=600))
+    def search(self, query: str, max_retry: int = 3) -> dict:
+        for attempt in range(max_retry):
+            try:
+                response = self._call_tencent_api(query)
+                return self._parse_response(response)
+            except Exception as e:
+                logging.exception(str(e))
+                warnings.warn(
+                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
+                time.sleep(random.randint(2, 5))
+        raise Exception(
+            'Failed to get search results from Tencent Search after retries.')
+
+    @acached(cache=TTLCache(maxsize=100, ttl=600))
+    async def asearch(self, query: str, max_retry: int = 3) -> dict:
+        for attempt in range(max_retry):
+            try:
+                response = await self._async_call_tencent_api(query)
+                return self._parse_response(response)
+            except Exception as e:
+                logging.exception(str(e))
+                warnings.warn(
+                    f'Retry {attempt + 1}/{max_retry} due to error: {e}')
+                await asyncio.sleep(random.randint(2, 5))
+        raise Exception(
+            'Failed to get search results from Tencent Search after retries.')
+
+    def _get_headers_and_payload(self, query: str) -> tuple:
+
+        def sign(key, msg):
+            return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
+
+        params = dict(Query=query)
+        # if self.topk:
+        #     params['Cnt'] = self.topk
+        if self.tsn:
+            params['Tsn'] = self.tsn
+        if self.insite:
+            params['Insite'] = self.insite
+        if self.category:
+            params['Category'] = self.category
+        if self.vrid:
+            params['Vrid'] = self.vrid
+        payload = json.dumps(params)
+        algorithm = 'TC3-HMAC-SHA256'
+        timestamp = int(time.time())
+        date = datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%d')
+
+        # ************* Step 1: build the canonical request string *************
+        http_request_method = 'POST'
+        canonical_uri = '/'
+        canonical_querystring = ''
+        ct = 'application/json; charset=utf-8'
+        canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
+        signed_headers = 'content-type;host;x-tc-action'
+        hashed_request_payload = hashlib.sha256(
+            payload.encode('utf-8')).hexdigest()
+        canonical_request = (
+            http_request_method + '\n' + canonical_uri + '\n' +
+            canonical_querystring + '\n' + canonical_headers + '\n' +
+            signed_headers + '\n' + hashed_request_payload)
+
+        # ************* Step 2: build the string to sign *************
+        credential_scope = date + '/' + self.service + '/' + 'tc3_request'
+        hashed_canonical_request = hashlib.sha256(
+            canonical_request.encode('utf-8')).hexdigest()
+        string_to_sign = (
+            algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
+            '\n' + hashed_canonical_request)
+
+        # ************* Step 3: compute the signature *************
+        secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
+        secret_service = sign(secret_date, self.service)
+        secret_signing = sign(secret_service, 'tc3_request')
+        signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
+                             hashlib.sha256).hexdigest()
+
+        # ************* Step 4: assemble the Authorization header *************
+        authorization = (
+            algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
+            credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
+            ', ' + 'Signature=' + signature)
+
+        # ************* Step 5: build and send the request *************
+        headers = {
+            'Authorization': authorization,
+            'Content-Type': 'application/json; charset=utf-8',
+            'Host': self.host,
+            'X-TC-Action': self.action,
+            'X-TC-Timestamp': str(timestamp),
+            'X-TC-Version': self.version
+        }
+        # if self.region:
+        #     headers["X-TC-Region"] = self.region
+        if self.api_key:
+            headers['X-TC-Token'] = self.api_key
+        return headers, payload
+
+    def _call_tencent_api(self, query: str) -> dict:
+        headers, payload = self._get_headers_and_payload(query)
+        req = HTTPSConnection(self.host)
+        req.request('POST', '/', headers=headers, body=payload.encode('utf-8'))
+        resp = req.getresponse()
+        raw = resp.read().decode('utf-8')
+        try:
+            resp = json.loads(raw)
+        except Exception as e:
+            logging.warning(str(e))
+            import ast
+            # Fall back to parsing the raw body; the original evaluated the
+            # already-exhausted response object, which could never succeed.
+            resp = ast.literal_eval(raw)
+        return resp.get('Response', dict())
+
+    async def _async_call_tencent_api(self, query: str):
+        headers, payload = self._get_headers_and_payload(query)
+        async with aiohttp.ClientSession(raise_for_status=True) as session:
+            async with session.post(
+                    'https://' + self.host.lstrip('/'),
+                    headers=headers,
+                    data=payload) as resp:
+                return (await resp.json()).get('Response', {})
+
+    def _parse_response(self, response: dict) -> dict:
+        raw_results = []
+        for item in response.get('Pages', []):
+            display = json.loads(item['Display'])
+            if not display['url']:
+                continue
+            raw_results.append((display['url'], display['content']
+                                or display['abstract_info'], display['title']))
+        return self._filter_results(raw_results)
+
+
+class ContentFetcher:
+
+    def __init__(self, timeout: int = 5):
+        self.timeout = timeout
+
+    @cached(cache=TTLCache(maxsize=100, ttl=600))
+    def fetch(self, url: str) -> Tuple[bool, str]:
+        try:
+            response = requests.get(url, timeout=self.timeout)
+            response.raise_for_status()
+            html = response.content
+        except requests.RequestException as e:
+            return False, str(e)
+
+        text = BeautifulSoup(html, 'html.parser').get_text()
+        cleaned_text = re.sub(r'\n+', '\n', text)
+        return True, cleaned_text
+
+    @acached(cache=TTLCache(maxsize=100, ttl=600))
+    async def afetch(self, url: str) -> Tuple[bool, str]:
+        try:
+            async with aiohttp.ClientSession(
+                    raise_for_status=True,
+                    timeout=aiohttp.ClientTimeout(self.timeout)) as session:
+                async with session.get(url) as resp:
+                    html = await resp.text(errors='ignore')
+                    text = BeautifulSoup(html, 'html.parser').get_text()
+                    cleaned_text = re.sub(r'\n+', '\n', text)
+                    return True, cleaned_text
+        except Exception as e:
+            return False, str(e)
+
+
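
ContentFetcher reports failures through its return value rather than raising. A quick sketch of the (success, text) contract, with an illustrative URL:

    from lagent.actions.web_browser import ContentFetcher

    ok, text = ContentFetcher(timeout=5).fetch('https://example.com')
    # On failure, `ok` is False and `text` carries the error string.
    print(text[:200] if ok else f'fetch failed: {text}')
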
+class WebBrowser(BaseAction):
+    """Wrapper around the Web Browser Tool.
+    """
+
+    def __init__(self,
+                 searcher_type: str = 'DuckDuckGoSearch',
+                 timeout: int = 5,
+                 black_list: Optional[List[str]] = [
+                     'enoN',
+                     'youtube.com',
+                     'bilibili.com',
+                     'researchgate.net',
+                 ],
+                 topk: int = 20,
+                 description: Optional[dict] = None,
+                 parser: Type[BaseParser] = JsonParser,
+                 **kwargs):
+        self.searcher = eval(searcher_type)(
+            black_list=black_list, topk=topk, **kwargs)
+        self.fetcher = ContentFetcher(timeout=timeout)
+        self.search_results = None
+        super().__init__(description, parser)
+
+    @tool_api
+    def search(self, query: Union[str, List[str]]) -> dict:
+        """Web search API
+        Args:
+            query (List[str]): list of search query strings
+        """
+        queries = query if isinstance(query, list) else [query]
+        search_results = {}
+
+        with ThreadPoolExecutor() as executor:
+            future_to_query = {
+                executor.submit(self.searcher.search, q): q
+                for q in queries
+            }
+
+            for future in as_completed(future_to_query):
+                query = future_to_query[future]
+                try:
+                    results = future.result()
+                except Exception as exc:
+                    warnings.warn(f'{query} generated an exception: {exc}')
+                else:
+                    for result in results.values():
+                        if result['url'] not in search_results:
+                            search_results[result['url']] = result
+                        else:
+                            search_results[
+                                result['url']]['summ'] += f"\n{result['summ']}"
+
+        self.search_results = {
+            idx: result
+            for idx, result in enumerate(search_results.values())
+        }
+        return self.search_results
+
+    @tool_api
+    def select(self, select_ids: List[int]) -> dict:
+        """get the detailed content on the selected pages.
+
+        Args:
+            select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
+        """
+        if not self.search_results:
+            raise ValueError('No search results to select from.')
+
+        new_search_results = {}
+        with ThreadPoolExecutor() as executor:
+            future_to_id = {
+                executor.submit(self.fetcher.fetch,
+                                self.search_results[select_id]['url']):
+                select_id
+                for select_id in select_ids if select_id in self.search_results
+            }
+            for future in as_completed(future_to_id):
+                select_id = future_to_id[future]
+                try:
+                    web_success, web_content = future.result()
+                except Exception as exc:
+                    warnings.warn(f'{select_id} generated an exception: {exc}')
+                else:
+                    if web_success:
+                        self.search_results[select_id][
+                            'content'] = web_content[:8192]
+                        new_search_results[select_id] = self.search_results[
+                            select_id].copy()
+                        new_search_results[select_id].pop('summ')
+
+        return new_search_results
+
+    @tool_api
+    def open_url(self, url: str) -> dict:
+        print(f'Start Browsing: {url}')
+        web_success, web_content = self.fetcher.fetch(url)
+        if web_success:
+            return {'type': 'text', 'content': web_content}
+        else:
+            return {'error': web_content}
+
+
+class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
+    """Wrapper around the Web Browser Tool.
+    """
+
+    @tool_api
+    async def search(self, query: Union[str, List[str]]) -> dict:
+        """Web search API
+
+        Args:
+            query (List[str]): list of search query strings
+        """
+        queries = query if isinstance(query, list) else [query]
+        search_results = {}
+
+        tasks = []
+        for q in queries:
+            task = asyncio.create_task(self.searcher.asearch(q))
+            task.query = q
+            tasks.append(task)
+        async for future in async_as_completed(tasks):
+            query = future.query
+            try:
+                results = await future
+            except Exception as exc:
+                warnings.warn(f'{query} generated an exception: {exc}')
+            else:
+                for result in results.values():
+                    if result['url'] not in search_results:
+                        search_results[result['url']] = result
+                    else:
+                        search_results[
+                            result['url']]['summ'] += f"\n{result['summ']}"
+
+        self.search_results = {
+            idx: result
+            for idx, result in enumerate(search_results.values())
+        }
+        return self.search_results
+
+    @tool_api
+    async def select(self, select_ids: List[int]) -> dict:
+        """get the detailed content on the selected pages.
+
+        Args:
+            select_ids (List[int]): list of index to select. Max number of index to be selected is no more than 4.
+        """
+        if not self.search_results:
+            raise ValueError('No search results to select from.')
+
+        new_search_results = {}
+        tasks = []
+        for select_id in select_ids:
+            if select_id in self.search_results:
+                task = asyncio.create_task(
+                    self.fetcher.afetch(self.search_results[select_id]['url']))
+                task.select_id = select_id
+                tasks.append(task)
+        async for future in async_as_completed(tasks):
+            select_id = future.select_id
+            try:
+                web_success, web_content = await future
+            except Exception as exc:
+                warnings.warn(f'{select_id} generated an exception: {exc}')
+            else:
+                if web_success:
+                    self.search_results[select_id][
+                        'content'] = web_content[:8192]
+                    new_search_results[select_id] = self.search_results[
+                        select_id].copy()
+                    new_search_results[select_id].pop('summ')
+        return new_search_results
+
+    @tool_api
+    async def open_url(self, url: str) -> dict:
+        print(f'Start Browsing: {url}')
+        web_success, web_content = await self.fetcher.afetch(url)
+        if web_success:
+            return {'type': 'text', 'content': web_content}
+        else:
+            return {'error': web_content}
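
The intended tool loop for WebBrowser is search -> select -> open_url. A rough synchronous sketch (query and index choices are illustrative; `tool_api` may wrap returns, so treat the shapes as approximate):

    from lagent.actions.web_browser import WebBrowser

    browser = WebBrowser(searcher_type='DuckDuckGoSearch', topk=5)
    hits = browser.search(query='lagent web browser tool')  # index -> url/summ/title
    pages = browser.select(select_ids=[0])                  # full text of hit 0
    page = browser.open_url('https://github.com/InternLM/lagent')
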
lagent/agents/__init__.py ADDED
@@ -0,0 +1,9 @@
+from .agent import Agent, AgentDict, AgentList, AsyncAgent, AsyncSequential, Sequential
+from .react import AsyncReAct, ReAct
+from .stream import AgentForInternLM, AsyncAgentForInternLM, AsyncMathCoder, MathCoder
+
+__all__ = [
+    'Agent', 'AgentDict', 'AgentList', 'AsyncAgent', 'AgentForInternLM',
+    'AsyncAgentForInternLM', 'MathCoder', 'AsyncMathCoder', 'ReAct',
+    'AsyncReAct', 'Sequential', 'AsyncSequential'
+]
lagent/agents/agent.py ADDED
@@ -0,0 +1,400 @@
+import copy
+import warnings
+from collections import OrderedDict, UserDict, UserList, abc
+from functools import wraps
+from itertools import chain, repeat
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+
+from lagent.agents.aggregator import DefaultAggregator
+from lagent.hooks import Hook, RemovableHandle
+from lagent.llms import BaseLLM
+from lagent.memory import Memory, MemoryManager
+from lagent.prompts.parsers import StrParser
+from lagent.prompts.prompt_template import PromptTemplate
+from lagent.schema import AgentMessage
+from lagent.utils import create_object
+
+
+class Agent:
+    """Agent is the basic unit of the system. It is responsible for
+    communicating with the LLM, managing the memory, and handling the
+    message aggregation and parsing. It can also be extended with hooks.
+
+    Args:
+        llm (Union[BaseLLM, Dict]): The language model used by the agent.
+        template (Union[PromptTemplate, str]): The template used to format the
+            messages.
+        memory (Dict): The memory used by the agent.
+        output_format (Dict): The output format used by the agent.
+        aggregator (Dict): The aggregator used by the agent.
+        name (Optional[str]): The name of the agent.
+        description (Optional[str]): The description of the agent.
+        hooks (Optional[Union[List[Dict], Dict]]): The hooks used by the agent.
+
+    Returns:
+        AgentMessage: The response message.
+    """
+
+    def __init__(
+        self,
+        llm: Union[BaseLLM, Dict] = None,
+        template: Union[PromptTemplate, str, dict, List[dict]] = None,
+        memory: Dict = dict(type=Memory),
+        output_format: Optional[Dict] = None,
+        aggregator: Dict = dict(type=DefaultAggregator),
+        name: Optional[str] = None,
+        description: Optional[str] = None,
+        hooks: Optional[Union[List[Dict], Dict]] = None,
+    ):
+        self.name = name or self.__class__.__name__
+        self.llm: BaseLLM = create_object(llm)
+        self.memory: MemoryManager = MemoryManager(memory) if memory else None
+        self.output_format: StrParser = create_object(output_format)
+        self.template = template
+        self.description = description
+        self.aggregator: DefaultAggregator = create_object(aggregator)
+        self._hooks: Dict[int, Hook] = OrderedDict()
+        if hooks:
+            for hook in hooks:
+                hook = create_object(hook)
+                self.register_hook(hook)
+
+    def update_memory(self, message, session_id=0):
+        if self.memory:
+            self.memory.add(message, session_id=session_id)
+
+    def __call__(
+        self,
+        *message: Union[str, AgentMessage, List[AgentMessage]],
+        session_id=0,
+        **kwargs,
+    ) -> AgentMessage:
+        # message.receiver = self.name
+        message = [
+            AgentMessage(sender='user', content=m)
+            if isinstance(m, str) else copy.deepcopy(m) for m in message
+        ]
+        for hook in self._hooks.values():
+            result = hook.before_agent(self, message, session_id)
+            if result:
+                message = result
+        self.update_memory(message, session_id=session_id)
+        response_message = self.forward(
+            *message, session_id=session_id, **kwargs)
+        if not isinstance(response_message, AgentMessage):
+            response_message = AgentMessage(
+                sender=self.name,
+                content=response_message,
+            )
+        self.update_memory(response_message, session_id=session_id)
+        response_message = copy.deepcopy(response_message)
+        for hook in self._hooks.values():
+            result = hook.after_agent(self, response_message, session_id)
+            if result:
+                response_message = result
+        return response_message
+
+    def forward(self,
+                *message: AgentMessage,
+                session_id=0,
+                **kwargs) -> Union[AgentMessage, str]:
+        formatted_messages = self.aggregator.aggregate(
+            self.memory.get(session_id),
+            self.name,
+            self.output_format,
+            self.template,
+        )
+        llm_response = self.llm.chat(formatted_messages, **kwargs)
+        if self.output_format:
+            formatted_messages = self.output_format.parse_response(
+                llm_response)
+            return AgentMessage(
+                sender=self.name,
+                content=llm_response,
+                formatted=formatted_messages,
+            )
+        return llm_response
+
+    def __setattr__(self, __name: str, __value: Any) -> None:
+        if isinstance(__value, Agent):
+            _agents = getattr(self, '_agents', OrderedDict())
+            _agents[__name] = __value
+            super().__setattr__('_agents', _agents)
+        super().__setattr__(__name, __value)
+
+    def state_dict(self, session_id=0):
+        state_dict, stack = {}, [('', self)]
+        while stack:
+            prefix, node = stack.pop()
+            key = prefix + 'memory'
+            if node.memory is not None:
+                if session_id not in node.memory.memory_map:
+                    warnings.warn(f'No session id {session_id} in {key}')
+                memory = node.memory.get(session_id)
+                state_dict[key] = memory and memory.save() or []
+            if hasattr(node, '_agents'):
+                for name, value in reversed(node._agents.items()):
+                    stack.append((prefix + name + '.', value))
+        return state_dict
+
+    def load_state_dict(self, state_dict: Dict, session_id=0):
+        _state_dict = self.state_dict()
+        missing_keys = set(_state_dict) - set(state_dict)
+        if missing_keys:
+            raise KeyError(f'Missing keys: {missing_keys}')
+        extra_keys = set(state_dict) - set(_state_dict)
+        if extra_keys:
+            warnings.warn(f'Mismatch keys which are not used: {extra_keys}')
+        for key in _state_dict:
+            obj = self
+            for attr in key.split('.')[:-1]:
+                if isinstance(obj, AgentList):
+                    assert attr.isdigit()
+                    obj = obj[int(attr)]
+                elif isinstance(obj, AgentDict):
+                    obj = obj[attr]
+                else:
+                    obj = getattr(obj, attr)
+            if obj.memory is not None:
+                if session_id not in obj.memory.memory_map:
+                    obj.memory.create_instance(session_id)
+                obj.memory.memory_map[session_id].load(state_dict[key] or [])
+
+    def register_hook(self, hook: Callable):
+        handle = RemovableHandle(self._hooks)
+        self._hooks[handle.id] = hook
+        return handle
+
+    def reset(self,
+              session_id=0,
+              keypath: Optional[str] = None,
+              recursive: bool = False):
+        assert not (keypath and
+                    recursive), 'keypath and recursive can\'t be used together'
+        if keypath:
+            keys, agent = keypath.split('.'), self
+            for key in keys:
+                agents = getattr(agent, '_agents', {})
+                if key not in agents:
+                    raise KeyError(f'No sub-agent named {key} in {agent}')
+                agent = agents[key]
+            agent.reset(session_id, recursive=False)
+        else:
+            if self.memory:
+                self.memory.reset(session_id=session_id)
+            if recursive:
+                for agent in getattr(self, '_agents', {}).values():
+                    agent.reset(session_id, recursive=True)
+
+    def __repr__(self):
+
+        def _rcsv_repr(agent, n_indent=1):
+            res = agent.__class__.__name__ + (f"(name='{agent.name}')"
+                                              if agent.name else '')
+            modules = [
+                f"{n_indent * ' '}({name}): {_rcsv_repr(agent, n_indent + 1)}"
+                for name, agent in getattr(agent, '_agents', {}).items()
+            ]
+            if modules:
+                res += '(\n' + '\n'.join(
+                    modules) + f'\n{(n_indent - 1) * " "})'
+            elif not res.endswith(')'):
+                res += '()'
+            return res
+
+        return _rcsv_repr(self)
+
+
+class AsyncAgent(Agent):
+
+    async def __call__(self,
+                       *message: Union[str, AgentMessage,
+                                       List[AgentMessage]],
+                       session_id=0,
+                       **kwargs) -> AgentMessage:
+        message = [
+            AgentMessage(sender='user', content=m)
+            if isinstance(m, str) else copy.deepcopy(m) for m in message
+        ]
+        for hook in self._hooks.values():
+            result = hook.before_agent(self, message, session_id)
+            if result:
+                message = result
+        self.update_memory(message, session_id=session_id)
+        response_message = await self.forward(
+            *message, session_id=session_id, **kwargs)
+        if not isinstance(response_message, AgentMessage):
+            response_message = AgentMessage(
+                sender=self.name,
+                content=response_message,
+            )
+        self.update_memory(response_message, session_id=session_id)
+        response_message = copy.deepcopy(response_message)
+        for hook in self._hooks.values():
+            result = hook.after_agent(self, response_message, session_id)
+            if result:
+                response_message = result
+        return response_message
+
+    async def forward(self,
+                      *message: AgentMessage,
+                      session_id=0,
+                      **kwargs) -> Union[AgentMessage, str]:
+        formatted_messages = self.aggregator.aggregate(
+            self.memory.get(session_id),
+            self.name,
+            self.output_format,
+            self.template,
+        )
+        llm_response = await self.llm.chat(formatted_messages, session_id,
+                                           **kwargs)
+        if self.output_format:
+            formatted_messages = self.output_format.parse_response(
+                llm_response)
+            return AgentMessage(
+                sender=self.name,
+                content=llm_response,
+                formatted=formatted_messages,
+            )
+        return llm_response
+
+
+class Sequential(Agent):
+    """Sequential is an agent container that forwards messages to each agent
+    in the order they are added."""
+
+    def __init__(self, *agents: Union[Agent, AsyncAgent, Iterable], **kwargs):
+        super().__init__(**kwargs)
+        self._agents = OrderedDict()
+        if not agents:
+            raise ValueError('At least one agent should be provided')
+        if isinstance(agents[0],
+                      Iterable) and not isinstance(agents[0], Agent):
+            if not agents[0]:
+                raise ValueError('At least one agent should be provided')
+            agents = agents[0]
+        for key, agent in enumerate(agents):
+            if isinstance(agents, Mapping):
+                key, agent = agent, agents[agent]
+            elif isinstance(agent, tuple):
+                key, agent = agent
+            self.add_agent(key, agent)
+
+    def add_agent(self, name: str, agent: Union[Agent, AsyncAgent]):
+        assert isinstance(
+            agent, (Agent, AsyncAgent
+                    )), f'{type(agent)} is not an Agent or AsyncAgent subclass'
+        self._agents[str(name)] = agent
+
+    def forward(self,
+                *message: AgentMessage,
+                session_id=0,
+                exit_at: Optional[int] = None,
+                **kwargs) -> AgentMessage:
+        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
+        if exit_at is None:
+            exit_at = len(self) - 1
+        iterator = chain.from_iterable(repeat(self._agents.values()))
+        for _ in range(exit_at + 1):
+            agent = next(iterator)
+            if isinstance(message, AgentMessage):
+                message = (message, )
+            message = agent(*message, session_id=session_id, **kwargs)
+        return message
+
+    def __getitem__(self, key):
+        if isinstance(key, int) and key < 0:
+            assert key >= -len(self), 'index out of range'
+            key = len(self) + key
+        return self._agents[str(key)]
+
+    def __len__(self):
+        return len(self._agents)
+
+
+class AsyncSequential(Sequential, AsyncAgent):
+
+    async def forward(self,
+                      *message: AgentMessage,
+                      session_id=0,
+                      exit_at: Optional[int] = None,
+                      **kwargs) -> AgentMessage:
+        assert exit_at is None or exit_at >= 0, 'exit_at should be greater than or equal to 0'
+        if exit_at is None:
+            exit_at = len(self) - 1
+        iterator = chain.from_iterable(repeat(self._agents.values()))
+        for _ in range(exit_at + 1):
+            agent = next(iterator)
+            if isinstance(message, AgentMessage):
+                message = (message, )
+            message = await agent(*message, session_id=session_id, **kwargs)
+        return message
+
+
+class AgentContainerMixin:
+
+    def __init_subclass__(cls):
+        super().__init_subclass__()
+
+        def wrap_api(func):
+
+            @wraps(func)
+            def wrapped_func(self, *args, **kwargs):
+                data = self.data.copy() if hasattr(self, 'data') else None
+
+                def _backup(d):
+                    if d is None:
+                        self.data.clear()
+                    else:
+                        self.data = d
+
+                ret = func(self, *args, **kwargs)
+                agents = OrderedDict()
+                for k, item in (self.data.items() if isinstance(
+                        self.data, abc.Mapping) else enumerate(self.data)):
+                    if isinstance(self.data,
+                                  abc.Mapping) and not isinstance(k, str):
+                        _backup(data)
+                        raise KeyError(
+                            f'agent name should be a string, got {type(k)}')
+                    if isinstance(k, str) and '.' in k:
+                        _backup(data)
+                        raise KeyError(
+                            f'agent name can\'t contain ".", got {k}')
+                    if not isinstance(item, (Agent, AsyncAgent)):
+                        _backup(data)
+                        raise TypeError(
+                            f'{type(item)} is not an Agent or AsyncAgent subclass'
+                        )
+                    agents[str(k)] = item
+                self._agents = agents
+                return ret
+
+            return wrapped_func
+
+        for method in [
+                'append', 'sort', 'reverse', 'pop', 'clear', 'update',
+                'insert', 'extend', 'remove', '__init__', '__setitem__',
+                '__delitem__', '__add__', '__iadd__', '__radd__', '__mul__',
+                '__imul__', '__rmul__'
+        ]:
+            if hasattr(cls, method):
+                setattr(cls, method, wrap_api(getattr(cls, method)))
+
+
+class AgentList(Agent, UserList, AgentContainerMixin):
+
+    def __init__(self,
+                 agents: Optional[Iterable[Union[Agent, AsyncAgent]]] = None):
+        Agent.__init__(self, memory=None)
+        UserList.__init__(self, agents)
+        self.name = None
+
+
+class AgentDict(Agent, UserDict, AgentContainerMixin):
+
+    def __init__(self,
+                 agents: Optional[Mapping[str, Union[Agent,
+                                                     AsyncAgent]]] = None):
+        Agent.__init__(self, memory=None)
+        UserDict.__init__(self, agents)
+        self.name = None
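
A minimal sketch of a bare Agent round trip, assuming any working BaseLLM configuration (the GPTAPI model name and key below are placeholders):

    from lagent.agents import Agent
    from lagent.llms import GPTAPI
    from lagent.schema import AgentMessage

    agent = Agent(
        llm=dict(type=GPTAPI, model_type='gpt-4o-2024-05-13', key='YOUR_KEY'),
        template='You are a concise assistant.')
    reply = agent(AgentMessage(sender='user', content='Say hi'))
    print(reply.content)
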
lagent/agents/aggregator/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .default_aggregator import DefaultAggregator
+from .tool_aggregator import InternLMToolAggregator
+
+__all__ = ['DefaultAggregator', 'InternLMToolAggregator']
lagent/agents/aggregator/default_aggregator.py ADDED
@@ -0,0 +1,44 @@
+from typing import Dict, List
+
+from lagent.memory import Memory
+from lagent.prompts import StrParser
+
+
+class DefaultAggregator:
+
+    def aggregate(self,
+                  messages: Memory,
+                  name: str,
+                  parser: StrParser = None,
+                  system_instruction: str = None) -> List[Dict[str, str]]:
+        _message = []
+        messages = messages.get_memory()
+        if system_instruction:
+            _message.extend(
+                self.aggregate_system_intruction(system_instruction))
+        for message in messages:
+            if message.sender == name:
+                _message.append(
+                    dict(role='assistant', content=str(message.content)))
+            else:
+                user_message = message.content
+                if len(_message) > 0 and _message[-1]['role'] == 'user':
+                    _message[-1]['content'] += user_message
+                else:
+                    _message.append(dict(role='user', content=user_message))
+        return _message
+
+    @staticmethod
+    def aggregate_system_intruction(system_intruction) -> List[dict]:
+        if isinstance(system_intruction, str):
+            system_intruction = dict(role='system', content=system_intruction)
+        if isinstance(system_intruction, dict):
+            system_intruction = [system_intruction]
+        if isinstance(system_intruction, list):
+            for msg in system_intruction:
+                if not isinstance(msg, dict):
+                    raise TypeError(f'Unsupported message type: {type(msg)}')
+                if not ('role' in msg and 'content' in msg):
+                    raise KeyError(
+                        f"Missing required key 'role' or 'content': {msg}")
+        return system_intruction
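
To make the aggregation contract concrete, a hedged sketch of what `aggregate` produces for a short history (message contents are illustrative):

    # memory: user 'hi' -> agent 'hello' -> user 'how are you'
    # DefaultAggregator().aggregate(memory, name='Agent',
    #                               system_instruction='Be brief')
    # yields approximately:
    [
        {'role': 'system', 'content': 'Be brief'},
        {'role': 'user', 'content': 'hi'},
        {'role': 'assistant', 'content': 'hello'},
        {'role': 'user', 'content': 'how are you'},
    ]
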
lagent/agents/aggregator/tool_aggregator.py ADDED
@@ -0,0 +1,106 @@
+from typing import Dict, List, Optional, Union
+
+from lagent.agents.aggregator.default_aggregator import DefaultAggregator
+from lagent.memory.base_memory import Memory
+from lagent.prompts.parsers.tool_parser import MixedToolParser, ToolParser, ToolStatusCode
+
+
+class InternLMToolAggregator(DefaultAggregator):
+
+    def __init__(self,
+                 environment_role='environment',
+                 environment_begin='',
+                 environment_end='',
+                 user_names: Optional[List[str]] = None,
+                 few_shot: Optional[List[List[dict]]] = None):
+        self.environment_role = environment_role
+        self.environment_begin = environment_begin
+        self.environment_end = environment_end
+        self.user_names = user_names or ['user']
+        self.few_shot = few_shot or []
+
+    def aggregate(self,
+                  messages: Memory,
+                  name: str,
+                  parser: Union[ToolParser, MixedToolParser],
+                  system_instruction: str = None) -> List[Dict[str, str]]:
+        _message = []
+        messages = messages.get_memory()
+        if system_instruction:
+            _message.extend(
+                self.aggregate_system_intruction(system_instruction))
+        tool_instruction = parser.format_instruction()
+        if tool_instruction:
+            if isinstance(tool_instruction, str):
+                tool_instruction = dict(
+                    role='system', content=tool_instruction)
+                if parser.tool_type:
+                    tool_instruction['name'] = parser.tool_type
+            if isinstance(tool_instruction, dict):
+                tool_instruction = [tool_instruction]
+            _message.extend(tool_instruction)
+
+        for shot in self.few_shot:
+            i = 0
+            while i < len(shot):
+                msg = shot[i]
+                if msg['role'] in ['assistant', 'user', 'system']:
+                    _message.append(msg)
+                elif msg['role'] == self.environment_role:
+                    if not msg['content'].startswith(self.environment_begin):
+                        msg['content'] = self.environment_begin + msg['content']
+                    if not msg['content'].endswith(self.environment_end):
+                        msg['content'] += self.environment_end
+                    _message.append(msg)
+                elif msg['role'] in ['thought', 'language']:
+                    if i < len(shot) - 1 and shot[i + 1]['role'] == 'tool':
+                        _message.append(
+                            dict(
+                                role='assistant',
+                                content=parser.format_response(
+                                    dict(
+                                        tool_type=shot[i + 1]['name'],
+                                        thought=msg['content'],
+                                        action=shot[i + 1]['content'],
+                                        status=None))))
+                        i += 1
+                    else:
+                        _message.append(
+                            dict(
+                                role='assistant',
+                                content=parser.format_response(
+                                    dict(
+                                        tool_type=None,
+                                        thought=msg['content'],
+                                        action=None,
+                                        status=None))))
+                else:
+                    raise KeyError(f'Unknown role: {msg["role"]}')
+                i += 1
+
+        tool_type = None
+        for message in messages:
+            if message.sender == name:
+                if isinstance(message.formatted, dict):
+                    parsed = message.formatted
+                    if parsed['status'] == ToolStatusCode.PARSING_ERROR:
+                        continue
+                    _message.append(
+                        dict(
+                            role='assistant',
+                            content=parser.format_response(parsed)))
+                    tool_type = parsed['tool_type']
+                else:
+                    _message.append(
+                        dict(role='assistant', content=str(message.content)))
+            elif message.sender in self.user_names:
+                _message.append(dict(role='user', content=message.content))
+            else:
+                msg = dict(
+                    role=self.environment_role,
+                    content=self.environment_begin + str(message.content) +
+                    self.environment_end)
+                if tool_type:
+                    msg['name'] = tool_type
+                _message.append(msg)
+        return _message
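
One subtlety worth spelling out: in `few_shot`, a `thought` turn immediately followed by a `tool` turn is folded into a single assistant message via `parser.format_response`. An illustrative pair (exact rendering depends on the parser):

    shot = [
        dict(role='thought', content='I should run this in Python.'),
        dict(role='tool', name='interpreter', content='print(3 ** 5)'),
    ]
    # -> one {'role': 'assistant', 'content': <thought + formatted call>} entry
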
lagent/agents/react.py ADDED
@@ -0,0 +1,161 @@
+import json
+from typing import Callable, Dict, List, Union
+
+from pydantic import BaseModel, Field
+
+from lagent.actions import ActionExecutor, AsyncActionExecutor, BaseAction
+from lagent.agents.agent import Agent, AsyncAgent
+from lagent.agents.aggregator import DefaultAggregator
+from lagent.hooks import ActionPreprocessor
+from lagent.llms import BaseLLM
+from lagent.memory import Memory
+from lagent.prompts.parsers.json_parser import JSONParser
+from lagent.prompts.prompt_template import PromptTemplate
+from lagent.schema import AgentMessage
+from lagent.utils import create_object
+
+select_action_template = """你是一个可以调用外部工具的助手,可以使用的工具包括:
+{action_info}
+{output_format}
+开始!"""
+
+output_format_template = """如果使用工具请遵循以下格式回复:
+{function_format}
+
+如果你已经知道了答案,或者你不需要工具,请遵循以下格式回复
+{finish_format}"""
+
+
+class ReAct(Agent):
+
+    def __init__(self,
+                 llm: Union[BaseLLM, Dict],
+                 actions: Union[BaseAction, List[BaseAction]],
+                 template: Union[PromptTemplate, str] = None,
+                 memory: Dict = dict(type=Memory),
+                 output_format: Dict = dict(type=JSONParser),
+                 aggregator: Dict = dict(type=DefaultAggregator),
+                 hooks: List = [dict(type=ActionPreprocessor)],
+                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
+                 'conclusion' in m.content or 'conclusion' in m.formatted,
+                 max_turn: int = 5,
+                 **kwargs):
+        self.max_turn = max_turn
+        self.finish_condition = finish_condition
+        actions = dict(
+            type=ActionExecutor,
+            actions=actions,
+            hooks=hooks,
+        )
+        self.actions: ActionExecutor = create_object(actions)
+        select_agent = dict(
+            type=Agent,
+            llm=llm,
+            template=template.format(
+                action_info=json.dumps(self.actions.description()),
+                output_format=output_format.format_instruction()),
+            output_format=output_format,
+            memory=memory,
+            aggregator=aggregator,
+            hooks=hooks,
+        )
+        self.select_agent = create_object(select_agent)
+        super().__init__(**kwargs)
+
+    def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
+        for _ in range(self.max_turn):
+            message = self.select_agent(message)
+            if self.finish_condition(message):
+                return message
+            message = self.actions(message)
+        return message
+
+
+class AsyncReAct(AsyncAgent):
+
+    def __init__(self,
+                 llm: Union[BaseLLM, Dict],
+                 actions: Union[BaseAction, List[BaseAction]],
+                 template: Union[PromptTemplate, str] = None,
+                 memory: Dict = dict(type=Memory),
+                 output_format: Dict = dict(type=JSONParser),
+                 aggregator: Dict = dict(type=DefaultAggregator),
+                 hooks: List = [dict(type=ActionPreprocessor)],
+                 finish_condition: Callable[[AgentMessage], bool] = lambda m:
+                 'conclusion' in m.content or 'conclusion' in m.formatted,
+                 max_turn: int = 5,
+                 **kwargs):
+        self.max_turn = max_turn
+        self.finish_condition = finish_condition
+        actions = dict(
+            type=AsyncActionExecutor,
+            actions=actions,
+            hooks=hooks,
+        )
+        self.actions: AsyncActionExecutor = create_object(actions)
+        select_agent = dict(
+            type=AsyncAgent,
+            llm=llm,
+            template=template.format(
+                action_info=json.dumps(self.actions.description()),
+                output_format=output_format.format_instruction()),
+            output_format=output_format,
+            memory=memory,
+            aggregator=aggregator,
+            hooks=hooks,
+        )
+        self.select_agent = create_object(select_agent)
+        super().__init__(**kwargs)
+
+    async def forward(self, message: AgentMessage, **kwargs) -> AgentMessage:
+        for _ in range(self.max_turn):
+            message = await self.select_agent(message)
+            if self.finish_condition(message):
+                return message
+            message = await self.actions(message)
+        return message
+
+
+if __name__ == '__main__':
+    from lagent.llms import GPTAPI
+
+    class ActionCall(BaseModel):
+        name: str = Field(description='调用的函数名称')
+        parameters: Dict = Field(description='调用函数的参数')
+
+    class ActionFormat(BaseModel):
+        thought_process: str = Field(
+            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
+        action: ActionCall = Field(description='当前步骤需要执行的操作,包括函数名称和参数。')
+
+    class FinishFormat(BaseModel):
+        thought_process: str = Field(
+            description='描述当前所处的状态和已知信息。这有助于明确目前所掌握的信息和接下来的搜索方向。')
+        conclusion: str = Field(description='总结当前的搜索结果,回答问题。')
+
+    prompt_template = PromptTemplate(select_action_template)
+    output_format = JSONParser(
+        output_format_template,
+        function_format=ActionFormat,
+        finish_format=FinishFormat)
+
+    llm = dict(
+        type=GPTAPI,
+        model_type='gpt-4o-2024-05-13',
+        key=None,
+        max_new_tokens=4096,
+        proxies=dict(),
+        retry=1000)
+
+    agent = ReAct(
+        llm=llm,
+        template=prompt_template,
+        output_format=output_format,
+        aggregator=dict(type='DefaultAggregator'),
+        actions=[dict(type='PythonInterpreter')],
+    )
+    response = agent(
+        AgentMessage(sender='user', content='用 Python 计算一下 3 ** 5'))
+    print(response)
+    response = agent(AgentMessage(sender='user', content=' 2 ** 5 呢'))
+    print(response)
lagent/agents/stream.py ADDED
@@ -0,0 +1,316 @@
+import json
+import warnings
+from copy import deepcopy
+from typing import Callable, Dict, List, Union
+
+from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive
+from lagent.agents.agent import Agent, AsyncAgent
+from lagent.agents.aggregator import InternLMToolAggregator
+from lagent.hooks import InternLMActionProcessor
+from lagent.llms import BaseLLM
+from lagent.memory import Memory
+from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode
+from lagent.schema import AgentMessage
+from lagent.utils import create_object
+
+API_PREFIX = (
+    "This is the subfunction for tool '{tool_name}', you can use this tool. "
+    'The description of this function is: \n{description}')
+
+META_CN = ('当开启工具以及代码时,根据需求选择合适的工具进行调用')
+
+INTERPRETER_CN = ('你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
+                  '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
+                  '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
+                  '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
+                  '文本处理和分析(比如文本解析和自然语言处理),'
+                  '机器学习和数据科学(用于展示模型训练和数据可视化),'
+                  '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。')
+
+PLUGIN_CN = ('你可以使用如下工具:'
+             '\n{prompt}\n'
+             '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
+             '同时注意你可以使用的工具,不要随意捏造!')
+
+
+def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
+    plugin_descriptions = []
+    for action in actions if isinstance(actions, list) else [actions]:
+        action = create_object(action)
+        action_desc = deepcopy(action.description)
+        if action.is_toolkit:
+            for api in action_desc['api_list']:
+                api['name'] = f"{action.name}.{api['name']}"
+                api['description'] = api_desc_template.format(
+                    tool_name=action.name, description=api['description'])
+                api['parameters'] = [
+                    param for param in api['parameters']
+                    if param['name'] in api['required']
+                ]
+                plugin_descriptions.append(api)
+        else:
+            action_desc['description'] = api_desc_template.format(
+                tool_name=action.name, description=action_desc['description'])
+            action_desc['parameters'] = [
+                param for param in action_desc['parameters']
+                if param['name'] in action_desc['required']
+            ]
+            plugin_descriptions.append(action_desc)
+    return json.dumps(plugin_descriptions, ensure_ascii=False, indent=4)
+
+
+class AgentForInternLM(Agent):
+
+    _INTERNAL_AGENT_CLS = Agent
+
+    def __init__(
+        self,
+        llm: Union[BaseLLM, Dict],
+        plugins: Union[dict, List[dict]] = None,
+        interpreter: dict = None,
+        template: Union[str, dict, List[dict]] = None,
+        memory: Dict = dict(type=Memory),
+        output_format: Dict = dict(
+            type=MixedToolParser,
+            template=META_CN,
+            parsers=[
+                dict(type=PluginParser, template=PLUGIN_CN),
+                dict(type=InterpreterParser, template=INTERPRETER_CN),
+            ]),
+        aggregator: Dict = dict(type=InternLMToolAggregator),
+        action_hooks: List = [dict(type=InternLMActionProcessor)],
+        finish_condition: Callable[
+            [AgentMessage],
+            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+        max_turn: int = 4,
+        **kwargs,
+    ):
+        agent = dict(
+            type=self._INTERNAL_AGENT_CLS,
+            llm=llm,
+            template=template,
+            output_format=output_format,
+            memory=memory,
+            aggregator=aggregator,
+            hooks=kwargs.pop('hooks', None),
+        )
+        self.agent = create_object(agent)
+        self.plugin_executor = plugins and ActionExecutor(
+            plugins, hooks=action_hooks)
+        self.interpreter_executor = interpreter and ActionExecutor(
+            interpreter, hooks=action_hooks)
+        if not (self.plugin_executor or self.interpreter_executor):
+            warnings.warn(
+                'Neither plugin nor interpreter executor is initialized. '
+                'An exception will be thrown when the agent calls a tool.')
+        self.finish_condition = finish_condition
+        self.max_turn = max_turn
+        super().__init__(**kwargs)
+
+    def forward(self, message: AgentMessage, session_id=0, **kwargs):
+        if isinstance(message, str):
+            message = AgentMessage(sender='user', content=message)
+        for _ in range(self.max_turn):
+            message = self.agent(message, session_id=session_id, **kwargs)
+            assert isinstance(message.formatted, dict)
+            if self.finish_condition(message):
+                return message
+            if message.formatted['tool_type']:
+                tool_type = message.formatted['tool_type']
+                executor = getattr(self, f'{tool_type}_executor', None)
+                if not executor:
+                    raise RuntimeError(f'No available {tool_type} executor')
+                message = executor(message, session_id=session_id)
+        return message
+
+    def get_steps(self, session_id=0):
+        steps, tool_type = [], None
+        for msg in self.agent.memory.get_memory(session_id):
+            if msg.sender == self.agent.name:
+                steps.append(
+                    dict(role='thought', content=msg.formatted['thought']))
+                if msg.formatted['tool_type']:
+                    tool_type = msg.formatted['tool_type']
+                    steps.append(
+                        dict(
+                            role='tool',
+                            content=msg.formatted['action'],
+                            name=tool_type))
+            elif msg.sender != 'user':
+                feedback = dict(role='environment', content=msg.content)
+                if tool_type:
+                    feedback['name'] = tool_type
+                steps.append(feedback)
+        return steps
+
+
+class MathCoder(AgentForInternLM):
+
+    def __init__(
+        self,
+        llm: Union[BaseLLM, Dict],
+        interpreter: dict = dict(
+            type=IPythonInteractive, timeout=20, max_out_len=8192),
+        template: Union[str, dict, List[dict]] = None,
+        memory: Dict = dict(type=Memory),
+        output_format: Dict = dict(
+            type=InterpreterParser,
+            template=
+            ('Integrate step-by-step reasoning and Python code to solve math problems '
+             'using the following guidelines:\n'
+             '- Analyze the question and write jupyter code to solve the problem;\n'
+             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+             'units. \n')),
+        aggregator: Dict = dict(type=InternLMToolAggregator),
+        action_hooks: List = [dict(type=InternLMActionProcessor)],
+        finish_condition: Callable[
+            [AgentMessage],
+            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+        max_turn: int = 6,
+        **kwargs,
+    ):
+        kwargs.pop('plugins', None)
+        super().__init__(
+            llm=llm,
+            interpreter=interpreter,
+            template=template,
+            memory=memory,
+            output_format=output_format,
+            aggregator=aggregator,
+            action_hooks=action_hooks,
+            finish_condition=finish_condition,
+            max_turn=max_turn,
+            **kwargs)
+
+
+class AsyncAgentForInternLM(AsyncAgent):
+
+    _INTERNAL_AGENT_CLS = AsyncAgent
+
+    def __init__(
+        self,
+        llm: Union[BaseLLM, Dict],
+        plugins: Union[dict, List[dict]] = None,
+        interpreter: dict = None,
+        template: Union[str, dict, List[dict]] = None,
+        memory: Dict = dict(type=Memory),
+        output_format: Dict = dict(
+            type=MixedToolParser,
+            template=META_CN,
+            parsers=[
+                dict(type=PluginParser, template=PLUGIN_CN),
+                dict(type=InterpreterParser, template=INTERPRETER_CN),
+            ]),
+        aggregator: Dict = dict(type=InternLMToolAggregator),
+        action_hooks: List = [dict(type=InternLMActionProcessor)],
+        finish_condition: Callable[
+            [AgentMessage],
+            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+        max_turn: int = 4,
+        **kwargs,
+    ):
+        agent = dict(
+            type=self._INTERNAL_AGENT_CLS,
+            llm=llm,
+            template=template,
+            output_format=output_format,
+            memory=memory,
+            aggregator=aggregator,
+            hooks=kwargs.pop('hooks', None),
+        )
+        self.agent = create_object(agent)
+        self.plugin_executor = plugins and AsyncActionExecutor(
+            plugins, hooks=action_hooks)
+        self.interpreter_executor = interpreter and AsyncActionExecutor(
+            interpreter, hooks=action_hooks)
+        if not (self.plugin_executor or self.interpreter_executor):
+            warnings.warn(
+                'Neither plugin nor interpreter executor is initialized. '
+                'An exception will be thrown when the agent calls a tool.')
+        self.finish_condition = finish_condition
+        self.max_turn = max_turn
+        super().__init__(**kwargs)
+
+    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
+        if isinstance(message, str):
+            message = AgentMessage(sender='user', content=message)
+        for _ in range(self.max_turn):
+            message = await self.agent(
+                message, session_id=session_id, **kwargs)
+            assert isinstance(message.formatted, dict)
+            if self.finish_condition(message):
+                return message
+            if message.formatted['tool_type']:
+                tool_type = message.formatted['tool_type']
+                executor = getattr(self, f'{tool_type}_executor', None)
+                if not executor:
+                    raise RuntimeError(f'No available {tool_type} executor')
+                message = await executor(message, session_id=session_id)
+        return message
+
+    def get_steps(self, session_id=0):
+        steps, tool_type = [], None
+        for msg in self.agent.memory.get_memory(session_id):
+            if msg.sender == self.agent.name:
+                steps.append(
+                    dict(role='thought', content=msg.formatted['thought']))
+                if msg.formatted['tool_type']:
+                    tool_type = msg.formatted['tool_type']
+                    steps.append(
+                        dict(
+                            role='tool',
+                            content=msg.formatted['action'],
+                            name=tool_type))
+            elif msg.sender != 'user':
+                feedback = dict(role='environment', content=msg.content)
+                if tool_type:
+                    feedback['name'] = tool_type
+                steps.append(feedback)
+        return steps
+
+
+class AsyncMathCoder(AsyncAgentForInternLM):
+
+    def __init__(
+        self,
+        llm: Union[BaseLLM, Dict],
+        interpreter: dict = dict(type=AsyncIPythonInterpreter),
+        template: Union[str, dict, List[dict]] = None,
+        memory: Dict = dict(type=Memory),
+        output_format: Dict = dict(
+            type=InterpreterParser,
+            template=
+            ('Integrate step-by-step reasoning and Python code to solve math problems '
+             'using the following guidelines:\n'
+             '- Analyze the question and write jupyter code to solve the problem;\n'
+             r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
+             'units. \n')),
+        aggregator: Dict = dict(type=InternLMToolAggregator),
+        action_hooks: List = [dict(type=InternLMActionProcessor)],
+        finish_condition: Callable[
+            [AgentMessage],
+            bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
+        max_turn: int = 6,
+        **kwargs,
+    ):
+        kwargs.pop('plugins', None)
+        super().__init__(
+            llm=llm,
+            interpreter=interpreter,
+            template=template,
+            memory=memory,
+            output_format=output_format,
+            aggregator=aggregator,
+            action_hooks=action_hooks,
+            finish_condition=finish_condition,
+            max_turn=max_turn,
+            **kwargs)
+
+    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
+        try:
+            return await super().forward(message, session_id, **kwargs)
+        finally:
+            interpreter = next(
+                iter(self.interpreter_executor.actions.values()))
+            if interpreter.name == 'AsyncIPythonInterpreter':
+                await interpreter.close_session(session_id)
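
A rough wiring sketch for AgentForInternLM with a plugin, assuming PluginParser takes the plugin descriptions through a `prompt` field as the `{prompt}` placeholder in PLUGIN_CN suggests; the model name and key are placeholders:

    from lagent.actions import ArxivSearch
    from lagent.agents.stream import AgentForInternLM, PLUGIN_CN, get_plugin_prompt
    from lagent.llms import GPTAPI
    from lagent.prompts.parsers import PluginParser

    plugins = [dict(type=ArxivSearch)]
    agent = AgentForInternLM(
        llm=dict(type=GPTAPI, model_type='internlm2.5-latest', key='YOUR_KEY'),
        plugins=plugins,
        output_format=dict(
            type=PluginParser,
            template=PLUGIN_CN,
            prompt=get_plugin_prompt(plugins)))
    print(agent('Find recent papers on ReAct agents').content)
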
lagent/distributed/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .http_serve import AgentAPIServer, AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer
+from .ray_serve import AgentRayActor, AsyncAgentRayActor
+
+__all__ = [
+    'AsyncAgentRayActor', 'AgentRayActor', 'HTTPAgentServer',
+    'HTTPAgentClient', 'AsyncHTTPAgentServer', 'AsyncHTTPAgentClient',
+    'AgentAPIServer'
+]
lagent/distributed/http_serve/__init__.py ADDED
@@ -0,0 +1,7 @@
+from .api_server import AsyncHTTPAgentClient, AsyncHTTPAgentServer, HTTPAgentClient, HTTPAgentServer
+from .app import AgentAPIServer
+
+__all__ = [
+    'HTTPAgentServer', 'HTTPAgentClient', 'AsyncHTTPAgentClient',
+    'AsyncHTTPAgentServer', 'AgentAPIServer'
+]
lagent/distributed/http_serve/api_server.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import json
+ import os
+ import subprocess
+ import sys
+ import time
+ 
+ import aiohttp
+ import requests
+ 
+ from lagent.schema import AgentMessage
+ 
+ 
+ class HTTPAgentClient:
+ 
+     def __init__(self, host='127.0.0.1', port=8090, timeout=None):
+         self.host = host
+         self.port = port
+         self.timeout = timeout
+ 
+     @property
+     def is_alive(self):
+         try:
+             resp = requests.get(
+                 f'http://{self.host}:{self.port}/health_check',
+                 timeout=self.timeout)
+             return resp.status_code == 200
+         except Exception:
+             return False
+ 
+     def __call__(self, *message, session_id: int = 0, **kwargs):
+         response = requests.post(
+             f'http://{self.host}:{self.port}/chat_completion',
+             json={
+                 'message': [
+                     m if isinstance(m, str) else m.model_dump()
+                     for m in message
+                 ],
+                 'session_id': session_id,
+                 **kwargs,
+             },
+             headers={'Content-Type': 'application/json'},
+             timeout=self.timeout)
+         resp = response.json()
+         if response.status_code != 200:
+             return resp
+         return AgentMessage.model_validate(resp)
+ 
+     def state_dict(self, session_id: int = 0):
+         resp = requests.get(
+             f'http://{self.host}:{self.port}/memory/{session_id}',
+             timeout=self.timeout)
+         return resp.json()
+ 
+ 
+ class HTTPAgentServer(HTTPAgentClient):
+ 
+     def __init__(self, gpu_id, config, host='127.0.0.1', port=8090):
+         super().__init__(host, port)
+         self.gpu_id = gpu_id
+         self.config = config
+         self.start_server()
+ 
+     def start_server(self):
+         # set CUDA_VISIBLE_DEVICES in the subprocess
+         env = os.environ.copy()
+         env['CUDA_VISIBLE_DEVICES'] = self.gpu_id
+         cmds = [
+             sys.executable, 'lagent/distributed/http_serve/app.py', '--host',
+             self.host, '--port',
+             str(self.port), '--config',
+             json.dumps(self.config)
+         ]
+         self.process = subprocess.Popen(
+             cmds,
+             env=env,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.STDOUT,
+             text=True)
+ 
+         while True:
+             output = self.process.stdout.readline()
+             if not output:  # break on EOF
+                 break
+             sys.stdout.write(output)  # echo to standard output
+             sys.stdout.flush()
+             if 'Uvicorn running on' in output:  # adjust to the actual startup log
+                 break
+             time.sleep(0.1)
+ 
+     def shutdown(self):
+         self.process.terminate()
+         self.process.wait()
+ 
+ 
+ class AsyncHTTPAgentMixin:
+ 
+     async def __call__(self, *message, session_id: int = 0, **kwargs):
+         async with aiohttp.ClientSession(
+                 timeout=aiohttp.ClientTimeout(self.timeout)) as session:
+             async with session.post(
+                     f'http://{self.host}:{self.port}/chat_completion',
+                     json={
+                         'message': [
+                             m if isinstance(m, str) else m.model_dump()
+                             for m in message
+                         ],
+                         'session_id': session_id,
+                         **kwargs,
+                     },
+                     headers={'Content-Type': 'application/json'},
+             ) as response:
+                 resp = await response.json()
+                 if response.status != 200:
+                     return resp
+                 return AgentMessage.model_validate(resp)
+ 
+ 
+ class AsyncHTTPAgentClient(AsyncHTTPAgentMixin, HTTPAgentClient):
+     pass
+ 
+ 
+ class AsyncHTTPAgentServer(AsyncHTTPAgentMixin, HTTPAgentServer):
+     pass
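A minimal sketch of the client/server pair above; the agent `type` and GPU id are placeholders, since any importable agent class accepted by app.py works:

from lagent.distributed import HTTPAgentServer

# Launch a worker subprocess pinned to GPU 0 and wait until uvicorn is up.
server = HTTPAgentServer(
    gpu_id='0',
    config={'type': 'lagent.agents.AsyncAgentForInternLM'},  # placeholder
    port=8091)
assert server.is_alive
reply = server('hello', session_id=0)  # POST /chat_completion
memory = server.state_dict(0)          # GET /memory/0
server.shutdown()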
lagent/distributed/http_serve/app.py ADDED
@@ -0,0 +1,96 @@
+ import argparse
+ import json
+ import logging
+ import time
+ 
+ import uvicorn
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.requests import Request
+ 
+ from lagent.schema import AgentMessage
+ from lagent.utils import load_class_from_string
+ 
+ 
+ class AgentAPIServer:
+ 
+     def __init__(self,
+                  config: dict,
+                  host: str = '127.0.0.1',
+                  port: int = 8090):
+         self.app = FastAPI(docs_url='/')
+         self.app.add_middleware(
+             CORSMiddleware,
+             allow_origins=['*'],
+             allow_credentials=True,
+             allow_methods=['*'],
+             allow_headers=['*'],
+         )
+         cls_name = config.pop('type')
+         python_path = config.pop('python_path', None)
+         cls_name = load_class_from_string(cls_name, python_path) if isinstance(
+             cls_name, str) else cls_name
+         self.agent = cls_name(**config)
+         self.setup_routes()
+         self.run(host, port)
+ 
+     def setup_routes(self):
+ 
+         def heartbeat():
+             return {'status': 'success', 'timestamp': time.time()}
+ 
+         async def process_message(request: Request):
+             try:
+                 body = await request.json()
+                 message = [
+                     m if isinstance(m, str) else AgentMessage.model_validate(m)
+                     for m in body.pop('message')
+                 ]
+                 result = await self.agent(*message, **body)
+                 return result
+             except Exception as e:
+                 logging.error(f'Error processing message: {str(e)}')
+                 raise HTTPException(
+                     status_code=500, detail='Internal Server Error')
+ 
+         def get_memory(session_id: int = 0):
+             try:
+                 result = self.agent.state_dict(session_id)
+                 return result
+             except KeyError:
+                 raise HTTPException(
+                     status_code=404, detail='Session ID not found')
+             except Exception as e:
+                 logging.error(f'Error processing message: {str(e)}')
+                 raise HTTPException(
+                     status_code=500, detail='Internal Server Error')
+ 
+         self.app.add_api_route('/health_check', heartbeat, methods=['GET'])
+         self.app.add_api_route(
+             '/chat_completion', process_message, methods=['POST'])
+         self.app.add_api_route(
+             '/memory/{session_id}', get_memory, methods=['GET'])
+ 
+     def run(self, host='127.0.0.1', port=8090):
+         logging.info(f'Starting server at {host}:{port}')
+         uvicorn.run(self.app, host=host, port=port)
+ 
+ 
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Async Agent API Server')
+     parser.add_argument('--host', type=str, default='127.0.0.1')
+     parser.add_argument('--port', type=int, default=8090)
+     parser.add_argument(
+         '--config',
+         type=json.loads,
+         required=True,
+         help='JSON configuration for the agent')
+     args = parser.parse_args()
+ 
+     return args
+ 
+ 
+ if __name__ == '__main__':
+     logging.basicConfig(level=logging.INFO)
+     args = parse_args()
+     AgentAPIServer(args.config, host=args.host, port=args.port)
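An in-process sketch of the server, assuming a placeholder agent config (HTTPAgentServer builds the equivalent CLI command for you):

from lagent.distributed.http_serve.app import AgentAPIServer

# Blocks the current process and serves /health_check, /chat_completion
# and /memory/{session_id}.
AgentAPIServer(
    {'type': 'lagent.agents.AsyncAgentForInternLM'},  # placeholder config
    host='127.0.0.1',
    port=8090)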
lagent/distributed/ray_serve/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .ray_warpper import AgentRayActor, AsyncAgentRayActor
+ 
+ __all__ = ['AsyncAgentRayActor', 'AgentRayActor']
lagent/distributed/ray_serve/ray_warpper.py ADDED
@@ -0,0 +1,48 @@
+ import importlib
+ import sys
+ from typing import Dict
+ 
+ import ray
+ 
+ from lagent.schema import AgentMessage
+ from lagent.utils import load_class_from_string
+ 
+ 
+ class AsyncAgentRayActor:
+ 
+     def __init__(
+         self,
+         config: Dict,
+         num_gpus: int,
+     ):
+         cls_name = config.pop('type')
+         python_path = config.pop('python_path', None)
+         cls_name = load_class_from_string(cls_name, python_path) if isinstance(
+             cls_name, str) else cls_name
+         AsyncAgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
+         self.agent_actor = AsyncAgentActor.remote(**config)
+ 
+     async def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
+         response = await self.agent_actor.__call__.remote(
+             *message, session_id=session_id, **kwargs)
+         return response
+ 
+ 
+ class AgentRayActor:
+ 
+     def __init__(
+         self,
+         config: Dict,
+         num_gpus: int,
+     ):
+         cls_name = config.pop('type')
+         python_path = config.pop('python_path', None)
+         cls_name = load_class_from_string(cls_name, python_path) if isinstance(
+             cls_name, str) else cls_name
+         AgentActor = ray.remote(num_gpus=num_gpus)(cls_name)
+         self.agent_actor = AgentActor.remote(**config)
+ 
+     def __call__(self, *message: AgentMessage, session_id=0, **kwargs):
+         response = self.agent_actor.__call__.remote(
+             *message, session_id=session_id, **kwargs)
+         return ray.get(response)
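A usage sketch for the Ray wrapper; the agent config is a placeholder and is resolved with load_class_from_string exactly as in the HTTP server:

import ray

from lagent.distributed import AgentRayActor

ray.init()
# One agent replica per actor, pinned to `num_gpus` GPUs.
actor = AgentRayActor(
    {'type': 'lagent.agents.AsyncAgentForInternLM'},  # placeholder config
    num_gpus=1)
print(actor('hello', session_id=0))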
lagent/hooks/__init__.py ADDED
@@ -0,0 +1,8 @@
+ from .action_preprocessor import ActionPreprocessor, InternLMActionProcessor
+ from .hook import Hook, RemovableHandle
+ from .logger import MessageLogger
+ 
+ __all__ = [
+     'Hook', 'RemovableHandle', 'ActionPreprocessor', 'InternLMActionProcessor',
+     'MessageLogger'
+ ]
lagent/hooks/action_preprocessor.py ADDED
@@ -0,0 +1,62 @@
+ from copy import deepcopy
+ 
+ from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall
+ from .hook import Hook
+ 
+ 
+ class ActionPreprocessor(Hook):
+     """A hook that preprocesses the action message and postprocesses the
+     action return message."""
+ 
+     def before_action(self, executor, message, session_id):
+         assert isinstance(message.formatted, FunctionCall) or (
+             isinstance(message.formatted, dict) and 'name' in message.formatted
+             and 'parameters' in message.formatted) or (
+                 'action' in message.formatted
+                 and 'parameters' in message.formatted['action']
+                 and 'name' in message.formatted['action'])
+         if isinstance(message.formatted, dict):
+             name = message.formatted.get('name',
+                                          message.formatted['action']['name'])
+             parameters = message.formatted.get(
+                 'parameters', message.formatted['action']['parameters'])
+         else:
+             name = message.formatted.name
+             parameters = message.formatted.parameters
+         message.content = dict(name=name, parameters=parameters)
+         return message
+ 
+     def after_action(self, executor, message, session_id):
+         action_return = message.content
+         if isinstance(action_return, ActionReturn):
+             if action_return.state == ActionStatusCode.SUCCESS:
+                 response = action_return.format_result()
+             else:
+                 response = action_return.errmsg
+         else:
+             response = action_return
+         message.content = response
+         return message
+ 
+ 
+ class InternLMActionProcessor(ActionPreprocessor):
+ 
+     def __init__(self, code_parameter: str = 'command'):
+         self.code_parameter = code_parameter
+ 
+     def before_action(self, executor, message, session_id):
+         message = deepcopy(message)
+         assert isinstance(message.formatted, dict) and set(
+             message.formatted).issuperset(
+                 {'tool_type', 'thought', 'action', 'status'})
+         if isinstance(message.formatted['action'], str):
+             # encapsulate code interpreter arguments
+             action_name = next(iter(executor.actions))
+             parameters = {self.code_parameter: message.formatted['action']}
+             if action_name in ['AsyncIPythonInterpreter']:
+                 parameters['session_id'] = session_id
+             message.formatted['action'] = dict(
+                 name=action_name, parameters=parameters)
+         return super().before_action(executor, message, session_id)
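A sketch of what the hook normalizes, with hypothetical values and a stand-in executor (only its `.actions` mapping is consulted here):

from lagent.hooks import InternLMActionProcessor
from lagent.schema import AgentMessage

proc = InternLMActionProcessor(code_parameter='command')
msg = AgentMessage(
    sender='agent',
    content='',
    formatted={
        'tool_type': 'interpreter',
        'thought': '...',
        'action': 'print(1 + 1)',  # bare code string from the parser
        'status': 1,
    })

class _FakeExecutor:
    actions = {'IPythonInterpreter': None}

out = proc.before_action(_FakeExecutor(), msg, session_id=0)
print(out.content)
# {'name': 'IPythonInterpreter', 'parameters': {'command': 'print(1 + 1)'}}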
lagent/hooks/hook.py ADDED
@@ -0,0 +1,50 @@
+ from itertools import count
+ from typing import Tuple
+ 
+ from lagent.schema import AgentMessage
+ 
+ 
+ class Hook:
+ 
+     def before_agent(
+         self,
+         agent,
+         message: Tuple[AgentMessage],
+         session_id: int,
+     ):
+         pass
+ 
+     def after_agent(
+         self,
+         agent,
+         message: AgentMessage,
+         session_id: int,
+     ):
+         pass
+ 
+     def before_action(
+         self,
+         executor,
+         message: AgentMessage,
+         session_id: int,
+     ):
+         pass
+ 
+     def after_action(
+         self,
+         executor,
+         message: AgentMessage,
+         session_id: int,
+     ):
+         pass
+ 
+ 
+ class RemovableHandle:
+     _id_iter = count(0)
+ 
+     def __init__(self, hooks_dict):
+         self.hooks_dict = hooks_dict
+         self.id = next(self._id_iter)
+ 
+     def remove(self):
+         del self.hooks_dict[self.id]
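A minimal custom hook sketch, assuming the executor fires before_action/after_action around every tool call as the base class above suggests:

import time

from lagent.hooks import Hook

class TimingHook(Hook):
    """Log how long each tool call takes."""

    def __init__(self):
        self._start = {}

    def before_action(self, executor, message, session_id):
        self._start[session_id] = time.time()

    def after_action(self, executor, message, session_id):
        elapsed = time.time() - self._start.pop(session_id)
        print(f'[session {session_id}] action took {elapsed:.2f}s')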
lagent/hooks/logger.py ADDED
@@ -0,0 +1,37 @@
+ import random
+ from typing import Optional
+ 
+ from termcolor import COLORS, colored
+ 
+ from lagent.utils import get_logger
+ from .hook import Hook
+ 
+ 
+ class MessageLogger(Hook):
+ 
+     def __init__(self, name: str = 'lagent'):
+         self.logger = get_logger(
+             name, 'info', '%(asctime)s %(levelname)8s %(name)8s - %(message)s')
+         self.sender2color = {}
+ 
+     def before_agent(self, agent, messages, session_id):
+         for message in messages:
+             self._process_message(message, session_id)
+ 
+     def after_agent(self, agent, message, session_id):
+         self._process_message(message, session_id)
+ 
+     def before_action(self, executor, message, session_id):
+         self._process_message(message, session_id)
+ 
+     def after_action(self, executor, message, session_id):
+         self._process_message(message, session_id)
+ 
+     def _process_message(self, message, session_id):
+         sender = message.sender
+         color = self.sender2color.setdefault(sender,
+                                              random.choice(list(COLORS)))
+         self.logger.info(
+             colored(
+                 f'session id: {session_id}, message sender: {sender}\n'
+                 f'{message.content}', color))
lagent/llms/__init__.py ADDED
@@ -0,0 +1,32 @@
+ from .base_api import AsyncBaseAPILLM, BaseAPILLM
+ from .base_llm import AsyncBaseLLM, BaseLLM
+ from .huggingface import HFTransformer, HFTransformerCasualLM, HFTransformerChat
+ from .lmdeploy_wrapper import (AsyncLMDeployClient, AsyncLMDeployPipeline,
+                                AsyncLMDeployServer, LMDeployClient,
+                                LMDeployPipeline, LMDeployServer)
+ from .meta_template import INTERNLM2_META
+ from .openai import GPTAPI, AsyncGPTAPI
+ from .sensenova import SensenovaAPI
+ from .vllm_wrapper import AsyncVllmModel, VllmModel
+ 
+ __all__ = [
+     'AsyncBaseLLM',
+     'BaseLLM',
+     'AsyncBaseAPILLM',
+     'BaseAPILLM',
+     'AsyncGPTAPI',
+     'GPTAPI',
+     'LMDeployClient',
+     'AsyncLMDeployClient',
+     'LMDeployPipeline',
+     'AsyncLMDeployPipeline',
+     'LMDeployServer',
+     'AsyncLMDeployServer',
+     'HFTransformer',
+     'HFTransformerCasualLM',
+     'INTERNLM2_META',
+     'HFTransformerChat',
+     'VllmModel',
+     'AsyncVllmModel',
+     'SensenovaAPI',
+ ]
lagent/llms/base_api.py ADDED
@@ -0,0 +1,175 @@
+ import warnings
+ from typing import Dict, List, Optional, Tuple, Union
+ 
+ from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
+ 
+ 
+ class APITemplateParser:
+     """Intermediate prompt template parser, specifically for API models.
+ 
+     Args:
+         meta_template (Dict): The meta template for the model.
+     """
+ 
+     def __init__(self, meta_template: Optional[Dict] = None):
+         self.meta_template = meta_template
+         # Check meta template
+         if meta_template:
+             assert isinstance(meta_template, list)
+             self.roles: Dict[str, dict] = dict()  # maps role name to config
+             for item in meta_template:
+                 assert isinstance(item, dict)
+                 assert item['role'] not in self.roles, \
+                     'role in meta prompt must be unique!'
+                 self.roles[item['role']] = item.copy()
+ 
+     def __call__(self, dialog: List[Union[str, List]]):
+         """Parse the intermediate prompt template, and wrap it with the meta
+         template if applicable. When the meta template is set and the input is
+         a list, the return value will be a list containing the full
+         conversation history. Each item looks like:
+ 
+         .. code-block:: python
+ 
+             {'role': 'user', 'content': '...'}
+ 
+         Args:
+             dialog (List[str or list]): An intermediate prompt
+                 template (potentially before being wrapped by meta template).
+ 
+         Returns:
+             List[str or list]: The finalized prompt or a conversation.
+         """
+         assert isinstance(dialog, (str, list))
+         if isinstance(dialog, str):
+             return dialog
+         if self.meta_template:
+ 
+             prompt = list()
+             # Whether to keep generating the prompt
+             generate = True
+             for i, item in enumerate(dialog):
+                 if not generate:
+                     break
+                 if isinstance(item, str):
+                     if item.strip():
+                         # TODO: logger
+                         warnings.warn('Non-empty string in prompt template '
+                                       'will be ignored in API models.')
+                 else:
+                     api_prompts = self._prompt2api(item)
+                     prompt.append(api_prompts)
+ 
+             # merge consecutive prompts assigned to the same role
+             new_prompt = list([prompt[0]])
+             last_role = prompt[0]['role']
+             for item in prompt[1:]:
+                 if item['role'] == last_role:
+                     new_prompt[-1]['content'] += '\n' + item['content']
+                 else:
+                     last_role = item['role']
+                     new_prompt.append(item)
+             prompt = new_prompt
+ 
+         else:
+             # in case the model does not have any meta template
+             prompt = ''
+             last_sep = ''
+             for item in dialog:
+                 if isinstance(item, str):
+                     if item:
+                         prompt += last_sep + item
+                 elif item.get('content', ''):
+                     prompt += last_sep + item.get('content', '')
+                 last_sep = '\n'
+         return prompt
+ 
+     def _prompt2api(self, prompts: Union[List, str]) -> Tuple[str, bool]:
+         """Convert the prompts to API-style prompts.
+ 
+         Args:
+             prompts (Union[List, str]): The prompts to be converted.
+ 
+         Returns:
+             Tuple[str, bool]: The converted string, and whether the follow-up
+             conversion should proceed.
+         """
+         if isinstance(prompts, str):
+             return prompts
+         elif isinstance(prompts, dict):
+             api_role = self._role2api_role(prompts)
+             return api_role
+ 
+         res = []
+         for prompt in prompts:
+             if isinstance(prompt, str):
+                 raise TypeError('Mixing str without explicit role is not '
+                                 'allowed in API models!')
+             else:
+                 api_role = self._role2api_role(prompt)
+                 res.append(api_role)
+         return res
+ 
+     def _role2api_role(self, role_prompt: Dict) -> Tuple[str, bool]:
+         merged_prompt = self.roles[role_prompt['role']]
+         if merged_prompt.get('fallback_role'):
+             merged_prompt = self.roles[merged_prompt['fallback_role']]
+         res = role_prompt.copy()
+         res['role'] = merged_prompt['api_role']
+         res['content'] = merged_prompt.get('begin', '')
+         res['content'] += role_prompt.get('content', '')
+         res['content'] += merged_prompt.get('end', '')
+         return res
+ 
+ 
+ class BaseAPILLM(BaseLLM):
+     """Base class for API model wrappers.
+ 
+     Args:
+         model_type (str): The type of model.
+         retry (int): Number of retries if the API call fails. Defaults to 2.
+         meta_template (Dict, optional): The model's meta prompt
+             template if needed, in case meta instructions have to be
+             injected or wrapped.
+     """
+ 
+     is_api: bool = True
+ 
+     def __init__(self,
+                  model_type: str,
+                  retry: int = 2,
+                  template_parser: 'APITemplateParser' = APITemplateParser,
+                  meta_template: Optional[Dict] = None,
+                  *,
+                  max_new_tokens: int = 512,
+                  top_p: float = 0.8,
+                  top_k: int = 40,
+                  temperature: float = 0.8,
+                  repetition_penalty: float = 0.0,
+                  stop_words: Union[List[str], str] = None):
+         self.model_type = model_type
+         self.meta_template = meta_template
+         self.retry = retry
+         if template_parser:
+             self.template_parser = template_parser(meta_template)
+ 
+         if isinstance(stop_words, str):
+             stop_words = [stop_words]
+         self.gen_params = dict(
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             temperature=temperature,
+             repetition_penalty=repetition_penalty,
+             stop_words=stop_words,
+             skip_special_tokens=False)
+ 
+ 
+ class AsyncBaseAPILLM(AsyncLLMMixin, BaseAPILLM):
+     pass
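A small demonstration of the parser's role mapping and same-role merging; the meta template below is hypothetical:

from lagent.llms.base_api import APITemplateParser

parser = APITemplateParser([
    dict(role='system', api_role='system'),
    dict(role='user', api_role='user'),
    dict(role='assistant', api_role='assistant'),
])
msgs = parser([
    dict(role='system', content='You are a helpful assistant.'),
    dict(role='user', content='Hi,'),
    dict(role='user', content='who are you?'),  # merged into the previous turn
])
# -> [{'role': 'system', ...}, {'role': 'user', 'content': 'Hi,\nwho are you?'}]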
lagent/llms/base_llm.py ADDED
@@ -0,0 +1,305 @@
+ from copy import copy
+ from typing import Dict, List, Optional, Tuple, Union
+ 
+ 
+ class LMTemplateParser:
+     """Intermediate prompt template parser, specifically for language models.
+ 
+     Args:
+         meta_template (list of dict, optional): The meta template for the
+             model.
+     """
+ 
+     def __init__(self, meta_template: Optional[List[Dict]] = None):
+         self.meta_template = meta_template
+         if meta_template:
+             assert isinstance(meta_template, list)
+             self.roles: Dict[str, dict] = dict()  # maps role name to config
+             for item in meta_template:
+                 assert isinstance(item, dict)
+                 assert item['role'] not in self.roles, \
+                     'role in meta prompt must be unique!'
+                 self.roles[item['role']] = item.copy()
+ 
+     def __call__(self, dialog) -> str:
+         """Parse a prompt template, and wrap it with the meta template if
+         applicable.
+ 
+         Args:
+             dialog (List[str or PromptList]): A prompt
+                 template (potentially before being wrapped by meta template).
+ 
+         Returns:
+             str: The final string.
+         """
+         assert isinstance(dialog, (str, list))
+         if isinstance(dialog, str):
+             return dialog
+         if self.meta_template:
+ 
+             prompt = ''
+             for index, item in enumerate(dialog):
+                 if isinstance(item, str):
+                     prompt += item
+                 else:
+                     new_str = self._prompt2str(item, index == len(dialog) - 1)
+                     prompt += new_str
+         else:
+             # in case the model does not have any meta template
+             prompt = ''
+             last_sep = ''
+             for item in dialog:
+                 if isinstance(item, str):
+                     if item:
+                         prompt += last_sep + item
+                 elif item.get('content', ''):
+                     prompt += last_sep + item.get('content', '')
+                 last_sep = '\n'
+         return prompt
+ 
+     def _format_begin(self, role_cfg, message):
+         name = message.get('name', None)
+         if name is not None:
+             begin = role_cfg['begin'].get('with_name', '')
+             if name in role_cfg['begin'].get('name', {}):
+                 begin = begin.format(name=role_cfg['begin']['name'][name])
+             else:
+                 begin = begin.format(name=name)
+         else:
+             if isinstance(role_cfg.get('begin', ''), str):
+                 begin = role_cfg.get('begin', '')
+             elif isinstance(role_cfg['begin'], dict):
+                 begin = role_cfg['begin'].get('without_name', '')
+         return begin
+ 
+     def _prompt2str(self,
+                     prompt: Union[str, Dict],
+                     last: bool = False) -> Tuple[str, bool]:
+         if isinstance(prompt, str):
+             return prompt
+         merged_prompt = self.roles.get(prompt['role'])
+ 
+         if merged_prompt.get('fallback_role'):
+             merged_prompt = self.roles.get(merged_prompt['fallback_role'])
+         begin = self._format_begin(merged_prompt, prompt)
+         res = begin
+         if last and merged_prompt.get('generate', False):
+             res += prompt.get('content', '')
+             return res
+         res += prompt.get('content', '') + merged_prompt.get('end', '')
+         if last and merged_prompt['role'] != 'assistant':
+             res += self._format_begin(self.roles['assistant'], {})
+             return res
+         return res
+ 
+ 
+ class BaseLLM:
+     """Base class for model wrappers.
+ 
+     Args:
+         path (str): The path to the model.
+         max_new_tokens (int): Maximum length of output expected to be
+             generated by the model. Defaults to 512.
+         tokenizer_only (bool): If True, only the tokenizer will be initialized.
+             Defaults to False.
+         meta_template (list of dict, optional): The model's meta prompt
+             template if needed, in case meta instructions have to be
+             injected or wrapped.
+     """
+ 
+     def __init__(self,
+                  path: str,
+                  tokenizer_only: bool = False,
+                  template_parser: 'LMTemplateParser' = LMTemplateParser,
+                  meta_template: Optional[List[Dict]] = None,
+                  *,
+                  max_new_tokens: int = 512,
+                  top_p: float = 0.8,
+                  top_k: float = 40,
+                  temperature: float = 0.8,
+                  repetition_penalty: float = 1.0,
+                  stop_words: Union[List[str], str] = None):
+         self.path = path
+         self.tokenizer_only = tokenizer_only
+         # meta template
+         self.template_parser = template_parser(meta_template)
+         self.eos_token_id = None
+         if meta_template and 'eos_token_id' in meta_template:
+             self.eos_token_id = meta_template['eos_token_id']
+ 
+         if isinstance(stop_words, str):
+             stop_words = [stop_words]
+         self.gen_params = dict(
+             max_new_tokens=max_new_tokens,
+             top_p=top_p,
+             top_k=top_k,
+             temperature=temperature,
+             repetition_penalty=repetition_penalty,
+             stop_words=stop_words)
+ 
+     def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
+         """Generate results given a str (or list of) inputs.
+ 
+         Args:
+             inputs (Union[str, List[str]]):
+             gen_params (dict): The input params for generation.
+ 
+         Returns:
+             Union[str, List[str]]: A (list of) generated strings.
+ 
+         e.g.
+             batched = True
+             if isinstance(inputs, str):
+                 inputs = [inputs]
+                 batched = False
+             response = ['']
+             if batched:
+                 return response
+             return response[0]
+         """
+         raise NotImplementedError
+ 
+     def stream_generate(self, inputs: str, **gen_params) -> List[str]:
+         """Generate results as streaming given a str input.
+ 
+         Args:
+             inputs (str):
+             gen_params (dict): The input params for generation.
+ 
+         Returns:
+             str: A generated string.
+         """
+         raise NotImplementedError
+ 
+     def chat(self,
+              inputs: Union[List[dict], List[List[dict]]],
+              session_ids: Union[int, List[int]] = None,
+              **gen_params):
+         """Generate completion from a list of templates.
+ 
+         Args:
+             inputs (Union[List[dict], List[List[dict]]]):
+             gen_params (dict): The input params for generation.
+         Returns:
+         """
+         if isinstance(inputs[0], list):
+             _inputs = list()
+             for msg in inputs:
+                 _inputs.append(self.template_parser(msg))
+         else:
+             _inputs = self.template_parser(inputs)
+         return self.generate(_inputs, **gen_params)
+ 
+     def stream_chat(self, inputs: List[dict], **gen_params):
+         """Generate results as streaming given a list of templates.
+ 
+         Args:
+             inputs (List[dict]):
+             gen_params (dict): The input params for generation.
+         Returns:
+         """
+         raise NotImplementedError
+ 
+     def tokenize(self, prompts: Union[str, List[str], List[dict],
+                                       List[List[dict]]]):
+         """Tokenize the input prompts.
+ 
+         Args:
+             prompts (str | List[str]): user's prompt, or a batch of prompts
+ 
+         Returns:
+             Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
+             ids, ids' length and requested output length
+         """
+         raise NotImplementedError
+ 
+     def update_gen_params(self, **kwargs):
+         gen_params = copy(self.gen_params)
+         gen_params.update(kwargs)
+         return gen_params
+ 
+ 
+ class AsyncLLMMixin:
+ 
+     async def generate(self,
+                        inputs: Union[str, List[str]],
+                        session_ids: Union[int, List[int]] = None,
+                        **gen_params) -> str:
+         """Generate results given a str (or list of) inputs.
+ 
+         Args:
+             inputs (Union[str, List[str]]):
+             gen_params (dict): The input params for generation.
+ 
+         Returns:
+             Union[str, List[str]]: A (list of) generated strings.
+ 
+         e.g.
+             batched = True
+             if isinstance(inputs, str):
+                 inputs = [inputs]
+                 batched = False
+             response = ['']
+             if batched:
+                 return response
+             return response[0]
+         """
+         raise NotImplementedError
+ 
+     async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
+         """Generate results as streaming given a str input.
+ 
+         Args:
+             inputs (str):
+             gen_params (dict): The input params for generation.
+ 
+         Returns:
+             str: A generated string.
+         """
+         raise NotImplementedError
+ 
+     async def chat(self,
+                    inputs: Union[List[dict], List[List[dict]]],
+                    session_ids: Union[int, List[int]] = None,
+                    **gen_params):
+         """Generate completion from a list of templates.
+ 
+         Args:
+             inputs (Union[List[dict], List[List[dict]]]):
+             gen_params (dict): The input params for generation.
+         Returns:
+         """
+         if isinstance(inputs[0], list):
+             _inputs = list()
+             for msg in inputs:
+                 _inputs.append(self.template_parser(msg))
+         else:
+             _inputs = self.template_parser(inputs)
+         return await self.generate(_inputs, session_ids, **gen_params)
+ 
+     async def stream_chat(self, inputs: List[dict], **gen_params):
+         """Generate results as streaming given a list of templates.
+ 
+         Args:
+             inputs (List[dict]):
+             gen_params (dict): The input params for generation.
+         Returns:
+         """
+         raise NotImplementedError
+ 
+     async def tokenize(self, prompts: Union[str, List[str], List[dict],
+                                             List[List[dict]]]):
+         """Tokenize the input prompts.
+ 
+         Args:
+             prompts (str | List[str]): user's prompt, or a batch of prompts
+ 
+         Returns:
+             Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
+             ids, ids' length and requested output length
+         """
+         raise NotImplementedError
+ 
+ 
+ class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
+     pass
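A demonstration of how LMTemplateParser stitches a meta template into a raw prompt string; the InternLM2-style markers below are hypothetical example values:

from lagent.llms.base_llm import LMTemplateParser

parser = LMTemplateParser([
    dict(role='user', begin='<|im_start|>user\n', end='<|im_end|>\n'),
    dict(role='assistant', begin='<|im_start|>assistant\n',
         end='<|im_end|>\n', generate=True),
])
prompt = parser([dict(role='user', content='Hi!')])
# '<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\n'
# The trailing assistant `begin` is appended because the last message is
# not from the assistant, cueing the model to generate.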
lagent/llms/huggingface.py ADDED
@@ -0,0 +1,337 @@
+ import copy
+ import logging
+ from typing import Dict, List, Optional, Union
+ 
+ from lagent.schema import ModelStatusCode
+ from .base_api import APITemplateParser
+ from .base_llm import BaseLLM
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class HFTransformer(BaseLLM):
+     """Model wrapper around HuggingFace general models.
+ 
+     Adapted from InternLM (https://github.com/InternLM/InternLM/blob/main/
+     chat/web_demo.py)
+ 
+     Args:
+         path (str): The name or path to HuggingFace's model.
+         tokenizer_path (str): The path to the tokenizer. Defaults to None.
+         tokenizer_kwargs (dict): Keyword arguments for the tokenizer.
+             Defaults to {}.
+         tokenizer_only (bool): If True, only the tokenizer will be initialized.
+             Defaults to False.
+         model_kwargs (dict): Keyword arguments for the model, used in loader.
+             Defaults to dict(device_map='auto').
+         meta_template (Dict, optional): The model's meta prompt
+             template if needed, in case meta instructions have to be
+             injected or wrapped.
+     """
+ 
+     def __init__(self,
+                  path: str,
+                  tokenizer_path: Optional[str] = None,
+                  tokenizer_kwargs: dict = dict(),
+                  tokenizer_only: bool = False,
+                  model_kwargs: dict = dict(device_map='auto'),
+                  meta_template: Optional[Dict] = None,
+                  stop_words_id: Union[List[int], int] = None,
+                  **kwargs):
+         super().__init__(
+             path=path,
+             tokenizer_only=tokenizer_only,
+             meta_template=meta_template,
+             **kwargs)
+         if isinstance(stop_words_id, int):
+             stop_words_id = [stop_words_id]
+         self.gen_params.update(stop_words_id=stop_words_id)
+         if self.gen_params['stop_words'] is not None and \
+                 self.gen_params['stop_words_id'] is not None:
+             logger.warning('Both stop_words and stop_words_id are specified; '
+                            'only stop_words_id will be used.')
+ 
+         self._load_tokenizer(
+             path=path,
+             tokenizer_path=tokenizer_path,
+             tokenizer_kwargs=tokenizer_kwargs)
+         if not tokenizer_only:
+             self._load_model(path=path, model_kwargs=model_kwargs)
+ 
+         from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList  # noqa: E501
+         self.logits_processor = LogitsProcessorList()
+         self.stopping_criteria = StoppingCriteriaList()
+         self.prefix_allowed_tokens_fn = None
+ 
+         stop_words_id = []
+         if self.gen_params.get('stop_words_id'):
+             stop_words_id = self.gen_params.get('stop_words_id')
+         elif self.gen_params.get('stop_words'):
+             for sw in self.gen_params.get('stop_words'):
+                 stop_words_id.append(self.tokenizer(sw)['input_ids'][-1])
+         self.additional_eos_token_id = stop_words_id
+ 
+     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
+                         tokenizer_kwargs: dict):
+         from transformers import AutoTokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             tokenizer_path if tokenizer_path else path,
+             trust_remote_code=True,
+             **tokenizer_kwargs)
+ 
+         if self.tokenizer.pad_token_id is None:
+             if self.tokenizer.eos_token is not None:
+                 logger.warning(
+                     f'Using eos_token {self.tokenizer.eos_token} '
+                     'as pad_token.')
+                 self.tokenizer.pad_token = self.tokenizer.eos_token
+             else:
+                 from transformers.generation import GenerationConfig
+                 self.gcfg = GenerationConfig.from_pretrained(path)
+ 
+                 if self.gcfg.pad_token_id is not None:
+                     logger.warning(
+                         f'Using pad_token_id {self.gcfg.pad_token_id} '
+                         'as pad_token_id.')
+                     self.tokenizer.pad_token_id = self.gcfg.pad_token_id
+                 else:
+                     raise ValueError(
+                         'pad_token_id is not set for this tokenizer. Try to '
+                         'set pad_token_id via passing '
+                         '`pad_token_id={PAD_TOKEN_ID}` in model_cfg.')
+ 
+     def _load_model(self, path: str, model_kwargs: dict):
+         import torch
+         from transformers import AutoModel
+         model_kwargs.setdefault('torch_dtype', torch.float16)
+         self.model = AutoModel.from_pretrained(
+             path, trust_remote_code=True, **model_kwargs)
+         self.model.eval()
+ 
+     def tokenize(self, inputs: str):
+         assert isinstance(inputs, str)
+         inputs = self.tokenizer(
+             inputs, return_tensors='pt', return_length=True)
+         return inputs['input_ids'].tolist()
+ 
+     def generate(
+         self,
+         inputs: Union[str, List[str]],
+         do_sample: bool = True,
+         **kwargs,
+     ):
+         """Return the chat completions in non-stream mode.
+ 
+         Args:
+             inputs (Union[str, List[str]]): input texts to be completed.
+             do_sample (bool): do sampling if enabled
+         Returns:
+             (a list of/batched) text/chat completion
+         """
+         for status, chunk, _ in self.stream_generate(inputs, do_sample,
+                                                      **kwargs):
+             response = chunk
+         return response
+ 
+     def stream_generate(
+         self,
+         inputs: List[str],
+         do_sample: bool = True,
+         **kwargs,
+     ):
+         """Return the chat completions in stream mode.
+ 
+         Args:
+             inputs (Union[str, List[str]]): input texts to be completed.
+             do_sample (bool): do sampling if enabled
+         Returns:
+             tuple(Status, str, int): status, text/chat completion,
+             generated token number
+         """
+         import torch
+         from torch import nn
+         with torch.no_grad():
+             batched = True
+             if isinstance(inputs, str):
+                 inputs = [inputs]
+                 batched = False
+             inputs = self.tokenizer(
+                 inputs, padding=True, return_tensors='pt', return_length=True)
+             input_length = inputs['length']
+             for k, v in inputs.items():
+                 inputs[k] = v.cuda()
+             input_ids = inputs['input_ids']
+             attention_mask = inputs['attention_mask']
+             batch_size = input_ids.shape[0]
+             input_ids_seq_length = input_ids.shape[-1]
+             generation_config = self.model.generation_config
+             generation_config = copy.deepcopy(generation_config)
+             new_gen_params = self.update_gen_params(**kwargs)
+             generation_config.update(**new_gen_params)
+             generation_config.update(**kwargs)
+             model_kwargs = generation_config.to_dict()
+             model_kwargs['attention_mask'] = attention_mask
+             _, eos_token_id = (  # noqa: F841  # pylint: disable=W0612
+                 generation_config.bos_token_id,
+                 generation_config.eos_token_id,
+             )
+             if eos_token_id is None:
+                 if self.gcfg.eos_token_id is not None:
+                     eos_token_id = self.gcfg.eos_token_id
+                 else:
+                     eos_token_id = []
+             if isinstance(eos_token_id, int):
+                 eos_token_id = [eos_token_id]
+             if self.additional_eos_token_id is not None:
+                 eos_token_id.extend(self.additional_eos_token_id)
+             eos_token_id_tensor = torch.tensor(eos_token_id).to(
+                 input_ids.device) if eos_token_id is not None else None
+             generation_config.max_length = (
+                 generation_config.max_new_tokens + input_ids_seq_length)
+             # Set generation parameters if not already defined
+             logits_processor = self.logits_processor
+             stopping_criteria = self.stopping_criteria
+ 
+             logits_processor = self.model._get_logits_processor(
+                 generation_config=generation_config,
+                 input_ids_seq_length=input_ids_seq_length,
+                 encoder_input_ids=input_ids,
+                 prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
+                 logits_processor=logits_processor,
+             )
+ 
+             stopping_criteria = self.model._get_stopping_criteria(
+                 generation_config=generation_config,
+                 stopping_criteria=stopping_criteria)
+             logits_warper = self.model._get_logits_warper(generation_config)
+ 
+             unfinished_sequences = input_ids.new(batch_size).fill_(1)
+             scores = None
+             while True:
+                 model_inputs = self.model.prepare_inputs_for_generation(
+                     input_ids, **model_kwargs)
+                 # forward pass to get the next token
+                 outputs = self.model(
+                     **model_inputs,
+                     return_dict=True,
+                     output_attentions=False,
+                     output_hidden_states=False,
+                 )
+ 
+                 next_token_logits = outputs.logits[:, -1, :]
+ 
+                 # pre-process the distribution
+                 next_token_scores = logits_processor(input_ids,
+                                                      next_token_logits)
+                 next_token_scores = logits_warper(input_ids, next_token_scores)
+ 
+                 # sample
+                 probs = nn.functional.softmax(next_token_scores, dim=-1)
+                 if do_sample:
+                     next_tokens = torch.multinomial(
+                         probs, num_samples=1).squeeze(1)
+                 else:
+                     next_tokens = torch.argmax(probs, dim=-1)
+ 
+                 # update generated ids, model inputs,
+                 # and length for the next step
+                 input_ids = torch.cat([input_ids, next_tokens[:, None]],
+                                       dim=-1)
+                 model_kwargs = self.model._update_model_kwargs_for_generation(  # noqa: E501
+                     outputs,
+                     model_kwargs,
+                     is_encoder_decoder=False)
+                 unfinished_sequences = unfinished_sequences.mul(
+                     next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(
+                         eos_token_id_tensor.unsqueeze(1)).prod(dim=0))
+                 output_token_ids = input_ids.cpu().tolist()
+                 for i in range(len(output_token_ids)):
+                     output_token_ids[i] = output_token_ids[i][:][
+                         input_length[i]:]
+                     # Find the first occurrence of
+                     # an EOS token in the sequence
+                     first_eos_idx = next(
+                         (idx
+                          for idx, token_id in enumerate(output_token_ids[i])
+                          if token_id in eos_token_id), None)
+                     # If an EOS token is found, keep only
+                     # the part that precedes it
+                     if first_eos_idx is not None:
+                         output_token_ids[i] = output_token_ids[
+                             i][:first_eos_idx]
+ 
+                 response = self.tokenizer.batch_decode(output_token_ids)
+                 if not batched:
+                     response = response[0]
+                 yield ModelStatusCode.STREAM_ING, response, None
+                 # stop when each sentence is finished,
+                 # or if we exceed the maximum length
+                 if (unfinished_sequences.max() == 0
+                         or stopping_criteria(input_ids, scores)):
+                     break
+             yield ModelStatusCode.END, response, None
+ 
+     def stream_chat(
+         self,
+         inputs: List[dict],
+         do_sample: bool = True,
+         **kwargs,
+     ):
+         """Return the chat completions in stream mode.
+ 
+         Args:
+             inputs (List[dict]): input messages to be completed.
+             do_sample (bool): do sampling if enabled
+         Returns:
+             the text/chat completion
+         """
+         prompt = self.template_parser(inputs)
+         yield from self.stream_generate(prompt, do_sample, **kwargs)
+ 
+ 
+ class HFTransformerCasualLM(HFTransformer):
+ 
+     def _load_model(self, path: str, model_kwargs: dict):
+         import torch
+         from transformers import AutoModelForCausalLM
+         model_kwargs.setdefault('torch_dtype', torch.float16)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             path, trust_remote_code=True, **model_kwargs)
+         self.model.eval()
+ 
+ 
+ class HFTransformerChat(HFTransformerCasualLM):
+ 
+     def __init__(self, template_parser=APITemplateParser, **kwargs):
+         super().__init__(template_parser=template_parser, **kwargs)
+ 
+     def chat(self,
+              inputs: Union[List[dict], List[List[dict]]],
+              do_sample: bool = True,
+              **kwargs):
+         """Return the chat completions in non-stream mode.
+ 
+         Args:
+             inputs (Union[List[dict], List[List[dict]]]): input messages to be completed.
+             do_sample (bool): do sampling if enabled
+         Returns:
+             the text/chat completion
+         """
+         # handle batch inference with a vanilla for loop
+         if isinstance(inputs[0], list):
+             resps = []
+             for input in inputs:
+                 resps.append(self.chat(input, do_sample, **kwargs))
+             return resps
+         prompt = self.template_parser(inputs)
+         query = prompt[-1]['content']
+         history = prompt[:-1]
+         try:
+             response, history = self.model.chat(
+                 self.tokenizer, query, history=history)
+         except Exception as e:
+             # handle over-length input error
+             logger.warning(str(e))
+             response = ''
+         return response
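A usage sketch for the causal-LM wrapper; the model id is a placeholder and a CUDA GPU is assumed, since stream_generate moves tensors with .cuda():

from lagent.llms import HFTransformerCasualLM

llm = HFTransformerCasualLM(
    'internlm/internlm2-chat-7b',  # placeholder Hub model id
    max_new_tokens=64,
    stop_words=['<|im_end|>'])
text = ''
for status, text, _ in llm.stream_generate('The capital of France is'):
    pass  # consume the stream; `text` holds the latest partial completion
print(text)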
lagent/llms/lmdeploy_wrapper.py ADDED
@@ -0,0 +1,790 @@
1
+ import asyncio
2
+ import copy
3
+ import logging
4
+ from dataclasses import asdict
5
+ from typing import List, Optional, Union
6
+
7
+ import aiohttp
8
+
9
+ from lagent.llms.base_llm import AsyncLLMMixin, BaseLLM
10
+ from lagent.schema import ModelStatusCode
11
+ from lagent.utils.util import filter_suffix
12
+
13
+
14
+ class TritonClient(BaseLLM):
15
+ """TritonClient is a wrapper of TritonClient for LLM.
16
+
17
+ Args:
18
+ tritonserver_addr (str): the address in format "ip:port" of
19
+ triton inference server
20
+ model_name (str): the name of the model
21
+ session_len (int): the context size
22
+ max_tokens (int): the expected generated token numbers
23
+ """
24
+
25
+ def __init__(self,
26
+ tritonserver_addr: str,
27
+ model_name: str,
28
+ session_len: int = 32768,
29
+ log_level: str = 'WARNING',
30
+ **kwargs):
31
+ super().__init__(path=None, **kwargs)
32
+ try:
33
+ from lmdeploy.serve.turbomind.chatbot import Chatbot, StatusCode
34
+ except Exception as e:
35
+ logging.error(f'{e}')
36
+ raise RuntimeError('DO NOT use turbomind.chatbot since it has '
37
+ 'been removed by lmdeploy since v0.5.2')
38
+ self.state_map = {
39
+ StatusCode.TRITON_STREAM_END: ModelStatusCode.END,
40
+ StatusCode.TRITON_SERVER_ERR: ModelStatusCode.SERVER_ERR,
41
+ StatusCode.TRITON_SESSION_CLOSED: ModelStatusCode.SESSION_CLOSED,
42
+ StatusCode.TRITON_STREAM_ING: ModelStatusCode.STREAM_ING,
43
+ StatusCode.TRITON_SESSION_OUT_OF_LIMIT:
44
+ ModelStatusCode.SESSION_OUT_OF_LIMIT,
45
+ StatusCode.TRITON_SESSION_INVALID_ARG:
46
+ ModelStatusCode.SESSION_INVALID_ARG,
47
+ StatusCode.TRITON_SESSION_READY: ModelStatusCode.SESSION_READY
48
+ }
49
+ self.chatbot = Chatbot(
50
+ tritonserver_addr=tritonserver_addr,
51
+ model_name=model_name,
52
+ session_len=session_len,
53
+ log_level=log_level,
54
+ **kwargs)
55
+
56
+ def generate(self,
57
+ inputs: Union[str, List[str]],
58
+ session_id: int = 2967,
59
+ request_id: str = '',
60
+ sequence_start: bool = True,
61
+ sequence_end: bool = True,
62
+ skip_special_tokens: bool = False,
63
+ **kwargs):
64
+ """Start a new round conversation of a session. Return the chat
65
+ completions in non-stream mode.
66
+
67
+ Args:
68
+ inputs (str, List[str]): user's prompt(s) in this round
69
+ session_id (int): the identical id of a session
70
+ request_id (str): the identical id of this round conversation
71
+ sequence_start (bool): start flag of a session
72
+ sequence_end (bool): end flag of a session
73
+ skip_special_tokens (bool): Whether or not to remove special tokens
74
+ in the decoding. Default to be False.
75
+ Returns:
76
+ (a list of/batched) text/chat completion
77
+ """
78
+ from lmdeploy.serve.turbomind.chatbot import Session, get_logger
79
+ if isinstance(inputs, str):
80
+ inputs = [inputs]
81
+ prompt = inputs
82
+
83
+ assert isinstance(session_id, int), \
84
+ f'INT session id is required, but got {type(session_id)}'
85
+
86
+ self.chatbot.cfg = self._update_gen_params(**kwargs)
87
+ max_new_tokens = self.chatbot.cfg.max_new_tokens
88
+
89
+ logger = get_logger('service.ft', log_level=self.chatbot.log_level)
90
+ logger.info(f'session {session_id}, request_id {request_id}, '
91
+ f'max_out_len {max_new_tokens}')
92
+
93
+ if self.chatbot._session is None:
94
+ sequence_start = True
95
+ self.chatbot._session = Session(session_id=session_id)
96
+ elif self.chatbot._session.status == 0:
97
+ logger.error(f'session {session_id} has been ended. Please set '
98
+ f'`sequence_start` be True if you want to restart it')
99
+ return ''
100
+
101
+ self.chatbot._session.status = 1
102
+ self.chatbot._session.request_id = request_id
103
+ self.chatbot._session.response = ''
104
+
105
+ status, res, _ = None, '', 0
106
+ for status, res, _ in self.chatbot._stream_infer(
107
+ self.chatbot._session,
108
+ prompt,
109
+ max_new_tokens,
110
+ sequence_start,
111
+ sequence_end,
112
+ skip_special_tokens=skip_special_tokens):
113
+ status = self.state_map.get(status)
114
+ if status < ModelStatusCode.END:
115
+ return ''
116
+ elif status == ModelStatusCode.END:
117
+ self.chatbot._session.histories = (
118
+ self.chatbot._session.histories +
119
+ self.chatbot._session.prompt +
120
+ self.chatbot._session.response)
121
+ # remove stop_words
122
+ res = filter_suffix(res, self.gen_params.get('stop_words'))
123
+ return res
124
+
125
+ def stream_chat(self,
126
+ inputs: List[dict],
127
+ session_id: int = 2967,
128
+ request_id: str = '',
129
+ sequence_start: bool = True,
130
+ sequence_end: bool = True,
131
+ skip_special_tokens: bool = False,
132
+ **kwargs):
133
+ """Start a new round conversation of a session. Return the chat
134
+ completions in stream mode.
135
+
136
+ Args:
137
+ session_id (int): the identical id of a session
138
+ inputs (List[dict]): user's inputs in this round conversation
139
+ request_id (str): the identical id of this round conversation
140
+ sequence_start (bool): start flag of a session
141
+ sequence_end (bool): end flag of a session
142
+ skip_special_tokens (bool): Whether or not to remove special tokens
143
+ in the decoding. Default to be False.
144
+ Returns:
145
+ tuple(Status, str, int): status, text/chat completion,
146
+ generated token number
147
+ """
148
+ from lmdeploy.serve.turbomind.chatbot import Session, get_logger
149
+ assert isinstance(session_id, int), \
150
+ f'INT session id is required, but got {type(session_id)}'
151
+
152
+ self.chatbot.cfg = self._update_gen_params(**kwargs)
153
+ max_new_tokens = self.chatbot.cfg.max_new_tokens
154
+
155
+ logger = get_logger('service.ft', log_level=self.chatbot.log_level)
156
+ logger.info(f'session {session_id}, request_id {request_id}, '
157
+ f'max_out_len {max_new_tokens}')
158
+
159
+ if self.chatbot._session is None:
160
+ sequence_start = True
161
+ self.chatbot._session = Session(session_id=session_id)
162
+ elif self.chatbot._session.status == 0:
163
+ logger.error(f'session {session_id} has been ended. Please set '
164
+ f'`sequence_start` be True if you want to restart it')
165
+ return ModelStatusCode.SESSION_CLOSED, '', 0
166
+
167
+ self.chatbot._session.status = 1
168
+ self.chatbot._session.request_id = request_id
169
+ self.chatbot._session.response = ''
170
+
171
+ prompt = self.template_parser(inputs)
172
+ status, res, _ = None, '', 0
173
+ for status, res, _ in self.chatbot._stream_infer(
174
+ self.chatbot._session,
175
+ prompt,
176
+ max_new_tokens,
177
+ sequence_start,
178
+ sequence_end,
179
+ skip_special_tokens=skip_special_tokens):
180
+ status = self.state_map.get(status)
181
+ # The stop symbol also appears in the output of the last STREAM_ING state.
182
+ res = filter_suffix(res, self.gen_params.get('stop_words'))
183
+ if status < ModelStatusCode.END:
184
+ return status, res, _
185
+ elif status == ModelStatusCode.END: # remove stop_words
186
+ self.chatbot._session.histories = (
187
+ self.chatbot._session.histories +
188
+ self.chatbot._session.prompt +
189
+ self.chatbot._session.response)
190
+ yield status, res, _
191
+ break
192
+ else:
193
+ yield status, res, _
194
+
195
+ def _update_gen_params(self, **kwargs):
196
+ import mmengine
197
+ new_gen_params = self.update_gen_params(**kwargs)
198
+ self.gen_params['stop_words'] = new_gen_params.pop('stop_words')
199
+ stop_words = self.chatbot._stop_words(
200
+ self.gen_params.get('stop_words'))
201
+ cfg = mmengine.Config(
202
+ dict(
203
+ session_len=self.chatbot.model.session_len,
204
+ stop_words=stop_words,
205
+ bad_words=self.chatbot.cfg.bad_words,
206
+ **new_gen_params))
207
+ return cfg
208
+
209
+
210
+ class LMDeployPipeline(BaseLLM):
211
+ """
212
+
213
+ Args:
214
+ path (str): The path to the model.
215
+ It could be one of the following options:
216
+ - i) A local directory path of a turbomind model which is
217
+ converted by `lmdeploy convert` command or download
218
+ from ii) and iii).
219
+ - ii) The model_id of a lmdeploy-quantized model hosted
220
+ inside a model repo on huggingface.co, such as
221
+ "InternLM/internlm-chat-20b-4bit",
222
+ "lmdeploy/llama2-chat-70b-4bit", etc.
223
+ - iii) The model_id of a model hosted inside a model repo
224
+ on huggingface.co, such as "internlm/internlm-chat-7b",
225
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
226
+ and so on.
227
+ model_name (str): needed when model_path is a pytorch model on
228
+ huggingface.co, such as "internlm-chat-7b",
229
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
230
+ tp (int): tensor parallel
231
+ pipeline_cfg (dict): config of pipeline
232
+ """
233
+
234
+ def __init__(self,
235
+ path: str,
236
+ model_name: Optional[str] = None,
237
+ tp: int = 1,
238
+ pipeline_cfg=dict(),
239
+ **kwargs):
240
+ import lmdeploy
241
+ from lmdeploy import ChatTemplateConfig, TurbomindEngineConfig, pipeline, version_info
242
+
243
+ self.str_version = lmdeploy.__version__
244
+ self.version = version_info
245
+ self.do_sample = kwargs.pop('do_sample', None)
246
+ if self.do_sample is not None and self.version < (0, 6, 0):
247
+ raise RuntimeError(
248
+ '`do_sample` parameter is not supported by lmdeploy until '
249
+ f'v0.6.0, but currently using lmdeloy {self.str_version}')
250
+ super().__init__(path=path, **kwargs)
251
+ backend_config = copy.deepcopy(pipeline_cfg)
252
+ backend_config.update(tp=tp)
253
+ backend_config = {
254
+ k: v
255
+ for k, v in backend_config.items()
256
+ if hasattr(TurbomindEngineConfig, k)
257
+ }
258
+ backend_config = TurbomindEngineConfig(**backend_config)
259
+ chat_template_config = ChatTemplateConfig(
260
+ model_name=model_name) if model_name else None
261
+ self.model = pipeline(
262
+ model_path=self.path,
263
+ backend_config=backend_config,
264
+ chat_template_config=chat_template_config,
265
+ log_level='WARNING')
266
+
267
+ def generate(self,
268
+ inputs: Union[str, List[str]],
269
+ do_preprocess: bool = None,
270
+ skip_special_tokens: bool = False,
271
+ return_dict: bool = False,
272
+ **kwargs):
273
+ """Return the chat completions in non-stream mode.
274
+
275
+ Args:
276
+ inputs (Union[str, List[str]]): input texts to be completed.
277
+ do_preprocess (bool): whether pre-process the messages. Default to
278
+ True, which means chat_template will be applied.
279
+ skip_special_tokens (bool): Whether or not to remove special tokens
280
+ in the decoding. Default to be False.
281
+ Returns:
282
+ (a list of/batched) text/chat completion
283
+ """
284
+ from lmdeploy.messages import GenerationConfig
285
+ batched = True
286
+ if isinstance(inputs, str):
287
+ inputs = [inputs]
288
+ batched = False
289
+ prompt = inputs
290
+ do_sample = kwargs.pop('do_sample', None)
291
+ gen_params = self.update_gen_params(**kwargs)
292
+
293
+ if do_sample is None:
294
+ do_sample = self.do_sample
295
+ if do_sample is not None and self.version < (0, 6, 0):
296
+ raise RuntimeError(
297
+ '`do_sample` parameter is not supported by lmdeploy until '
298
+ f'v0.6.0, but currently using lmdeloy {self.str_version}')
299
+ if self.version >= (0, 6, 0):
300
+ if do_sample is None:
301
+ do_sample = gen_params['top_k'] > 1 or gen_params[
302
+ 'temperature'] > 0
303
+ gen_params.update(do_sample=do_sample)
304
+
305
+ gen_config = GenerationConfig(
306
+ skip_special_tokens=skip_special_tokens, **gen_params)
307
+ response = self.model.batch_infer(
308
+ prompt, gen_config=gen_config, do_preprocess=do_preprocess)
309
+ texts = [resp.text for resp in response]
310
+ # remove stop_words
311
+ texts = filter_suffix(texts, self.gen_params.get('stop_words'))
312
+ for resp, text in zip(response, texts):
313
+ resp.text = text
314
+ if batched:
315
+ return [asdict(resp)
316
+ for resp in response] if return_dict else texts
317
+ return asdict(response[0]) if return_dict else texts[0]
318
+
319
+
320
+ class LMDeployServer(BaseLLM):
321
+ """
322
+
323
+ Args:
324
+ path (str): The path to the model.
325
+ It could be one of the following options:
326
+ - i) A local directory path of a turbomind model which is
327
+ converted by `lmdeploy convert` command or download from
328
+ ii) and iii).
329
+ - ii) The model_id of a lmdeploy-quantized model hosted
330
+ inside a model repo on huggingface.co, such as
331
+ "InternLM/internlm-chat-20b-4bit",
332
+ "lmdeploy/llama2-chat-70b-4bit", etc.
333
+ - iii) The model_id of a model hosted inside a model repo
334
+ on huggingface.co, such as "internlm/internlm-chat-7b",
335
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
336
+ and so on.
337
+ model_name (str): needed when model_path is a pytorch model on
338
+ huggingface.co, such as "internlm-chat-7b",
339
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
340
+ server_name (str): host ip for serving
341
+ server_port (int): server port
342
+ tp (int): tensor parallel
343
+ log_level (str): set log level whose value among
344
+ [CRITICAL, ERROR, WARNING, INFO, DEBUG]
345
+ """
346
+
347
+ def __init__(self,
348
+ path: str,
349
+ model_name: Optional[str] = None,
350
+ server_name: str = '0.0.0.0',
351
+ server_port: int = 23333,
352
+ tp: int = 1,
353
+ log_level: str = 'WARNING',
354
+ serve_cfg=dict(),
355
+ **kwargs):
356
+ super().__init__(path=path, **kwargs)
357
+ self.model_name = model_name
358
+ # TODO get_logger issue in multi processing
359
+ import lmdeploy
360
+ self.client = lmdeploy.serve(
361
+ model_path=self.path,
362
+ model_name=model_name,
363
+ server_name=server_name,
364
+ server_port=server_port,
365
+ tp=tp,
366
+ log_level=log_level,
367
+ **serve_cfg)
368
+
369
+ def generate(self,
370
+ inputs: Union[str, List[str]],
371
+ session_id: int = 2967,
372
+ sequence_start: bool = True,
373
+ sequence_end: bool = True,
374
+ ignore_eos: bool = False,
375
+ skip_special_tokens: Optional[bool] = False,
376
+ timeout: int = 30,
377
+ **kwargs) -> List[str]:
378
+ """Start a new round conversation of a session. Return the chat
379
+ completions in non-stream mode.
380
+
381
+ Args:
382
+ inputs (str, List[str]): user's prompt(s) in this round
383
+ session_id (int): the identical id of a session
384
+ sequence_start (bool): start flag of a session
385
+ sequence_end (bool): end flag of a session
386
+ ignore_eos (bool): indicator for ignoring eos
387
+ skip_special_tokens (bool): Whether or not to remove special tokens
388
+ in the decoding. Defaults to False.
389
+ timeout (int): max time to wait for response
390
+ Returns:
391
+ (a list of/batched) text/chat completion
392
+ """
393
+
394
+ batched = True
395
+ if isinstance(inputs, str):
396
+ inputs = [inputs]
397
+ batched = False
398
+
399
+ gen_params = self.update_gen_params(**kwargs)
400
+ max_new_tokens = gen_params.pop('max_new_tokens')
401
+ gen_params.update(max_tokens=max_new_tokens)
402
+
403
+ resp = [''] * len(inputs)
404
+ for text in self.client.completions_v1(
405
+ self.model_name,
406
+ inputs,
407
+ session_id=session_id,
408
+ sequence_start=sequence_start,
409
+ sequence_end=sequence_end,
410
+ stream=False,
411
+ ignore_eos=ignore_eos,
412
+ skip_special_tokens=skip_special_tokens,
413
+ timeout=timeout,
414
+ **gen_params):
415
+ resp = [
416
+ resp[i] + item['text']
417
+ for i, item in enumerate(text['choices'])
418
+ ]
419
+ # remove stop_words
420
+ resp = filter_suffix(resp, self.gen_params.get('stop_words'))
421
+ if not batched:
422
+ return resp[0]
423
+ return resp
424
+
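A hedged sketch of calling this server-backed generate (model path and port are illustrative assumptions):

server = LMDeployServer(path='internlm/internlm2_5-7b-chat', server_port=23333)
print(server.generate('Summarize LMDeploy in one sentence.'))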
425
+ def stream_chat(self,
426
+ inputs: List[dict],
427
+ session_id=0,
428
+ sequence_start: bool = True,
429
+ sequence_end: bool = True,
430
+ stream: bool = True,
431
+ ignore_eos: bool = False,
432
+ skip_special_tokens: Optional[bool] = False,
433
+ timeout: int = 30,
434
+ **kwargs):
435
+ """Start a new round conversation of a session. Return the chat
436
+ completions in stream mode.
437
+
438
+ Args:
439
+ session_id (int): the identical id of a session
440
+ inputs (List[dict]): user's inputs in this round conversation
441
+ sequence_start (bool): start flag of a session
442
+ sequence_end (bool): end flag of a session
443
+ stream (bool): return in a streaming format if enabled
444
+ ignore_eos (bool): indicator for ignoring eos
445
+ skip_special_tokens (bool): Whether or not to remove special tokens
446
+ in the decoding. Defaults to False.
447
+ timeout (int): max time to wait for response
448
+ Returns:
449
+ tuple(Status, str, int): status, text/chat completion,
450
+ generated token number
451
+ """
452
+ gen_params = self.update_gen_params(**kwargs)
453
+ max_new_tokens = gen_params.pop('max_new_tokens')
454
+ gen_params.update(max_tokens=max_new_tokens)
455
+ prompt = self.template_parser(inputs)
456
+
457
+ resp = ''
458
+ finished = False
459
+ stop_words = self.gen_params.get('stop_words') or []  # guard: may be unset
460
+ for text in self.client.completions_v1(
461
+ self.model_name,
462
+ prompt,
463
+ session_id=session_id,
464
+ sequence_start=sequence_start,
465
+ sequence_end=sequence_end,
466
+ stream=stream,
467
+ ignore_eos=ignore_eos,
468
+ skip_special_tokens=skip_special_tokens,
469
+ timeout=timeout,
470
+ **gen_params):
471
+ resp += text['choices'][0]['text']
472
+ if not resp:
473
+ continue
474
+ # remove stop_words
475
+ for sw in stop_words:
476
+ if sw in resp:
477
+ resp = filter_suffix(resp, stop_words)
478
+ finished = True
479
+ break
480
+ yield ModelStatusCode.STREAM_ING, resp, None
481
+ if finished:
482
+ break
483
+ yield ModelStatusCode.END, resp, None
484
+
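Consuming the stream looks roughly like this, reusing the server from the sketch above (ModelStatusCode is assumed to live in lagent.schema):

from lagent.schema import ModelStatusCode

for status, text, _ in server.stream_chat([dict(role='user', content='Hi!')]):
    if status == ModelStatusCode.END:
        print(text)  # the fully accumulated completion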
485
+
486
+ class LMDeployClient(LMDeployServer):
487
+ """
488
+
489
+ Args:
490
+ url (str): communicating address 'http://<ip>:<port>' of
491
+ api_server
492
+ model_name (str): needed when model_path is a pytorch model on
493
+ huggingface.co, such as "internlm-chat-7b",
494
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
495
+ """
496
+
497
+ def __init__(self, url: str, model_name: str, **kwargs):
498
+ BaseLLM.__init__(self, path=url, **kwargs)
499
+ from lmdeploy.serve.openai.api_client import APIClient
500
+ self.client = APIClient(url)
501
+ self.model_name = model_name
502
+
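Because LMDeployClient only swaps the in-process server for an APIClient, pointing it at an already running `lmdeploy serve api_server` is a one-liner (URL and model name are illustrative):

client = LMDeployClient(url='http://127.0.0.1:23333', model_name='internlm2')
print(client.generate('Hello!'))  # generate/stream_chat are inherited from LMDeployServer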
503
+
504
+ class AsyncLMDeployPipeline(AsyncLLMMixin, LMDeployPipeline):
505
+ """
506
+
507
+ Args:
508
+ path (str): The path to the model.
509
+ It could be one of the following options:
510
+ - i) A local directory path of a turbomind model which is
511
+ converted by `lmdeploy convert` command or download
512
+ from ii) and iii).
513
+ - ii) The model_id of a lmdeploy-quantized model hosted
514
+ inside a model repo on huggingface.co, such as
515
+ "InternLM/internlm-chat-20b-4bit",
516
+ "lmdeploy/llama2-chat-70b-4bit", etc.
517
+ - iii) The model_id of a model hosted inside a model repo
518
+ on huggingface.co, such as "internlm/internlm-chat-7b",
519
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
520
+ and so on.
521
+ model_name (str): needed when model_path is a pytorch model on
522
+ huggingface.co, such as "internlm-chat-7b",
523
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
524
+ tp (int): tensor parallel
525
+ pipeline_cfg (dict): config of pipeline
526
+ """
527
+
528
+ async def generate(self,
529
+ inputs: Union[str, List[str]],
530
+ session_ids: Union[int, List[int]] = None,
531
+ do_preprocess: bool = None,
532
+ skip_special_tokens: bool = False,
533
+ return_dict: bool = False,
534
+ **kwargs):
535
+ """Return the chat completions in non-stream mode.
536
+
537
+ Args:
538
+ inputs (Union[str, List[str]]): input texts to be completed.
539
+ do_preprocess (bool): whether to pre-process the messages. Defaults to
540
+ True, which means chat_template will be applied.
541
+ skip_special_tokens (bool): Whether or not to remove special tokens
542
+ in the decoding. Defaults to False.
543
+ Returns:
544
+ (a list of/batched) text/chat completion
545
+ """
546
+ from lmdeploy.messages import GenerationConfig, Response
547
+
548
+ batched = True
549
+ if isinstance(inputs, str):
550
+ inputs = [inputs]
551
+ batched = False
552
+ if session_ids is None:
553
+ session_ids = list(range(len(inputs)))
554
+ elif isinstance(session_ids, (int, str)):
555
+ session_ids = [session_ids]
556
+ assert len(inputs) == len(session_ids)
557
+
558
+ prompt = inputs
559
+ gen_params = self.update_gen_params(**kwargs)
560
+ gen_config = GenerationConfig(
561
+ skip_special_tokens=skip_special_tokens, **gen_params)
562
+
563
+ async def _inner_generate(uid, text):
564
+ resp = Response('', 0, 0, uid)
565
+ async for out in self.model.generate(
566
+ text,
567
+ uid,
568
+ gen_config,
569
+ stream_response=True,
570
+ sequence_start=True,
571
+ sequence_end=True,
572
+ do_preprocess=do_preprocess,
573
+ **kwargs):
574
+ resp.text += out.response
575
+ resp.generate_token_len = out.generate_token_len
576
+ resp.input_token_len = out.input_token_len
577
+ resp.finish_reason = out.finish_reason
578
+ if out.token_ids:
579
+ resp.token_ids.extend(out.token_ids)
580
+ if out.logprobs:
581
+ if resp.logprobs is None:
582
+ resp.logprobs = []
583
+ resp.logprobs.extend(out.logprobs)
584
+ return resp
585
+
586
+ response = await asyncio.gather(*[
587
+ _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
588
+ ])
589
+ texts = [resp.text for resp in response]
590
+ # remove stop_words
591
+ texts = filter_suffix(texts, self.gen_params.get('stop_words'))
592
+ for resp, text in zip(response, texts):
593
+ resp.text = text
594
+ if batched:
595
+ return [asdict(resp)
596
+ for resp in response] if return_dict else texts
597
+ return asdict(response[0]) if return_dict else texts[0]
598
+
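Driving the async pipeline from an event loop, as a sketch (model path assumed):

import asyncio

async def main():
    llm = AsyncLMDeployPipeline(path='internlm/internlm2_5-7b-chat')
    return await llm.generate(['Hi', 'Hello'], session_ids=[0, 1])

print(asyncio.run(main()))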
599
+
600
+ class AsyncLMDeployServer(AsyncLLMMixin, LMDeployServer):
601
+ """
602
+
603
+ Args:
604
+ path (str): The path to the model.
605
+ It could be one of the following options:
606
+ - i) A local directory path of a turbomind model which is
607
+ converted by `lmdeploy convert` command or download from
608
+ ii) and iii).
609
+ - ii) The model_id of a lmdeploy-quantized model hosted
610
+ inside a model repo on huggingface.co, such as
611
+ "InternLM/internlm-chat-20b-4bit",
612
+ "lmdeploy/llama2-chat-70b-4bit", etc.
613
+ - iii) The model_id of a model hosted inside a model repo
614
+ on huggingface.co, such as "internlm/internlm-chat-7b",
615
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
616
+ and so on.
617
+ model_name (str): needed when model_path is a pytorch model on
618
+ huggingface.co, such as "internlm-chat-7b",
619
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
620
+ server_name (str): host ip for serving
621
+ server_port (int): server port
622
+ tp (int): tensor parallel
623
+ log_level (str): set log level whose value among
624
+ [CRITICAL, ERROR, WARNING, INFO, DEBUG]
625
+ """
626
+
627
+ async def generate(
628
+ self,
629
+ inputs: Union[str, List[str]],
630
+ session_ids: Union[int, List[int]] = None,
631
+ sequence_start: bool = True,
632
+ sequence_end: bool = True,
633
+ ignore_eos: bool = False,
634
+ skip_special_tokens: Optional[bool] = False,
635
+ timeout: int = 30,
636
+ **kwargs,
637
+ ):
638
+ """Start a new round conversation of a session. Return the chat
639
+ completions in non-stream mode.
640
+
641
+ Args:
642
+ inputs (str, List[str]): user's prompt(s) in this round
643
+ session_ids (int, List[int]): session id(s)
644
+ sequence_start (bool): start flag of a session
645
+ sequence_end (bool): end flag of a session
646
+ ignore_eos (bool): indicator for ignoring eos
647
+ skip_special_tokens (bool): Whether or not to remove special tokens
648
+ in the decoding. Defaults to False.
649
+ timeout (int): max time to wait for response
650
+ Returns:
651
+ (a list of/batched) text/chat completion
652
+ """
653
+ from lmdeploy.serve.openai.api_client import json_loads
654
+
655
+ batched = True
656
+ if isinstance(inputs, str):
657
+ inputs = [inputs]
658
+ batched = False
659
+
660
+ gen_params = self.update_gen_params(**kwargs)
661
+ max_new_tokens = gen_params.pop('max_new_tokens')
662
+ gen_params.update(max_tokens=max_new_tokens)
663
+
664
+ responses = [''] * len(inputs)
665
+ pload = dict(
666
+ model=self.model_name,
667
+ prompt=inputs,
668
+ sequence_start=sequence_start,
669
+ sequence_end=sequence_end,
670
+ stream=False,
671
+ ignore_eos=ignore_eos,
672
+ skip_special_tokens=skip_special_tokens,
673
+ timeout=timeout,
674
+ **gen_params)
675
+ async with aiohttp.ClientSession(
676
+ timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
677
+ async with session.post(
678
+ self.client.completions_v1_url,
679
+ headers=self.client.headers,
680
+ json=pload) as resp:
681
+ async for chunk in resp.content:
682
+ if chunk:
683
+ decoded = chunk.decode('utf-8')
684
+ output = json_loads(decoded)
685
+ responses = [
686
+ response + item['text'] for response, item in zip(
687
+ responses, output['choices'])
688
+ ]
689
+ # remove stop_words
690
+ responses = filter_suffix(responses, self.gen_params.get('stop_words'))
691
+ if not batched:
692
+ return responses[0]
693
+ return responses
694
+
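The awaitable variant composes naturally with asyncio.gather for concurrent requests; a sketch under the same path assumptions:

async def main():
    server = AsyncLMDeployServer(path='internlm/internlm2_5-7b-chat')
    return await asyncio.gather(server.generate('Hi'),
                                server.generate('Hello'))

print(asyncio.run(main()))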
695
+ async def stream_chat(
696
+ self,
697
+ inputs: List[dict],
698
+ session_id: int = None,
699
+ sequence_start: bool = True,
700
+ sequence_end: bool = True,
701
+ stream: bool = True,
702
+ ignore_eos: bool = False,
703
+ skip_special_tokens: Optional[bool] = False,
704
+ timeout: int = 30,
705
+ **kwargs,
706
+ ):
707
+ """Start a new round conversation of a session. Return the chat
708
+ completions in stream mode.
709
+
710
+ Args:
711
+ inputs (List[dict]): user's inputs in this round conversation
712
+ session_id (int): session id
713
+ sequence_start (bool): start flag of a session
714
+ sequence_end (bool): end flag of a session
715
+ stream (bool): return in a streaming format if enabled
716
+ ignore_eos (bool): indicator for ignoring eos
717
+ skip_special_tokens (bool): Whether or not to remove special tokens
718
+ in the decoding. Defaults to False.
719
+ timeout (int): max time to wait for response
720
+ Returns:
721
+ tuple(Status, str, int): status, text/chat completion,
722
+ generated token number
723
+ """
724
+ from lmdeploy.serve.openai.api_client import json_loads
725
+
726
+ gen_params = self.update_gen_params(**kwargs)
727
+ max_new_tokens = gen_params.pop('max_new_tokens')
728
+ gen_params.update(max_tokens=max_new_tokens)
729
+ prompt = self.template_parser(inputs)
730
+
731
+ response = ''
732
+ finished = False
733
+ stop_words = self.gen_params.get('stop_words') or []  # guard: may be unset
734
+
735
+ pload = dict(
736
+ model=self.model_name,
737
+ prompt=prompt,
738
+ sequence_start=sequence_start,
739
+ sequence_end=sequence_end,
740
+ stream=stream,
741
+ ignore_eos=ignore_eos,
742
+ skip_special_tokens=skip_special_tokens,
743
+ timeout=timeout,
744
+ **gen_params)
745
+ async with aiohttp.ClientSession(
746
+ timeout=aiohttp.ClientTimeout(3 * 3600)) as session:
747
+ async with session.post(
748
+ self.client.completions_v1_url,
749
+ headers=self.client.headers,
750
+ json=pload) as resp:
751
+ async for chunk in resp.content:
752
+ if chunk:
753
+ decoded = chunk.decode('utf-8')
754
+ if not decoded.strip() or decoded.rstrip(
755
+ ) == 'data: [DONE]':
756
+ continue
757
+ if decoded[:6] == 'data: ':
758
+ decoded = decoded[6:]
759
+ output = json_loads(decoded)
760
+ response += output['choices'][0]['text']
761
+ if not response:
762
+ continue
763
+ # remove stop_words
764
+ for sw in stop_words:
765
+ if sw in response:
766
+ response = filter_suffix(response, stop_words)
767
+ finished = True
768
+ break
769
+ yield ModelStatusCode.STREAM_ING, response, None
770
+ if finished:
771
+ break
772
+ yield ModelStatusCode.END, response, None
773
+
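Its streaming counterpart is an async generator and is consumed with `async for` (taking the server built above):

async def stream_once(server):
    async for status, text, _ in server.stream_chat(
            [dict(role='user', content='Hi')], session_id=0):
        if status == ModelStatusCode.END:
            return text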
774
+
775
+ class AsyncLMDeployClient(AsyncLMDeployServer):
776
+ """
777
+
778
+ Args:
779
+ url (str): communicating address 'http://<ip>:<port>' of
780
+ api_server
781
+ model_name (str): needed when model_path is a pytorch model on
782
+ huggingface.co, such as "internlm-chat-7b",
783
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on.
784
+ """
785
+
786
+ def __init__(self, url: str, model_name: str, **kwargs):
787
+ BaseLLM.__init__(self, path=url, **kwargs)
788
+ from lmdeploy.serve.openai.api_client import APIClient
789
+ self.client = APIClient(url)
790
+ self.model_name = model_name
lagent/llms/meta_template.py ADDED
@@ -0,0 +1,40 @@
1
+ INTERNLM2_META = [
2
+ dict(
3
+ role='system',
4
+ begin=dict(
5
+ with_name='<|im_start|>system name={name}\n',
6
+ without_name='<|im_start|>system\n',
7
+ name={
8
+ 'interpreter': '<|interpreter|>',
9
+ 'plugin': '<|plugin|>',
10
+ }),
11
+ end='<|im_end|>\n',
12
+ ),
13
+ dict(
14
+ role='user',
15
+ begin=dict(
16
+ with_name='<|im_start|>user name={name}\n',
17
+ without_name='<|im_start|>user\n',
18
+ ),
19
+ end='<|im_end|>\n'),
20
+ dict(
21
+ role='assistant',
22
+ begin=dict(
23
+ with_name='<|im_start|>assistant name={name}\n',
24
+ without_name='<|im_start|>assistant\n',
25
+ name={
26
+ 'interpreter': '<|interpreter|>',
27
+ 'plugin': '<|plugin|>',
28
+ }),
29
+ end='<|im_end|>\n'),
30
+ dict(
31
+ role='environment',
32
+ begin=dict(
33
+ with_name='<|im_start|>environment name={name}\n',
34
+ without_name='<|im_start|>environment\n',
35
+ name={
36
+ 'interpreter': '<|interpreter|>',
37
+ 'plugin': '<|plugin|>',
38
+ }),
39
+ end='<|im_end|>\n'),
40
+ ]
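To make the structure concrete, a small sketch of how one of these entries renders a named assistant turn (the actual rendering is done by the template parser; this only illustrates the intended semantics):

entry = INTERNLM2_META[2]  # the 'assistant' role
begin = entry['begin']['with_name'].format(
    name=entry['begin']['name']['plugin'])
print(begin + '{"name": "search"}' + entry['end'])
# -> <|im_start|>assistant name=<|plugin|>
#    {"name": "search"}<|im_end|>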
lagent/llms/openai.py ADDED
@@ -0,0 +1,924 @@
1
+ import asyncio
2
+ import json
3
+ import os
4
+ import time
5
+ import traceback
6
+ import warnings
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from logging import getLogger
9
+ from threading import Lock
10
+ from typing import AsyncGenerator, Dict, List, Optional, Union
11
+
12
+ import aiohttp
13
+ import requests
14
+
15
+ from ..schema import ModelStatusCode
16
+ from ..utils import filter_suffix
17
+ from .base_api import AsyncBaseAPILLM, BaseAPILLM
18
+
19
+ warnings.simplefilter('default')
20
+
21
+ OPENAI_API_BASE = 'https://api.openai.com/v1/chat/completions'
22
+
23
+
24
+ class GPTAPI(BaseAPILLM):
25
+ """Model wrapper around OpenAI's models.
26
+
27
+ Args:
28
+ model_type (str): The name of OpenAI's model.
29
+ retry (int): Number of retries if the API call fails. Defaults to 2.
30
+ key (str or List[str]): OpenAI key(s). In particular, when it
31
+ is set to "ENV", the key will be fetched from the environment
32
+ variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
33
+ list, the keys will be used in round-robin manner. Defaults to
34
+ 'ENV'.
35
+ org (str or List[str], optional): OpenAI organization(s). If not
36
+ specified, OpenAI uses the default organization bound to each API
37
+ key. If specified, the orgs will be posted with each request in
38
+ round-robin manner. Defaults to None.
39
+ meta_template (Dict, optional): The model's meta prompt
40
+ template, if needed, e.g. for injecting or
41
+ wrapping meta instructions around messages.
42
+ api_base (str): The base url of OpenAI's API. Defaults to
43
+ 'https://api.openai.com/v1/chat/completions'.
44
+ gen_params: Default generation configuration which could be overridden
45
+ on the fly during generation.
46
+ """
47
+
48
+ is_api: bool = True
49
+
50
+ def __init__(self,
51
+ model_type: str = 'gpt-3.5-turbo',
52
+ retry: int = 2,
53
+ json_mode: bool = False,
54
+ key: Union[str, List[str]] = 'ENV',
55
+ org: Optional[Union[str, List[str]]] = None,
56
+ meta_template: Optional[Dict] = [
57
+ dict(role='system', api_role='system'),
58
+ dict(role='user', api_role='user'),
59
+ dict(role='assistant', api_role='assistant'),
60
+ dict(role='environment', api_role='system')
61
+ ],
62
+ api_base: str = OPENAI_API_BASE,
63
+ proxies: Optional[Dict] = None,
64
+ **gen_params):
65
+ if 'top_k' in gen_params:
66
+ warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
67
+ DeprecationWarning)
68
+ gen_params.pop('top_k')
69
+ super().__init__(
70
+ model_type=model_type,
71
+ meta_template=meta_template,
72
+ retry=retry,
73
+ **gen_params)
74
+ self.gen_params.pop('top_k')
75
+ self.logger = getLogger(__name__)
76
+
77
+ if isinstance(key, str):
78
+ self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
79
+ else:
80
+ self.keys = key
81
+
82
+ # record invalid keys and skip them when requesting API
83
+ # - keys have insufficient_quota
84
+ self.invalid_keys = set()
85
+
86
+ self.key_ctr = 0
87
+ if isinstance(org, str):
88
+ self.orgs = [org]
89
+ else:
90
+ self.orgs = org
91
+ self.org_ctr = 0
92
+ self.url = api_base
93
+ self.model_type = model_type
94
+ self.proxies = proxies
95
+ self.json_mode = json_mode
96
+
97
+ def chat(
98
+ self,
99
+ inputs: Union[List[dict], List[List[dict]]],
100
+ **gen_params,
101
+ ) -> Union[str, List[str]]:
102
+ """Generate responses given the contexts.
103
+
104
+ Args:
105
+ inputs (Union[List[dict], List[List[dict]]]): a list of messages
106
+ or list of lists of messages
107
+ gen_params: additional generation configuration
108
+
109
+ Returns:
110
+ Union[str, List[str]]: generated string(s)
111
+ """
112
+ assert isinstance(inputs, list)
113
+ if 'max_tokens' in gen_params:
114
+ raise NotImplementedError('unsupported parameter: max_tokens')
115
+ gen_params = {**self.gen_params, **gen_params}
116
+ with ThreadPoolExecutor(max_workers=20) as executor:
117
+ tasks = [
118
+ executor.submit(self._chat,
119
+ self.template_parser._prompt2api(messages),
120
+ **gen_params)
121
+ for messages in (
122
+ [inputs] if isinstance(inputs[0], dict) else inputs)
123
+ ]
124
+ ret = [task.result() for task in tasks]
125
+ return ret[0] if isinstance(inputs[0], dict) else ret
126
+
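A hedged usage sketch (the key is read from $OPENAI_API_KEY; prompts are illustrative):

gpt = GPTAPI(model_type='gpt-3.5-turbo', key='ENV')
print(gpt.chat([dict(role='user', content='Say hi.')]))   # single dialog -> str
print(gpt.chat([[dict(role='user', content='Hi')],        # batched dialogs,
                [dict(role='user', content='Hello')]]))   # run in a thread pool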
127
+ def stream_chat(
128
+ self,
129
+ inputs: List[dict],
130
+ **gen_params,
131
+ ):
132
+ """Generate responses given the contexts.
133
+
134
+ Args:
135
+ inputs (List[dict]): a list of messages
136
+ gen_params: additional generation configuration
137
+
138
+ Returns:
139
+ str: generated string
140
+ """
141
+ assert isinstance(inputs, list)
142
+ if 'max_tokens' in gen_params:
143
+ raise NotImplementedError('unsupported parameter: max_tokens')
144
+ gen_params = self.update_gen_params(**gen_params)
145
+ gen_params['stream'] = True
146
+
147
+ resp = ''
148
+ finished = False
149
+ stop_words = gen_params.get('stop_words')
150
+ if stop_words is None:
151
+ stop_words = []
152
+ # mapping to role that openai supports
153
+ messages = self.template_parser._prompt2api(inputs)
154
+ for text in self._stream_chat(messages, **gen_params):
155
+ if self.model_type.lower().startswith('qwen'):
156
+ resp = text
157
+ else:
158
+ resp += text
159
+ if not resp:
160
+ continue
161
+ # remove stop_words
162
+ for sw in stop_words:
163
+ if sw in resp:
164
+ resp = filter_suffix(resp, stop_words)
165
+ finished = True
166
+ break
167
+ yield ModelStatusCode.STREAM_ING, resp, None
168
+ if finished:
169
+ break
170
+ yield ModelStatusCode.END, resp, None
171
+
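Each yielded tuple carries the text accumulated so far, so the last one holds the full reply (reusing gpt from the sketch above):

for status, text, _ in gpt.stream_chat([dict(role='user', content='Hi')]):
    pass  # text grows incrementally
print(text)  # complete once status == ModelStatusCode.END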
172
+ def _chat(self, messages: List[dict], **gen_params) -> str:
173
+ """Generate completion from a list of templates.
174
+
175
+ Args:
176
+ messages (List[dict]): a list of prompt dictionaries
177
+ gen_params: additional generation configuration
178
+
179
+ Returns:
180
+ str: The generated string.
181
+ """
182
+ assert isinstance(messages, list)
183
+
184
+ header, data = self.generate_request_data(
185
+ model_type=self.model_type,
186
+ messages=messages,
187
+ gen_params=gen_params,
188
+ json_mode=self.json_mode)
189
+
190
+ max_num_retries, errmsg = 0, ''
191
+ while max_num_retries < self.retry:
192
+ with Lock():
193
+ if len(self.invalid_keys) == len(self.keys):
194
+ raise RuntimeError('All keys have insufficient quota.')
195
+
196
+ # find the next valid key
197
+ while True:
198
+ self.key_ctr += 1
199
+ if self.key_ctr == len(self.keys):
200
+ self.key_ctr = 0
201
+
202
+ if self.keys[self.key_ctr] not in self.invalid_keys:
203
+ break
204
+
205
+ key = self.keys[self.key_ctr]
206
+ header['Authorization'] = f'Bearer {key}'
207
+
208
+ if self.orgs:
209
+ with Lock():
210
+ self.org_ctr += 1
211
+ if self.org_ctr == len(self.orgs):
212
+ self.org_ctr = 0
213
+ header['OpenAI-Organization'] = self.orgs[self.org_ctr]
214
+
215
+ response = dict()
216
+ try:
217
+ raw_response = requests.post(
218
+ self.url,
219
+ headers=header,
220
+ data=json.dumps(data),
221
+ proxies=self.proxies)
222
+ response = raw_response.json()
223
+ return response['choices'][0]['message']['content'].strip()
224
+ except requests.ConnectionError:
225
+ errmsg = 'Got connection error ' + str(traceback.format_exc())
226
+ self.logger.error(errmsg)
227
+ continue
228
+ except requests.JSONDecodeError:
229
+ errmsg = 'JsonDecode error, got ' + str(raw_response.content)
230
+ self.logger.error(errmsg)
231
+ continue
232
+ except KeyError:
233
+ if 'error' in response:
234
+ if response['error']['code'] == 'rate_limit_exceeded':
235
+ time.sleep(1)
236
+ continue
237
+ elif response['error']['code'] == 'insufficient_quota':
238
+ self.invalid_keys.add(key)
239
+ self.logger.warning(f'insufficient_quota key: {key}')
240
+ continue
241
+
242
+ errmsg = 'Found an error message in the response: ' + str(
243
+ response['error'])
244
+ self.logger.error(errmsg)
245
+ except Exception as error:
246
+ errmsg = str(error) + '\n' + str(traceback.format_exc())
247
+ self.logger.error(errmsg)
248
+ max_num_retries += 1
249
+
250
+ raise RuntimeError('Calling OpenAI failed after retrying for '
251
+ f'{max_num_retries} times. Check the logs for '
252
+ f'details. errmsg: {errmsg}')
253
+
254
+ def _stream_chat(self, messages: List[dict], **gen_params) -> str:
255
+ """Generate completion from a list of templates.
256
+
257
+ Args:
258
+ messages (List[dict]): a list of prompt dictionaries
259
+ gen_params: additional generation configuration
260
+
261
+ Returns:
262
+ str: The generated string.
263
+ """
264
+
265
+ def streaming(raw_response):
266
+ for chunk in raw_response.iter_lines(
267
+ chunk_size=8192, decode_unicode=False, delimiter=b'\n'):
268
+ if chunk:
269
+ decoded = chunk.decode('utf-8')
270
+ if decoded.startswith('data: [DONE]'):
271
+ return
272
+ if decoded[:5] == 'data:':
273
+ decoded = decoded[5:]
274
+ if decoded[0] == ' ':
275
+ decoded = decoded[1:]
276
+ else:
277
+ print(decoded)
278
+ continue
279
+ try:
280
+ response = json.loads(decoded)
281
+ if 'code' in response and response['code'] == -20003:
282
+ # Context exceeds maximum length
283
+ yield ''
284
+ return
285
+ if self.model_type.lower().startswith('qwen'):
286
+ choice = response['output']['choices'][0]
287
+ yield choice['message']['content']
288
+ if choice['finish_reason'] == 'stop':
289
+ return
290
+ else:
291
+ choice = response['choices'][0]
292
+ if choice['finish_reason'] == 'stop':
293
+ return
294
+ yield choice['delta'].get('content', '')
295
+ except Exception as exc:
296
+ msg = f'response {decoded} lead to exception of {str(exc)}'
297
+ self.logger.error(msg)
298
+ raise Exception(msg) from exc
299
+
300
+ assert isinstance(messages, list)
301
+
302
+ header, data = self.generate_request_data(
303
+ model_type=self.model_type,
304
+ messages=messages,
305
+ gen_params=gen_params,
306
+ json_mode=self.json_mode)
307
+
308
+ max_num_retries, errmsg = 0, ''
309
+ while max_num_retries < self.retry:
310
+ if len(self.invalid_keys) == len(self.keys):
311
+ raise RuntimeError('All keys have insufficient quota.')
312
+
313
+ # find the next valid key
314
+ while True:
315
+ self.key_ctr += 1
316
+ if self.key_ctr == len(self.keys):
317
+ self.key_ctr = 0
318
+
319
+ if self.keys[self.key_ctr] not in self.invalid_keys:
320
+ break
321
+
322
+ key = self.keys[self.key_ctr]
323
+ header['Authorization'] = f'Bearer {key}'
324
+
325
+ if self.orgs:
326
+ self.org_ctr += 1
327
+ if self.org_ctr == len(self.orgs):
328
+ self.org_ctr = 0
329
+ header['OpenAI-Organization'] = self.orgs[self.org_ctr]
330
+
331
+ response = dict()
332
+ try:
333
+ raw_response = requests.post(
334
+ self.url,
335
+ headers=header,
336
+ data=json.dumps(data),
337
+ proxies=self.proxies)
338
+ return streaming(raw_response)
339
+ except requests.ConnectionError:
340
+ errmsg = 'Got connection error ' + str(traceback.format_exc())
341
+ self.logger.error(errmsg)
342
+ continue
343
+ except requests.JSONDecodeError:
344
+ errmsg = 'JsonDecode error, got ' + str(raw_response.content)
345
+ self.logger.error(errmsg)
346
+ continue
347
+ except KeyError:
348
+ if 'error' in response:
349
+ if response['error']['code'] == 'rate_limit_exceeded':
350
+ time.sleep(1)
351
+ continue
352
+ elif response['error']['code'] == 'insufficient_quota':
353
+ self.invalid_keys.add(key)
354
+ self.logger.warning(f'insufficient_quota key: {key}')
355
+ continue
356
+
357
+ errmsg = 'Found an error message in the response: ' + str(
358
+ response['error'])
359
+ self.logger.error(errmsg)
360
+ except Exception as error:
361
+ errmsg = str(error) + '\n' + str(traceback.format_exc())
362
+ self.logger.error(errmsg)
363
+ max_num_retries += 1
364
+
365
+ raise RuntimeError('Calling OpenAI failed after retrying for '
366
+ f'{max_num_retries} times. Check the logs for '
367
+ f'details. errmsg: {errmsg}')
368
+
369
+ def generate_request_data(self,
370
+ model_type,
371
+ messages,
372
+ gen_params,
373
+ json_mode=False):
374
+ """
375
+ Generates the request data for different model types.
376
+
377
+ Args:
378
+ model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen').
379
+ messages (list): The list of messages to be sent to the model.
380
+ gen_params (dict): The generation parameters.
381
+ json_mode (bool): Flag to determine if the response format should be JSON.
382
+
383
+ Returns:
384
+ tuple: A tuple containing the header and the request data.
385
+ """
386
+ # Copy generation parameters to avoid modifying the original dictionary
387
+ gen_params = gen_params.copy()
388
+
389
+ # Cap the completion length at 4096 tokens (the limit assumed for these APIs)
390
+ max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
391
+ if max_tokens <= 0:
392
+ return '', ''
393
+
394
+ # Initialize the header
395
+ header = {
396
+ 'content-type': 'application/json',
397
+ }
398
+
399
+ # Common parameters processing
400
+ gen_params['max_tokens'] = max_tokens
401
+ if 'stop_words' in gen_params:
402
+ gen_params['stop'] = gen_params.pop('stop_words')
403
+ if 'repetition_penalty' in gen_params:
404
+ gen_params['frequency_penalty'] = gen_params.pop(
405
+ 'repetition_penalty')
406
+
407
+ # Model-specific processing
408
+ data = {}
409
+ if model_type.lower().startswith('gpt'):
410
+ if 'top_k' in gen_params:
411
+ warnings.warn(
412
+ '`top_k` parameter is deprecated in OpenAI APIs.',
413
+ DeprecationWarning)
414
+ gen_params.pop('top_k')
415
+ gen_params.pop('skip_special_tokens', None)
416
+ gen_params.pop('session_id', None)
417
+ data = {
418
+ 'model': model_type,
419
+ 'messages': messages,
420
+ 'n': 1,
421
+ **gen_params
422
+ }
423
+ if json_mode:
424
+ data['response_format'] = {'type': 'json_object'}
425
+ elif model_type.lower().startswith('internlm'):
426
+ data = {
427
+ 'model': model_type,
428
+ 'messages': messages,
429
+ 'n': 1,
430
+ **gen_params
431
+ }
432
+ if json_mode:
433
+ data['response_format'] = {'type': 'json_object'}
434
+ elif model_type.lower().startswith('qwen'):
435
+ header['X-DashScope-SSE'] = 'enable'
436
+ gen_params.pop('skip_special_tokens', None)
437
+ gen_params.pop('session_id', None)
438
+ if 'frequency_penalty' in gen_params:
439
+ gen_params['repetition_penalty'] = gen_params.pop(
440
+ 'frequency_penalty')
441
+ gen_params['result_format'] = 'message'
442
+ data = {
443
+ 'model': model_type,
444
+ 'input': {
445
+ 'messages': messages
446
+ },
447
+ 'parameters': {
448
+ **gen_params
449
+ }
450
+ }
451
+ else:
452
+ raise NotImplementedError(
453
+ f'Model type {model_type} is not supported')
454
+
455
+ return header, data
456
+
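For a GPT-family model the returned payload is OpenAI-style; roughly (values illustrative, reusing gpt from above):

header, data = gpt.generate_request_data(
    'gpt-3.5-turbo',
    [dict(role='user', content='Hi')],
    dict(max_new_tokens=256, temperature=0.7, stop_words=['<|im_end|>']))
# data == {'model': 'gpt-3.5-turbo', 'messages': [...], 'n': 1,
#          'max_tokens': 256, 'temperature': 0.7, 'stop': ['<|im_end|>']}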
457
+ def tokenize(self, prompt: str) -> list:
458
+ """Tokenize the input prompt.
459
+
460
+ Args:
461
+ prompt (str): Input string.
462
+
463
+ Returns:
464
+ list: token ids
465
+ """
466
+ import tiktoken
467
+ self.tiktoken = tiktoken
468
+ enc = self.tiktoken.encoding_for_model(self.model_type)
469
+ return enc.encode(prompt)
470
+
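Token counting with the wrapper (requires the tiktoken package; reusing gpt from above):

n_tokens = len(gpt.tokenize('How many tokens is this sentence?'))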
471
+
472
+ class AsyncGPTAPI(AsyncBaseAPILLM):
473
+ """Model wrapper around OpenAI's models.
474
+
475
+ Args:
476
+ model_type (str): The name of OpenAI's model.
477
+ retry (int): Number of retries if the API call fails. Defaults to 2.
478
+ key (str or List[str]): OpenAI key(s). In particular, when it
479
+ is set to "ENV", the key will be fetched from the environment
480
+ variable $OPENAI_API_KEY, as how openai defaults to be. If it's a
481
+ list, the keys will be used in round-robin manner. Defaults to
482
+ 'ENV'.
483
+ org (str or List[str], optional): OpenAI organization(s). If not
484
+ specified, OpenAI uses the default organization bound to each API
485
+ key. If specified, the orgs will be posted with each request in
486
+ round-robin manner. Defaults to None.
487
+ meta_template (Dict, optional): The model's meta prompt
488
+ template, if needed, e.g. for injecting or
489
+ wrapping meta instructions around messages.
490
+ api_base (str): The base url of OpenAI's API. Defaults to
491
+ 'https://api.openai.com/v1/chat/completions'.
492
+ gen_params: Default generation configuration which could be overridden
493
+ on the fly during generation.
494
+ """
495
+
496
+ is_api: bool = True
497
+
498
+ def __init__(self,
499
+ model_type: str = 'gpt-3.5-turbo',
500
+ retry: int = 2,
501
+ json_mode: bool = False,
502
+ key: Union[str, List[str]] = 'ENV',
503
+ org: Optional[Union[str, List[str]]] = None,
504
+ meta_template: Optional[Dict] = [
505
+ dict(role='system', api_role='system'),
506
+ dict(role='user', api_role='user'),
507
+ dict(role='assistant', api_role='assistant')
508
+ ],
509
+ api_base: str = OPENAI_API_BASE,
510
+ proxies: Optional[Dict] = None,
511
+ **gen_params):
512
+ if 'top_k' in gen_params:
513
+ warnings.warn('`top_k` parameter is deprecated in OpenAI APIs.',
514
+ DeprecationWarning)
515
+ gen_params.pop('top_k')
516
+ super().__init__(
517
+ model_type=model_type,
518
+ meta_template=meta_template,
519
+ retry=retry,
520
+ **gen_params)
521
+ self.gen_params.pop('top_k')
522
+ self.logger = getLogger(__name__)
523
+
524
+ if isinstance(key, str):
525
+ self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
526
+ else:
527
+ self.keys = key
528
+
529
+ # record invalid keys and skip them when requesting API
530
+ # - keys have insufficient_quota
531
+ self.invalid_keys = set()
532
+
533
+ self.key_ctr = 0
534
+ if isinstance(org, str):
535
+ self.orgs = [org]
536
+ else:
537
+ self.orgs = org
538
+ self.org_ctr = 0
539
+ self.url = api_base
540
+ self.model_type = model_type
541
+ self.proxies = proxies or {}
542
+ self.json_mode = json_mode
543
+
544
+ async def chat(
545
+ self,
546
+ inputs: Union[List[dict], List[List[dict]]],
547
+ session_ids: Union[int, List[int]] = None,
548
+ **gen_params,
549
+ ) -> Union[str, List[str]]:
550
+ """Generate responses given the contexts.
551
+
552
+ Args:
553
+ inputs (Union[List[dict], List[List[dict]]]): a list of messages
554
+ or list of lists of messages
555
+ gen_params: additional generation configuration
556
+
557
+ Returns:
558
+ Union[str, List[str]]: generated string(s)
559
+ """
560
+ assert isinstance(inputs, list)
561
+ if 'max_tokens' in gen_params:
562
+ raise NotImplementedError('unsupported parameter: max_tokens')
563
+ gen_params = {**self.gen_params, **gen_params}
564
+ tasks = [
565
+ self._chat(messages, **gen_params) for messages in (
566
+ [inputs] if isinstance(inputs[0], dict) else inputs)
567
+ ]
568
+ ret = await asyncio.gather(*tasks)
569
+ return ret[0] if isinstance(inputs[0], dict) else ret
570
+
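The async twin awaits all dialogs concurrently via asyncio.gather; a sketch (key read from $OPENAI_API_KEY):

async def demo():
    agpt = AsyncGPTAPI(model_type='gpt-3.5-turbo', key='ENV')
    return await agpt.chat([[dict(role='user', content='Hi')],
                            [dict(role='user', content='Hello')]])

print(asyncio.run(demo()))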
571
+ async def stream_chat(
572
+ self,
573
+ inputs: List[dict],
574
+ **gen_params,
575
+ ):
576
+ """Generate responses given the contexts.
577
+
578
+ Args:
579
+ inputs (List[dict]): a list of messages
580
+ gen_params: additional generation configuration
581
+
582
+ Returns:
583
+ str: generated string
584
+ """
585
+ assert isinstance(inputs, list)
586
+ if 'max_tokens' in gen_params:
587
+ raise NotImplementedError('unsupported parameter: max_tokens')
588
+ gen_params = self.update_gen_params(**gen_params)
589
+ gen_params['stream'] = True
590
+
591
+ resp = ''
592
+ finished = False
593
+ stop_words = gen_params.get('stop_words')
594
+ if stop_words is None:
595
+ stop_words = []
596
+ # mapping to role that openai supports
597
+ messages = self.template_parser._prompt2api(inputs)
598
+ async for text in self._stream_chat(messages, **gen_params):
599
+ if self.model_type.lower().startswith('qwen'):
600
+ resp = text
601
+ else:
602
+ resp += text
603
+ if not resp:
604
+ continue
605
+ # remove stop_words
606
+ for sw in stop_words:
607
+ if sw in resp:
608
+ resp = filter_suffix(resp, stop_words)
609
+ finished = True
610
+ break
611
+ yield ModelStatusCode.STREAM_ING, resp, None
612
+ if finished:
613
+ break
614
+ yield ModelStatusCode.END, resp, None
615
+
616
+ async def _chat(self, messages: List[dict], **gen_params) -> str:
617
+ """Generate completion from a list of templates.
618
+
619
+ Args:
620
+ messages (List[dict]): a list of prompt dictionaries
621
+ gen_params: additional generation configuration
622
+
623
+ Returns:
624
+ str: The generated string.
625
+ """
626
+ assert isinstance(messages, list)
627
+
628
+ header, data = self.generate_request_data(
629
+ model_type=self.model_type,
630
+ messages=messages,
631
+ gen_params=gen_params,
632
+ json_mode=self.json_mode)
633
+
634
+ max_num_retries, errmsg = 0, ''
635
+ while max_num_retries < self.retry:
636
+ if len(self.invalid_keys) == len(self.keys):
637
+ raise RuntimeError('All keys have insufficient quota.')
638
+
639
+ # find the next valid key
640
+ while True:
641
+ self.key_ctr += 1
642
+ if self.key_ctr == len(self.keys):
643
+ self.key_ctr = 0
644
+
645
+ if self.keys[self.key_ctr] not in self.invalid_keys:
646
+ break
647
+
648
+ key = self.keys[self.key_ctr]
649
+ header['Authorization'] = f'Bearer {key}'
650
+
651
+ if self.orgs:
652
+ self.org_ctr += 1
653
+ if self.org_ctr == len(self.orgs):
654
+ self.org_ctr = 0
655
+ header['OpenAI-Organization'] = self.orgs[self.org_ctr]
656
+
657
+ response = dict()
658
+ try:
659
+ async with aiohttp.ClientSession() as session:
660
+ async with session.post(
661
+ self.url,
662
+ headers=header,
663
+ json=data,
664
+ proxy=self.proxies.get(
665
+ 'https', self.proxies.get('http'))) as resp:
666
+ response = await resp.json()
667
+ return response['choices'][0]['message'][
668
+ 'content'].strip()
669
+ except aiohttp.ClientConnectionError:
670
+ errmsg = 'Got connection error ' + str(traceback.format_exc())
671
+ self.logger.error(errmsg)
672
+ continue
673
+ except aiohttp.ClientResponseError as e:
674
+ errmsg = 'Response error, got ' + str(e)
675
+ self.logger.error(errmsg)
676
+ continue
677
+ except json.JSONDecodeError:
678
+ errmsg = 'JsonDecode error, got ' + (await resp.text(
679
+ errors='replace'))
680
+ self.logger.error(errmsg)
681
+ continue
682
+ except KeyError:
683
+ if 'error' in response:
684
+ if response['error']['code'] == 'rate_limit_exceeded':
685
+ await asyncio.sleep(1)  # don't block the event loop
686
+ continue
687
+ elif response['error']['code'] == 'insufficient_quota':
688
+ self.invalid_keys.add(key)
689
+ self.logger.warning(f'insufficient_quota key: {key}')
690
+ continue
691
+
692
+ errmsg = 'Found an error message in the response: ' + str(
693
+ response['error'])
694
+ self.logger.error(errmsg)
695
+ except Exception as error:
696
+ errmsg = str(error) + '\n' + str(traceback.format_exc())
697
+ self.logger.error(errmsg)
698
+ max_num_retries += 1
699
+
700
+ raise RuntimeError('Calling OpenAI failed after retrying for '
701
+ f'{max_num_retries} times. Check the logs for '
702
+ f'details. errmsg: {errmsg}')
703
+
704
+ async def _stream_chat(self, messages: List[dict],
705
+ **gen_params) -> AsyncGenerator[str, None]:
706
+ """Generate completion from a list of templates.
707
+
708
+ Args:
709
+ messages (List[dict]): a list of prompt dictionaries
710
+ gen_params: additional generation configuration
711
+
712
+ Returns:
713
+ str: The generated string.
714
+ """
715
+
716
+ async def streaming(raw_response):
717
+ async for chunk in raw_response.content:
718
+ if chunk:
719
+ decoded = chunk.decode('utf-8')
720
+ if decoded.startswith('data: [DONE]'):
721
+ return
722
+ if decoded[:5] == 'data:':
723
+ decoded = decoded[5:]
724
+ if decoded[0] == ' ':
725
+ decoded = decoded[1:]
726
+ else:
727
+ print(decoded)
728
+ continue
729
+ try:
730
+ response = json.loads(decoded)
731
+ if 'code' in response and response['code'] == -20003:
732
+ # Context exceeds maximum length
733
+ yield ''
734
+ return
735
+ if self.model_type.lower().startswith('qwen'):
736
+ choice = response['output']['choices'][0]
737
+ yield choice['message']['content']
738
+ if choice['finish_reason'] == 'stop':
739
+ return
740
+ else:
741
+ choice = response['choices'][0]
742
+ if choice['finish_reason'] == 'stop':
743
+ return
744
+ yield choice['delta'].get('content', '')
745
+ except Exception as exc:
746
+ msg = f'response {decoded} lead to exception of {str(exc)}'
747
+ self.logger.error(msg)
748
+ raise Exception(msg) from exc
749
+
750
+ assert isinstance(messages, list)
751
+
752
+ header, data = self.generate_request_data(
753
+ model_type=self.model_type,
754
+ messages=messages,
755
+ gen_params=gen_params,
756
+ json_mode=self.json_mode)
757
+
758
+ max_num_retries, errmsg = 0, ''
759
+ while max_num_retries < self.retry:
760
+ if len(self.invalid_keys) == len(self.keys):
761
+ raise RuntimeError('All keys have insufficient quota.')
762
+
763
+ # find the next valid key
764
+ while True:
765
+ self.key_ctr += 1
766
+ if self.key_ctr == len(self.keys):
767
+ self.key_ctr = 0
768
+
769
+ if self.keys[self.key_ctr] not in self.invalid_keys:
770
+ break
771
+
772
+ key = self.keys[self.key_ctr]
773
+ header['Authorization'] = f'Bearer {key}'
774
+
775
+ if self.orgs:
776
+ self.org_ctr += 1
777
+ if self.org_ctr == len(self.orgs):
778
+ self.org_ctr = 0
779
+ header['OpenAI-Organization'] = self.orgs[self.org_ctr]
780
+
781
+ response = dict()
782
+ try:
783
+ async with aiohttp.ClientSession() as session:
784
+ async with session.post(
785
+ self.url,
786
+ headers=header,
787
+ json=data,
788
+ proxy=self.proxies.get(
789
+ 'https',
790
+ self.proxies.get('http'))) as raw_response:
791
+ async for msg in streaming(raw_response):
792
+ yield msg
793
+ return
794
+ except aiohttp.ClientConnectionError:
795
+ errmsg = 'Got connection error ' + str(traceback.format_exc())
796
+ self.logger.error(errmsg)
797
+ continue
798
+ except aiohttp.ClientResponseError as e:
799
+ errmsg = 'Response error, got ' + str(e)
800
+ self.logger.error(errmsg)
801
+ continue
802
+ except KeyError:
803
+ if 'error' in response:
804
+ if response['error']['code'] == 'rate_limit_exceeded':
805
+ await asyncio.sleep(1)  # don't block the event loop
806
+ continue
807
+ elif response['error']['code'] == 'insufficient_quota':
808
+ self.invalid_keys.add(key)
809
+ self.logger.warning(f'insufficient_quota key: {key}')
810
+ continue
811
+
812
+ errmsg = 'Found an error message in the response: ' + str(
813
+ response['error'])
814
+ self.logger.error(errmsg)
815
+ except Exception as error:
816
+ errmsg = str(error) + '\n' + str(traceback.format_exc())
817
+ self.logger.error(errmsg)
818
+ max_num_retries += 1
819
+
820
+ raise RuntimeError('Calling OpenAI failed after retrying for '
821
+ f'{max_num_retries} times. Check the logs for '
822
+ f'details. errmsg: {errmsg}')
823
+
824
+ def generate_request_data(self,
825
+ model_type,
826
+ messages,
827
+ gen_params,
828
+ json_mode=False):
829
+ """
830
+ Generates the request data for different model types.
831
+
832
+ Args:
833
+ model_type (str): The type of the model (e.g., 'gpt', 'internlm', 'qwen').
834
+ messages (list): The list of messages to be sent to the model.
835
+ gen_params (dict): The generation parameters.
836
+ json_mode (bool): Flag to determine if the response format should be JSON.
837
+
838
+ Returns:
839
+ tuple: A tuple containing the header and the request data.
840
+ """
841
+ # Copy generation parameters to avoid modifying the original dictionary
842
+ gen_params = gen_params.copy()
843
+
844
+ # Cap the completion length at 4096 tokens (the limit assumed for these APIs)
845
+ max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
846
+ if max_tokens <= 0:
847
+ return '', ''
848
+
849
+ # Initialize the header
850
+ header = {
851
+ 'content-type': 'application/json',
852
+ }
853
+
854
+ # Common parameters processing
855
+ gen_params['max_tokens'] = max_tokens
856
+ if 'stop_words' in gen_params:
857
+ gen_params['stop'] = gen_params.pop('stop_words')
858
+ if 'repetition_penalty' in gen_params:
859
+ gen_params['frequency_penalty'] = gen_params.pop(
860
+ 'repetition_penalty')
861
+
862
+ # Model-specific processing
863
+ data = {}
864
+ if model_type.lower().startswith('gpt'):
865
+ if 'top_k' in gen_params:
866
+ warnings.warn(
867
+ '`top_k` parameter is deprecated in OpenAI APIs.',
868
+ DeprecationWarning)
869
+ gen_params.pop('top_k')
870
+ gen_params.pop('skip_special_tokens', None)
871
+ gen_params.pop('session_id', None)
872
+ data = {
873
+ 'model': model_type,
874
+ 'messages': messages,
875
+ 'n': 1,
876
+ **gen_params
877
+ }
878
+ if json_mode:
879
+ data['response_format'] = {'type': 'json_object'}
880
+ elif model_type.lower().startswith('internlm'):
881
+ data = {
882
+ 'model': model_type,
883
+ 'messages': messages,
884
+ 'n': 1,
885
+ **gen_params
886
+ }
887
+ if json_mode:
888
+ data['response_format'] = {'type': 'json_object'}
889
+ elif model_type.lower().startswith('qwen'):
890
+ header['X-DashScope-SSE'] = 'enable'
891
+ gen_params.pop('skip_special_tokens', None)
892
+ gen_params.pop('session_id', None)
893
+ if 'frequency_penalty' in gen_params:
894
+ gen_params['repetition_penalty'] = gen_params.pop(
895
+ 'frequency_penalty')
896
+ gen_params['result_format'] = 'message'
897
+ data = {
898
+ 'model': model_type,
899
+ 'input': {
900
+ 'messages': messages
901
+ },
902
+ 'parameters': {
903
+ **gen_params
904
+ }
905
+ }
906
+ else:
907
+ raise NotImplementedError(
908
+ f'Model type {model_type} is not supported')
909
+
910
+ return header, data
911
+
912
+ def tokenize(self, prompt: str) -> list:
913
+ """Tokenize the input prompt.
914
+
915
+ Args:
916
+ prompt (str): Input string.
917
+
918
+ Returns:
919
+ list: token ids
920
+ """
921
+ import tiktoken
922
+ self.tiktoken = tiktoken
923
+ enc = self.tiktoken.encoding_for_model(self.model_type)
924
+ return enc.encode(prompt)
lagent/llms/sensenova.py ADDED
@@ -0,0 +1,406 @@
1
+ import json
2
+ import os
3
+ import time
4
+ import warnings
5
+ from concurrent.futures import ThreadPoolExecutor
6
+ from logging import getLogger
7
+ from threading import Lock
8
+ from typing import Dict, Generator, List, Optional, Tuple, Union
9
+
10
+ import requests
11
+
12
+ from lagent.schema import ModelStatusCode
13
+ from lagent.utils.util import filter_suffix
14
+ from .base_api import BaseAPILLM
15
+
16
+ warnings.simplefilter('default')
17
+
18
+ SENSENOVA_API_BASE = 'https://api.sensenova.cn/v1/llm/chat-completions'
19
+
20
+ sensechat_models = {'SenseChat-5': 131072, 'SenseChat-5-Cantonese': 32768}
21
+
22
+
23
+ class SensenovaAPI(BaseAPILLM):
24
+ """Model wrapper around SenseTime's models.
25
+
26
+ Args:
27
+ model_type (str): The name of SenseTime's model.
28
+ retry (int): Number of retries if the API call fails. Defaults to 2.
29
+ key (str or List[str]): SenseTime key(s). In particular, when it
30
+ is set to "ENV", the key will be fetched from the environment
31
+ variable $SENSENOVA_API_KEY. If it's a list, the keys will be
32
+ used in round-robin manner. Defaults to 'ENV'.
33
+ meta_template (Dict, optional): The model's meta prompt
34
+ template, if needed, e.g. for injecting or
35
+ wrapping meta instructions around messages.
36
+ sensenova_api_base (str): The base url of SenseTime's API. Defaults to
37
+ 'https://api.sensenova.cn/v1/llm/chat-completions'.
38
+ gen_params: Default generation configuration which could be overridden
39
+ on the fly during generation.
40
+ """
41
+
42
+ is_api: bool = True
43
+
44
+ def __init__(
45
+ self,
46
+ model_type: str = 'SenseChat-5-Cantonese',
47
+ retry: int = 2,
48
+ json_mode: bool = False,
49
+ key: Union[str, List[str]] = 'ENV',
50
+ meta_template: Optional[Dict] = [
51
+ dict(role='system', api_role='system'),
52
+ dict(role='user', api_role='user'),
53
+ dict(role='assistant', api_role='assistant'),
54
+ dict(role='environment', api_role='system'),
55
+ ],
56
+ sensenova_api_base: str = SENSENOVA_API_BASE,
57
+ proxies: Optional[Dict] = None,
58
+ **gen_params,
59
+ ):
60
+
61
+ super().__init__(
62
+ model_type=model_type,
63
+ meta_template=meta_template,
64
+ retry=retry,
65
+ **gen_params,
66
+ )
67
+ self.logger = getLogger(__name__)
68
+
69
+ if isinstance(key, str):
70
+ # First, apply for SenseNova's ak and sk from SenseTime staff
71
+ # Then, generate SENSENOVA_API_KEY using lagent.utils.gen_key.auto_gen_jwt_token(ak, sk)
72
+ self.keys = [
73
+ os.getenv('SENSENOVA_API_KEY') if key == 'ENV' else key
74
+ ]
75
+ else:
76
+ self.keys = key
77
+
78
+ # record invalid keys and skip them when requesting API
79
+ # - keys have insufficient_quota
80
+ self.invalid_keys = set()
81
+
82
+ self.key_ctr = 0
83
+ self.url = sensenova_api_base
84
+ self.model_type = model_type
85
+ self.proxies = proxies
86
+ self.json_mode = json_mode
87
+
88
+ def chat(
89
+ self,
90
+ inputs: Union[List[dict], List[List[dict]]],
91
+ **gen_params,
92
+ ) -> Union[str, List[str]]:
93
+ """Generate responses given the contexts.
94
+
95
+ Args:
96
+ inputs (Union[List[dict], List[List[dict]]]): a list of messages
97
+ or list of lists of messages
98
+ gen_params: additional generation configuration
99
+
100
+ Returns:
101
+ Union[str, List[str]]: generated string(s)
102
+ """
103
+ assert isinstance(inputs, list)
104
+ if 'max_tokens' in gen_params:
105
+ raise NotImplementedError('unsupported parameter: max_tokens')
106
+ gen_params = {**self.gen_params, **gen_params}
107
+ with ThreadPoolExecutor(max_workers=20) as executor:
108
+ tasks = [
109
+ executor.submit(self._chat,
110
+ self.template_parser._prompt2api(messages),
111
+ **gen_params)
112
+ for messages in (
113
+ [inputs] if isinstance(inputs[0], dict) else inputs)
114
+ ]
115
+ ret = [task.result() for task in tasks]
116
+ return ret[0] if isinstance(inputs[0], dict) else ret
117
+
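Usage mirrors GPTAPI; the key must be a pre-generated SenseNova JWT, read here from $SENSENOVA_API_KEY (prompt illustrative; max_new_tokens is assumed to come from the BaseAPILLM defaults):

sense = SensenovaAPI(model_type='SenseChat-5', key='ENV')
print(sense.chat([dict(role='user', content='Hello!')]))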
118
+ def stream_chat(
119
+ self,
120
+ inputs: List[dict],
121
+ **gen_params,
122
+ ) -> Generator[Tuple[ModelStatusCode, str, Optional[str]], None, None]:
123
+ """Generate responses given the contexts.
124
+
125
+ Args:
126
+ inputs (List[dict]): a list of messages
127
+ gen_params: additional generation configuration
128
+
129
+ Yields:
130
+ Tuple[ModelStatusCode, str, Optional[str]]: Status code, generated string, and optional metadata
131
+ """
132
+ assert isinstance(inputs, list)
133
+ if 'max_tokens' in gen_params:
134
+ raise NotImplementedError('unsupported parameter: max_tokens')
135
+ gen_params = self.update_gen_params(**gen_params)
136
+ gen_params['stream'] = True
137
+
138
+ resp = ''
139
+ finished = False
140
+ stop_words = gen_params.get('stop_words') or []
141
+ messages = self.template_parser._prompt2api(inputs)
142
+ for text in self._stream_chat(messages, **gen_params):
143
+ # TODO: verify whether this should be resp = text or resp += text
144
+ resp += text
145
+ if not resp:
146
+ continue
147
+ # remove stop_words
148
+ for sw in stop_words:
149
+ if sw in resp:
150
+ resp = filter_suffix(resp, stop_words)
151
+ finished = True
152
+ break
153
+ yield ModelStatusCode.STREAM_ING, resp, None
154
+ if finished:
155
+ break
156
+ yield ModelStatusCode.END, resp, None
157
+
158
+ def _chat(self, messages: List[dict], **gen_params) -> str:
159
+ """Generate completion from a list of templates.
160
+
161
+ Args:
162
+ messages (List[dict]): a list of prompt dictionaries
163
+ gen_params: additional generation configuration
164
+
165
+ Returns:
166
+ str: The generated string.
167
+ """
168
+ assert isinstance(messages, list)
169
+
170
+ header, data = self.generate_request_data(
171
+ model_type=self.model_type,
172
+ messages=messages,
173
+ gen_params=gen_params,
174
+ json_mode=self.json_mode,
175
+ )
176
+
177
+ max_num_retries = 0
178
+ while max_num_retries < self.retry:
179
+ self._wait()
180
+
181
+ with Lock():
182
+ if len(self.invalid_keys) == len(self.keys):
183
+ raise RuntimeError('All keys have insufficient quota.')
184
+
185
+ # find the next valid key
186
+ while True:
187
+ self.key_ctr += 1
188
+ if self.key_ctr == len(self.keys):
189
+ self.key_ctr = 0
190
+
191
+ if self.keys[self.key_ctr] not in self.invalid_keys:
192
+ break
193
+
194
+ key = self.keys[self.key_ctr]
195
+ header['Authorization'] = f'Bearer {key}'
196
+
197
+ response = dict()
198
+ try:
199
+ raw_response = requests.post(
200
+ self.url,
201
+ headers=header,
202
+ data=json.dumps(data),
203
+ proxies=self.proxies,
204
+ )
205
+ response = raw_response.json()
206
+ return response['choices'][0]['message']['content'].strip()
207
+ except requests.ConnectionError:
208
+ print('Got connection error, retrying...')
209
+ continue
210
+ except requests.JSONDecodeError:
211
+ print('JsonDecode error, got', str(raw_response.content))
212
+ continue
213
+ except KeyError:
214
+ if 'error' in response:
215
+ if response['error']['code'] == 'rate_limit_exceeded':
216
+ time.sleep(1)
217
+ continue
218
+ elif response['error']['code'] == 'insufficient_quota':
219
+ self.invalid_keys.add(key)
220
+ self.logger.warning(f'insufficient_quota key: {key}')
221
+ continue
222
+
223
+ print('Found an error message in the response: ',
224
+ str(response['error']))
225
+ except Exception as error:
226
+ print(str(error))
227
+ max_num_retries += 1
228
+
229
+ raise RuntimeError('Calling SenseTime failed after retrying for '
230
+ f'{max_num_retries} times. Check the logs for '
231
+ 'details.')
232
+
+     def _stream_chat(self, messages: List[dict], **gen_params):
+         """Generate a streamed completion from a list of messages.
+
+         Args:
+             messages (List[dict]): a list of prompt dictionaries
+             gen_params: additional generation configuration
+
+         Returns:
+             generator: a generator yielding the streamed text chunks.
+         """
+
+         def streaming(raw_response):
+             for chunk in raw_response.iter_lines():
+                 if chunk:
+                     try:
+                         decoded_chunk = chunk.decode('utf-8')
+
+                         if decoded_chunk == 'data:[DONE]':
+                             # end-of-stream sentinel
+                             break
+
+                         if decoded_chunk.startswith('data:'):
+                             json_str = decoded_chunk[5:]
+                             chunk_data = json.loads(json_str)
+
+                             if 'data' in chunk_data and 'choices' in chunk_data[
+                                     'data']:
+                                 choice = chunk_data['data']['choices'][0]
+                                 if 'delta' in choice:
+                                     content = choice['delta']
+                                     yield content
+                         else:
+                             print(f'Unexpected format: {decoded_chunk}')
+
+                     except json.JSONDecodeError as e:
+                         print(f'JSON parsing error: {e}')
+                     except Exception as e:
+                         print(
+                             f'An error occurred while processing the chunk: {e}'
+                         )
+
+         assert isinstance(messages, list)
+
+         header, data = self.generate_request_data(
+             model_type=self.model_type,
+             messages=messages,
+             gen_params=gen_params,
+             json_mode=self.json_mode,
+         )
+
+         max_num_retries = 0
+         while max_num_retries < self.retry:
+             if len(self.invalid_keys) == len(self.keys):
+                 raise RuntimeError('All keys have insufficient quota.')
+
+             # find the next valid key (round-robin over the key list)
+             while True:
+                 self.key_ctr += 1
+                 if self.key_ctr == len(self.keys):
+                     self.key_ctr = 0
+
+                 if self.keys[self.key_ctr] not in self.invalid_keys:
+                     break
+
+             key = self.keys[self.key_ctr]
+             header['Authorization'] = f'Bearer {key}'
+
+             response = dict()
+             try:
+                 raw_response = requests.post(
+                     self.url,
+                     headers=header,
+                     data=json.dumps(data),
+                     proxies=self.proxies,
+                 )
+                 return streaming(raw_response)
+             except requests.ConnectionError:
+                 print('Got connection error, retrying...')
+                 continue
+             except requests.JSONDecodeError:
+                 print('JSONDecode error, got', str(raw_response.content))
+                 continue
+             except KeyError:
+                 if 'error' in response:
+                     if response['error']['code'] == 'rate_limit_exceeded':
+                         time.sleep(1)
+                         continue
+                     elif response['error']['code'] == 'insufficient_quota':
+                         self.invalid_keys.add(key)
+                         self.logger.warn(f'insufficient_quota key: {key}')
+                         continue
+
+                     print('Found error message in response: ',
+                           str(response['error']))
+             except Exception as error:
+                 print(str(error))
+             max_num_retries += 1
+
+         raise RuntimeError('Calling SenseTime failed after retrying for '
+                            f'{max_num_retries} times. Check the logs for '
+                            'details.')
+
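The streaming helper above consumes server-sent-event lines of the form "data:{...}". A minimal sketch of how one such line is unpacked (the payload is illustrative, not a captured SenseNova response):

    import json

    # illustrative SSE line in the format the helper expects
    decoded_chunk = 'data:{"data": {"choices": [{"delta": "Hello"}]}}'
    if decoded_chunk.startswith('data:'):
        chunk_data = json.loads(decoded_chunk[5:])   # strip the 'data:' prefix
        delta = chunk_data['data']['choices'][0]['delta']
        print(delta)  # Hello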
+     def generate_request_data(self,
+                               model_type,
+                               messages,
+                               gen_params,
+                               json_mode=False):
+         """Generate the request data for different model types.
+
+         Args:
+             model_type (str): The type of the model (e.g., 'sense').
+             messages (list): The list of messages to be sent to the model.
+             gen_params (dict): The generation parameters.
+             json_mode (bool): Flag to determine if the response format
+                 should be JSON.
+
+         Returns:
+             tuple: A tuple containing the header and the request data.
+         """
+         # Copy generation parameters to avoid modifying the original dictionary
+         gen_params = gen_params.copy()
+
+         # Cap the completion length at the API limit
+         max_tokens = min(gen_params.pop('max_new_tokens'), 4096)
+         if max_tokens <= 0:
+             return '', ''
+
+         # Initialize the header
+         header = {
+             'content-type': 'application/json',
+         }
+
+         # Common parameters processing
+         gen_params['max_tokens'] = max_tokens
+         if 'stop_words' in gen_params:
+             gen_params['stop'] = gen_params.pop('stop_words')
+         if 'repetition_penalty' in gen_params:
+             gen_params['frequency_penalty'] = gen_params.pop(
+                 'repetition_penalty')
+
+         # Model-specific processing
+         data = {}
+         if model_type.lower().startswith('sense'):
+             gen_params.pop('skip_special_tokens', None)
+             gen_params.pop('session_id', None)
+             data = {
+                 'model': model_type,
+                 'messages': messages,
+                 'n': 1,
+                 **gen_params
+             }
+             if json_mode:
+                 data['response_format'] = {'type': 'json_object'}
+         else:
+             raise NotImplementedError(
+                 f'Model type {model_type} is not supported')
+
+         return header, data
+
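For a 'sense' model the method above produces a payload along these lines; a sketch, where api is a hypothetical instance and 'SenseChat-5' a hypothetical model name:

    header, data = api.generate_request_data(
        model_type='SenseChat-5',
        messages=[{'role': 'user', 'content': 'Hi'}],
        gen_params={'max_new_tokens': 512,
                    'stop_words': ['<|im_end|>'],
                    'repetition_penalty': 1.02},
    )
    # header -> {'content-type': 'application/json'}
    # data   -> {'model': 'SenseChat-5',
    #            'messages': [{'role': 'user', 'content': 'Hi'}],
    #            'n': 1, 'max_tokens': 512, 'stop': ['<|im_end|>'],
    #            'frequency_penalty': 1.02}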
+     def tokenize(self, prompt: str) -> list:
+         """Tokenize the input prompt.
+
+         Args:
+             prompt (str): Input string.
+
+         Returns:
+             list: token ids
+         """
+         import tiktoken
+
+         # NOTE: uses the GPT-4o tiktoken encoding as an approximation; the
+         # remote model's own tokenizer may differ.
+         self.tiktoken = tiktoken
+         enc = self.tiktoken.encoding_for_model('gpt-4o')
+         return enc.encode(prompt)
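A minimal end-to-end sketch of this wrapper; the exported class name (assumed SensenovaAPI here) and constructor arguments should be checked against the rest of this module, and the model name is hypothetical:

    from lagent.llms.sensenova import SensenovaAPI  # assumed class name

    llm = SensenovaAPI(model_type='SenseChat-5', key='YOUR_API_KEY')
    # stream_chat yields (status, accumulated_text, extra) tuples
    for status, text, _ in llm.stream_chat([{'role': 'user', 'content': 'Hi'}]):
        print(status, text)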
lagent/llms/vllm_wrapper.py ADDED
@@ -0,0 +1,176 @@
+ import asyncio
+ from typing import List, Union
+
+ from lagent.llms.base_llm import AsyncBaseLLM, BaseLLM
+ from lagent.utils.util import filter_suffix
+
+
+ def asdict_completion(output):
+     return {
+         key: getattr(output, key)
+         for key in [
+             'text', 'token_ids', 'cumulative_logprob', 'logprobs',
+             'finish_reason', 'stop_reason'
+         ]
+     }
+
+
+ class VllmModel(BaseLLM):
+     """
+     A wrapper of vLLM model.
+
+     Args:
+         path (str): The path to the model.
+             It could be one of the following options:
+                 - i) A local directory path of a huggingface model.
+                 - ii) The model_id of a model hosted inside a model repo
+                     on huggingface.co, such as "internlm/internlm-chat-7b",
+                     "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
+                     and so on.
+         tp (int): tensor parallel size
+         vllm_cfg (dict): Other kwargs for vllm model initialization.
+     """
+
+     def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
+         super().__init__(path=path, **kwargs)
+         from vllm import LLM
+         self.model = LLM(
+             model=self.path,
+             trust_remote_code=True,
+             tensor_parallel_size=tp,
+             **vllm_cfg)
+
+     def generate(self,
+                  inputs: Union[str, List[str]],
+                  do_preprocess: bool = None,
+                  skip_special_tokens: bool = False,
+                  return_dict: bool = False,
+                  **kwargs):
+         """Return the chat completions in non-stream mode.
+
+         Args:
+             inputs (Union[str, List[str]]): input texts to be completed.
+             do_preprocess (bool): whether to pre-process the messages.
+                 Defaults to True, which means chat_template will be applied.
+             skip_special_tokens (bool): Whether or not to remove special
+                 tokens in the decoding. Defaults to False.
+
+         Returns:
+             (a list of/batched) text/chat completion
+         """
+         from vllm import SamplingParams
+
+         batched = True
+         if isinstance(inputs, str):
+             inputs = [inputs]
+             batched = False
+         prompt = inputs
+         gen_params = self.update_gen_params(**kwargs)
+         max_new_tokens = gen_params.pop('max_new_tokens')
+         stop_words = gen_params.pop('stop_words')
+
+         sampling_config = SamplingParams(
+             skip_special_tokens=skip_special_tokens,
+             max_tokens=max_new_tokens,
+             stop=stop_words,
+             **gen_params)
+         response = self.model.generate(prompt, sampling_params=sampling_config)
+         texts = [resp.outputs[0].text for resp in response]
+         # remove stop_words
+         texts = filter_suffix(texts, self.gen_params.get('stop_words'))
+         for resp, text in zip(response, texts):
+             resp.outputs[0].text = text
+         if batched:
+             return [asdict_completion(resp.outputs[0])
+                     for resp in response] if return_dict else texts
+         return asdict_completion(
+             response[0].outputs[0]) if return_dict else texts[0]
+
+
+ class AsyncVllmModel(AsyncBaseLLM):
+     """
+     An asynchronous wrapper of vLLM model.
+
+     Args:
+         path (str): The path to the model.
+             It could be one of the following options:
+                 - i) A local directory path of a huggingface model.
+                 - ii) The model_id of a model hosted inside a model repo
+                     on huggingface.co, such as "internlm/internlm-chat-7b",
+                     "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
+                     and so on.
+         tp (int): tensor parallel size
+         vllm_cfg (dict): Other kwargs for vllm model initialization.
+     """
+
+     def __init__(self, path: str, tp: int = 1, vllm_cfg=dict(), **kwargs):
+         super().__init__(path=path, **kwargs)
+         from vllm import AsyncEngineArgs, AsyncLLMEngine
+
+         engine_args = AsyncEngineArgs(
+             model=self.path,
+             trust_remote_code=True,
+             tensor_parallel_size=tp,
+             **vllm_cfg)
+         self.model = AsyncLLMEngine.from_engine_args(engine_args)
+
+     async def generate(self,
+                        inputs: Union[str, List[str]],
+                        session_ids: Union[int, List[int]] = None,
+                        do_preprocess: bool = None,
+                        skip_special_tokens: bool = False,
+                        return_dict: bool = False,
+                        **kwargs):
+         """Return the chat completions in non-stream mode.
+
+         Args:
+             inputs (Union[str, List[str]]): input texts to be completed.
+             session_ids (Union[int, List[int]]): request ids passed to the
+                 engine; defaults to the indices of the inputs.
+             do_preprocess (bool): whether to pre-process the messages.
+                 Defaults to True, which means chat_template will be applied.
+             skip_special_tokens (bool): Whether or not to remove special
+                 tokens in the decoding. Defaults to False.
+
+         Returns:
+             (a list of/batched) text/chat completion
+         """
+         from vllm import SamplingParams
+
+         batched = True
+         if isinstance(inputs, str):
+             inputs = [inputs]
+             batched = False
+         if session_ids is None:
+             session_ids = list(range(len(inputs)))
+         elif isinstance(session_ids, (int, str)):
+             session_ids = [session_ids]
+         assert len(inputs) == len(session_ids)
+
+         prompt = inputs
+         gen_params = self.update_gen_params(**kwargs)
+         max_new_tokens = gen_params.pop('max_new_tokens')
+         stop_words = gen_params.pop('stop_words')
+
+         sampling_config = SamplingParams(
+             skip_special_tokens=skip_special_tokens,
+             max_tokens=max_new_tokens,
+             stop=stop_words,
+             **gen_params)
+
+         async def _inner_generate(uid, text):
+             resp, generator = '', self.model.generate(
+                 text, sampling_params=sampling_config, request_id=uid)
+             async for out in generator:
+                 resp = out.outputs[0]
+             return resp
+
+         response = await asyncio.gather(*[
+             _inner_generate(sid, inp) for sid, inp in zip(session_ids, prompt)
+         ])
+         texts = [resp.text for resp in response]
+         # remove stop_words
+         texts = filter_suffix(texts, self.gen_params.get('stop_words'))
+         for resp, text in zip(response, texts):
+             resp.text = text
+         if batched:
+             return [asdict_completion(resp)
+                     for resp in response] if return_dict else texts
+         return asdict_completion(response[0]) if return_dict else texts[0]
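A usage sketch for the synchronous wrapper; the model id and sampling arguments are illustrative, and assume BaseLLM accepts generation defaults such as max_new_tokens and stop_words in its constructor:

    from lagent.llms.vllm_wrapper import VllmModel

    llm = VllmModel(
        path='internlm/internlm2-chat-7b',  # any vLLM-compatible model id
        max_new_tokens=256,
        stop_words=['<|im_end|>'])
    print(llm.generate('The capital of France is'))
    # return_dict=True additionally yields token ids, logprobs and
    # finish_reason alongside the text.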
lagent/memory/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .base_memory import Memory
+ from .manager import MemoryManager
+
+ __all__ = ['Memory', 'MemoryManager']
lagent/memory/base_memory.py ADDED
@@ -0,0 +1,60 @@
+ from typing import Callable, Dict, List, Optional, Union
+
+ from lagent.schema import AgentMessage
+
+
+ class Memory:
+
+     def __init__(self, recent_n=None) -> None:
+         self.memory: List[AgentMessage] = []
+         self.recent_n = recent_n
+
+     def get_memory(
+         self,
+         recent_n: Optional[int] = None,
+         filter_func: Optional[Callable[[int, dict], bool]] = None,
+     ) -> list:
+         recent_n = recent_n or self.recent_n
+         if recent_n is not None:
+             memory = self.memory[-recent_n:]
+         else:
+             memory = self.memory
+         if filter_func is not None:
+             memory = [m for i, m in enumerate(memory) if filter_func(i, m)]
+         return memory
+
+     def add(self, memories: Union[List[Dict], Dict, None]) -> None:
+         for memory in memories if isinstance(memories,
+                                              (list, tuple)) else [memories]:
+             if isinstance(memory, str):
+                 memory = AgentMessage(sender='user', content=memory)
+             if isinstance(memory, AgentMessage):
+                 self.memory.append(memory)
+
+     def delete(self, index: Union[List, int]) -> None:
+         # NOTE: deleting by a list removes items one at a time, so indices
+         # after a deleted position shift; pass indices in descending order.
+         if isinstance(index, int):
+             del self.memory[index]
+         else:
+             for i in index:
+                 del self.memory[i]
+
+     def load(
+         self,
+         memories: Union[str, Dict, List],
+         overwrite: bool = True,
+     ) -> None:
+         if overwrite:
+             self.memory = []
+         if isinstance(memories, dict):
+             self.memory.append(AgentMessage(**memories))
+         elif isinstance(memories, list):
+             for m in memories:
+                 self.memory.append(AgentMessage(**m))
+         else:
+             raise TypeError(f'{type(memories)} is not supported')
+
+     def save(self) -> List[dict]:
+         memory = []
+         for m in self.memory:
+             memory.append(m.model_dump())
+         return memory
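A short sketch of how the class above behaves:

    from lagent.memory import Memory
    from lagent.schema import AgentMessage

    mem = Memory(recent_n=2)
    mem.add('hi')  # a bare string is wrapped into an AgentMessage from 'user'
    mem.add(AgentMessage(sender='assistant', content='hello'))
    mem.add(AgentMessage(sender='user', content='bye'))

    print(len(mem.get_memory()))            # 2: capped by recent_n
    print(len(mem.get_memory(recent_n=3)))  # 3
    dumps = mem.save()   # list of plain dicts, suitable for persistence
    mem.load(dumps)      # rebuilds the history from the dumps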
lagent/memory/manager.py ADDED
@@ -0,0 +1,29 @@
+ from typing import Dict
+
+ from ..utils import create_object
+ from .base_memory import Memory
+
+
+ class MemoryManager:
+
+     def __init__(self, cfg) -> None:
+         self.cfg = cfg
+         self.memory_map: Dict[str, Memory] = {}
+
+     def create_instance(self, session_id):
+         self.memory_map[session_id] = create_object(self.cfg)
+
+     def get_memory(self, session_id=0, **kwargs) -> list:
+         return self.memory_map[session_id].get_memory(**kwargs)
+
+     def add(self, memory, session_id=0, **kwargs) -> None:
+         if session_id not in self.memory_map:
+             self.create_instance(session_id)
+         self.memory_map[session_id].add(memory, **kwargs)
+
+     def get(self, session_id=0) -> Memory:
+         return self.memory_map.get(session_id, None)
+
+     def reset(self, session_id=0) -> None:
+         if session_id in self.memory_map:
+             del self.memory_map[session_id]
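The manager lazily creates one Memory per session; a sketch assuming create_object accepts a dict with a 'type' key, as elsewhere in lagent:

    from lagent.memory import Memory, MemoryManager

    manager = MemoryManager(dict(type=Memory))
    manager.add('first message', session_id=42)  # creates the session lazily
    print(manager.get_memory(session_id=42))
    manager.reset(session_id=42)                 # drops the whole session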
lagent/prompts/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .parsers import *  # noqa
+ from .prompt_template import PromptTemplate
+
+ __all__ = ['PromptTemplate']