File size: 3,051 Bytes
cfc816f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
from unittest import TestCase, mock
from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO, ReWOOProtocol
from lagent.schema import ActionReturn, ActionStatusCode
class TestReWOO(TestCase):
    """Tests for the ReWOO agent and the ReWOOProtocol worker parser."""

    # Decorators are applied bottom-up, so the mocks arrive in the order
    # parse_worker, LLMQA.run, SerperSearch.run.
    @mock.patch.object(SerperSearch, 'run')
    @mock.patch.object(LLMQA, 'run')
    @mock.patch.object(ReWOOProtocol, 'parse_worker')
    def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
                         mock_search_func):
        """Chat with every external call mocked and check the response.

        The LLM, the plan parser and both tools are replaced by mocks, so
        the only thing verified here is that ``ReWOO.chat`` returns the
        text produced by the (mocked) LLM as ``agent_return.response``.
        """
        mock_model = mock.Mock()
        mock_model.generate_from_template.return_value = 'LLM response'

        # Canned parser output: (thoughts, action names, action inputs).
        mock_parse_worker_func.return_value = (
            ['Thought1', 'Thought2'],
            ['LLMQA', 'SerperSearch'],
            ['abc', 'abc'],
        )

        # Canned successful result for the search tool.
        search_return = ActionReturn(args=None)
        search_return.state = ActionStatusCode.SUCCESS
        search_return.result = dict(text='search_return')
        mock_search_func.return_value = search_return

        # Canned successful result for the QA tool.
        qa_return = ActionReturn(args=None)
        qa_return.state = ActionStatusCode.SUCCESS
        qa_return.result = dict(text='qa_return')
        mock_qa_func.return_value = qa_return

        chatbot = ReWOO(
            llm=mock_model,
            action_executor=ActionExecutor(actions=[
                LLMQA(mock_model),
                SerperSearch(api_key=''),
            ]))
        agent_return = chatbot.chat('abc')
        self.assertEqual(agent_return.response, 'LLM response')

    def test_parse_worker(self):
        """Check ``parse_worker`` on two malformed plans and one valid plan."""
        prompt = ReWOOProtocol()
        should_raise = 'it should raise exception when the format is incorrect'

        # One "Plan:" line followed by two "#E" lines: must be rejected with
        # the specific error message below.
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1["a"]\n'
                   '#E2 = tool2["b"]\n')
        # Only the call under test lives inside the context manager so that
        # assertion failures are never swallowed by assertRaises.
        with self.assertRaises(Exception, msg=should_raise) as raised:
            prompt.parse_worker(message)
        self.assertEqual(
            'Each Plan should only correspond to only ONE action',
            str(raised.exception))

        # First action uses parentheses instead of square brackets: must be
        # rejected (the exact exception type is not pinned by this test).
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1("a")\n'
                   'Plan: b.\n'
                   '#E2 = tool2["b"]\n')
        with self.assertRaises(Exception, msg=should_raise):
            prompt.parse_worker(message)

        # Well-formed plan: must parse. Any exception propagates and fails
        # the test with its real traceback instead of a generic message.
        message = ('\n'
                   'Plan: a.\n'
                   '#E1 = tool1["a"]\n'
                   'Plan: b.\n'
                   '#E2 = tool2["b"]\n')
        thoughts, actions, actions_input = prompt.parse_worker(message)
        self.assertEqual(thoughts, ['a.', 'b.'])
        self.assertEqual(actions, ['tool1', 'tool2'])
        self.assertEqual(actions_input, ['"a"', '"b"'])
|