# lagent/hooks/action_preprocessor.py
# Source snapshot: commit ec878fd ("update") by yanyoyo.
from copy import deepcopy
from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall
from .hook import Hook
class ActionPreprocessor(Hook):
    """Hook that preprocesses action messages and postprocesses their returns.

    ``before_action`` rewrites ``message.content`` into a
    ``dict(name=..., parameters=...)`` call extracted from
    ``message.formatted``. ``after_action`` turns an ``ActionReturn`` stored
    in ``message.content`` into its formatted result on success, or into its
    error message otherwise.
    """

    def before_action(self, executor, message, session_id):
        """Normalise ``message.formatted`` into a ``{name, parameters}`` dict.

        Accepted shapes of ``message.formatted``:

        * a ``FunctionCall`` (its ``name`` / ``parameters`` attributes are used);
        * a dict with top-level ``'name'`` and ``'parameters'`` keys;
        * a dict whose ``'action'`` value is a dict with ``'name'`` and
          ``'parameters'`` keys.

        For dict inputs, a top-level key takes precedence over the same key
        nested under ``'action'``.

        Args:
            executor: The action executor (not used by this hook).
            message: Message whose ``formatted`` field holds the parsed call.
                It is modified in place.
            session_id: Session identifier (not used by this hook).

        Returns:
            The same ``message``, with ``content`` set to
            ``dict(name=..., parameters=...)``.

        Raises:
            AssertionError: If ``message.formatted`` matches none of the
                accepted shapes.
        """
        formatted = message.formatted
        if isinstance(formatted, FunctionCall):
            name = formatted.name
            parameters = formatted.parameters
        elif self._is_flat_call(formatted) or self._is_nested_call(formatted):
            name = self._lookup(formatted, 'name')
            parameters = self._lookup(formatted, 'parameters')
        else:
            # Raised explicitly rather than via ``assert`` so the check is
            # not stripped under ``python -O``. AssertionError is kept so
            # existing callers see the same exception type.
            raise AssertionError(
                f'Unsupported action message format: {formatted!r}')
        message.content = dict(name=name, parameters=parameters)
        return message

    @staticmethod
    def _is_flat_call(formatted):
        """Return True if ``formatted`` is a dict with top-level call keys."""
        return (isinstance(formatted, dict) and 'name' in formatted
                and 'parameters' in formatted)

    @staticmethod
    def _is_nested_call(formatted):
        """Return True if ``formatted['action']`` is a dict with call keys."""
        if not isinstance(formatted, dict):
            return False
        action = formatted.get('action')
        return (isinstance(action, dict) and 'name' in action
                and 'parameters' in action)

    @staticmethod
    def _lookup(formatted, key):
        """Return ``formatted[key]``, falling back to ``formatted['action'][key]``.

        The fallback is evaluated only when the top-level key is missing.
        ``dict.get(key, default)`` would compute the default eagerly and
        raise ``KeyError`` for flat dicts that have no ``'action'`` entry.
        """
        if key in formatted:
            return formatted[key]
        return formatted['action'][key]

    def after_action(self, executor, message, session_id):
        """Replace an ``ActionReturn`` in ``message.content`` with its text.

        If the state is ``ActionStatusCode.SUCCESS``, the content becomes
        ``action_return.format_result()``. For any other state it becomes
        ``action_return.errmsg``. Content that is not an ``ActionReturn`` is
        assigned back unchanged.

        Args:
            executor: The action executor (not used by this hook).
            message: Message whose ``content`` holds the action's return value.
                It is modified in place.
            session_id: Session identifier (not used by this hook).

        Returns:
            The same ``message``.
        """
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message
class InternLMActionProcessor(ActionPreprocessor):
    """Action preprocessor for InternLM-style parsed tool calls.

    ``message.formatted`` must be a dict containing at least the keys
    ``'tool_type'``, ``'thought'``, ``'action'`` and ``'status'``. When
    ``formatted['action']`` is a plain string, it is wrapped into
    ``dict(name=<first action registered on the executor>,
    parameters={code_parameter: <the string>})`` before being passed to
    ``ActionPreprocessor.before_action``.

    Args:
        code_parameter: Parameter name under which a string action is passed
            to the target action. Defaults to ``'command'``.
    """

    # Keys that must all be present in ``message.formatted``.
    _REQUIRED_KEYS = frozenset({'tool_type', 'thought', 'action', 'status'})
    # Action names that additionally receive ``session_id`` as a parameter.
    _SESSION_ACTIONS = frozenset({'AsyncIPythonInterpreter'})

    def __init__(self, code_parameter: str = 'command'):
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        """Wrap a string action into a call dict, then normalise the message.

        The incoming ``message`` is deep-copied first. All modifications
        apply to the copy, so the caller's object is left untouched.

        Args:
            executor: Object exposing an iterable ``actions`` attribute. Its
                first element is used as the action name for string actions
                (presumably a name -> action mapping; confirm with caller).
            message: Message whose ``formatted`` dict holds the parsed output.
            session_id: Added to the parameters of session-aware actions.

        Returns:
            A new message whose ``content`` is ``dict(name=..., parameters=...)``.

        Raises:
            AssertionError: If ``message.formatted`` is not a dict containing
                the required keys, or if the resulting call is malformed.
            ValueError: If the action is a string but ``executor.actions``
                is empty.
        """
        message = deepcopy(message)
        formatted = message.formatted
        if not (isinstance(formatted, dict)
                and self._REQUIRED_KEYS.issubset(formatted)):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``. AssertionError is kept for compatibility.
            raise AssertionError(
                'InternLM action message must be a dict with keys '
                f'{sorted(self._REQUIRED_KEYS)}, got {formatted!r}')
        if isinstance(formatted['action'], str):
            # Encapsulate code interpreter arguments: the raw string is
            # routed to the first action registered on the executor.
            try:
                action_name = next(iter(executor.actions))
            except StopIteration:
                # Avoid leaking a bare StopIteration. Inside generators or
                # coroutines it would be turned into RuntimeError (PEP 479).
                raise ValueError(
                    'Cannot dispatch a string action: the executor has no '
                    'registered actions') from None
            parameters = {self.code_parameter: formatted['action']}
            if action_name in self._SESSION_ACTIONS:
                parameters['session_id'] = session_id
            formatted['action'] = dict(name=action_name, parameters=parameters)
        return super().before_action(executor, message, session_id)