File size: 3,051 Bytes
cfc816f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from unittest import TestCase, mock

from lagent.actions import ActionExecutor
from lagent.actions.llm_qa import LLMQA
from lagent.actions.serper_search import SerperSearch
from lagent.agents.rewoo import ReWOO, ReWOOProtocol
from lagent.schema import ActionReturn, ActionStatusCode


class TestReWOO(TestCase):
    """Unit tests for the ``ReWOO`` agent and ``ReWOOProtocol.parse_worker``."""

    @mock.patch.object(SerperSearch, 'run')
    @mock.patch.object(LLMQA, 'run')
    @mock.patch.object(ReWOOProtocol, 'parse_worker')
    def test_normal_chat(self, mock_parse_worker_func, mock_qa_func,
                         mock_search_func):
        """Run one chat turn with the LLM, the planner parser and both tools mocked.

        ``mock.patch.object`` decorators are applied bottom-up, so the mocks
        are passed in the order: ``parse_worker``, ``LLMQA.run``,
        ``SerperSearch.run``.
        """
        mock_model = mock.Mock()
        # Every templated generation returns the same canned text, so the
        # final response checked below is known in advance.
        mock_model.generate_from_template.return_value = 'LLM response'

        # Parsed planner output: (thoughts, action names, action inputs).
        mock_parse_worker_func.return_value = (
            ['Thought1', 'Thought2'],
            ['LLMQA', 'SerperSearch'],
            ['abc', 'abc'],
        )

        search_return = ActionReturn(args=None)
        search_return.state = ActionStatusCode.SUCCESS
        search_return.result = dict(text='search_return')
        mock_search_func.return_value = search_return

        qa_return = ActionReturn(args=None)
        qa_return.state = ActionStatusCode.SUCCESS
        qa_return.result = dict(text='qa_return')
        mock_qa_func.return_value = qa_return

        chatbot = ReWOO(
            llm=mock_model,
            action_executor=ActionExecutor(actions=[
                LLMQA(mock_model),
                SerperSearch(api_key=''),
            ]))
        agent_return = chatbot.chat('abc')
        self.assertEqual(agent_return.response, 'LLM response')

    def test_parse_worker(self):
        """Check ``parse_worker`` on two malformed plans and one valid plan."""
        prompt = ReWOOProtocol()

        # Two actions under a single "Plan:" must be rejected with a
        # specific message.
        message = """
        Plan: a.
        #E1 = tool1["a"]
        #E2 = tool2["b"]
        """
        with self.assertRaises(Exception) as ctx:
            prompt.parse_worker(message)
        self.assertEqual(
            'Each Plan should only correspond to only ONE action',
            str(ctx.exception))

        # An action whose input is wrapped in parentheses instead of square
        # brackets is malformed; any exception is acceptable here.
        message = """
        Plan: a.
        #E1 = tool1("a")
        Plan: b.
        #E2 = tool2["b"]
        """
        with self.assertRaises(Exception):
            prompt.parse_worker(message)

        # One bracketed action per plan parses into parallel lists.
        message = """
        Plan: a.
        #E1 = tool1["a"]
        Plan: b.
        #E2 = tool2["b"]
        """
        try:
            thoughts, actions, actions_input = prompt.parse_worker(message)
        except Exception:
            # Report a test failure (not an error); the original exception
            # is chained automatically for debugging.
            self.fail(
                'it should not raise exception when the format is correct')
        self.assertEqual(thoughts, ['a.', 'b.'])
        self.assertEqual(actions, ['tool1', 'tool2'])
        self.assertEqual(actions_input, ['"a"', '"b"'])