DABstep / baseline /run.py
martinigoyanes's picture
initial commit
883eeae
#!/usr/bin/env python3
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from openinference.instrumentation.smolagents import SmolagentsInstrumentor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
# Export smolagents/OpenInference spans over OTLP/HTTP to a collector.
# The standard OTEL_EXPORTER_OTLP_TRACES_ENDPOINT variable overrides the target;
# without it, the original hard-coded address is used.
# NOTE(review): port 6006 presumably points at a local Arize Phoenix instance —
# confirm. "0.0.0.0" as a *destination* relies on the OS mapping it to the local
# host; "http://localhost:6006/v1/traces" is the portable spelling.
# NOTE(review): SimpleSpanProcessor exports every span synchronously on the
# thread that ends it; consider the already-imported BatchSpanProcessor if
# export latency shows up in task run times.
import os

endpoint = os.environ.get("OTEL_EXPORTER_OTLP_TRACES_ENDPOINT", "http://0.0.0.0:6006/v1/traces")
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
import argparse
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import datasets
import pandas as pd
from data_agents_benchmark.utils import evaluate
from huggingface_hub import hf_hub_download
from smolagents import CodeAgent, LiteLLMModel
from smolagents.utils import console
from tenacity import retry, stop_after_attempt, wait_fixed, before_sleep_log, retry_if_exception_type, wait_exponential, wait_random
from tqdm import tqdm
import litellm
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes formatted records with ``tqdm.write``.

    Printing through tqdm's writer, rather than a plain stream handler, is
    tqdm's supported way to emit text while progress bars (such as the one
    shown in ``main``) are active.
    """

    def emit(self, record):
        try:
            message = self.format(record)
            tqdm.write(message)
        except RecursionError:
            # Mirror logging.StreamHandler: a RecursionError usually means the
            # stack is exhausted and must not be swallowed.
            raise
        except Exception:
            # Handler contract: a failure while formatting or writing a record
            # is reported through handleError() instead of propagating into
            # the code that issued the log call.
            self.handleError(record)
# Root logger: WARNING and above, emitted through the tqdm-aware handler.
logging.basicConfig(handlers=[TqdmLoggingHandler()], level=logging.WARNING)
logger = logging.getLogger(__name__)

# Worker threads append to shared output files; one lock per kind of output
# (answers JSONL vs. captured console logs) keeps their writes from interleaving.
append_answer_lock = threading.Lock()
append_console_output_lock = threading.Lock()
def parse_args(argv=None):
    """Parse the command-line options of a benchmark run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        argparse.Namespace with ``concurrency``, ``model_id``, ``max_tasks``,
        ``api_base``, ``api_key``, ``split`` and ``timestamp``.
    """
    parser = argparse.ArgumentParser(
        description="Run a smolagents CodeAgent over the adyen/data-agents-benchmark tasks."
    )
    parser.add_argument("--concurrency", type=int, default=4,
                        help="Number of tasks run in parallel (worker threads).")
    parser.add_argument("--model-id", type=str, default="huggingface/meta-llama/Meta-Llama-3.1-70B-Instruct",
                        help="LiteLLM model identifier used by the agent.")
    parser.add_argument("--max-tasks", type=int, default=-1,
                        help="Only consider the first N tasks; a negative value means all tasks.")
    parser.add_argument("--api-base", type=str, default=None,
                        help="Optional API base URL passed to the model.")
    parser.add_argument("--api-key", type=str, default=None,
                        help="Optional API key passed to the model.")
    parser.add_argument("--split", type=str, default="default", choices=["default", "dev"],
                        help="Dataset split to run; answers on 'dev' are scored against the ground truth.")
    parser.add_argument("--timestamp", type=str, default=None,
                        help="Integer run id naming the run directory (e.g. to resume a previous run); "
                             "defaults to the current time.")
    return parser.parse_args(argv)
def download_context(base_dir: str) -> str:
    """Download the benchmark's context files into ``base_dir``.

    Each file listed below is fetched from the ``data/context`` folder of the
    ``adyen/data-agents-benchmark`` Hub dataset. ``force_download=True`` is
    passed, so every file is fetched again on each call even if a local copy
    already exists.

    Returns:
        The local context directory, ``<base_dir>/data/context``.
    """
    context_subdir = "data/context"
    names = [
        "acquirer_countries.csv",
        "payments.csv",
        "merchant_category_codes.csv",
        "fees.json",
        "merchant_data.json",
        "manual.md",
        "payments-readme.md",
    ]
    for name in names:
        hf_hub_download(
            "adyen/data-agents-benchmark",
            repo_type="dataset",
            filename=f"{context_subdir}/{name}",
            local_dir=base_dir,
            force_download=True,
        )
    return os.path.join(base_dir, Path(context_subdir))
def get_tasks_to_run(data, total: int, base_filename: Path):
    """Return the first ``total`` tasks of ``data`` that have no recorded answer yet.

    Recorded answers are read from ``<base_filename>/answers.jsonl``, the file
    ``run_single_task`` appends to. Re-using the same run directory therefore
    skips tasks that were already answered. The path this function previously
    read, ``<base_filename.parent>/<base_filename.stem>_answers.jsonl``, is
    still consulted for backward compatibility.

    Blank lines are ignored. Lines that are not valid JSON objects with a
    ``task_id`` (e.g. a write cut short by a crash) are skipped with a warning,
    so the corresponding task is simply run again.

    Args:
        data: Indexable collection of task dicts providing ``data[i]["task_id"]``.
        total: Number of leading tasks of ``data`` to consider.
        base_filename: Run directory of this model/split/timestamp.

    Returns:
        List of the pending task dicts, in dataset order.
    """
    answer_files = (
        # Where run_single_task appends this run's answers.
        base_filename / "answers.jsonl",
        # Location this function used to read; kept for backward compatibility.
        base_filename.parent / f"{base_filename.stem}_answers.jsonl",
    )
    done = set()
    for answers_file in answer_files:
        if not answers_file.exists():
            continue
        with open(answers_file, encoding="utf-8") as fh:
            for line_no, line in enumerate(fh, start=1):
                if not line.strip():
                    continue
                try:
                    # Stored ids are strings; str() also normalises any numeric id.
                    done.add(str(json.loads(line)["task_id"]))
                except (json.JSONDecodeError, KeyError, TypeError):
                    logging.getLogger(__name__).warning(
                        f"Ignoring malformed line {line_no} in {answers_file}"
                    )
    # Index access: `data` may be a datasets.Dataset, where slicing returns columns.
    return [data[i] for i in range(total) if str(data[i]["task_id"]) not in done]
def append_answer(entry: dict, jsonl_file: Path) -> None:
    """Append ``entry`` to ``jsonl_file`` as a single JSON line.

    Missing parent directories are created first. Opening and writing happen
    while ``append_answer_lock`` is held, so lines written by concurrent
    worker threads of this process never interleave.
    """
    jsonl_file.parent.mkdir(exist_ok=True, parents=True)
    with append_answer_lock:
        with open(jsonl_file, mode="a", encoding="utf-8") as out:
            print(json.dumps(entry), file=out)
def append_console_output(captured_text: str, txt_file: Path) -> None:
    """Append ``captured_text`` followed by a newline to ``txt_file``.

    Parent directories are created on demand. The file is opened and written
    under ``append_console_output_lock`` so logs from parallel tasks do not
    interleave.
    """
    txt_file.parent.mkdir(exist_ok=True, parents=True)
    with append_console_output_lock:
        with open(txt_file, mode="a", encoding="utf-8") as out:
            out.writelines((captured_text, "\n"))
class LiteLLMModelWithBackOff(LiteLLMModel):
    """``LiteLLMModel`` whose calls are retried on transient LiteLLM failures.

    Only ``litellm.Timeout``, ``RateLimitError``, ``APIConnectionError`` and
    ``InternalServerError`` trigger a retry; any other exception propagates
    immediately. Waits grow exponentially (starting at 1s, doubling, capped at
    120s) with 0-5s of random jitter added. Each retry is logged at WARNING.
    After 450 attempts tenacity gives up and raises its ``RetryError``.
    """

    @retry(
        retry=retry_if_exception_type(
            (litellm.Timeout, litellm.RateLimitError, litellm.APIConnectionError, litellm.InternalServerError)
        ),
        wait=wait_exponential(multiplier=1, exp_base=2, min=1, max=120) + wait_random(0, 5),
        stop=stop_after_attempt(450),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)
def create_code_agent(model_id: str, api_base=None, api_key=None, max_steps=10):
    """Build a fresh tool-less ``CodeAgent`` for a single task.

    The agent talks to ``model_id`` through ``LiteLLMModelWithBackOff``, which
    retries transient LiteLLM errors. It may only import the listed modules
    and stops after ``max_steps`` steps. The ``open`` seen by its Python
    executor is replaced with a wrapper that refuses any mode able to create
    or modify files.

    NOTE(review): ``"os"`` is an authorized import, so agent code may still be
    able to change the filesystem through ``os`` functions. Confirm this is
    acceptable for the sandbox.
    """
    agent = CodeAgent(
        tools=[],
        model=LiteLLMModelWithBackOff(model_id=model_id, api_base=api_base, api_key=api_key),
        additional_authorized_imports=["numpy", "pandas", "json", "csv", "glob", "markdown", "os"],
        max_steps=max_steps,
    )

    def read_only_open(*a, **kw):
        # The mode may be passed positionally (open(path, "r")) or as mode=...;
        # the builtin default is "r".
        mode = a[1] if len(a) > 1 else kw.get("mode", "r")
        # "w" truncates, "a" appends, "x" creates and "+" opens for updating, so
        # any of them yields a writable handle. Pure read modes such as "r",
        # "rt" and "rb" are forwarded unchanged.
        if isinstance(mode, str) and any(flag in mode for flag in "wax+"):
            # PermissionError subclasses Exception, so existing handlers still match.
            raise PermissionError("Only read-only modes (e.g. 'r', 'rb') are allowed for the function open")
        return open(*a, **kw)

    agent.python_executor.static_tools.update({"open": read_only_open})
    return agent
def run_single_task(
    task: dict,
    model_id: str,
    api_base: str,
    api_key: str,
    ctx_path: str,
    base_filename: Path,  # run directory; answers.jsonl and logs.txt are written inside it
    is_dev_data: bool
):
    """Answer one benchmark task with a fresh agent and persist the outcome.

    Appends ``{"task_id", "agent_answer"}`` to ``<base_filename>/answers.jsonl``.
    When ``is_dev_data`` is true, the answer is first scored by ``evaluate``
    (which receives the task row as ``tasks_with_gt``). The stored entry then
    also carries ``answer``, ``score`` and ``level``.

    The agent's captured console output is appended to
    ``<base_filename>/logs.txt``. This also happens when the agent run raises,
    so the trace of a crashed task is kept; the exception is then re-raised.

    Args:
        task: Dataset row with ``task_id``, ``question`` and ``guidelines``
            (plus ``answer`` when ``is_dev_data``).
        model_id, api_base, api_key: Forwarded to ``create_code_agent``.
        ctx_path: Directory with the context files, quoted in the prompt.
        base_filename: Run directory for this model/split/timestamp.
        is_dev_data: Whether to score the answer against the ground truth.
    """
    prompt = f"""You are an expert data analyst and you will answer factoid questions by referencing files in the data directory: `{ctx_path}`
