init
- .gitignore +27 -0
- app.py +238 -0
- bmtools/__init__.py +2 -0
- bmtools/agent/BabyagiTools.py +317 -0
- bmtools/agent/__init__.py +0 -0
- bmtools/agent/apitool.py +102 -0
- bmtools/agent/executor.py +109 -0
- bmtools/agent/singletool.py +173 -0
- bmtools/agent/tools_controller.py +113 -0
- bmtools/agent/translator.py +116 -0
- bmtools/tools/__init__.py +15 -0
- bmtools/tools/film/__init__.py +6 -0
- bmtools/tools/film/douban/__init__.py +1 -0
- bmtools/tools/film/douban/api.py +234 -0
- bmtools/tools/film/douban/readme.md +3 -0
- bmtools/tools/film/douban/test.py +15 -0
- bmtools/tools/registry.py +38 -0
- bmtools/tools/retriever.py +45 -0
- bmtools/tools/serve.py +102 -0
- bmtools/tools/tool.py +87 -0
- bmtools/utils/__init__.py +0 -0
- bmtools/utils/logging.py +279 -0
- requirements.txt +23 -0
.gitignore
ADDED
@@ -0,0 +1,27 @@
+# Compiled python files
+*.pyc
+
+# Compiled C extensions
+*.so
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+
+# IDE / editor files
+.idea/
+*.swp
+*~
+
+# Virtual environment
+venv/
+env/
+
+__pycache__/
+.vscode/
+.DS_Store
+
+cache/
+
+secret_keys_mine.sh
app.py
ADDED
@@ -0,0 +1,238 @@
+import gradio as gr
+import sys
+# sys.path.append('./inference/')
+import bmtools
+from bmtools.agent.tools_controller import MTQuestionAnswerer, load_valid_tools
+from bmtools.agent.singletool import STQuestionAnswerer
+from langchain.schema import AgentFinish
+import os
+import requests
+from threading import Thread
+from multiprocessing import Process
+import time
+
+available_models = ["ChatGPT", "GPT-3.5"]
+DEFAULTMODEL = "GPT-3.5"
+
+tools_mappings = {
+    "klarna": "https://www.klarna.com/",
+    "chemical-prop": "http://127.0.0.1:8079/tools/chemical-prop/",
+    "wolframalpha": "http://127.0.0.1:8079/tools/wolframalpha/",
+    "weather": "http://127.0.0.1:8079/tools/weather/",
+    "douban-film": "http://127.0.0.1:8079/tools/douban-film/",
+    "wikipedia": "http://127.0.0.1:8079/tools/wikipedia/",
+    "office-ppt": "http://127.0.0.1:8079/tools/office-ppt/",
+    "bing_search": "http://127.0.0.1:8079/tools/bing_search/",
+    "map": "http://127.0.0.1:8079/tools/map/",
+    "stock": "http://127.0.0.1:8079/tools/stock/",
+    "baidu-translation": "http://127.0.0.1:8079/tools/baidu-translation/",
+    "nllb-translation": "http://127.0.0.1:8079/tools/nllb-translation/",
+}
+
+valid_tools_info = {}
+all_tools_list = []
+
+gr.close_all()
+
+MAX_TURNS = 30
+MAX_BOXES = MAX_TURNS * 2
+
+return_msg = []
+chat_history = ""
+
+tool_server_flag = False
+
+def run_tool_server():
+    def run_server():
+        server = bmtools.ToolServer()
+        # server.load_tool("chemical-prop")
+        server.load_tool("douban-film")
+        # server.load_tool("weather")
+        # server.load_tool("wikipedia")
+        # server.load_tool("wolframalpha")
+        # server.load_tool("bing_search")
+        # server.load_tool("office-ppt")
+        # server.load_tool("stock")
+        # server.load_tool("map")
+        # server.load_tool("nllb-translation")
+        # server.load_tool("baidu-translation")
+        # server.load_tool("tutorial")
+        server.serve()
+    # server = Thread(target=run_server)
+    server = Process(target=run_server)
+    server.start()
+    global tool_server_flag
+    tool_server_flag = True
+
+def load_tools():
+    global valid_tools_info
+    global all_tools_list
+    valid_tools_info = load_valid_tools(tools_mappings)
+    all_tools_list = sorted(list(valid_tools_info.keys()))
+    return gr.update(choices=all_tools_list)
+
+def set_environ(OPENAI_API_KEY: str,
+                WOLFRAMALPH_APP_ID: str = "",
+                WEATHER_API_KEYS: str = "",
+                BING_SUBSCRIPT_KEY: str = "",
+                ALPHA_VANTAGE_KEY: str = "",
+                BING_MAP_KEY: str = "",
+                BAIDU_TRANSLATE_KEY: str = "",
+                BAIDU_SECRET_KEY: str = "") -> str:
+    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
+    os.environ["WOLFRAMALPH_APP_ID"] = WOLFRAMALPH_APP_ID
+    os.environ["WEATHER_API_KEYS"] = WEATHER_API_KEYS
+    os.environ["BING_SUBSCRIPT_KEY"] = BING_SUBSCRIPT_KEY
+    os.environ["ALPHA_VANTAGE_KEY"] = ALPHA_VANTAGE_KEY
+    os.environ["BING_MAP_KEY"] = BING_MAP_KEY
+    os.environ["BAIDU_TRANSLATE_KEY"] = BAIDU_TRANSLATE_KEY
+    os.environ["BAIDU_SECRET_KEY"] = BAIDU_SECRET_KEY
+    if not tool_server_flag:
+        run_tool_server()
+        time.sleep(10)
+    return gr.update(value="OK!")
+
+def show_avatar_imgs(tools_chosen):
+    if len(tools_chosen) == 0:
+        tools_chosen = list(valid_tools_info.keys())
+    img_template = '<a href="{}" style="float: left"> <img style="margin:5px" src="{}.png" width="24" height="24" alt="avatar" /> {} </a>'
+    imgs = [valid_tools_info[tool]['avatar'] for tool in tools_chosen if valid_tools_info[tool]['avatar'] is not None]
+    imgs = ' '.join([img_template.format(img, img, tool) for img, tool in zip(imgs, tools_chosen)])
+    return [gr.update(value='<span class="">' + imgs + '</span>', visible=True), gr.update(visible=True)]
+
+def answer_by_tools(question, tools_chosen, model_chosen):
+    global return_msg
+    return_msg += [(question, None), (None, '...')]
+    yield [gr.update(visible=True, value=return_msg), gr.update(), gr.update()]
+    OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY', '')
+
+    if len(tools_chosen) == 0:  # if no tools are chosen, use them all (TODO: what if the pool is too large?)
+        tools_chosen = list(valid_tools_info.keys())
+
+    if len(tools_chosen) == 1:
+        answerer = STQuestionAnswerer(OPENAI_API_KEY.strip(), stream_output=True, llm=model_chosen)
+        agent_executor = answerer.load_tools(tools_chosen[0], valid_tools_info[tools_chosen[0]], prompt_type="react-with-tool-description", return_intermediate_steps=True)
+    else:
+        answerer = MTQuestionAnswerer(OPENAI_API_KEY.strip(), load_valid_tools({k: tools_mappings[k] for k in tools_chosen}), stream_output=True, llm=model_chosen)
+
+        agent_executor = answerer.build_runner()
+
+    global chat_history
+    chat_history += "Question: " + question + "\n"
+    question = chat_history
+    for inter in agent_executor(question):
+        if isinstance(inter, AgentFinish): continue
+        result_str = []
+        return_msg.pop()
+        if isinstance(inter, dict):
+            result_str.append("<font color=red>Answer:</font> {}".format(inter['output']))
+            chat_history += "Answer:" + inter['output'] + "\n"
+            result_str.append("...")
+        else:
+            not_observation = inter[0].log
+            if not not_observation.startswith('Thought:'):
+                not_observation = "Thought: " + not_observation
+            chat_history += not_observation
+            not_observation = not_observation.replace('Thought:', '<font color=green>Thought: </font>')
+            not_observation = not_observation.replace('Action:', '<font color=purple>Action: </font>')
+            not_observation = not_observation.replace('Action Input:', '<font color=purple>Action Input: </font>')
+            result_str.append("{}".format(not_observation))
+            result_str.append("<font color=blue>Action output:</font>\n{}".format(inter[1]))
+            chat_history += "\nAction output:" + inter[1] + "\n"
+            result_str.append("...")
+        return_msg += [(None, result) for result in result_str]
+        yield [gr.update(visible=True, value=return_msg), gr.update(), gr.update()]
+    return_msg.pop()
+    if return_msg[-1][1].startswith("<font color=red>Answer:</font> "):
+        return_msg[-1] = (return_msg[-1][0], return_msg[-1][1].replace("<font color=red>Answer:</font> ", "<font color=green>Final Answer:</font> "))
+    yield [gr.update(visible=True, value=return_msg), gr.update(visible=True), gr.update(visible=False)]
+
+def retrieve(tools_search):
+    if tools_search == "":
+        return gr.update(choices=all_tools_list)
+    else:
+        url = "http://127.0.0.1:8079/retrieve"
+        param = {
+            "query": tools_search
+        }
+        response = requests.post(url, json=param)
+        result = response.json()
+        retrieved_tools = result["tools"]
+        return gr.update(choices=retrieved_tools)
+
+def clear_history():
+    global return_msg
+    global chat_history
+    return_msg = []
+    chat_history = ""
+    yield gr.update(visible=True, value=return_msg)
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column(scale=14):
+            gr.Markdown("<h1 align='left'> BMTools </h1>")
+        with gr.Column(scale=1):
+            gr.Markdown('<img src="https://openbmb.cn/openbmb/img/head_logo.e9d9f3f.png" width="140">')
+    with gr.Row():
+        with gr.Column(scale=1):
+            OPENAI_API_KEY = gr.Textbox(label="OpenAI API KEY:", placeholder="sk-...", type="text")
+            # WOLFRAMALPH_APP_ID = gr.Textbox(label="WOLFRAMALPH APP ID:", type="text")
+            # WEATHER_API_KEYS = gr.Textbox(label="WEATHER API KEYS:", type="text")
+            # BING_SUBSCRIPT_KEY = gr.Textbox(label="BING SUBSCRIPT KEY:", type="text")
+            # ALPHA_VANTAGE_KEY = gr.Textbox(label="ALPHA VANTAGE KEY:", type="text")
+            # BING_MAP_KEY = gr.Textbox(label="BING MAP KEY:", type="text")
+            # BAIDU_TRANSLATE_KEY = gr.Textbox(label="BAIDU TRANSLATE KEY:", type="text")
+            # BAIDU_SECRET_KEY = gr.Textbox(label="BAIDU SECRET KEY:", type="text")
+            key_set_btn = gr.Button(value="Set")
+
+        with gr.Column(scale=4):
+            with gr.Row():
+                with gr.Column(scale=0.85):
+                    txt = gr.Textbox(show_label=False, placeholder="Question here. Use Shift+Enter to add new line.", lines=1).style(container=False)
+                with gr.Column(scale=0.15, min_width=0):
+                    buttonClear = gr.Button("Clear History")
+                    buttonStop = gr.Button("Stop", visible=False)
+
+            chatbot = gr.Chatbot(show_label=False, visible=True).style(height=600)
+
+        with gr.Column(scale=1):
+            with gr.Column():
+                tools_search = gr.Textbox(
+                    lines=1,
+                    label="Tools Search",
+                    info="Please input some text to search tools.",
+                )
+                buttonSearch = gr.Button("Clear")
+                tools_chosen = gr.CheckboxGroup(
+                    choices=all_tools_list,
+                    value=["chemical-prop"],
+                    label="Tools provided",
+                    info="Choose the tools to solve your question.",
+                )
+                model_chosen = gr.Dropdown(
+                    list(available_models), value=DEFAULTMODEL, multiselect=False, label="Model provided", info="Choose the model to solve your question. Default means ChatGPT."
+                )
+
+    key_set_btn.click(fn=set_environ, inputs=[
+        OPENAI_API_KEY,
+        # WOLFRAMALPH_APP_ID,
+        # WEATHER_API_KEYS,
+        # BING_SUBSCRIPT_KEY,
+        # ALPHA_VANTAGE_KEY,
+        # BING_MAP_KEY,
+        # BAIDU_TRANSLATE_KEY,
+        # BAIDU_SECRET_KEY
+    ], outputs=key_set_btn)
+    key_set_btn.click(fn=load_tools, outputs=tools_chosen)
+
+    tools_search.change(retrieve, tools_search, tools_chosen)
+    buttonSearch.click(lambda: [gr.update(value=""), gr.update(choices=all_tools_list)], [], [tools_search, tools_chosen])
+
+    txt.submit(lambda: [gr.update(value=''), gr.update(visible=False), gr.update(visible=True)], [], [txt, buttonClear, buttonStop])
+    inference_event = txt.submit(answer_by_tools, [txt, tools_chosen, model_chosen], [chatbot, buttonClear, buttonStop])
+    buttonStop.click(lambda: [gr.update(visible=True), gr.update(visible=False)], [], [buttonClear, buttonStop], cancels=[inference_event])
+    buttonClear.click(clear_history, [], chatbot)
+
+demo.queue().launch(share=False, inbrowser=True, server_name="127.0.0.1", server_port=7001)
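
For reference, the same question-answering flow can be driven without the Gradio UI. The snippet below is a minimal sketch, assuming the douban-film tool server started by run_tool_server() is already listening on http://127.0.0.1:8079 and OPENAI_API_KEY is set in the environment; it only uses the loaders defined elsewhere in this commit.

# Minimal sketch of the app's single-tool path without the UI. Assumes the
# douban-film tool server is already serving on port 8079 and OPENAI_API_KEY
# is set in the environment.
from bmtools.agent.singletool import STQuestionAnswerer
from bmtools.agent.tools_controller import load_valid_tools

mappings = {"douban-film": "http://127.0.0.1:8079/tools/douban-film/"}
valid_info = load_valid_tools(mappings)  # fetches each tool's ai-plugin.json

answerer = STQuestionAnswerer(stream_output=False, llm="GPT-3.5")
executor = answerer.load_tools("douban-film", valid_info["douban-film"],
                               prompt_type="react-with-tool-description")
executor("Recommend a popular science fiction film.")
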
bmtools/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .tools.serve import ToolServer
+from .utils.logging import get_logger
bmtools/agent/BabyagiTools.py
ADDED
@@ -0,0 +1,317 @@
+from collections import deque
+from typing import Dict, List, Optional, Any
+import re
+
+from langchain import LLMChain, OpenAI, PromptTemplate, SerpAPIWrapper
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.llms import BaseLLM
+from langchain.vectorstores.base import VectorStore
+from pydantic import BaseModel, Field
+from langchain.chains.base import Chain
+
+from langchain.vectorstores import FAISS
+import faiss
+from langchain.docstore import InMemoryDocstore
+from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
+from bmtools.agent.executor import Executor, AgentExecutorWithTranslation
+
+class ContextAwareAgent(ZeroShotAgent):
+    def get_full_inputs(
+        self, intermediate_steps, **kwargs: Any
+    ) -> Dict[str, Any]:
+        """Create the full inputs for the LLMChain from intermediate steps."""
+        thoughts = self._construct_scratchpad(intermediate_steps)
+        new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
+        full_inputs = {**kwargs, **new_inputs}
+        return full_inputs
+
+    def _construct_scratchpad(self, intermediate_steps):
+        """Construct the scratchpad that lets the agent continue its thought process."""
+        thoughts = ""
+        # only the following line is modified: keep just the last two steps ([-2:])
+        for action, observation in intermediate_steps[-2:]:
+            thoughts += action.log
+            thoughts += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
+            if "is not a valid tool, try another one" in observation:
+                thoughts += "You should select another tool rather than the invalid one.\n"
+        return thoughts
+
+class TaskCreationChain(LLMChain):
+    """Chain to generate tasks."""
+
+    @classmethod
+    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
+        """Get the response parser."""
+        task_creation_template = (
+            "You are a task creation AI that uses the result of an execution agent"
+            " to create new tasks with the following objective: {objective},"
+            " The last completed task has the result: {result}."
+            " This result was based on this task description: {task_description}."
+            " These are incomplete tasks: {incomplete_tasks}."
+            " Based on the result, create new tasks to be completed"
+            " by the AI system that do not overlap with incomplete tasks."
+            " For a simple objective, do not generate complex todo lists."
+            " Do not generate repetitive tasks (e.g., tasks that have already been completed)."
+            " If there is no further task needed to complete the objective, return NO TASK."
+            " Now return the tasks as an array."
+        )
+        prompt = PromptTemplate(
+            template=task_creation_template,
+            input_variables=["result", "task_description", "incomplete_tasks", "objective"],
+        )
+        return cls(prompt=prompt, llm=llm, verbose=verbose)
+
+class InitialTaskCreationChain(LLMChain):
+    """Chain to generate the initial task."""
+
+    @classmethod
+    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
+        """Get the response parser."""
+        task_creation_template = (
+            "You are a planner who is an expert at coming up with a todo list for a given objective. For a simple objective, do not generate a complex todo list. Generate the first (only one) task needed to do for this objective: {objective}"
+        )
+        prompt = PromptTemplate(
+            template=task_creation_template,
+            input_variables=["objective"],
+        )
+        return cls(prompt=prompt, llm=llm, verbose=verbose)
+
+class TaskPrioritizationChain(LLMChain):
+    """Chain to prioritize tasks."""
+
+    @classmethod
+    def from_llm(cls, llm: BaseLLM, verbose: bool = True) -> LLMChain:
+        """Get the response parser."""
+        task_prioritization_template = (
+            "You are a task prioritization AI tasked with cleaning the formatting of and reprioritizing"
+            " the following tasks: {task_names}."
+            " Consider the ultimate objective of your team: {objective}."
+            " Do not make up any tasks, just reorganize the existing tasks."
+            " Do not remove any tasks. Return the result as a numbered list, like:"
+            " #. First task"
+            " #. Second task"
+            " Start the task list with number {next_task_id}. (e.g., 2. ***, 3. ***, etc.)"
+        )
+        prompt = PromptTemplate(
+            template=task_prioritization_template,
+            input_variables=["task_names", "next_task_id", "objective"],
+        )
+        return cls(prompt=prompt, llm=llm, verbose=verbose)
+
+def get_next_task(task_creation_chain: LLMChain, result: Dict, task_description: str, task_list: List[str], objective: str) -> List[Dict]:
+    """Get the next task."""
+    incomplete_tasks = ", ".join(task_list)
+    response = task_creation_chain.run(result=result, task_description=task_description, incomplete_tasks=incomplete_tasks, objective=objective)
+    # parse numbered tasks with a regex instead of splitting on newlines
+    # new_tasks = response.split('\n')
+    task_pattern = re.compile(r'\d+\. (.+?)\n')
+    new_tasks = task_pattern.findall(response)
+
+    return [{"task_name": task_name} for task_name in new_tasks if task_name.strip()]
+
+def prioritize_tasks(task_prioritization_chain: LLMChain, this_task_id: int, task_list: List[Dict], objective: str) -> List[Dict]:
+    """Prioritize tasks."""
+    task_names = [t["task_name"] for t in task_list]
+    next_task_id = int(this_task_id) + 1
+    response = task_prioritization_chain.run(task_names=task_names, next_task_id=next_task_id, objective=objective)
+    new_tasks = response.split('\n')
+    prioritized_task_list = []
+    for task_string in new_tasks:
+        if not task_string.strip():
+            continue
+        task_parts = task_string.strip().split(".", 1)
+        if len(task_parts) == 2:
+            task_id = task_parts[0].strip()
+            task_name = task_parts[1].strip()
+            prioritized_task_list.append({"task_id": task_id, "task_name": task_name})
+    return prioritized_task_list
+
+def _get_top_tasks(vectorstore, query: str, k: int) -> List[str]:
+    """Get the top k tasks based on the query."""
+    results = vectorstore.similarity_search_with_score(query, k=k)
+    if not results:
+        return []
+    sorted_results, _ = zip(*sorted(results, key=lambda x: x[1], reverse=True))
+    return [str(item.metadata['task']) for item in sorted_results]
+
+def execute_task(vectorstore, execution_chain: LLMChain, objective: str, task: str, k: int = 5) -> str:
+    """Execute a task."""
+    context = _get_top_tasks(vectorstore, query=objective, k=k)
+    return execution_chain.run(objective=objective, context=context, task=task)
+
+class BabyAGI(Chain, BaseModel):
+    """Controller model for the BabyAGI agent."""
+
+    task_list: deque = Field(default_factory=deque)
+    task_creation_chain: TaskCreationChain = Field(...)
+    task_prioritization_chain: TaskPrioritizationChain = Field(...)
+    initial_task_creation_chain: InitialTaskCreationChain = Field(...)
+    execution_chain: AgentExecutor = Field(...)
+    task_id_counter: int = Field(1)
+    vectorstore: VectorStore = Field(init=False)
+    max_iterations: Optional[int] = None
+
+    class Config:
+        """Configuration for this pydantic object."""
+        arbitrary_types_allowed = True
+
+    def add_task(self, task: Dict):
+        self.task_list.append(task)
+
+    def print_task_list(self):
+        print("\033[95m\033[1m" + "\n*****TASK LIST*****\n" + "\033[0m\033[0m")
+        for t in self.task_list:
+            print(str(t["task_id"]) + ": " + t["task_name"])
+
+    def print_next_task(self, task: Dict):
+        print("\033[92m\033[1m" + "\n*****NEXT TASK*****\n" + "\033[0m\033[0m")
+        print(str(task["task_id"]) + ": " + task["task_name"])
+
+    def print_task_result(self, result: str):
+        print("\033[93m\033[1m" + "\n*****TASK RESULT*****\n" + "\033[0m\033[0m")
+        print(result)
+
+    @property
+    def input_keys(self) -> List[str]:
+        return ["objective"]
+
+    @property
+    def output_keys(self) -> List[str]:
+        return []
+
+    def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        """Run the agent."""
+        # not an elegant implementation, but it works for the first task
+        objective = inputs['objective']
+        first_task = inputs.get("first_task", self.initial_task_creation_chain.run(objective=objective))  # self.task_creation_chain.llm(initial_task_prompt)
+
+        self.add_task({"task_id": 1, "task_name": first_task})
+        num_iters = 0
+        while True:
+            if self.task_list:
+                self.print_task_list()
+
+                # Step 1: Pull the first task
+                task = self.task_list.popleft()
+                self.print_next_task(task)
+
+                # Step 2: Execute the task
+                result = execute_task(
+                    self.vectorstore, self.execution_chain, objective, task["task_name"]
+                )
+                this_task_id = int(task["task_id"])
+                self.print_task_result(result)
+
+                # Step 3: Store the result in the vector store
+                result_id = f"result_{task['task_id']}"
+                self.vectorstore.add_texts(
+                    texts=[result],
+                    metadatas=[{"task": task["task_name"]}],
+                    ids=[result_id],
+                )
+
+                # Step 4: Create new tasks and reprioritize task list
+                new_tasks = get_next_task(
+                    self.task_creation_chain, result, task["task_name"], [t["task_name"] for t in self.task_list], objective
+                )
+                for new_task in new_tasks:
+                    self.task_id_counter += 1
+                    new_task.update({"task_id": self.task_id_counter})
+                    self.add_task(new_task)
+
+                if len(self.task_list) == 0:
+                    print("\033[91m\033[1m" + "\n*****NO TASK, ABORTING*****\n" + "\033[0m\033[0m")
+                    break
+
+                self.task_list = deque(
+                    prioritize_tasks(
+                        self.task_prioritization_chain, this_task_id, list(self.task_list), objective
+                    )
+                )
+            num_iters += 1
+            if self.max_iterations is not None and num_iters == self.max_iterations:
+                print("\033[91m\033[1m" + "\n*****TASK ENDING*****\n" + "\033[0m\033[0m")
+                break
+        return {}
+
+    @classmethod
+    def from_llm(
+        cls,
+        llm: BaseLLM,
+        prompt=None,
+        verbose: bool = False,
+        tools=None,
+        stream_output=None,
+        **kwargs
+    ) -> "BabyAGI":
+        embeddings_model = OpenAIEmbeddings()
+        embedding_size = 1536
+        index = faiss.IndexFlatL2(embedding_size)
+        vectorstore = FAISS(embeddings_model.embed_query, index, InMemoryDocstore({}), {})
+
+        task_creation_chain = TaskCreationChain.from_llm(
+            llm, verbose=verbose
+        )
+        initial_task_creation_chain = InitialTaskCreationChain.from_llm(
+            llm, verbose=verbose
+        )
+        task_prioritization_chain = TaskPrioritizationChain.from_llm(
+            llm, verbose=verbose
+        )
+        llm_chain = LLMChain(llm=llm, prompt=prompt)
+        tool_names = [tool.name for tool in tools]
+        agent = ContextAwareAgent(llm_chain=llm_chain, allowed_tools=tool_names)
+
+        if stream_output:
+            agent_executor = Executor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
+        else:
+            agent_executor = AgentExecutorWithTranslation.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
+
+        return cls(
+            task_creation_chain=task_creation_chain,
+            task_prioritization_chain=task_prioritization_chain,
+            initial_task_creation_chain=initial_task_creation_chain,
+            execution_chain=agent_executor,
+            vectorstore=vectorstore,
+            **kwargs
+        )
+
+if __name__ == "__main__":
+    todo_prompt = PromptTemplate.from_template("You are a planner who is an expert at coming up with a todo list for a given objective. For a simple objective, do not generate a complex todo list. Come up with a todo list for this objective: {objective}")
+    todo_chain = LLMChain(llm=OpenAI(temperature=0), prompt=todo_prompt)
+    search = SerpAPIWrapper()
+    tools = [
+        Tool(
+            name="Search",
+            func=search.run,
+            description="useful for when you need to answer questions about current events"
+        ),
+        Tool(
+            name="TODO",
+            func=todo_chain.run,
+            description="useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!"
+        )
+    ]
+
+    prefix = """You are an AI who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}."""
+    suffix = """Question: {task}
+{agent_scratchpad}"""
+    prompt = ZeroShotAgent.create_prompt(
+        tools,
+        prefix=prefix,
+        suffix=suffix,
+        input_variables=["objective", "task", "context", "agent_scratchpad"]
+    )
+
+    OBJECTIVE = "Write a weather report for SF today"
+    llm = OpenAI(temperature=0)
+    # Logging of LLMChains
+    verbose = False
+    # If None, will keep on going forever
+    max_iterations: Optional[int] = 10
+    baby_agi = BabyAGI.from_llm(
+        llm=llm,
+        prompt=prompt,
+        tools=tools,
+        verbose=verbose,
+        max_iterations=max_iterations
+    )
+    baby_agi({"objective": OBJECTIVE})
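
One subtlety in get_next_task above: the pattern r'\d+\. (.+?)\n' only captures a numbered item that is terminated by a newline, so the last task in a response is silently dropped unless the LLM output ends with "\n". A small standalone check (the sample response string is hypothetical):

# Standalone illustration of the task-parsing regex used in get_next_task.
import re

task_pattern = re.compile(r'\d+\. (.+?)\n')
response = "1. Fetch today's forecast\n2. Summarize it\n3. Draft the report"
print(task_pattern.findall(response))
# -> ["Fetch today's forecast", 'Summarize it']  (task 3 lost: no trailing newline)
print(task_pattern.findall(response + "\n"))
# -> ["Fetch today's forecast", 'Summarize it', 'Draft the report']
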
bmtools/agent/__init__.py
ADDED
File without changes
bmtools/agent/apitool.py
ADDED
@@ -0,0 +1,102 @@
+"""Interface for tools."""
+from inspect import signature
+from typing import Any, Awaitable, Callable, Optional, Union
+
+from langchain.agents import Tool as LangChainTool
+from langchain.tools.base import BaseTool
+import requests
+import json
+import http.client
+http.client._MAXLINE = 655360
+
+from bmtools import get_logger
+
+logger = get_logger(__name__)
+
+class Tool(LangChainTool):
+    tool_logo_md: str = ""
+
+class RequestTool(BaseTool):
+    """Tool that takes in function or coroutine directly."""
+
+    description: str = ""
+    func: Callable[[str], str]
+    coroutine: Optional[Callable[[str], Awaitable[str]]] = None
+    max_output_len = 4000
+    tool_logo_md: str = ""
+
+    def _run(self, tool_input: str) -> str:
+        """Use the tool."""
+        return self.func(tool_input)
+
+    async def _arun(self, tool_input: str) -> str:
+        """Use the tool asynchronously."""
+        if self.coroutine:
+            return await self.coroutine(tool_input)
+        raise NotImplementedError("Tool does not support async")
+
+    def convert_prompt(self, params):
+        lines = "Your input should be a json: {{"
+        for p in params:
+            logger.debug(p)
+            optional = not p['required']
+            description = p.get('description', '')
+            if len(description) > 0:
+                description = "(" + description + ")"
+
+            lines += '"{name}" : {type}{desc},'.format(
+                name=p['name'],
+                type=p['schema']['type'],
+                optional=optional,
+                desc=description)
+
+        lines += "}}"
+        return lines
+
+    def __init__(self, root_url, func_url, method, request_info, **kwargs):
+        """Store the function, description, and tool_name on the instance."""
+        url = root_url + func_url
+
+        def func(json_args):
+            if isinstance(json_args, str):
+                # json_args = json_args.replace("\'", "\"")
+                # print(json_args)
+                try:
+                    json_args = json.loads(json_args)
+                except json.JSONDecodeError:
+                    return "Your input can not be parsed as json, please use thought."
+            response = requests.get(url, json_args)
+            if response.status_code == 200:
+                message = response.text
+            else:
+                message = f"Error code {response.status_code}. You can try (1) Change your input (2) Call another function. (If the same error code is produced more than 4 times, please use Thought: I can not use these APIs, so I will stop. Final Answer: No Answer, please check the APIs.)"
+
+            message = message[:self.max_output_len]  # TODO: not rigorous, to improve
+            return message
+
+        tool_name = func_url.replace("/", ".").strip(".")
+
+        if 'parameters' in request_info[method]:
+            str_doc = self.convert_prompt(request_info[method]['parameters'])
+        else:
+            str_doc = ''
+
+        description = f"- {tool_name}:\n" + \
+            request_info[method].get('summary', '').replace("{", "{{").replace("}", "}}") \
+            + "," \
+            + request_info[method].get('description', '').replace("{", "{{").replace("}", "}}") \
+            + str_doc \
+            + f"The Action to trigger this API should be {tool_name}\n and the input parameters should be a json dict string. Pay attention to the type of parameters.\n"
+
+        logger.info("API Name: {}".format(tool_name))
+        logger.info("API Description: {}".format(description))
+
+        super(RequestTool, self).__init__(
+            name=tool_name, func=func, description=description, **kwargs
+        )
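
RequestTool derives its name from the OpenAPI path and assembles its prompt description from the operation's summary, description, and parameter schema. A hedged sketch of the minimal request_info shape the constructor consumes; every value below is hypothetical:

# Hypothetical sketch: request_info mirrors one path item from an OpenAPI
# spec, keyed by HTTP method, with only the fields this class reads.
from bmtools.agent.apitool import RequestTool

request_info = {
    "get": {
        "summary": "Search films by name",
        "description": "Returns matching films from the douban-film tool.",
        "parameters": [
            {"name": "film_name", "required": True,
             "schema": {"type": "string"},
             "description": "name of the film to look up"},
        ],
    }
}
tool = RequestTool(root_url="http://127.0.0.1:8079/tools/douban-film",
                   func_url="/search_film", method="get",
                   request_info=request_info)
# tool.name == "search_film"; tool.description embeds the JSON input schema.
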
bmtools/agent/executor.py
ADDED
@@ -0,0 +1,109 @@
+import types
+from typing import Any, Dict, List, Tuple, Union
+from langchain.agents import AgentExecutor
+from langchain.input import get_color_mapping
+from langchain.schema import AgentAction, AgentFinish
+from bmtools.agent.translator import Translator
+
+class AgentExecutorWithTranslation(AgentExecutor):
+
+    translator: Translator = Translator()
+
+    def prep_outputs(
+        self,
+        inputs: Dict[str, str],
+        outputs: Dict[str, str],
+        return_only_outputs: bool = False,
+    ) -> Dict[str, str]:
+        try:
+            outputs = super().prep_outputs(inputs, outputs, return_only_outputs)
+        except ValueError as e:
+            return outputs
+        else:
+            if "input" in outputs:
+                outputs = self.translator(outputs)
+            return outputs
+
+class Executor(AgentExecutorWithTranslation):
+    def _call(self, inputs: Dict[str, str]) -> Dict[str, Any]:
+        """Run text through and get agent response."""
+        # Do any preparation necessary when receiving a new input.
+        self.agent.prepare_for_new_call()
+        # Construct a mapping of tool name to tool for easy lookup
+        name_to_tool_map = {tool.name: tool for tool in self.tools}
+        # We construct a mapping from each tool to a color, used for logging.
+        color_mapping = get_color_mapping(
+            [tool.name for tool in self.tools], excluded_colors=["green"]
+        )
+        intermediate_steps: List[Tuple[AgentAction, str]] = []
+        # Let's start tracking the iterations the agent has gone through
+        iterations = 0
+        # We now enter the agent loop (until it returns something).
+        while self._should_continue(iterations):
+            next_step_output = self._take_next_step(
+                name_to_tool_map, color_mapping, inputs, intermediate_steps
+            )
+            if isinstance(next_step_output, AgentFinish):
+                yield self._return(next_step_output, intermediate_steps)
+                return
+
+            agent_action = next_step_output[0]
+            tool_logo = None
+            for tool in self.tools:
+                if tool.name == agent_action.tool:
+                    tool_logo = tool.tool_logo_md
+            if isinstance(next_step_output[1], types.GeneratorType):
+                logo = f"{tool_logo}" if tool_logo is not None else ""
+                yield (AgentAction("", agent_action.tool_input, agent_action.log), f"Continuing with another tool {logo} to answer the question.")
+                for output in next_step_output[1]:
+                    yield output
+                next_step_output = (agent_action, output)
+            else:
+                for tool in self.tools:
+                    if tool.name == agent_action.tool:
+                        yield (AgentAction(tool_logo, agent_action.tool_input, agent_action.log), next_step_output[1])
+
+            intermediate_steps.append(next_step_output)
+            # See if tool should return directly
+            tool_return = self._get_tool_return(next_step_output)
+            if tool_return is not None:
+                yield self._return(tool_return, intermediate_steps)
+                return
+            iterations += 1
+        output = self.agent.return_stopped_response(
+            self.early_stopping_method, intermediate_steps, **inputs
+        )
+        yield self._return(output, intermediate_steps)
+        return
+
+    def __call__(
+        self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
+    ) -> Dict[str, Any]:
+        """Run the logic of this chain and add to output if desired.
+
+        Args:
+            inputs: Dictionary of inputs, or single input if chain expects
+                only one param.
+            return_only_outputs: boolean for whether to return only outputs in the
+                response. If True, only new keys generated by this chain will be
+                returned. If False, both input keys and new keys generated by this
+                chain will be returned. Defaults to False.
+
+        """
+        inputs = self.prep_inputs(inputs)
+        self.callback_manager.on_chain_start(
+            {"name": self.__class__.__name__},
+            inputs,
+            verbose=self.verbose,
+        )
+        try:
+            for output in self._call(inputs):
+                if type(output) is dict:
+                    output = self.prep_outputs(inputs, output, return_only_outputs)
+                yield output
+        except (KeyboardInterrupt, Exception) as e:
+            self.callback_manager.on_chain_error(e, verbose=self.verbose)
+            raise e
+        self.callback_manager.on_chain_end(output, verbose=self.verbose)
+        # return self.prep_outputs(inputs, output, return_only_outputs)
+        return output
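
Executor turns the agent run into a generator: it yields (AgentAction, observation) pairs for intermediate steps and a dict for the final answer, which is exactly what answer_by_tools in app.py iterates over. A minimal consumption sketch, assuming agent_executor was built with stream_output=True by one of the loaders in this commit:

# Sketch of consuming the streaming executor, mirroring answer_by_tools in
# app.py; `agent_executor` is assumed to come from STQuestionAnswerer with
# stream_output=True.
from langchain.schema import AgentFinish

for step in agent_executor("What films did Tian Xiaopeng direct?"):
    if isinstance(step, AgentFinish):
        continue
    if isinstance(step, dict):          # final result
        print("Answer:", step["output"])
    else:                               # (AgentAction, observation) pair
        action, observation = step
        print(action.log)
        print("Action output:", observation)
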
bmtools/agent/singletool.py
ADDED
@@ -0,0 +1,173 @@
+from langchain.llms import OpenAI
+from langchain import OpenAI, LLMChain, PromptTemplate, SerpAPIWrapper
+from langchain.agents import ZeroShotAgent, AgentExecutor, initialize_agent, Tool
+import importlib
+import json
+import os
+import requests
+import yaml
+from bmtools.agent.apitool import RequestTool
+from bmtools.agent.executor import Executor, AgentExecutorWithTranslation
+from bmtools import get_logger
+from bmtools.agent.BabyagiTools import BabyAGI
+# from bmtools.models.customllm import CustomLLM
+
+logger = get_logger(__name__)
+
+def import_all_apis(tool_json):
+    '''Import all APIs of a tool.'''
+    doc_url = tool_json['api']['url']
+    response = requests.get(doc_url)
+
+    logger.info("Doc string URL: {}".format(doc_url))
+    if doc_url.endswith('yaml') or doc_url.endswith('yml'):
+        plugin = yaml.safe_load(response.text)
+    else:
+        plugin = json.loads(response.text)
+
+    server_url = plugin['servers'][0]['url']
+    if server_url.startswith("/"):
+        server_url = "http://127.0.0.1:8079" + server_url
+    logger.info("server_url {}".format(server_url))
+    all_apis = []
+    for key in plugin['paths']:
+        value = plugin['paths'][key]
+        api = RequestTool(root_url=server_url, func_url=key, method='get', request_info=value)
+        all_apis.append(api)
+    return all_apis
+
+def load_single_tools(tool_name, tool_url):
+    # tool_name, tool_url = "datasette", "https://datasette.io/"
+    # tool_name, tool_url = "klarna", "https://www.klarna.com/"
+    # tool_name, tool_url = 'chemical-prop', "http://127.0.0.1:8079/tools/chemical-prop/"
+    # tool_name, tool_url = 'douban-film', "http://127.0.0.1:8079/tools/douban-film/"
+    # tool_name, tool_url = 'weather', "http://127.0.0.1:8079/tools/weather/"
+    # tool_name, tool_url = 'wikipedia', "http://127.0.0.1:8079/tools/wikipedia/"
+    # tool_name, tool_url = 'wolframalpha', "http://127.0.0.1:8079/tools/wolframalpha/"
+
+    get_url = tool_url + ".well-known/ai-plugin.json"
+    response = requests.get(get_url)
+
+    if response.status_code == 200:
+        tool_config_json = response.json()
+    else:
+        raise RuntimeError("Your URL of the tool is invalid.")
+
+    return tool_name, tool_config_json
+
+class STQuestionAnswerer:
+    def __init__(self, openai_api_key="", stream_output=False, llm='ChatGPT'):
+        if len(openai_api_key) < 3:  # not a valid key (TODO: more rigorous checking)
+            openai_api_key = os.environ.get('OPENAI_API_KEY')
+
+        self.openai_api_key = openai_api_key
+        self.llm_model = llm
+
+        self.set_openai_api_key(openai_api_key)
+        self.stream_output = stream_output
+
+    def set_openai_api_key(self, key):
+        # self.llm = CustomLLM()
+        logger.info("Using {}".format(self.llm_model))
+        if self.llm_model == "GPT-3.5":
+            self.llm = OpenAI(temperature=0.0, openai_api_key=key)  # use text-davinci
+        elif self.llm_model == "ChatGPT":
+            self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key)  # use chatgpt
+        else:
+            raise RuntimeError("Your model is not available.")
+
+    def load_tools(self, name, meta_info, prompt_type="react-with-tool-description", return_intermediate_steps=True):
+        self.all_tools_map = {}
+        self.all_tools_map[name] = import_all_apis(meta_info)
+
+        logger.info("Tool [{}] has the following apis: {}".format(name, self.all_tools_map[name]))
+
+        if prompt_type == "zero-shot-react-description":
+            subagent = initialize_agent(self.all_tools_map[name], self.llm, agent="zero-shot-react-description", verbose=True, return_intermediate_steps=return_intermediate_steps)
+            return subagent
+        elif prompt_type == "react-with-tool-description":
+            description_for_model = meta_info['description_for_model'].replace("{", "{{").replace("}", "}}").strip()
+
+            prefix = f"""Answer the following questions as best you can. General instructions are: {description_for_model}. Specifically, you have access to the following APIs:"""
+            suffix = """Begin! Remember: (1) Follow the format, i.e.,\nThought:\nAction:\nAction Input:\nObservation:\nFinal Answer:\n (2) Provide as much useful information as you can in your Final Answer. (3) YOU MUST INCLUDE all relevant IMAGES in your Final Answer using format ![img](url), and include relevant links. (4) Do not make up anything, and if your Observation has no link, DO NOT hallucinate one. (5) If you have enough information, please use \nThought: I have got enough information\nFinal Answer: \n\nQuestion: {input}\n{agent_scratchpad}"""
+            prompt = ZeroShotAgent.create_prompt(
+                self.all_tools_map[name],
+                prefix=prefix,
+                suffix=suffix,
+                input_variables=["input", "agent_scratchpad"]
+            )
+            llm_chain = LLMChain(llm=self.llm, prompt=prompt)
+            logger.info("Full prompt template: {}".format(prompt.template))
+            tool_names = [tool.name for tool in self.all_tools_map[name]]
+            agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
+            if self.stream_output:
+                agent_executor = Executor.from_agent_and_tools(agent=agent, tools=self.all_tools_map[name], verbose=True, return_intermediate_steps=return_intermediate_steps)
+            else:
+                agent_executor = AgentExecutorWithTranslation.from_agent_and_tools(agent=agent, tools=self.all_tools_map[name], verbose=True, return_intermediate_steps=return_intermediate_steps)
+            return agent_executor
+        elif prompt_type == "babyagi":
+            # customllm = CustomLLM()
+            tool_str = "; ".join([t.name for t in self.all_tools_map[name]] + ["TODO"])
+            prefix = """You are an AI who performs one task based on the following objective: {objective}. Take into account these previously completed tasks: {context}.\n You have access to the following APIs:"""
+            suffix = """YOUR CONSTRAINTS: (1) YOU MUST follow this format:
+\nThought:\nAction:\nAction Input: \n or \nThought:\nFinal Answer:\n (2) Do not make up anything, and if your Observation has no link, DO NOT hallucinate one. (3) The Action: MUST be one of the following: """ + tool_str + """\nQuestion: {task}\n Agent scratchpad (history actions): {agent_scratchpad}."""
+
+            # tool_str = "; ".join([t.name for t in self.all_tools_map[name]])
+            # todo_prompt = PromptTemplate.from_template("You are a planner who is an expert at coming up with a todo list for a given objective. For a simple objective, do not generate a complex todo list. Generate a todo list that can largely be completed by the following APIs: " + tool_str + ". Come up with a todo list for this objective: {objective}")
+
+            # # todo_chain = LLMChain(llm=self.llm, prompt=todo_prompt)
+            # todo_chain = LLMChain(llm=customllm, prompt=todo_prompt)
+
+            # todo_tool = Tool(
+            #     name="TODO",
+            #     func=todo_chain.run,
+            #     description="useful for when you need to come up with todo lists. Input: an objective to create a todo list for. Output: a todo list for that objective. Please be very clear what the objective is!"
+            # )
+            # self.all_tools_map[name].append(todo_tool)
+
+            prompt = ZeroShotAgent.create_prompt(
+                self.all_tools_map[name],
+                prefix=prefix,
+                suffix=suffix,
+                input_variables=["objective", "task", "context", "agent_scratchpad"]
+            )
+
+            logger.info("Full prompt template: {}".format(prompt.template))
+
+            # specify the maximum number of iterations you want babyAGI to perform
+            max_iterations = 10
+            baby_agi = BabyAGI.from_llm(
+                llm=self.llm,
+                # llm=customllm,
+                prompt=prompt,
+                verbose=False,
+                tools=self.all_tools_map[name],
+                stream_output=self.stream_output,
+                return_intermediate_steps=return_intermediate_steps,
+                max_iterations=max_iterations,
+            )
+
+            return baby_agi
+
+if __name__ == "__main__":
+    tools_name, tools_config = load_single_tools('wolframalpha', "http://127.0.0.1:8079/tools/wolframalpha/")
+    print(tools_name, tools_config)
+
+    qa = STQuestionAnswerer()
+
+    agent = qa.load_tools(tools_name, tools_config)
+
+    agent("Calc integral of sin(x)+2x^2+3x+1 from 0 to 1")
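
load_tools also accepts prompt_type="babyagi", which wraps the same API tools in the BabyAGI planner from BabyagiTools.py. A hedged sketch under the same assumptions as the __main__ block above (a wolframalpha tool server on port 8079 and a valid OPENAI_API_KEY):

# Hedged sketch of the babyagi prompt type; assumes the wolframalpha tool
# server used by the __main__ block is running locally.
from bmtools.agent.singletool import STQuestionAnswerer, load_single_tools

tools_name, tools_config = load_single_tools(
    "wolframalpha", "http://127.0.0.1:8079/tools/wolframalpha/")
qa = STQuestionAnswerer(stream_output=False)
baby_agi = qa.load_tools(tools_name, tools_config, prompt_type="babyagi")
baby_agi({"objective": "Integrate sin(x) + 2x^2 + 3x + 1 from 0 to 1"})
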
bmtools/agent/tools_controller.py
ADDED
@@ -0,0 +1,113 @@
+from langchain.llms import OpenAI
+from langchain import OpenAI, LLMChain
+from langchain.agents import ZeroShotAgent, AgentExecutor
+import importlib
+import json
+import os
+import requests
+import yaml
+from bmtools.agent.apitool import Tool
+from bmtools.agent.singletool import STQuestionAnswerer
+from bmtools.agent.executor import Executor, AgentExecutorWithTranslation
+from bmtools import get_logger
+
+logger = get_logger(__name__)
+
+def load_valid_tools(tools_mappings):
+    tools_to_config = {}
+    for key in tools_mappings:
+        get_url = tools_mappings[key] + ".well-known/ai-plugin.json"
+
+        response = requests.get(get_url)
+
+        if response.status_code == 200:
+            tools_to_config[key] = response.json()
+        else:
+            logger.warning("Load tool {} error, status code {}".format(key, response.status_code))
+
+    return tools_to_config
+
+class MTQuestionAnswerer:
+    """Use multiple tools to answer a question. Basically, it passes a natural-language question to a pool of single-tool agents."""
+    def __init__(self, openai_api_key, all_tools, stream_output=False, llm='ChatGPT'):
+        if len(openai_api_key) < 3:  # not a valid key (TODO: more rigorous checking)
+            openai_api_key = os.environ.get('OPENAI_API_KEY')
+        self.openai_api_key = openai_api_key
+        self.stream_output = stream_output
+        self.llm_model = llm
+        self.set_openai_api_key(openai_api_key)
+        self.load_tools(all_tools)
+
+    def set_openai_api_key(self, key):
+        logger.info("Using {}".format(self.llm_model))
+        if self.llm_model == "GPT-3.5":
+            self.llm = OpenAI(temperature=0.0, openai_api_key=key)  # use text-davinci
+        elif self.llm_model == "ChatGPT":
+            self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0.0, openai_api_key=key)  # use chatgpt
+        else:
+            raise RuntimeError("Your model is not available.")
+
+    def load_tools(self, all_tools):
+        logger.info("All tools: {}".format(all_tools))
+        self.all_tools_map = {}
+        self.tools_pool = []
+        for name in all_tools:
+            meta_info = all_tools[name]
+
+            question_answer = STQuestionAnswerer(self.openai_api_key, stream_output=self.stream_output, llm=self.llm_model)
+            subagent = question_answer.load_tools(name, meta_info, prompt_type="react-with-tool-description", return_intermediate_steps=False)
+            tool_logo_md = f'<img src="{meta_info["logo_url"]}" width="32" height="32" style="display:inline-block">'
+            for tool in subagent.tools:
+                tool.tool_logo_md = tool_logo_md
+            tool = Tool(
+                name=meta_info['name_for_model'],
+                description=meta_info['description_for_model'].replace("{", "{{").replace("}", "}}"),
+                func=subagent,
+            )
+            tool.tool_logo_md = tool_logo_md
+            self.tools_pool.append(tool)
+
+    def build_runner(self):
+        # You can tweak the prompt to make the model behave better, and you can also edit each tool's doc.
+        prefix = """Answer the following questions as best you can. In this level, you are calling the tools in natural language format, since the tools are actually intelligent agents like you, but each is expert in only one area. Several things to remember. (1) Remember to follow the format of passing natural language as the Action Input. (2) DO NOT use your imagination, only use concrete information given by the tools. (3) If the observation contains images or urls which have useful information, YOU MUST INCLUDE ALL USEFUL IMAGES and links in your Answer and Final Answers using format ![img](url). BUT DO NOT provide any imaginary links. (4) The information in your Final Answer should include ALL the information returned by the tools. (5) If a user's query is in a language other than English, please translate it to English without tools, and translate it back to the source language in the Final Answer. You have access to the following tools (Only use these tools we provide you):"""
+        suffix = """\nBegin! Remember to follow the instructions above.\nQuestion: {input}\n{agent_scratchpad}"""
+
+        prompt = ZeroShotAgent.create_prompt(
+            self.tools_pool,
+            prefix=prefix,
+            suffix=suffix,
+            input_variables=["input", "agent_scratchpad"]
+        )
+        llm_chain = LLMChain(llm=self.llm, prompt=prompt)
+        logger.info("Full Prompt Template:\n {}".format(prompt.template))
+        tool_names = [tool.name for tool in self.tools_pool]
+        agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
+        if self.stream_output:
+            agent_executor = Executor.from_agent_and_tools(agent=agent, tools=self.tools_pool, verbose=True, return_intermediate_steps=True)
+        else:
+            agent_executor = AgentExecutorWithTranslation.from_agent_and_tools(agent=agent, tools=self.tools_pool, verbose=True, return_intermediate_steps=True)
+        return agent_executor
+
+if __name__ == "__main__":
+    tools_mappings = {
+        "klarna": "https://www.klarna.com/",
+        "chemical-prop": "http://127.0.0.1:8079/tools/chemical-prop/",
+        "wolframalpha": "http://127.0.0.1:8079/tools/wolframalpha/",
+        "weather": "http://127.0.0.1:8079/tools/weather/",
+    }
+
+    tools = load_valid_tools(tools_mappings)
+
+    qa = MTQuestionAnswerer(openai_api_key='', all_tools=tools)
+
+    agent = qa.build_runner()
+
+    agent("How many carbon atoms are there in CH3COOH? How many people are there in China?")
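
load_valid_tools and MTQuestionAnswerer.load_tools only consume a handful of fields from each tool's /.well-known/ai-plugin.json. A hedged sketch of the minimal manifest shape they rely on; every value below is hypothetical:

# Hypothetical minimal manifest: these are the only fields this module and
# import_all_apis in singletool.py actually read from ai-plugin.json.
manifest = {
    "name_for_model": "weather",
    "description_for_model": "Look up current weather and forecasts by city.",
    "logo_url": "http://127.0.0.1:8079/tools/weather/logo.png",
    "api": {"url": "http://127.0.0.1:8079/tools/weather/openapi.json"},
}
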
bmtools/agent/translator.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain

from bmtools.models import CustomLLM

import py3langid as langid
from iso639 import languages

from typing import Dict
from copy import deepcopy
import os

def detect_lang(text: str):
    lang_code = langid.classify(text)[0]
    lang_name = languages.get(part1=lang_code[:2]).name
    return lang_name

class Translator:

    def __init__(self,
                 openai_api_key: str = None,
                 model_name: str = "gpt-3.5-turbo"):
        llm = self.create_openai_model(openai_api_key, model_name)
        prompt = self.create_prompt()
        self.chain = LLMChain(llm=llm, prompt=prompt)

    def __call__(self, inputs: Dict[str, str]) -> Dict[str, str]:
        question = inputs["input"]
        answer = inputs["output"]

        src_lang = detect_lang(answer)
        tgt_lang = detect_lang(question)

        if src_lang != tgt_lang:
            translated_answer = self.chain.run(text=answer, language=tgt_lang)
            outputs = deepcopy(inputs)
            outputs["output"] = translated_answer
            return outputs
        else:
            return inputs

    def create_openai_model(self, openai_api_key: str, model_name: str) -> OpenAI:
        # if openai_api_key is None:
        #     openai_api_key = os.environ.get('OPENAI_API_KEY')
        # llm = OpenAI(model_name=model_name,
        #              temperature=0.0,
        #              openai_api_key=openai_api_key)
        llm = CustomLLM()
        return llm

    def create_prompt(self) -> PromptTemplate:
        template = """
        Translate to {language}: {text} =>
        """
        prompt = PromptTemplate(
            input_variables=["text", "language"],
            template=template
        )
        return prompt

if __name__ == "__main__":
    lang = {
        "zh": {
            "question": "帮我介绍下《深海》这部电影",
            "answer": "《深海》是一部中国大陆的动画、奇幻电影,由田晓鹏导演,苏鑫、王亭文、滕奎兴等人主演。剧情简介是在大海的最深处,藏着所有秘密。一位现代少女(参宿)误入梦幻的深海世界,却因此邂逅了一段独特的生命旅程。![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "ja": {
            "question": "映画「深海」について教えてください",
            "answer": "「深海」は、中国本土のアニメーションおよびファンタジー映画で、Tian Xiaopeng が監督し、Su Xin、Wang Tingwen、Teng Kuixing などが出演しています。 あらすじは、海の最深部にはすべての秘密が隠されているというもの。 夢のような深海の世界に迷い込んだ現代少女(さんすけ)は、それをきっかけに独特の人生の旅に出くわす。 ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "ko": {
            "question": "영화 딥씨에 대해 알려주세요",
            "answer": "\"Deep Sea\"는 Tian Xiaopeng 감독, Su Xin, Wang Tingwen, Teng Kuixing 등이 출연한 중국 본토의 애니메이션 및 판타지 영화입니다. 시놉시스는 바다 가장 깊은 곳에 모든 비밀이 숨겨져 있다는 것입니다. 현대 소녀(산스케)는 꿈 같은 심해 세계로 방황하지만 그것 때문에 독특한 삶의 여정을 만난다. ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "en": {
            "question": "Tell me about the movie '深海'",
            "answer": "\"Deep Sea\" is an animation and fantasy film in mainland China, directed by Tian Xiaopeng, starring Su Xin, Wang Tingwen, Teng Kuixing and others. The synopsis is that in the deepest part of the sea, all secrets are hidden. A modern girl (Sansuke) strays into the dreamy deep sea world, but encounters a unique journey of life because of it. ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "de": {
            "question": "Erzähl mir von dem Film '深海'",
            "answer": "\"Deep Sea\" ist ein Animations- und Fantasyfilm in Festlandchina unter der Regie von Tian Xiaopeng mit Su Xin, Wang Tingwen, Teng Kuixing und anderen in den Hauptrollen. Die Zusammenfassung ist, dass im tiefsten Teil des Meeres alle Geheimnisse verborgen sind. Ein modernes Mädchen (Sansuke) verirrt sich in die verträumte Tiefseewelt, trifft dabei aber auf eine einzigartige Lebensreise. ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "fr": {
            "question": "Parlez-moi du film 'Deep Sea'",
            "answer": "\"Deep Sea\" est un film d'animation et fantastique en Chine continentale, réalisé par Tian Xiaopeng, avec Su Xin, Wang Tingwen, Teng Kuixing et d'autres. Le synopsis est que dans la partie la plus profonde de la mer, tous les secrets sont cachés. Une fille moderne (Sansuke) s'égare dans le monde onirique des profondeurs marines, mais rencontre un voyage de vie unique à cause de cela. ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
        "ru": {
            "question": "Расскажите о фильме 'Глубокое море'",
            "answer": "«Глубокое море» — это анимационный и фэнтезийный фильм в материковом Китае, снятый Тянь Сяопином, в главных ролях Су Синь, Ван Тинвэнь, Тэн Куйсин и другие. Суть в том, что в самой глубокой части моря скрыты все секреты. Современная девушка (Сансукэ) заблудилась в мечтательном глубоководном мире, но из-за этого столкнулась с уникальным жизненным путешествием. ![img](https://img3.doubanio.com/view/photo/s_ratio_poster/public/p2635450820.webp)",
        },
    }

    translator = Translator()
    for source in lang:
        for target in lang:
            print(source, "=>", target, end=":\t")
            question = lang[target]["question"]
            answer = lang[source]["answer"]
            inputs = {
                "input": question,
                "output": answer
            }

            result = translator(inputs)
            translated_answer = result["output"]

            if detect_lang(question) == detect_lang(translated_answer) == languages.get(part1=target).name:
                print("Y")
            else:
                print("N")
                print("====================")
                print("Question:\t", detect_lang(question), " - ", question)
                print("Answer:\t", detect_lang(answer), " - ", answer)
                print("Translated Answer:\t", detect_lang(translated_answer), " - ", translated_answer)
                print("====================")
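For reference, a single call to the translator looks like the sketch below; the question and answer strings are taken from the test table above, and the CustomLLM backend from bmtools.models is assumed to be available.

from bmtools.agent.translator import Translator

translator = Translator()
result = translator({
    "input": "映画「深海」について教えてください",                       # Japanese question
    "output": "\"Deep Sea\" is an animation and fantasy film ...",  # English answer from a tool
})
# detect_lang sees Japanese vs. English, so the answer is rewritten in Japanese
print(result["output"])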
bmtools/tools/__init__.py
ADDED
@@ -0,0 +1,15 @@
# from . import chemical
from . import film
# from . import kg
# from . import stock
# from . import weather
# from . import wikipedia
# from . import wolframalpha
# from . import office
# from . import bing_search
# from . import translation
# from . import tutorial

from .tool import Tool
from .registry import register
from .serve import ToolServer
bmtools/tools/film/__init__.py
ADDED
@@ -0,0 +1,6 @@
from ..registry import register

@register("douban-film")
def douban_film():
    from .douban import build_tool
    return build_tool
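The decorator above is the entire registration step: a registered function takes no arguments and returns a tool builder, which the registry later calls with a config dict (see registry.py below). A minimal sketch of adding a new tool the same way; the my-echo name and its builder are hypothetical, for illustration only:

from bmtools.tools.registry import register
from bmtools.tools.tool import Tool

@register("my-echo")                      # hypothetical tool name
def my_echo():
    def build_tool(config) -> Tool:
        tool = Tool("Echo", "echoes its input back.")

        @tool.get("/echo")
        def echo(text : str):
            return {"echo": text}

        return tool
    return build_tool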
bmtools/tools/film/douban/__init__.py
ADDED
@@ -0,0 +1 @@
from .api import build_tool
bmtools/tools/film/douban/api.py
ADDED
@@ -0,0 +1,234 @@
import requests
from lxml import etree
import pandas as pd
from translate import Translator
import re
from ...tool import Tool


def build_tool(config) -> Tool:
    tool = Tool(
        "Film Search Plugin",
        "search for up-to-date film information.",
        name_for_model="Film Search",
        description_for_model="Plugin for searching for up-to-date film information.",
        logo_url="https://your-app-url.com/.well-known/logo.png",
        contact_email="[email protected]",
        legal_info_url="[email protected]"
    )

    def fetch_page(url : str):
        """fetch_page(url: str) fetches the page at the given url and returns the response
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) '
                          'Chrome/108.0.0.0 Safari/537.36'}
        s = requests.session()
        s.keep_alive = False
        response = s.get(url, headers=headers, verify=False)

        return response

    def parse_coming_page():
        """parse_coming_page() returns the details of all upcoming films, including date, title, cate, region, wantWatchPeopleNum, link
        """
        # Fetch the list of upcoming films
        url = 'https://movie.douban.com/coming'
        response = fetch_page(url)

        df_filmsComing = pd.DataFrame(columns=["date", "title", "cate", "region", "wantWatchPeopleNum", 'link'])

        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)

        movies_table_path = '//*[@id="content"]/div/div[1]/table/tbody'
        movies_table = tree.xpath(movies_table_path)
        for filmChild in movies_table[0].iter('tr'):
            filmTime = filmChild.xpath('td[1]/text()')[0].strip()
            filmName = filmChild.xpath('td[2]/a/text()')[0]
            filmType = filmChild.xpath('td[3]/text()')[0].strip()
            filmRegion = filmChild.xpath('td[4]/text()')[0].strip()
            filmWantWatching = filmChild.xpath('td[5]/text()')[0].strip()
            filmLink = filmChild.xpath('td[2]/a/@href')[0]
            df_filmsComing.loc[len(df_filmsComing.index)] = [
                filmTime, filmName, filmType, filmRegion, filmWantWatching, filmLink
            ]
        return df_filmsComing

    def parse_nowplaying_page():
        """parse_nowplaying_page() returns the details of all films playing now, including title, score, region, director, actors, link
        """
        # Fetch the list of films currently in cinemas
        url = 'https://movie.douban.com/cinema/nowplaying/beijing/'
        response = fetch_page(url)
        df_filmsNowPlaying = pd.DataFrame(columns=["title", "score", "region", "director", "actors", 'link'])

        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)

        movies_table_path = './/div[@id="nowplaying"]/div[2]/ul'
        movies_table = tree.xpath(movies_table_path)
        for filmChild in movies_table[0]:
            filmName = filmChild.xpath('@data-title')[0]
            filmScore = filmChild.xpath('@data-score')[0]
            filmRegion = filmChild.xpath('@data-region')[0]
            filmDirector = filmChild.xpath('@data-director')[0]
            filmActors = filmChild.xpath('@data-actors')[0]
            filmLink = filmChild.xpath('ul/li[1]/a/@href')[0]
            df_filmsNowPlaying.loc[len(df_filmsNowPlaying.index)] = [
                filmName, filmScore, filmRegion, filmDirector, filmActors, filmLink
            ]
        return df_filmsNowPlaying

    def parse_detail_page(response):
        """parse_detail_page(response) extracts a film's details from response.text
        """
        parser = etree.HTMLParser(encoding='utf-8')
        tree = etree.HTML(response.text, parser=parser)
        info_path = './/div[@class="subject clearfix"]/div[2]'

        director = tree.xpath(f'{info_path}/span[1]/span[2]/a/text()')[0]

        actors = []
        actors_spans = tree.xpath(f'{info_path}/span[3]/span[2]')[0]
        for actors_span in actors_spans:
            actors.append(actors_span.text)
        actors = '、'.join(actors[:3])

        types = []
        spans = tree.xpath(f'{info_path}')[0]
        for span in spans.iter('span'):
            if 'property' in span.attrib and span.attrib['property']=='v:genre':
                types.append(span.text)
        types = '、'.join(types)

        for span in spans:
            if span.text=='制片国家/地区:':
                region = span.tail.strip()
                break
        Synopsis = tree.xpath('.//div[@class="related-info"]/div/span')[0].text.strip()
        detail = f'是一部{region}的{types}电影,由{director}导演,{actors}等人主演.\n剧情简介:{Synopsis}'
        return detail

    @tool.get("/coming_out_filter")
    def coming_out_filter(args : str):
        """coming_out_filter(args: str) prints the details of the filtered [outNum] coming films according to region, cate and outNum.
        args is a list like 'str1, str2, str3, str4'
        str1 represents the production country or region. If you cannot find a region, str1 is 全部
        str2 represents the movie's category. If you cannot find a category, str2 is 全部
        str3 can be an integer number of films the agent wants to get. If you cannot find a number, str3 is 100. If the number of found movies is less than str3, the Final Answer only prints [the found movie's num] movies.
        str4 can be True or False, reflecting whether the agent wants the result sorted by the number of people looking forward to the movie.
        Final answer should be complete.

        This is an example:
        Thought: I need to find the upcoming Chinese drama movies and the top 2 most wanted movies
        Action: coming_out_filter
        Action Input: {"args" : "中国, 剧情, 2, True"}
        Observation: {"date":{"23":"04月28日","50":"07月"},"title":{"23":"长空之王","50":"热烈"},"cate":{"23":"剧情 / 动作","50":"剧情 / 喜剧"},"region":{"23":"中国大陆","50":"中国大陆"},"wantWatchPeopleNum":{"23":"39303人","50":"26831人"}}
        Thought: I now know the top 2 upcoming Chinese drama movies
        Final Answer: 即将上映的中国剧情电影有2部:长空之王、热烈,大家最想看的前2部分别是:长空之王、热烈。
        """
        args = re.findall(r'\b\w+\b', args)
        region = args[0]
        if region=='全部':
            region = ''
        cate = args[1]
        if cate=='全部':
            cate = ''
        outNum = int(args[2])
        WantSort = True if args[3]=='True' else False

        df = parse_coming_page()
        df_recon = pd.DataFrame.copy(df, deep=True)

        # Filter the upcoming films by category and region; convert the want-to-watch column to integers first
        df_recon['wantWatchPeopleNum'] = df_recon['wantWatchPeopleNum'].apply(lambda x: int(x.replace('人', '')))

        df_recon = df_recon[df_recon['cate'].str.contains(cate)]
        df_recon = df_recon[df_recon['region'].str.contains(region)]

        # Finally, sort by want-to-watch count in descending order
        if WantSort:
            df_recon.sort_values(by="wantWatchPeopleNum" , inplace=True, ascending = not WantSort)
        outDf = df_recon[:outNum]
        return df.loc[outDf.index, 'date':'wantWatchPeopleNum']


    @tool.get("/now_playing_out_filter")
    def now_playing_out_filter(args : str):
        """now_playing_out_filter(args: str) prints the details of the filtered [outNum] films playing now according to region and scoreSort
        args is a list like 'str1, str2, str3'
        str1 can be '中国', '日本' or another production country or region. If you cannot find a region, str1 is 全部
        str2 can be an integer number of films the agent wants to get. If you cannot find a number, str2 is 100. If the number of found movies is less than str2, the Final Answer only prints [the found movie's num] movies.
        str3 can be True or False, reflecting whether the agent wants the result sorted by score.
        Final answer should be complete.

        This is an example:
        Input: 您知道现在有正在上映中国的电影吗?请输出3部
        Thought: I need to find the currently playing movies with the highest scores
        Action: now_playing_out_filter
        Action Input: {"args" : "全部, 3, True"}
        Observation: {"title":{"34":"切腹","53":"吉赛尔","31":"小森林 夏秋篇"},"score":{"34":"9.4","53":"9.2","31":"9.0"},"region":{"34":"日本","53":"西德","31":"日本"},"director":{"34":"小林正树","53":"Hugo Niebeling","31":"森淳一"},"actors":{"34":"仲代达矢 / 石浜朗 / 岩下志麻","53":"卡拉·弗拉奇 / 埃里克·布鲁恩 / Bruce Marks","31":"桥本爱 / 三浦贵大 / 松冈茉优"}}
        Thought: I now know the currently playing movies with the highest scores
        Final Answer: 现在上映的评分最高的3部电影是:切腹、吉赛尔、小森林 夏秋篇

        """
        args = re.findall(r'\b\w+\b', args)
        region = args[0]
        if region=='全部':
            region = ''
        outNum = int(args[1])
        scoreSort = True if args[2]=='True' else False

        df = parse_nowplaying_page()

        df_recon = pd.DataFrame.copy(df, deep=True)

        df_recon['score'] = df_recon['score'].apply(lambda x: float(x))

        # Filter the now-playing films by region
        df_recon = df_recon[df_recon['region'].str.contains(region)]

        # Finally, sort by score in descending order
        if scoreSort:
            df_recon.sort_values(by="score" , inplace=True, ascending = not scoreSort)
        outDf = df_recon[:outNum]
        return df.loc[outDf.index, 'title':'actors']

    @tool.get("/print_detail")
    def print_detail(args : str):
        """print_detail(args: str) prints the details of a movie, given its name.
        args is a list like 'str1'
        str1 is the target movie's name.
        step1: apply parse_coming_page and parse_nowplaying_page to get all movies' links and other information.
        step2: get the target movie's link from df_coming or df_nowplaying
        step3: get the detail from step2's link

        This is an example:
        Input: "电影流浪地球2怎么样?"
        Thought: I need to find the movie's information
        Action: print_detail
        Action Input: {"args" : "流浪地球2"}
        Observation: "是一部中国大陆的科幻、冒险、灾难电影,由郭帆导演,吴京、刘德华、李雪健等人主演.\n剧情简介:太阳即将毁灭,人类在地球表面建造出巨大的推进器,寻找新的家园。然而宇宙之路危机四伏,为了拯救地球,流浪地球时代的年轻人再次挺身而出,展开争分夺秒的生死之战。"
        Thought: I now know the final answer
        Final Answer: 流浪地球2是一部中国大陆的科幻、冒险、灾难电影,由郭帆导演,吴京、刘德华、李雪健等人主演,剧情简介是太阳即将毁灭,人类在地球表面建造出巨大的推进器,寻找新的家园,然而宇宙之路危机四伏,为了拯救地球,流浪地球时代的年轻人再次挺身而出,

        """
        args = re.findall(r'\b\w+\b', args)
        filmName = args[0]

        df_coming = parse_coming_page()
        df_nowplaying = parse_nowplaying_page()

        if filmName in list(df_coming['title']):
            df = df_coming
            url = df[df['title']==filmName]['link'].values[0]
            response = fetch_page(url)
            detail = parse_detail_page(response)
        elif filmName in list(df_nowplaying['title']):
            df = df_nowplaying
            url = df[df['title']==filmName]['link'].values[0]
            response = fetch_page(url)
            detail = parse_detail_page(response)
        else:
            # The film is in neither list; return early instead of hitting the undefined `detail` below
            return f'No film named {filmName} was found in the coming or now-playing lists.'
        return f'{filmName}{detail}'
    return tool
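Since build_tool returns a plain FastAPI app, the routes can also be exercised without an agent; a sketch using FastAPI's test client (live access to movie.douban.com is assumed, and whether a given film is found depends on what is currently listed):

from fastapi.testclient import TestClient
from bmtools.tools.film.douban.api import build_tool

client = TestClient(build_tool({}))
resp = client.get("/print_detail", params={"args": "深海"})
print(resp.json())  # a synopsis string, or the not-found message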
bmtools/tools/film/douban/readme.md
ADDED
@@ -0,0 +1,3 @@
# Douban Film Search

Contributor: [Jing Yi](https://github.com/yijing16)
bmtools/tools/film/douban/test.py
ADDED
@@ -0,0 +1,15 @@
from bmtools.agent.singletool import load_single_tools, STQuestionAnswerer

tool_name, tool_url = 'douban', "http://127.0.0.1:8079/tools/douban-film/"
tools_name, tools_config = load_single_tools(tool_name, tool_url)
# tools_name, tools_config = load_single_tools()
print(tools_name, tools_config)

qa = STQuestionAnswerer()

agent = qa.load_tools(tools_name, tools_config)

agent("有哪些即将上映的中国喜剧电影?哪些是大家最想看的前5部?")  # Which Chinese comedies are about to be released? What are the top 5 people most want to see?
agent("想去电影院看一些国产电影,有评分高的吗?输出3部")  # I want to see some domestic films in the cinema; any highly rated ones? List 3.
agent("帮我介绍下《深海》这部电影")  # Tell me about the film "Deep Sea"
bmtools/tools/registry.py
ADDED
@@ -0,0 +1,38 @@
from .tool import Tool
from typing import Dict, Callable, Any, List

ToolBuilder = Callable[[Any], Tool]
FuncToolBuilder = Callable[[], ToolBuilder]


class ToolsRegistry:
    def __init__(self) -> None:
        self.tools : Dict[str, FuncToolBuilder] = {}

    def register(self, tool_name : str, tool : FuncToolBuilder):
        print(f"will register {tool_name}")
        self.tools[tool_name] = tool

    def build(self, tool_name, config) -> Tool:
        ret = self.tools[tool_name]()(config)
        if isinstance(ret, Tool):
            return ret
        raise ValueError("Tool builder {} did not return a Tool instance".format(tool_name))

    def list_tools(self) -> List[str]:
        return list(self.tools.keys())

tools_registry = ToolsRegistry()

def register(tool_name):
    def decorator(tool : FuncToolBuilder):
        tools_registry.register(tool_name, tool)
        return tool
    return decorator

def build_tool(tool_name : str, config : Any) -> Tool:
    print(f"will build {tool_name}")
    return tools_registry.build(tool_name, config)

def list_tools() -> List[str]:
    return tools_registry.list_tools()
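The round trip through the registry then looks like the sketch below; importing bmtools.tools runs the @register("douban-film") decorator shown earlier, after which the tool can be listed and built:

import bmtools.tools                              # triggers @register("douban-film")
from bmtools.tools.registry import build_tool, list_tools

print(list_tools())                               # ['douban-film'] (the only uncommented tool package)
tool = build_tool("douban-film", {})              # calls douban_film()({}) and returns a Tool
print(tool.api_info["name_for_model"])            # 'Film Search'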
bmtools/tools/retriever.py
ADDED
@@ -0,0 +1,45 @@
from langchain.embeddings import OpenAIEmbeddings
# from bmtools.models import CustomEmbedding
from typing import List, Dict
from queue import PriorityQueue
import os

class Retriever:
    def __init__(self,
                 openai_api_key: str = None,
                 model: str = "text-embedding-ada-002"):
        if openai_api_key is None:
            openai_api_key = os.environ.get("OPENAI_API_KEY")
        self.embed = OpenAIEmbeddings(openai_api_key=openai_api_key, model=model)
        # self.embed = CustomEmbedding()
        self.documents = dict()

    def add_tool(self, tool_name: str, api_info: Dict) -> None:
        if tool_name in self.documents:
            return
        document = api_info["name_for_model"] + ". " + api_info["description_for_model"]
        document_embedding = self.embed.embed_documents([document])
        self.documents[tool_name] = {
            "document": document,
            "embedding": document_embedding[0]
        }

    def query(self, query: str, topk: int = 3) -> List[str]:
        query_embedding = self.embed.embed_query(query)

        queue = PriorityQueue()
        for tool_name, tool_info in self.documents.items():
            tool_embedding = tool_info["embedding"]
            tool_sim = self.similarity(query_embedding, tool_embedding)
            # Negate the similarity so the min-heap pops the most similar tool first
            queue.put([-tool_sim, tool_name])

        result = []
        for i in range(min(topk, len(queue.queue))):
            tool = queue.get()
            result.append(tool[1])

        return result

    def similarity(self, query: List[float], document: List[float]) -> float:
        # Dot product; OpenAI embeddings are unit-length, so this equals cosine similarity
        return sum([i * j for i, j in zip(query, document)])
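A short usage sketch of the retriever; an OPENAI_API_KEY is assumed to be set in the environment, and the second tool's description is hypothetical:

from bmtools.tools.retriever import Retriever

retriever = Retriever()        # reads OPENAI_API_KEY from the environment
retriever.add_tool("douban-film", {
    "name_for_model": "Film Search",
    "description_for_model": "Plugin for searching for up-to-date film information.",
})
retriever.add_tool("weather", {
    "name_for_model": "Weather",
    "description_for_model": "Look up current weather conditions.",   # hypothetical description
})
print(retriever.query("what movies are in cinemas tonight?", topk=1))  # ['douban-film']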
bmtools/tools/serve.py
ADDED
@@ -0,0 +1,102 @@
import fastapi
import uvicorn
from .registry import build_tool, list_tools
from .retriever import Retriever
from typing import List
from pydantic import BaseModel

class RetrieveRequest(BaseModel):
    query: str
    topk: int = 3

def _bind_tool_server(t : "ToolServer"):
    """ Add the property API to a ToolServer.
    t.api is a FastAPI object
    """

    @t.api.get("/")
    def health():
        return {
            "status": "ok",
        }

    @t.api.get("/list")
    def get_tools_list():
        return {
            "tools": t.list_tools(),
        }

    @t.api.get("/loaded")
    def get_loaded_tools():
        return {
            "tools": list(t.loaded_tools),
        }

    @t.api.get("/.well-known/ai-plugin.json", include_in_schema=False)
    def get_api_info():
        return {
            "schema_version": "v1",
            "name_for_human": "BMTools",
            "name_for_model": "BMTools",
            "description_for_human": "tools to big models",
            "description_for_model": "tools to big models",
            "auth": {
                "type": "none",
            },
            "api": {
                "type": "openapi",
                "url": "/openapi.json",
                "is_user_authenticated": False,
            },
            "logo_url": None,
            "contact_email": "",
            "legal_info_url": "",
        }

    @t.api.post("/retrieve")
    def retrieve(request: RetrieveRequest):
        tool_list = t.retrieve(request.query, request.topk)
        return {
            "tools": tool_list
        }

class ToolServer:
    """ This class hosts your own API backend.
    """
    def __init__(self) -> None:
        # define the root API server
        self.api = fastapi.FastAPI(
            title="BMTools",
            description="Tools for big models",
        )
        self.loaded_tools = dict()
        self.retriever = Retriever()
        _bind_tool_server(self)

    def load_tool(self, name : str, config = {}):
        if self.is_loaded(name):
            raise ValueError(f"Tool {name} is already loaded")
        try:
            tool = build_tool(name, config)
        except BaseException as e:
            print(f"Cannot load tool {name}: {repr(e)}")
            return
        # tool = build_tool(name, config)
        self.loaded_tools[name] = tool.api_info
        self.retriever.add_tool(name, tool.api_info)

        # Mount the sub API server on the root API server, so all of its routes appear under /tools/{name}
        self.api.mount(f"/tools/{name}", tool, name)
        return

    def is_loaded(self, name : str):
        return name in self.loaded_tools

    def serve(self, host : str = "0.0.0.0", port : int = 8079):
        uvicorn.run(self.api, host=host, port=port)

    def list_tools(self) -> List[str]:
        return list_tools()

    def retrieve(self, query: str, topk: int = 3) -> List[str]:
        return self.retriever.query(query, topk)
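Putting serve.py together with the registry, a minimal sketch of hosting the douban-film tool at the URL that test.py above expects:

from bmtools.tools.serve import ToolServer

server = ToolServer()
server.load_tool("douban-film")   # built via the registry and mounted at /tools/douban-film
server.serve()                    # serves http://127.0.0.1:8079/tools/douban-film/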
bmtools/tools/tool.py
ADDED
@@ -0,0 +1,87 @@
import fastapi
from typing import Optional
import copy
from starlette.middleware.sessions import SessionMiddleware
from fastapi import Request

class Tool(fastapi.FastAPI):
    """ Tool inherits from the FastAPI class, thus:
    1. It can act as a server
    2. It has a get method; you can use Tool.get to bind a function to a URL
    3. It can be easily mounted to another server
    4. It has a list of sub-routes; each route is a function

    Diagram:
    Root API server (ToolServer object)
    │
    ├───── "./weather": Tool object
    │        ├── "./get_weather_today": function_get_weather(location: str) -> str
    │        ├── "./get_weather_forecast": function_get_weather_forecast(location: str, day_offset: int) -> str
    │        └── "...more routes"
    ├───── "./wikidata": Tool object
    │        ├── "... more routes"
    └───── "... more routes"
    """
    def __init__(
            self,
            tool_name : str,
            description : str,
            name_for_human : Optional[str] = None,
            name_for_model : Optional[str] = None,
            description_for_human : Optional[str] = None,
            description_for_model : Optional[str] = None,
            logo_url : Optional[str] = None,
            author_github : Optional[str] = None,
            contact_email : str = "",
            legal_info_url : str = "",
            version : str = "0.1.0",
        ):
        super().__init__(
            title=tool_name,
            description=description,
            version=version,
        )

        if name_for_human is None:
            name_for_human = tool_name
        if name_for_model is None:
            name_for_model = name_for_human
        if description_for_human is None:
            description_for_human = description
        if description_for_model is None:
            description_for_model = description_for_human

        self.api_info = {
            "schema_version": "v1",
            "name_for_human": name_for_human,
            "name_for_model": name_for_model,
            "description_for_human": description_for_human,
            "description_for_model": description_for_model,
            "auth": {
                "type": "none",
            },
            "api": {
                "type": "openapi",
                "url": "/openapi.json",
                "is_user_authenticated": False,
            },
            "author_github": author_github,
            "logo_url": logo_url,
            "contact_email": contact_email,
            "legal_info_url": legal_info_url,
        }

        @self.get("/.well-known/ai-plugin.json", include_in_schema=False)
        def get_api_info(request : fastapi.Request):
            openapi_path = str(request.url).replace("/.well-known/ai-plugin.json", "/openapi.json")
            info = copy.deepcopy(self.api_info)
            info["api"]["url"] = str(openapi_path)
            return info


        self.add_middleware(
            SessionMiddleware,
            secret_key=tool_name,
            session_cookie="session_{}".format(tool_name.replace(" ", "_")),
        )
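Because every Tool is itself a FastAPI app, it can be exercised directly with FastAPI's test client; a minimal sketch with a hypothetical demo tool:

from fastapi.testclient import TestClient
from bmtools.tools.tool import Tool

tool = Tool("Demo", "a demo tool.")     # hypothetical tool, for illustration only

@tool.get("/hello")
def hello(name : str):
    return {"greeting": f"Hello, {name}!"}

client = TestClient(tool)
print(client.get("/hello", params={"name": "BMTools"}).json())              # {'greeting': 'Hello, BMTools!'}
print(client.get("/.well-known/ai-plugin.json").json()["name_for_model"])   # 'Demo'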
bmtools/utils/__init__.py
ADDED
File without changes
bmtools/utils/logging.py
ADDED
@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# BMTools: copied from Hugging Face Transformers
""" Logging utilities."""

import logging
import os
import sys
import threading
from logging import CRITICAL  # NOQA
from logging import DEBUG  # NOQA
from logging import ERROR  # NOQA
from logging import FATAL  # NOQA
from logging import INFO  # NOQA
from logging import NOTSET  # NOQA
from logging import WARN  # NOQA
from logging import WARNING  # NOQA
from typing import Optional


_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

log_levels = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

_default_log_level = logging.INFO


def _get_default_logging_level():
    """
    If the BMTOOLS_VERBOSITY env var is set to one of the valid choices, return that as the new default level. If it is
    not - fall back to ``_default_log_level``
    """
    env_level_str = os.getenv("BMTOOLS_VERBOSITY", None)
    if env_level_str:
        if env_level_str in log_levels:
            return log_levels[env_level_str]
        else:
            logging.getLogger().warning(
                f"Unknown option BMTOOLS_VERBOSITY={env_level_str}, "
                f"has to be one of: { ', '.join(log_levels.keys()) }"
            )
    return _default_log_level


def _get_library_name() -> str:

    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:

    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:

    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush
        formatter = logging.Formatter(
            "\033[1;31m[%(levelname)s|(BMTools)%(module)s:%(lineno)d]%(asctime)s >> \033[0m %(message)s")
        _default_handler.setFormatter(formatter)

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_get_default_logging_level())

        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:

    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def get_log_levels_dict():
    return log_levels


def get_verbosity() -> int:
    """
    Return the current level for BMTools's root logger as an int.
    Returns:
        :obj:`int`: The logging level.
    <Tip>
    BMTools has the following logging levels:
    - 50: ``bmtools.utils.logging.CRITICAL`` or ``bmtools.utils.logging.FATAL``
    - 40: ``bmtools.utils.logging.ERROR``
    - 30: ``bmtools.utils.logging.WARNING`` or ``bmtools.utils.logging.WARN``
    - 20: ``bmtools.utils.logging.INFO``
    - 10: ``bmtools.utils.logging.DEBUG``
    </Tip>"""

    _configure_library_root_logger()
    return _get_library_root_logger().getEffectiveLevel()


def set_verbosity(verbosity: int) -> None:
    """
    Set the verbosity level for BMTools's root logger.
    Args:
        verbosity (:obj:`int`):
            Logging level, e.g., one of:
            - ``bmtools.utils.logging.CRITICAL`` or ``bmtools.utils.logging.FATAL``
            - ``bmtools.utils.logging.ERROR``
            - ``bmtools.utils.logging.WARNING`` or ``bmtools.utils.logging.WARN``
            - ``bmtools.utils.logging.INFO``
            - ``bmtools.utils.logging.DEBUG``
    """

    _configure_library_root_logger()
    _get_library_root_logger().setLevel(verbosity)


def set_verbosity_info():
    """Set the verbosity to the ``INFO`` level."""
    return set_verbosity(INFO)


def set_verbosity_warning():
    """Set the verbosity to the ``WARNING`` level."""
    return set_verbosity(WARNING)


def set_verbosity_debug():
    """Set the verbosity to the ``DEBUG`` level."""
    return set_verbosity(DEBUG)


def set_verbosity_error():
    """Set the verbosity to the ``ERROR`` level."""
    return set_verbosity(ERROR)


def disable_default_handler() -> None:
    """Disable the default handler of the BMTools root logger."""

    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().removeHandler(_default_handler)


def enable_default_handler() -> None:
    """Enable the default handler of the BMTools root logger."""

    _configure_library_root_logger()

    assert _default_handler is not None
    _get_library_root_logger().addHandler(_default_handler)


def add_handler(handler: logging.Handler) -> None:
    """Adds a handler to the BMTools root logger."""

    _configure_library_root_logger()

    assert handler is not None
    _get_library_root_logger().addHandler(handler)


def remove_handler(handler: logging.Handler) -> None:
    """Removes the given handler from the BMTools root logger."""

    _configure_library_root_logger()

    assert handler is not None and handler not in _get_library_root_logger().handlers
    _get_library_root_logger().removeHandler(handler)


def disable_propagation() -> None:
    """
    Disable propagation of the library log outputs. Note that log propagation is disabled by default.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = False


def enable_propagation() -> None:
    """
    Enable propagation of the library log outputs. Please disable the BMTools default handler to
    prevent double logging if the root logger has been configured.
    """

    _configure_library_root_logger()
    _get_library_root_logger().propagate = True


def enable_explicit_format() -> None:
    """
    Enable explicit formatting for every BMTools logger. The explicit formatter is as follows:
    ```
    [LEVELNAME|FILENAME|LINE NUMBER] TIME >> MESSAGE
    ```
    All handlers currently bound to the root logger are affected by this method.
    """
    handlers = _get_library_root_logger().handlers

    for handler in handlers:
        formatter = logging.Formatter("[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s")
        handler.setFormatter(formatter)


def reset_format() -> None:
    """
    Resets the formatting for BMTools loggers.
    All handlers currently bound to the root logger are affected by this method.
    """
    handlers = _get_library_root_logger().handlers

    for handler in handlers:
        handler.setFormatter(None)


def warning_advice(self, *args, **kwargs):
    """
    This method is identical to ``logger.warning()``, but if the env var TRANSFORMERS_NO_ADVISORY_WARNINGS=1 is set, this
    warning will not be printed
    """
    no_advisory_warnings = os.getenv("TRANSFORMERS_NO_ADVISORY_WARNINGS", False)
    if no_advisory_warnings:
        return
    self.warning(*args, **kwargs)


logging.Logger.warning_advice = warning_advice


def get_logger(name: Optional[str] = None, verbosity='info') -> logging.Logger:
    """
    Return a logger with the specified name.
    This function is not supposed to be directly accessed unless you are writing a custom BMTools module.
    """

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()
    logger = logging.getLogger(name)
    logger.setLevel(log_levels[verbosity])
    return logger
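The logger is obtained the same way across the repo (the logger.info call in tools_controller.py above uses it); a short sketch, where the bmtools.demo name is hypothetical and BMTOOLS_VERBOSITY may be set to any key of log_levels:

import os
os.environ["BMTOOLS_VERBOSITY"] = "debug"        # optional; the default level is "info"

from bmtools.utils.logging import get_logger

logger = get_logger("bmtools.demo")              # hypothetical child of the "bmtools" root logger
logger.info("tool server starting")              # printed with the colored BMTools prefix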
requirements.txt
ADDED
@@ -0,0 +1,23 @@
pandas
SPARQLWrapper
Pillow
oss2
regex
matplotlib
tabulate
python-pptx
replicate
bs4
langchain
pandasql
SQLAlchemy==1.4.46
openai
gradio
fastapi_sessions
translate
socksio
py3langid
iso-639
transformers
cchardet
faiss-cpu