|
from copy import deepcopy |
|
|
|
from lagent.schema import ActionReturn, ActionStatusCode, FunctionCall |
|
from .hook import Hook |
|
|
|
|
|
class ActionPreprocessor(Hook):
    """The ActionPreprocessor is a hook that preprocesses the action message
    and postprocesses the action return message.

    ``before_action`` turns ``message.formatted`` into
    ``message.content = dict(name=..., parameters=...)``. Accepted shapes of
    ``message.formatted``:

    * a :class:`FunctionCall` -- its ``name`` and ``parameters`` attributes
      are used;
    * a dict with top-level ``'name'`` and ``'parameters'`` keys;
    * a dict whose ``'action'`` entry is a dict with ``'name'`` and
      ``'parameters'`` keys.

    For dicts, a top-level ``'name'`` / ``'parameters'`` key, when present,
    takes precedence over the one nested under ``'action'``.

    ``after_action`` replaces an :class:`ActionReturn` held in
    ``message.content`` with ``format_result()`` on success and with
    ``errmsg`` otherwise; any other content is left as is.
    """

    def before_action(self, executor, message, session_id):
        """Normalize the action call described by ``message.formatted``.

        Args:
            executor: The action executor (not used by this hook).
            message: Message whose ``formatted`` attribute describes the
                call. It is modified in place.
            session_id: Session identifier (not used by this hook).

        Returns:
            The same ``message`` with ``content`` set to
            ``dict(name=..., parameters=...)``.

        Raises:
            AssertionError: If ``message.formatted`` has none of the accepted
                shapes.
        """
        formatted = message.formatted
        if isinstance(formatted, FunctionCall):
            name = formatted.name
            parameters = formatted.parameters
        else:
            assert isinstance(formatted, dict), (
                f'Unsupported action format: {type(formatted).__name__}')
            action = formatted.get('action')
            has_top_level = 'name' in formatted and 'parameters' in formatted
            # Require a real dict: ``in`` on a string would do a substring
            # check and the subsequent indexing would then crash.
            has_nested = (
                isinstance(action, dict) and 'name' in action
                and 'parameters' in action)
            assert has_top_level or has_nested, (
                'Action must provide "name" and "parameters", either at the '
                f'top level or under "action"; got keys {list(formatted)}')
            # Fall back to the nested entry only when the top-level key is
            # absent. ``formatted.get(key, action[key])`` would evaluate the
            # default eagerly and raise KeyError for top-level-only actions.
            name = formatted['name'] if 'name' in formatted else action['name']
            parameters = (
                formatted['parameters']
                if 'parameters' in formatted else action['parameters'])
        message.content = dict(name=name, parameters=parameters)
        return message

    def after_action(self, executor, message, session_id):
        """Convert an ``ActionReturn`` in ``message.content`` to a response.

        Args:
            executor: The action executor (not used by this hook).
            message: Message whose ``content`` holds the action result. It is
                modified in place.
            session_id: Session identifier (not used by this hook).

        Returns:
            The same ``message``. If its content was an ``ActionReturn``, the
            content becomes ``format_result()`` when ``state`` is
            ``ActionStatusCode.SUCCESS`` and ``errmsg`` otherwise. Any other
            content is kept unchanged.
        """
        action_return = message.content
        if isinstance(action_return, ActionReturn):
            if action_return.state == ActionStatusCode.SUCCESS:
                response = action_return.format_result()
            else:
                response = action_return.errmsg
        else:
            response = action_return
        message.content = response
        return message
|
|
|
|
|
class InternLMActionProcessor(ActionPreprocessor):
    """Action preprocessor for InternLM-style tool calls.

    ``message.formatted`` must be a dict containing at least the keys
    ``'tool_type'``, ``'thought'``, ``'action'`` and ``'status'``. When
    ``'action'`` is a plain string, it is wrapped into a call of the first
    entry yielded by iterating ``executor.actions``. The string is passed
    under the parameter name ``code_parameter``. The result is then
    normalized by :meth:`ActionPreprocessor.before_action`.

    Args:
        code_parameter: Name of the parameter that receives a raw string
            ``'action'``. Defaults to ``'command'``.
    """

    def __init__(self, code_parameter: str = 'command'):
        super().__init__()
        self.code_parameter = code_parameter

    def before_action(self, executor, message, session_id):
        """Wrap a raw string action into a call, then normalize it.

        Args:
            executor: The action executor. Only ``executor.actions`` is read,
                and only when the action is a raw string.
            message: Incoming action message. It is deep-copied, so the
                caller's object is left untouched.
            session_id: Session identifier. It is forwarded as the
                ``'session_id'`` parameter when the chosen action is
                ``'AsyncIPythonInterpreter'``.

        Returns:
            A new message whose ``content`` is
            ``dict(name=..., parameters=...)``.

        Raises:
            AssertionError: If ``message.formatted`` is not a dict with the
                required keys.
            ValueError: If the action is a raw string but ``executor.actions``
                yields no entries.
        """
        # Work on a copy: the rewrite below and the parent hook both mutate
        # the message, and the caller's instance must stay intact.
        message = deepcopy(message)
        formatted = message.formatted
        assert isinstance(formatted, dict) and set(formatted).issuperset(
            {'tool_type', 'thought', 'action', 'status'})
        if isinstance(formatted['action'], str):
            try:
                # NOTE(review): presumably ``executor.actions`` is keyed by
                # action name, so this is the first registered action's
                # name -- confirm against the executor implementation.
                action_name = next(iter(executor.actions))
            except StopIteration:
                # A bare StopIteration escaping here would be confusing
                # (and becomes RuntimeError inside generators/coroutines).
                raise ValueError(
                    'InternLMActionProcessor cannot dispatch a raw string '
                    'action: the executor has no registered actions') from None
            parameters = {self.code_parameter: formatted['action']}
            if action_name == 'AsyncIPythonInterpreter':
                parameters['session_id'] = session_id
            formatted['action'] = dict(name=action_name, parameters=parameters)
        return super().before_action(executor, message, session_id)
|
|