from unittest import TestCase, mock

from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO, ReWOOProtocol
from lagent.schema import ActionReturn, ActionStatusCode


def _success_return(text):
    """Build an ``ActionReturn`` in SUCCESS state whose result is ``{'text': text}``."""
    action_return = ActionReturn(args=None)
    action_return.state = ActionStatusCode.SUCCESS
    action_return.result = dict(text=text)
    return action_return


class TestReWOO(TestCase):
    """Unit tests for the ReWOO agent and its ``ReWOOProtocol`` plan parser."""

    # mock.patch decorators are applied bottom-up, so the injected mocks
    # arrive in reverse order: parse_worker, LLMQA.run, SerperSearch.run.
    @mock.patch.object(SerperSearch, 'run')
    @mock.patch.object(LLMQA, 'run')
    @mock.patch.object(ReWOOProtocol, 'parse_worker')
    def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
                         mock_search_func):
        """Run ``ReWOO.chat`` with the LLM, plan parsing and both tools mocked.

        ``parse_worker`` is forced to return a fixed two-step plan (LLMQA,
        then SerperSearch). Both tools' ``run`` methods return canned
        successful results. The agent's final response must equal the mocked
        LLM output.
        """
        mock_model = mock.Mock()
        mock_model.generate_from_template.return_value = 'LLM response'
        mock_parse_worker_func.return_value = (['Thought1', 'Thought2'],
                                               ['LLMQA', 'SerperSearch'],
                                               ['abc', 'abc'])
        mock_search_func.return_value = _success_return('search_return')
        mock_qa_func.return_value = _success_return('qa_return')

        chatbot = ReWOO(
            llm=mock_model,
            action_executor=ActionExecutor(actions=[
                LLMQA(mock_model),
                SerperSearch(api_key=''),
            ]))
        agent_return = chatbot.chat('abc')
        self.assertEqual(agent_return.response, 'LLM response')

    def test_parse_worker(self):
        """Check that ``ReWOOProtocol.parse_worker`` rejects malformed plans.

        Well-formed plans must split into parallel lists of thoughts, action
        names and action inputs.
        """
        prompt = ReWOOProtocol()

        # Two actions under a single Plan: must fail with this exact message.
        message = """
        Plan: a.
        #E1 = tool1["a"]
        #E2 = tool2["b"]
        """
        with self.assertRaises(
                Exception,
                msg='it should raise exception when the format is incorrect'
        ) as ctx:
            prompt.parse_worker(message)
        self.assertEqual(
            'Each Plan should only correspond to only ONE action',
            str(ctx.exception))

        # Action arguments in parentheses instead of square brackets are
        # malformed; any Exception is acceptable.
        message = """
        Plan: a.
        #E1 = tool1("a")
        Plan: b.
        #E2 = tool2["b"]
        """
        with self.assertRaises(
                Exception,
                msg='it should raise exception when the format is incorrect'):
            prompt.parse_worker(message)

        # Well-formed plan. Called directly (no catch-all) so that an
        # unexpected error surfaces with its own traceback.
        message = """
        Plan: a.
        #E1 = tool1["a"]
        Plan: b.
        #E2 = tool2["b"]
        """
        thoughts, actions, actions_input = prompt.parse_worker(message)
        self.assertEqual(thoughts, ['a.', 'b.'])
        self.assertEqual(actions, ['tool1', 'tool2'])
        self.assertEqual(actions_input, ['"a"', '"b"'])