File size: 2,605 Bytes
ec878fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
from copy import deepcopy
from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall
from .hook import Hook
class ActionPreprocessor(Hook):
    """The ActionPreprocessor is a hook that preprocesses the action message
    and postprocesses the action return message.

    ``before_action`` normalizes ``message.formatted`` into
    ``message.content = {'name': ..., 'parameters': ...}``. Three shapes of
    ``message.formatted`` are accepted:

    * a ``FunctionCall`` (its ``name`` / ``parameters`` attributes are read);
    * a flat dict ``{'name': ..., 'parameters': ...}``;
    * a nested dict ``{'action': {'name': ..., 'parameters': ...}, ...}``.

    ``after_action`` replaces an ``ActionReturn`` held in ``message.content``
    with its formatted result on success, or with its ``errmsg`` otherwise.
    """

    def before_action(self, executor, message, session_id):
        """Normalize the tool call described by ``message.formatted``.

        Args:
            executor: The action executor. Not used by this implementation.
            message: Message whose ``formatted`` attribute describes the tool
                call. Its ``content`` attribute is overwritten in place.
            session_id: Session identifier. Not used by this implementation.

        Returns:
            The same ``message`` object, with ``content`` set to
            ``dict(name=name, parameters=parameters)``.

        Raises:
            AssertionError: If ``message.formatted`` matches none of the
                supported shapes.
        """
        formatted = message.formatted
        is_call = isinstance(formatted, FunctionCall)
        is_flat = (
            isinstance(formatted, dict) and 'name' in formatted
            and 'parameters' in formatted)
        nested = (
            formatted.get('action') if isinstance(formatted, dict) else None)
        is_nested = (
            isinstance(nested, dict) and 'name' in nested
            and 'parameters' in nested)
        assert is_call or is_flat or is_nested, (
            f'Unsupported action message format: {formatted!r}')

        if is_call:
            name = formatted.name
            parameters = formatted.parameters
        else:
            # Top-level keys take precedence over the nested ``action`` dict.
            # The fallback is looked up lazily: a ``dict.get(key, default)``
            # default is evaluated eagerly and would raise KeyError for a
            # flat dict that has no ``action`` entry.
            name = formatted['name'] if 'name' in formatted else nested['name']
            parameters = (
                formatted['parameters']
                if 'parameters' in formatted else nested['parameters'])
        message.content = dict(name=name, parameters=parameters)
        return message

    def after_action(self, executor, message, session_id):
        """Turn an ``ActionReturn`` in ``message.content`` into a response.

        Args:
            executor: The action executor. Not used by this implementation.
            message: Message whose ``content`` holds the action result. Its
                ``content`` attribute is overwritten in place.
            session_id: Session identifier. Not used by this implementation.

        Returns:
            The same ``message`` object. If ``content`` was an
            ``ActionReturn``, it is replaced by ``format_result()`` when the
            state is ``ActionStatusCode.SUCCESS`` and by ``errmsg``
            otherwise; any other content is reassigned unchanged.
        """
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message
class InternLMActionProcessor(ActionPreprocessor):
    """Action preprocessor for InternLM-style tool-call messages.

    Expects ``message.formatted`` to be a dict containing at least the keys
    ``tool_type``, ``thought``, ``action`` and ``status``. When ``action`` is
    a plain string it is treated as interpreter code and wrapped into
    ``{'name': <first executor action>, 'parameters': {...}}`` before the
    base-class normalization runs.

    Args:
        code_parameter: Parameter name under which a string ``action`` is
            passed to the interpreter action. Defaults to ``'command'``.
    """

    # Action names that additionally receive the caller's ``session_id``
    # when a string ``action`` is wrapped. Override on a subclass or an
    # instance to extend it.
    session_aware_actions = ('AsyncIPythonInterpreter', )

    def __init__(self, code_parameter: str = 'command'):
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        """Wrap a raw code ``action`` if needed, then normalize the call.

        Args:
            executor: The action executor; only ``executor.actions`` is read,
                and only when ``action`` is a string.
            message: Message whose ``formatted`` dict describes the tool
                call. It is deep-copied, so the caller's object is not
                modified.
            session_id: Session identifier, forwarded as a ``session_id``
                parameter to actions listed in ``session_aware_actions``.

        Returns:
            The normalized copy of ``message`` produced by
            ``ActionPreprocessor.before_action``.

        Raises:
            AssertionError: If ``message.formatted`` is not a dict with the
                expected keys, or (from the base class) its ``action`` has
                an unsupported shape.
            ValueError: If a string ``action`` must be wrapped but
                ``executor.actions`` is empty.
        """
        # Work on a copy so the caller's message is left untouched.
        message = deepcopy(message)
        assert isinstance(message.formatted, dict) and set(
            message.formatted).issuperset(
                {'tool_type', 'thought', 'action', 'status'})
        code = message.formatted['action']
        if isinstance(code, str):
            # Encapsulate code interpreter arguments. The first entry of
            # ``executor.actions`` is taken as the interpreter; presumably
            # an insertion-ordered mapping keyed by action name (TODO
            # confirm against ActionExecutor).
            try:
                action_name = next(iter(executor.actions))
            except StopIteration:
                # A bare StopIteration escaping here would be converted into
                # an opaque RuntimeError inside generators/coroutines.
                raise ValueError(
                    'Cannot wrap a string action: the executor has no '
                    'actions registered') from None
            parameters = {self.code_parameter: code}
            if action_name in self.session_aware_actions:
                parameters['session_id'] = session_id
            message.formatted['action'] = dict(
                name=action_name, parameters=parameters)
        return super().before_action(executor, message, session_id)
|