File size: 1,575 Bytes
ac819bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import json
from unittest import TestCase, mock

from lagent.actions import GoogleSearch
from lagent.schema import ActionStatusCode


class TestGoogleSearch(TestCase):

    @mock.patch.object(GoogleSearch, '_search')
    def test_search_tool(self, mock_search_func):
        mock_response = (200, json.load('tests/data/search.json'))
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
        self.assertDictEqual(tool_return.result, dict(text="['Beijing']"))

    @mock.patch.object(GoogleSearch, '_search')
    def test_api_error(self, mock_search_func):
        mock_response = (403, {'message': 'bad requests'})
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(tool_return.errmsg, str(403))

    @mock.patch.object(GoogleSearch, '_search')
    def test_http_error(self, mock_search_func):
        mock_response = (-1, 'HTTPSConnectionPool')
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.HTTP_ERROR)
        self.assertEqual(tool_return.errmsg, 'HTTPSConnectionPool')