Spaces:
Running
Running
File size: 1,575 Bytes
ac819bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import json
from unittest import TestCase, mock
from lagent.actions import GoogleSearch
from lagent.schema import ActionStatusCode
class TestGoogleSearch(TestCase):
    """Unit tests for ``GoogleSearch.run`` with ``GoogleSearch._search`` mocked.

    Every test patches ``_search`` so that it returns a canned two-element
    tuple and then checks the ``state`` / ``result`` / ``errmsg`` attributes
    of the object returned by ``run``.
    """

    @mock.patch.object(GoogleSearch, '_search')
    def test_search_tool(self, mock_search_func):
        """A ``(200, <fixture JSON>)`` reply yields SUCCESS and the answer."""
        # json.load() needs a file object, not a path string, so the fixture
        # has to be opened first. The path is relative to the working
        # directory the tests are launched from.
        with open('tests/data/search.json', encoding='utf-8') as fixture:
            mock_response = (200, json.load(fixture))
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.SUCCESS)
        self.assertDictEqual(tool_return.result, dict(text="['Beijing']"))

    @mock.patch.object(GoogleSearch, '_search')
    def test_api_error(self, mock_search_func):
        """A ``(403, {...})`` reply yields API_ERROR with errmsg ``'403'``."""
        mock_response = (403, {'message': 'bad requests'})
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.API_ERROR)
        self.assertEqual(tool_return.errmsg, str(403))

    @mock.patch.object(GoogleSearch, '_search')
    def test_http_error(self, mock_search_func):
        """A ``(-1, <message>)`` reply yields HTTP_ERROR with that message."""
        mock_response = (-1, 'HTTPSConnectionPool')
        mock_search_func.return_value = mock_response
        search_tool = GoogleSearch(api_key='abc')
        tool_return = search_tool.run("What's the capital of China?")
        self.assertEqual(tool_return.state, ActionStatusCode.HTTP_ERROR)
        self.assertEqual(tool_return.errmsg, 'HTTPSConnectionPool')
|