|
import asyncio |
|
import json |
|
import time |
|
|
|
from datasets import load_dataset |
|
|
|
from lagent.agents.stream import PLUGIN_CN, AsyncAgentForInternLM, AsyncMathCoder, get_plugin_prompt |
|
from lagent.llms import INTERNLM2_META |
|
from lagent.llms.lmdeploy_wrapper import AsyncLMDeployPipeline |
|
from lagent.prompts.parsers import PluginParser |
|
|
|
|
|
# Create a fresh event loop and install it as the current one; the benchmark
# sections below drive their coroutines through this `loop` object.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
|
|
|
# Shared LLM backend used by both benchmark sections below.
model = AsyncLMDeployPipeline(
    # Which model to load and how to format its prompts.
    path='internlm/internlm2_5-7b-chat',
    model_name='internlm-chat',
    meta_template=INTERNLM2_META,
    # Parallelism setting passed through to the pipeline.
    tp=1,
    # Generation settings.
    # NOTE(review): top_k=1 presumably makes decoding deterministic despite
    # temperature=1.0 -- confirm against the LMDeploy wrapper.
    top_k=1,
    temperature=1.0,
    max_new_tokens=1024,
    stop_words=['<|im_end|>', '<|action_end|>'],
)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Benchmark 1: code-interpreter agent (AsyncMathCoder) on MATH problems.
# ---------------------------------------------------------------------------
print('-' * 80, 'interpreter', '-' * 80)

# Select rows 0, 2, 4, ..., 4998 of the MATH test split and keep only the
# problem statement of each row.
ds = load_dataset('lighteval/MATH', split='test')
problems = [item['problem'] for item in ds.select(range(0, 5000, 2))]

coder = AsyncMathCoder(
    llm=model,
    interpreter=dict(
        type='lagent.actions.AsyncIPythonInterpreter', max_kernels=300),
    max_turn=11)

# One coroutine per problem. The enumeration index doubles as the session id
# and is reused below to look up that session's recorded steps.
tic = time.time()
coros = [coder(query, session_id=i) for i, query in enumerate(problems)]
res = loop.run_until_complete(asyncio.gather(*coros))

print('-' * 120)
print(f'time elapsed: {time.time() - tic}')

# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened explicitly as UTF-8 instead of relying on the platform's default
# encoding (which may be unable to encode them).
with open('./tmp_1.json', 'w', encoding='utf-8') as f:
    json.dump(
        [coder.get_steps(i) for i in range(len(res))],
        f,
        ensure_ascii=False,
        indent=4)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Benchmark 2: plugin agent (AsyncAgentForInternLM) with the ArXiv search tool.
# ---------------------------------------------------------------------------
print('-' * 80, 'plugin', '-' * 80)

# The same plugin spec list is handed to the agent and used to build the
# parser's prompt.
plugins = [{'type': 'lagent.actions.AsyncArxivSearch'}]
agent = AsyncAgentForInternLM(
    llm=model,
    plugins=plugins,
    output_format={
        'type': PluginParser,
        'template': PLUGIN_CN,
        'prompt': get_plugin_prompt(plugins),
    },
)

tic = time.time()
# Issue the identical query in 50 sessions (session ids 0..49) and run them
# together on the shared loop.
coros = [
    agent('LLM智能体方向的最新论文有哪些?', session_id=session)
    for session in range(50)
]
res = loop.run_until_complete(asyncio.gather(*coros))

print('-' * 120)
print(f'time elapsed: {time.time() - tic}')
|
|