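"""Evaluate MLAgentBench agent runs from their trace.json logs.

Walks a log folder, scores intermediate and final workspace snapshots with the
benchmark-specific eval script, flags common failure modes (OOM, connection,
JSON, long-prompt errors), and dumps the aggregated results to a JSON file.
"""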
import os
import anthropic
from pathlib import Path
import re
import sys
import json
import dataclasses
from dataclasses import dataclass
from typing import List, Dict
from importlib import util
import argparse
import importlib
import matplotlib.pyplot as plt
# from .LLM import complete_text_gpt4, complete_text_claude
from .environment import get_task_info


class EnhancedJSONEncoder(json.JSONEncoder):
    """JSON encoder that also handles dataclasses and callables."""

    def default(self, o):
        if dataclasses.is_dataclass(o):
            return dataclasses.asdict(o)
        # If it is a function, serialize it by its name.
        elif hasattr(o, '__call__'):
            return o.__name__
        return super().default(o)


def oom_error(path):
    """Check whether the run hit a CUDA out-of-memory error."""
    log = path.replace("trace.json", "../log")
    main_log = path.replace("trace.json", "../agent_log/main_log")
    message = "CUDA out of memory"
    return (message in open(log, "r").read()) or (message in open(main_log, "r").read())


def connection_error(path):
    """Check whether the run failed due to API quota or connection problems."""
    log = path.replace("trace.json", "../log")
    main_log = path.replace("trace.json", "../agent_log/main_log")
    bad = ["You exceeded your current quota, please check your plan and billing details.", "Error: 'text-similarity-ada-001'", "Error: 'text-embedding-ada-001'"]
    main_log_content = open(main_log, "r").read()
    return ("Connection aborted" in open(log, "r").read()) or any(b in main_log_content for b in bad)


def error(path):
    """Check whether the run left an error.txt or never wrote overall_time.txt."""
    run_dir = path.replace("trace.json", "")
    return os.path.exists(os.path.join(run_dir, "error.txt")) or not os.path.exists(os.path.join(run_dir, "overall_time.txt"))


def json_error(path):
    """Check whether the agent repeatedly failed to produce parseable JSON."""
    main_log = path.replace("trace.json", "../agent_log/main_log")
    return open(main_log, "r").read().count("JSONDecodeError") > 2


def long_prompt_error(path):
    """Check whether a tool call failed because its input was too long."""
    main_log = path.replace("trace.json", "../agent_log/main_log")
    return "EnvError: too long input for the tool" in open(main_log, "r").read()


@dataclass
class EvaluationResult:
    """Aggregated evaluation outcome for a single agent run."""
    path: str
    summary: str
    rubric_questions: Dict[str, str]
    score: List[float]
    score_steps: List[float]
    submitted_final_answer: bool
    final_score: float
    total_time: float
    error: str
    extra: Dict[str, bool]


def run_eval(log_folder, benchmark_folder_name, eval_intermediate=False):
    """Score every run under `log_folder` that produced a trace.json."""
    results = {}
    for subdir, dirs, files in os.walk(log_folder):
        for file in files:
            if file == 'trace.json':
                result = EvaluationResult(
                    path=os.path.join(subdir, file),
                    summary="",
                    rubric_questions={},
                    score=[],
                    score_steps=[],
                    final_score=-1,
                    submitted_final_answer=False,
                    total_time=0,
                    error="",
                    extra={},
                )
                try:
                    with open(os.path.join(subdir, file)) as f:
                        data = json.load(f)
                except Exception:
                    # Skip traces that cannot be parsed.
                    continue

                num_steps = len(data['steps'])
                for step in range(num_steps):
                    if data['steps'][step]["action"]["name"] == "Final Answer":
                        result.submitted_final_answer = True

                # Subsample at most num_steps_eval evenly spaced steps for intermediate evaluation.
                num_steps_eval = 50
                step_list = range(num_steps)
                if num_steps_eval >= len(step_list):
                    subsampled_list = step_list
                else:
                    step = num_steps // num_steps_eval
                    subsampled_list = step_list[::step][:num_steps_eval]

                if eval_intermediate:
                    # Score the workspace snapshot saved at each subsampled step.
                    for step in subsampled_list:
                        eval_step_score = 0
                        try:
                            folder_path = os.path.join(subdir, f'traces/step_{step}_files')
                            if os.path.exists(folder_path):
                                print(folder_path)
                                module = importlib.import_module(f'MLAgentBench.benchmarks.{benchmark_folder_name}.scripts.eval')
                                eval_step_score = module.get_score(folder_path)
                                result.score.append(eval_step_score)
                        except Exception as e:
                            print(e)
                            result.score.append(eval_step_score)
                    result.score_steps = list(subsampled_list)

                # Score the final workspace snapshot.
                folder_path = os.path.join(subdir, 'traces/step_final_files')
                try:
                    if os.path.exists(folder_path):
                        module = importlib.import_module(f'MLAgentBench.benchmarks.{benchmark_folder_name}.scripts.eval')
                        eval_final_score = module.get_score(folder_path)
                        result.score.append(eval_final_score)
                        result.final_score = eval_final_score
                        print(eval_final_score)
                except Exception as e:
                    print(e)

                if os.path.exists(os.path.join(subdir, "error.txt")):
                    result.error = open(os.path.join(subdir, "error.txt")).read()
                if os.path.exists(os.path.join(subdir, "overall_time.txt")):
                    result.total_time = float(open(os.path.join(subdir, "overall_time.txt")).read())
                    print(result.total_time)

                # Flag common failure modes from the run logs.
                result.extra = {
                    "oom_error": oom_error(os.path.join(subdir, file)),
                    "connection_error": connection_error(os.path.join(subdir, file)),
                    "error": error(os.path.join(subdir, file)),
                    "json_error": json_error(os.path.join(subdir, file)),
                    "long_prompt_error": long_prompt_error(os.path.join(subdir, file)),
                }
                results[os.path.join(subdir, file)] = result
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--log-folder", type=str, default="logs")
    parser.add_argument("--task", type=str, default="cifar10_training")
    parser.add_argument("--output-file", type=str, default="results.json")
    parser.add_argument("--eval-intermediate", action="store_true")
    args = parser.parse_args()

    benchmark_folder_name = get_task_info(args.task)[0]
    results = run_eval(args.log_folder, benchmark_folder_name, eval_intermediate=args.eval_intermediate)
    with open(args.output_file, "w") as f:
        json.dump(results, f, indent=4, cls=EnhancedJSONEncoder)
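
# Example invocation (the module path `MLAgentBench.eval` is an assumption based
# on the relative imports in this file; adjust it to the actual package layout):
#   python -m MLAgentBench.eval --log-folder logs --task cifar10_training \
#       --output-file results.json --eval-intermediate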