Don't forget to reference any documentation in the data dir before answering a question.
Here is the question you need to answer: {task['question']}
Here are the guidelines you MUST follow when answering the question above: {task['guidelines']}
Before answering the question, reference any documentation in the data dir and leverage its information in your reasoning / planning.
"""
    answers_file = base_filename / "answers.jsonl"
    logs_file = base_filename / "logs.txt"

    agent = create_code_agent(model_id, api_base, api_key)
    capture = None
    try:
        with console.capture() as capture:
            answer = agent.run(prompt)
    finally:
        # The capture context has exited here (normally or via an exception),
        # so its text is available; persist it before any error propagates.
        if capture is not None:
            append_console_output(capture.get(), logs_file)

    logger.warning(f"Task id: {task['task_id']}\tQuestion: {task['question']} Answer: {answer}\n{'=' * 50}")
    answer_dict = {"task_id": str(task["task_id"]), "agent_answer": str(answer)}
    if is_dev_data:
        scores = evaluate(agent_answers=pd.DataFrame([answer_dict]), tasks_with_gt=pd.DataFrame([task]))
        entry = {**answer_dict, "answer": task["answer"], "score": scores[0]["score"], "level": scores[0]["level"]}
        append_answer(entry, answers_file)
    else:
        append_answer(answer_dict, answers_file)
def main():
    """Command-line entry point: run all pending benchmark tasks concurrently.

    Steps:
        1. Download the context files into the current directory.
        2. Load the chosen dataset split.
        3. Ask ``get_tasks_to_run`` which tasks still need running in this
           run's directory, and execute those on a thread pool.

    A failing task does not hide the others. Every failure is logged with its
    task id and traceback, and collection continues. Once all tasks have
    finished, the first failure is re-raised so the process still exits with
    an error.
    """
    args = parse_args()
    # Log the arguments without leaking the API key.
    printable = vars(args).copy()
    if printable.get("api_key"):
        printable["api_key"] = "***"
    logger.warning(f"Starting run with arguments: {argparse.Namespace(**printable)}")
    ctx_path = download_context(str(Path().resolve()))

    # One directory per run: runs/<model id with '/' and '.' -> '_'>/<split>/<timestamp>
    runs_dir = Path().resolve() / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.time() if not args.timestamp else args.timestamp
    base_filename = runs_dir / f"{args.model_id.replace('/', '_').replace('.', '_')}/{args.split}/{int(timestamp)}"

    # Load dataset with user-chosen split
    data = datasets.load_dataset("adyen/data-agents-benchmark", name="tasks", split=args.split, download_mode='force_redownload')
    total = len(data) if args.max_tasks < 0 else min(len(data), args.max_tasks)
    tasks_to_run = get_tasks_to_run(data, total, base_filename)

    is_dev_data = args.split == "dev"
    failures = []
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        future_to_task = {
            exe.submit(run_single_task, task, args.model_id, args.api_base, args.api_key,
                       ctx_path, base_filename, is_dev_data): task
            for task in tasks_to_run
        }
        for future in tqdm(as_completed(future_to_task), total=len(future_to_task), desc="Processing tasks"):
            try:
                future.result()
            except Exception as exc:
                # Keep collecting so every failure is reported, not just the first.
                logger.error(f"Task id: {future_to_task[future]['task_id']} failed: {exc!r}", exc_info=exc)
                failures.append(exc)

    if failures:
        logger.error(f"{len(failures)} of {len(tasks_to_run)} task(s) failed; re-raising the first failure.")
        raise failures[0]
    logger.warning("All tasks processed.")


if __name__ == "__main__":
    main()