|
from unittest import TestCase, mock |
|
|
|
from lagent.actions import ActionExecutor |
|
from lagent.actions.llm_qa import LLMQA |
|
from lagent.actions.serper_search import SerperSearch |
|
from lagent.agents.rewoo import ReWOO, ReWOOProtocol |
|
from lagent.schema import ActionReturn, ActionStatusCode |
|
|
|
|
|
class TestReWOO(TestCase): |
|
|
|
@mock.patch.object(SerperSearch, 'run') |
|
@mock.patch.object(LLMQA, 'run') |
|
@mock.patch.object(ReWOOProtocol, 'parse_worker') |
|
def test_normal_chat(self, mock_parse_worker_func, mock_qa_func, |
|
mock_search_func): |
|
mock_model = mock.Mock() |
|
mock_model.generate_from_template.return_value = 'LLM response' |
|
|
|
mock_parse_worker_func.return_value = (['Thought1', 'Thought2' |
|
], ['LLMQA', 'SerperSearch'], |
|
['abc', 'abc']) |
|
|
|
search_return = ActionReturn(args=None) |
|
search_return.state = ActionStatusCode.SUCCESS |
|
search_return.result = dict(text='search_return') |
|
mock_search_func.return_value = search_return |
|
|
|
qa_return = ActionReturn(args=None) |
|
qa_return.state = ActionStatusCode.SUCCESS |
|
qa_return.result = dict(text='qa_return') |
|
mock_qa_func.return_value = qa_return |
|
|
|
chatbot = ReWOO( |
|
llm=mock_model, |
|
action_executor=ActionExecutor(actions=[ |
|
LLMQA(mock_model), |
|
SerperSearch(api_key=''), |
|
])) |
|
agent_return = chatbot.chat('abc') |
|
self.assertEqual(agent_return.response, 'LLM response') |
|
|
|
def test_parse_worker(self): |
|
prompt = ReWOOProtocol() |
|
message = """ |
|
Plan: a. |
|
#E1 = tool1["a"] |
|
#E2 = tool2["b"] |
|
""" |
|
try: |
|
thoughts, actions, actions_input = prompt.parse_worker(message) |
|
except Exception as e: |
|
self.assertEqual( |
|
'Each Plan should only correspond to only ONE action', str(e)) |
|
else: |
|
self.assertFalse( |
|
True, 'it should raise exception when the format is incorrect') |
|
|
|
message = """ |
|
Plan: a. |
|
#E1 = tool1("a") |
|
Plan: b. |
|
#E2 = tool2["b"] |
|
""" |
|
try: |
|
thoughts, actions, actions_input = prompt.parse_worker(message) |
|
except Exception as e: |
|
self.assertIsInstance(e, BaseException) |
|
else: |
|
self.assertFalse( |
|
True, 'it should raise exception when the format is incorrect') |
|
|
|
message = """ |
|
Plan: a. |
|
#E1 = tool1["a"] |
|
Plan: b. |
|
#E2 = tool2["b"] |
|
""" |
|
try: |
|
thoughts, actions, actions_input = prompt.parse_worker(message) |
|
except Exception: |
|
self.assertFalse( |
|
True, |
|
'it should not raise exception when the format is correct') |
|
self.assertEqual(thoughts, ['a.', 'b.']) |
|
self.assertEqual(actions, ['tool1', 'tool2']) |
|
self.assertEqual(actions_input, ['"a"', '"b"']) |
|
|