File size: 11,843 Bytes
85e3d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
"""
This file contains the Environment class, which prepares the environment for the research agent to run in.
"""

import json
import os
import sys
import subprocess
import shutil
import copy
import time
import fnmatch
import signal
from traceback import format_exception
from multiprocessing import active_children
from dacite import from_dict

from .low_level_actions import LOW_LEVEL_ACTIONS
from .high_level_actions import HIGH_LEVEL_ACTIONS
from .p2m_actions import P2M_ACTIONS
from .schema import Step, Trace, EnvException, TooLongPromptError, LLMError, EnhancedJSONEncoder 
from .prepare_task import prepare_task, get_task_info

class TimeoutException(Exception):
    """Raised by the SIGALRM handler installed in ``Environment.__enter__``."""

class Environment:
    def __init__(self, args):
        """Build the environment from the parsed run arguments.

        Sets up the log directory, prepares the workspace, indexes all
        available actions by name and creates (or restores) the trace.

        Args:
            args: run-arguments object. This constructor reads ``log_dir``,
                ``research_problem``, ``work_dir``, ``device`` and ``python``
                from it; other methods of the class also read ``resume``,
                ``resume_step``, ``max_steps`` and ``max_time``.
        """
        self._args = args
        self._log_dir = os.path.join(args.log_dir, "env_log")
        self._setup_log_dir()

        self._research_problem = args.research_problem
        self._work_dir = args.work_dir
        self._read_only_files = []

        # Prepares the work dir and writes task metadata under the log dir.
        self._initialize_env()

        # Index every known action by name; execute() looks actions up here.
        every_action = LOW_LEVEL_ACTIONS + HIGH_LEVEL_ACTIONS + P2M_ACTIONS
        self._action_infos = {info.name: info for info in every_action}

        # Keyword arguments handed to every tool call besides the action input.
        self._static_kwargs_for_tools = dict(
            device=args.device,
            python=args.python,
            work_dir=self.work_dir,
            args=args,
            read_only_files=self.read_only_files,
            research_problem=self.research_problem,
        )
        self._trace = self._initialize_trace()
        self._start_time = time.time()

    ############################## getters ########################################
    @property
    def user(self):
        return self._user

    @property
    def research_problem(self):
        """Research problem description taken from ``args.research_problem``."""
        return self._research_problem

    @property
    def log_dir(self):
        """Environment log directory, i.e. ``<args.log_dir>/env_log``."""
        return self._log_dir

    @property
    def work_dir(self):
        """Workspace directory (``args.work_dir``) passed to tools and snapshotted by ``save``."""
        return self._work_dir
    
    @property
    def read_only_files(self):
        """Work-dir-relative paths flagged read-only by ``_initialize_env``.

        The live list is returned, not a copy.
        """
        return self._read_only_files

    @property
    def action_infos(self):
        """Mapping from action name to its action-info object, built in ``__init__``."""
        return self._action_infos
    
    @property
    def args(self):
        """The run-arguments object that was passed to the constructor."""
        return self._args

    @property
    def static_kwargs_for_tools(self):
        """Keyword arguments ``execute`` passes to every tool besides the action's own input."""
        return self._static_kwargs_for_tools
    
    @property
    def trace(self):
        return copy.deepcopy(self._trace)

    @property
    def start_time(self):
        """``time.time()`` value captured at the end of ``__init__``; basis for elapsed-time checks."""
        return self._start_time
    
    ############################## internal functions ########################################
    
    def _setup_log_dir(self):
        # set up log dir
        if os.path.exists(self.args.log_dir):
            print("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(self.log_dir)

        if os.path.exists(os.path.join(self.log_dir, "tool_logs")):
            print("tools_log_dir {} already exists".format(os.path.join(self.log_dir, "tool_logs")))
            # raise ValueError("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(os.path.join(self.log_dir, "tool_logs"))

        if os.path.exists(os.path.join(self.log_dir, "traces")):
            print("tools_log_dir {} already exists".format(os.path.join(self.log_dir, "traces")))
            # raise ValueError("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(os.path.join(self.log_dir, "traces"))

    def _initialize_env(self):
        os.makedirs(os.path.join(self.work_dir), exist_ok=True)
        # set up read only files
        can_modify_files = '*'
        size = 0
        self._read_only_files = []
        for path, subdirs, files in os.walk(os.path.join(self.work_dir)):
            relpath = os.path.relpath(path, self.work_dir)
            # filter out the files that are read only
            filenames = [os.path.join(relpath, filename) for filename in files]
            for not_ignore in can_modify_files:
                ignore_filenames = [n for n in filenames if not fnmatch.fnmatch(n, not_ignore)]
                self.read_only_files.extend(ignore_filenames)
            for f in files:
                size += os.path.getsize(os.path.join(path, f))
                
        # try save this task to a benchmark folder
        os.makedirs(os.path.join(self.log_dir), exist_ok=True)
        if size / 1e6 < 10:
            # save if the size is smaller than 10MB
            shutil.copytree(self.work_dir, os.path.join(self.log_dir, "env"))
        os.makedirs(os.path.join(self.log_dir, "scripts"), exist_ok=True)
        with open(os.path.join(self.log_dir, "scripts", "research_problem.txt"), "w") as f:
            f.write(self.research_problem)
        with open(os.path.join(self.log_dir, "scripts", "read_only_files.txt"), "w") as f:
            f.write("\n".join(self.read_only_files))

        # init backup folder and remove all content if it exists
        if os.path.exists(os.path.join(self.work_dir, "backup")):
            shutil.rmtree(os.path.join(self.work_dir, "backup"))
        os.mkdir(os.path.join(self.work_dir, "backup"))

        # restore data if resuming
        if self.args.resume:
            shutil.rmtree(self.work_dir)
            resume_dir = os.path.join(self.log_dir, "traces" , f"step_{self.args.resume_step}_files")
            print("Restoring workspace ing from {}".format(resume_dir))
            shutil.copytree(resume_dir, self.work_dir, symlinks=True)


    def _initialize_trace(self):
        """Create the trace for this run, restoring a previous one on resume.

        When ``args.resume`` is set, ``<args.resume>/env_log/trace.json`` is
        loaded, the steps are cut after index ``args.resume_step`` and only
        low-level steps recorded before the last kept step's timestamp are
        retained. Otherwise an empty trace is created.

        Returns:
            Trace: the restored or freshly created trace.
        """
        if self.args.resume:
            print("Restoring trace from {}".format(self.args.resume))
            trace_path = os.path.join(self.args.resume, "env_log", "trace.json")
            # Context manager so the file handle is closed right after parsing.
            with open(trace_path, "r") as f:
                prev_trace = from_dict(data_class=Trace, data=json.load(f))
            print("Resetting trace to step {}".format(self.args.resume_step))
            steps = prev_trace.steps[:self.args.resume_step+1]
            # Timestamp of the last kept step bounds the low-level steps to keep.
            t = steps[-1].timestamp
            low_level_steps = [s for s in prev_trace.low_level_steps if s.timestamp < t]
            trace = Trace(
                steps=steps,
                low_level_steps=low_level_steps,
                action_infos=self.action_infos,
                task_description=self.research_problem,
            )
        else:
            trace = Trace(
                steps=[],
                low_level_steps=[],
                action_infos=self.action_infos,
                task_description=self.research_problem,
            )
        return trace
    
    def __enter__(self):
        # set time out
        def signal_handler(signum, frame):
            raise TimeoutException("Timed out!")
        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(self.args.max_time)
        return self
    
    def __exit__(self, exc_type, exc_value, traceback):  
        # save error message
        active = active_children()
        print(f'Active Children: {len(active)}')
        # terminate all active children
        for child in active:
            child.terminate()
        # block until all children have closed
        for child in active:
            child.join()
        # report active children
        active = active_children()
        print(f'Active Children: {len(active)}')
            
        if traceback is not None:
            print("Error message saved in error.txt")
            open(os.path.join(self.log_dir, "error.txt"), "w").write(''.join(format_exception(exc_type, exc_value, traceback)))
        open(os.path.join(self.log_dir, "overall_time.txt"), "w").write(str(time.time() - self.start_time))
            
    ################################# public functions ########################################

    def is_final(self):
        """Check if the task has reached a final state, either by reaching the maximum steps or time, or because the agent has submitted a final answer. """
        
        curr_step = len(self.trace.steps)
        # check if any step is final answer
        any_final_answer = any([s.action.name == "Final Answer" for s in self.trace.steps])
        return curr_step >= self.args.max_steps or any_final_answer or time.time() - self.start_time > self.args.max_time

    def execute(self, action):
        """Execute an action, record it in the trace and return the observation.

        The observation is one of:
          * ``"end"`` for the "Final Answer" action;
          * a shutdown notice when ``is_final()`` is already true;
          * an "Invalid action" message listing the known action names;
          * the tool's return value, or an ``EnvError:``/``LLMError:`` message
            when the tool raises a handled error or the input is not a dict
            accepted by the tool.

        In every non-raising case a ``Step`` is appended to the trace and
        ``save`` is called for the current step index.

        Args:
            action: object with ``name`` (action name) and ``args`` (expected
                to be a dict of keyword arguments for the tool).

        Returns:
            The observation produced for this action.

        Raises:
            TimeoutException: re-raised untouched when it occurs inside a tool.
            Exception: when a tool error message contains
                "Connection aborted.".
        """
        trace = self._trace

        curr_step = len(trace.steps)
        action_name = action.name
        action_input = action.args

        if action_name == "Final Answer":
            observation = "end"

        elif self.is_final():
            observation = "The environment has shut down because the maximum number of steps or time has been reached. Please submit your final answer."

        elif action_name not in self.action_infos:
            actions = ", ".join(self.action_infos.keys())
            observation = f"Invalid action: {action_name}. Action did not execute. Please use one of the following actions:\n{actions}"

        else:
            # execute the action and get the observation
            action_info = self.action_infos[action_name]
            log_file = os.path.join(self.log_dir, "tool_logs", f"step_{curr_step}_tool_log.log")

            # Render the expected input format shown to the agent on bad input:
            # "{", entries indented by 12 spaces and separated by ",\n", then "}".
            indent = " " * 12
            entries = (",\n" + indent).join(f"{k}: [{v}]" for k, v in action_info.usage.items())
            usage = "{\n" + indent + entries + "\n}"
            invalid_action_error = f"The action input for {action_name} needs to be a valid json with proper entries. You may have missed the comma between entries. Please use the correct format and try again:\n{usage}"

            if isinstance(action_input, dict):
                try:
                    observation = action_info.function(**action_input, log_file=log_file, trace=trace, **self.static_kwargs_for_tools)
                except TooLongPromptError:
                    observation = "EnvError: too long input for the tool"
                except LLMError as e:
                    observation = "LLMError: " + e.message
                except EnvException as e:
                    observation = "EnvError: " + e.message
                except TypeError as e:
                    # Typically wrong, missing or duplicated keyword arguments.
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    print(action_input, file=sys.stderr)
                    observation = "EnvError: " + invalid_action_error
                except TimeoutException:
                    # The time limit must abort the run, not become an observation.
                    raise
                except Exception as e:
                    # should not happen
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    if "Connection aborted." in str(e):
                        # Chain the original error so its traceback is kept.
                        raise Exception("Connection aborted for crfm") from e
                    observation = f"EnvError: Error executing {action_name}."
            else:
                observation = invalid_action_error

        step_time = time.time()

        trace.steps.append(Step(action, observation, step_time))

        self.save(curr_step)
        return observation

    def save(self, curr_step):
        """Persist the trace and a snapshot of the workspace for one step.

        Writes the trace as JSON to ``<log_dir>/trace.json`` and copies the
        workspace to ``<log_dir>/traces/step_<curr_step>_files``, replacing
        any snapshot already stored under that name.

        Args:
            curr_step: step index used to name the snapshot folder.
        """
        trace_path = os.path.join(self.log_dir, "trace.json")
        with open(trace_path, "w") as trace_file:
            json.dump(self.trace, trace_file, indent=4, cls=EnhancedJSONEncoder)

        # Snapshot of the workspace as it is after this step.
        snapshot_dir = os.path.join(self.log_dir, "traces", f"step_{curr_step}_files")
        if os.path.exists(snapshot_dir):
            shutil.rmtree(snapshot_dir)

        shutil.copytree(self.work_dir, snapshot_dir, symlinks=True)

    ############## for logging convenience ##############

    @property
    def low_level_actions(self):
        return list(filter(lambda x: x.is_primitive, self.action_infos.values()))

    @property
    def high_level_actions(self):
        return list(filter(lambda x: not x.is_primitive, self.action_infos.values()))

    def print_action(self, entries):
        return "".join([ k + ": " + v for k,v in  entries.items()])