Spaces:
Running
Running
""" | |
This file contains the Environment class, which prepares the environment for the research agent to run in. | |
""" | |
import json | |
import os | |
import sys | |
import subprocess | |
import shutil | |
import copy | |
import time | |
import fnmatch | |
import signal | |
from traceback import format_exception | |
from multiprocessing import active_children | |
from dacite import from_dict | |
from .low_level_actions import LOW_LEVEL_ACTIONS | |
from .high_level_actions import HIGH_LEVEL_ACTIONS | |
from .p2m_actions import P2M_ACTIONS | |
from .schema import Step, Trace, EnvException, TooLongPromptError, LLMError, EnhancedJSONEncoder | |
from .prepare_task import prepare_task, get_task_info | |
class TimeoutException(Exception):
    """Raised by the SIGALRM handler when the time budget is exhausted."""
class Environment:
    """Prepare and manage the workspace the research agent runs in.

    The environment owns a working directory, a log directory
    (``<args.log_dir>/env_log``), the table of available actions and the
    trace of executed steps. It is meant to be used as a context manager:
    entering arms a SIGALRM-based timeout of ``args.max_time`` seconds and
    exiting terminates child processes and writes timing/error files.

    ``args`` is read for: ``log_dir``, ``work_dir``, ``research_problem``,
    ``device``, ``python``, ``resume``, ``resume_step``, ``max_steps`` and
    ``max_time``.
    """

    def __init__(self, args):
        """Set up directories, the action table and the initial trace.

        Args:
            args: Run configuration object (see the class docstring for the
                attributes that are read from it).
        """
        self._args = args
        self._log_dir = os.path.join(args.log_dir, "env_log")
        self._setup_log_dir()

        self._research_problem = args.research_problem
        self._work_dir = args.work_dir
        self._read_only_files = []
        self._initialize_env()  # set up work dir and log dir

        self._action_infos = {
            t.name: t
            for t in LOW_LEVEL_ACTIONS + HIGH_LEVEL_ACTIONS + P2M_ACTIONS
        }
        # Keyword arguments passed to every tool function on top of the
        # action-specific input.
        self._static_kwargs_for_tools = {
            "device": args.device,
            "python": args.python,
            "work_dir": self.work_dir,
            "args": args,
            "read_only_files": self.read_only_files,
            "research_problem": self.research_problem,
        }
        self._trace = self._initialize_trace()
        self._start_time = time.time()

    ############################## getters ########################################

    # The rest of the class reads these as attributes (e.g.
    # ``self.args.log_dir``, ``self.trace.steps``), so they must be
    # properties rather than plain methods.

    @property
    def user(self):
        # NOTE(review): ``_user`` is never assigned in this class, so reading
        # this raises AttributeError unless it is set elsewhere.
        return self._user

    @property
    def research_problem(self):
        return self._research_problem

    @property
    def log_dir(self):
        return self._log_dir

    @property
    def work_dir(self):
        return self._work_dir

    @property
    def read_only_files(self):
        return self._read_only_files

    @property
    def action_infos(self):
        return self._action_infos

    @property
    def args(self):
        return self._args

    @property
    def static_kwargs_for_tools(self):
        return self._static_kwargs_for_tools

    @property
    def trace(self):
        """Return a deep copy of the trace so callers cannot mutate it."""
        return copy.deepcopy(self._trace)

    @property
    def start_time(self):
        return self._start_time

    ############################## internal functions ########################################

    def _setup_log_dir(self):
        """Create the log directory and its ``tool_logs``/``traces`` subdirs.

        Existing directories are reused; a message is printed for each one
        that is already present.
        """
        if os.path.exists(self.log_dir):
            print("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(self.log_dir)

        for subdir, label in (("tool_logs", "tools_log_dir"), ("traces", "traces_dir")):
            path = os.path.join(self.log_dir, subdir)
            if os.path.exists(path):
                print("{} {} already exists".format(label, path))
            else:
                os.makedirs(path)

    def _initialize_env(self):
        """Prepare the work dir, record read-only files and snapshot the task.

        Side effects:
            * creates ``work_dir`` if missing and (re)creates an empty
              ``work_dir/backup``;
            * copies ``work_dir`` to ``log_dir/env`` when it is smaller than
              10MB and no snapshot exists yet;
            * writes ``research_problem.txt`` and ``read_only_files.txt``
              under ``log_dir/scripts``;
            * when ``args.resume`` is truthy, replaces ``work_dir`` with the
              files saved for step ``args.resume_step``.
        """
        os.makedirs(self.work_dir, exist_ok=True)

        # set up read only files: anything matching none of these patterns
        can_modify_patterns = ["*"]
        total_bytes = 0
        self._read_only_files = []
        for path, _subdirs, files in os.walk(self.work_dir):
            relpath = os.path.relpath(path, self.work_dir)
            for filename in files:
                rel_name = os.path.join(relpath, filename)
                if not any(fnmatch.fnmatch(rel_name, pattern) for pattern in can_modify_patterns):
                    self._read_only_files.append(rel_name)
                total_bytes += os.path.getsize(os.path.join(path, filename))

        # try save this task to a benchmark folder
        os.makedirs(self.log_dir, exist_ok=True)
        env_snapshot = os.path.join(self.log_dir, "env")
        # save if the size is smaller than 10MB; keep an existing snapshot
        # (reused log dir / resumed run) instead of failing in copytree
        if total_bytes / 1e6 < 10 and not os.path.exists(env_snapshot):
            shutil.copytree(self.work_dir, env_snapshot)

        scripts_dir = os.path.join(self.log_dir, "scripts")
        os.makedirs(scripts_dir, exist_ok=True)
        with open(os.path.join(scripts_dir, "research_problem.txt"), "w") as f:
            f.write(self.research_problem)
        with open(os.path.join(scripts_dir, "read_only_files.txt"), "w") as f:
            f.write("\n".join(self.read_only_files))

        # init backup folder and remove all content if it exists
        backup_dir = os.path.join(self.work_dir, "backup")
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
        os.mkdir(backup_dir)

        # restore data if resuming
        if self.args.resume:
            shutil.rmtree(self.work_dir)
            resume_dir = os.path.join(self.log_dir, "traces", f"step_{self.args.resume_step}_files")
            print("Restoring workspace ing from {}".format(resume_dir))
            shutil.copytree(resume_dir, self.work_dir, symlinks=True)

    def _initialize_trace(self):
        """Build the initial trace, restoring earlier steps when resuming.

        When ``args.resume`` is truthy, ``<args.resume>/env_log/trace.json``
        is loaded, the steps up to and including ``args.resume_step`` are
        kept, and only low-level steps recorded before the last kept step's
        timestamp are retained. Otherwise the trace starts empty.

        Returns:
            The new ``Trace`` instance.
        """
        steps = []
        low_level_steps = []
        if self.args.resume:
            print("Restoring trace from {}".format(self.args.resume))
            trace_path = os.path.join(self.args.resume, "env_log", "trace.json")
            with open(trace_path, "r") as f:
                prev_trace = from_dict(data_class=Trace, data=json.load(f))
            print("Resetting trace to step {}".format(self.args.resume_step))
            steps = prev_trace.steps[:self.args.resume_step + 1]
            t = steps[-1].timestamp
            low_level_steps = [s for s in prev_trace.low_level_steps if s.timestamp < t]

        return Trace(
            steps=steps,
            low_level_steps=low_level_steps,
            action_infos=self.action_infos,
            task_description=self.research_problem,
        )

    def __enter__(self):
        """Arm a SIGALRM timeout of ``args.max_time`` seconds.

        The installed handler raises ``TimeoutException``.
        """
        def signal_handler(signum, frame):
            raise TimeoutException("Timed out!")

        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(self.args.max_time)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop child processes and write ``error.txt``/``overall_time.txt``.

        Exceptions are not suppressed (the method returns ``None``).
        """
        # Cancel the alarm armed in __enter__ so it cannot fire during
        # cleanup or after the ``with`` block has finished.
        signal.alarm(0)

        active = active_children()
        print(f'Active Children: {len(active)}')
        # terminate all active children
        for child in active:
            child.terminate()
        # block until all children have closed
        for child in active:
            child.join()
        # report active children
        active = active_children()
        print(f'Active Children: {len(active)}')

        # save error message
        if traceback is not None:
            print("Error message saved in error.txt")
            with open(os.path.join(self.log_dir, "error.txt"), "w") as f:
                f.write(''.join(format_exception(exc_type, exc_value, traceback)))
        with open(os.path.join(self.log_dir, "overall_time.txt"), "w") as f:
            f.write(str(time.time() - self.start_time))

    ################################# public functions ########################################

    def is_final(self):
        """Check if the task has reached a final state, either by reaching the maximum steps or time, or because the agent has submitted a final answer. """
        # Read the live trace: the ``trace`` property deep-copies on every
        # access, which is unnecessary for a read-only check.
        steps = self._trace.steps
        if len(steps) >= self.args.max_steps:
            return True
        if any(s.action.name == "Final Answer" for s in steps):
            return True
        return time.time() - self.start_time > self.args.max_time

    def execute(self, action):
        """Execute an action and return the observation.

        The step (action, observation, timestamp) is appended to the trace
        and the trace plus a workspace snapshot are saved afterwards.

        Args:
            action: Object with ``name`` (action name) and ``args`` (expected
                to be a dict of keyword arguments for the tool function).

        Returns:
            The observation string produced for this action.

        Raises:
            TimeoutException: Propagated unchanged from the tool call.
            Exception: If an unexpected tool error mentions
                "Connection aborted.".
        """
        trace = self._trace
        curr_step = len(trace.steps)
        action_name = action.name
        action_input = action.args

        if action_name == "Final Answer":
            observation = "end"
        elif self.is_final():
            observation = "The environment has shut down because the maximum number of steps or time has been reached. Please submit your final answer."
        elif action_name not in self.action_infos:
            actions = ", ".join(self.action_infos.keys())
            observation = f"Invalid action: {action_name}. Action did not execute. Please use one of the following actions:\n{actions}"
        else:
            # execute the action and get the observation
            action_info = self.action_infos[action_name]
            log_file = os.path.join(self.log_dir, "tool_logs", f"step_{curr_step}_tool_log.log")
            usage = ",\n ".join([f"{k}: [{v}]" for k, v in action_info.usage.items()])
            usage = f"{{\n{usage}\n}}"
            invalid_action_error = f"The action input for {action_name} needs to be a valid json with proper entries. You may have missed the comma between entries. Please use the correct format and try again:\n{usage}"

            if isinstance(action_input, dict):
                try:
                    observation = action_info.function(**action_input, log_file=log_file, trace=trace, **self.static_kwargs_for_tools)
                except TooLongPromptError:
                    observation = "EnvError: too long input for the tool"
                except LLMError as e:
                    observation = "LLMError: " + e.message
                except EnvException as e:
                    observation = "EnvError: " + e.message
                except TypeError as e:
                    # Typically wrong/missing keyword arguments for the tool.
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    print(action_input, file=sys.stderr)
                    observation = "EnvError: " + invalid_action_error
                except TimeoutException:
                    # The global time budget ran out: let it reach the caller.
                    raise
                except Exception as e:
                    # should not happen
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    if "Connection aborted." in str(e):
                        raise Exception("Connection aborted for crfm") from e
                    observation = f"EnvError: Error executing {action_name}."
            else:
                observation = invalid_action_error

        step_time = time.time()
        trace.steps.append(Step(action, observation, step_time))
        self.save(curr_step)
        return observation

    def save(self, curr_step):
        """ Save the trace and snapshot of the workspace folder """
        with open(os.path.join(self.log_dir, "trace.json"), "w") as f:
            # Serialising does not mutate the trace, so no deep copy is needed.
            json.dump(self._trace, f, indent=4, cls=EnhancedJSONEncoder)

        ##### save a snapshot of the current step
        save_folder = os.path.join(self.log_dir, "traces", f"step_{curr_step}_files")
        if os.path.exists(save_folder):
            shutil.rmtree(save_folder)
        shutil.copytree(self.work_dir, save_folder, symlinks=True)

    ############## for logging convenience ##############

    # NOTE(review): unlike the getters above, nothing in this file shows how
    # these two are accessed, so they are kept as plain methods - confirm
    # against callers.
    def low_level_actions(self):
        """Return the action infos whose ``is_primitive`` flag is truthy."""
        return [info for info in self.action_infos.values() if info.is_primitive]

    def high_level_actions(self):
        """Return the action infos whose ``is_primitive`` flag is falsy."""
        return [info for info in self.action_infos.values() if not info.is_primitive]

    def print_action(self, entries):
        """Concatenate ``key: value`` pairs of ``entries`` with no separator."""
        return "".join([k + ": " + v for k, v in entries.items()])