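"""Run the adyen/data-agents-benchmark tasks with a smolagents CodeAgent.

Tasks are executed concurrently with a thread pool; each answer is appended to
runs/<model>/<split>/<timestamp>/answers.jsonl (and scored on the fly for the dev split),
while agent traces are exported over OTLP via the OpenInference smolagents instrumentor.
"""
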
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

from openinference.instrumentation.smolagents import SmolagentsInstrumentor

endpoint = "http://0.0.0.0:6006/v1/traces"
trace_provider = TracerProvider()
trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)

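# Note: the OTLP endpoint above assumes a trace collector is already running locally,
# e.g. Arize Phoenix, which listens on port 6006 by default; adjust it for your own backend.
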
import argparse
import json
import logging
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import datasets
import litellm
import pandas as pd
from data_agents_benchmark.utils import evaluate
from huggingface_hub import hf_hub_download
from smolagents import CodeAgent, LiteLLMModel
from smolagents.utils import console
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
    wait_fixed,
    wait_random,
)
from tqdm import tqdm

class TqdmLoggingHandler(logging.Handler):
    """Route log records through tqdm.write so they do not break active progress bars."""

    def emit(self, record):
        tqdm.write(self.format(record))


logging.basicConfig(level=logging.WARNING, handlers=[TqdmLoggingHandler()])
logger = logging.getLogger(__name__)

# Serialize file appends from worker threads so concurrent tasks never interleave writes.
append_answer_lock = threading.Lock()
append_console_output_lock = threading.Lock()

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--concurrency", type=int, default=4)
    parser.add_argument("--model-id", type=str, default="huggingface/meta-llama/Meta-Llama-3.1-70B-Instruct")
    parser.add_argument("--max-tasks", type=int, default=-1)
    parser.add_argument("--api-base", type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--split", type=str, default="default", choices=["default", "dev"])
    parser.add_argument("--timestamp", type=str, default=None)
    return parser.parse_args()

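# Example invocation (the script name is illustrative; a negative --max-tasks means "run everything"):
#   python run_benchmark.py --model-id huggingface/meta-llama/Meta-Llama-3.1-70B-Instruct \
#       --split dev --concurrency 8 --max-tasks 10
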
def download_context(base_dir: str) -> str:
    ctx_files = [
        "data/context/acquirer_countries.csv",
        "data/context/payments.csv",
        "data/context/merchant_category_codes.csv",
        "data/context/fees.json",
        "data/context/merchant_data.json",
        "data/context/manual.md",
        "data/context/payments-readme.md",
    ]
    repo_id = "adyen/data-agents-benchmark"
    for f in ctx_files:
        hf_hub_download(repo_id, repo_type="dataset", filename=f, local_dir=base_dir, force_download=True)
    return os.path.join(base_dir, Path(ctx_files[0]).parent)

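# The returned path resolves to "<base_dir>/data/context", i.e. the folder holding the files
# downloaded above; run_single_task interpolates it into the agent prompt.
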
def get_tasks_to_run(data, total: int, base_filename: Path):
    # Resume support: skip tasks whose answers were already written by run_single_task to
    # base_filename / "answers.jsonl", e.g. when re-running with the same --timestamp.
    answers_file = base_filename / "answers.jsonl"
    done = set()
    if answers_file.exists():
        with open(answers_file, encoding="utf-8") as fh:
            done = {json.loads(line)["task_id"] for line in fh if line.strip()}
    return [data[i] for i in range(total) if str(data[i]["task_id"]) not in done]

def append_answer(entry: dict, jsonl_file: Path) -> None:
    jsonl_file.parent.mkdir(parents=True, exist_ok=True)
    with append_answer_lock, open(jsonl_file, "a", encoding="utf-8") as fp:
        fp.write(json.dumps(entry) + "\n")


def append_console_output(captured_text: str, txt_file: Path) -> None:
    txt_file.parent.mkdir(parents=True, exist_ok=True)
    with append_console_output_lock, open(txt_file, "a", encoding="utf-8") as fp:
        fp.write(captured_text + "\n")

class LiteLLMModelWithBackOff(LiteLLMModel):
    # Retry transient LiteLLM failures (timeouts, rate limits, connection and server errors)
    # instead of failing the whole task.
    @retry(
        stop=stop_after_attempt(450),
        wait=wait_exponential(min=1, max=120, exp_base=2, multiplier=1) + wait_random(0, 5),
        before_sleep=before_sleep_log(logger, logging.WARNING),
        retry=retry_if_exception_type((
            litellm.Timeout,
            litellm.RateLimitError,
            litellm.APIConnectionError,
            litellm.InternalServerError,
        )),
    )
    def __call__(self, *args, **kwargs):
        return super().__call__(*args, **kwargs)

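# Backoff behaviour (tenacity): waits of roughly 1s, 2s, 4s, ... doubling up to the 120s cap,
# plus 0-5s of random jitter, for at most 450 attempts before giving up.
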
def create_code_agent(model_id: str, api_base=None, api_key=None, max_steps=10):
    agent = CodeAgent(
        tools=[],
        model=LiteLLMModelWithBackOff(model_id=model_id, api_base=api_base, api_key=api_key),
        additional_authorized_imports=["numpy", "pandas", "json", "csv", "glob", "markdown", "os"],
        max_steps=max_steps,
    )

    def read_only_open(*a, **kw):
        # Reject any mode other than 'r', whether passed positionally or as a keyword,
        # so the agent can read the context files but never modify them.
        if (len(a) > 1 and isinstance(a[1], str) and a[1] != 'r') or kw.get('mode', 'r') != 'r':
            raise Exception("Only mode='r' allowed for the function open")
        return open(*a, **kw)

    agent.python_executor.static_tools.update({"open": read_only_open})
    return agent

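# Illustrative behaviour of the read-only guard registered above, inside the agent's executor:
#   open(f"{ctx_path}/payments.csv")        # allowed, defaults to mode='r'
#   open(f"{ctx_path}/payments.csv", "w")   # raises: Only mode='r' allowed for the function open
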
def run_single_task(
    task: dict,
    model_id: str,
    api_base: str,
    api_key: str,
    ctx_path: str,
    base_filename: Path,
    is_dev_data: bool,
):
    prompt = f"""You are an expert data analyst and you will answer factoid questions by referencing files in the data directory: `{ctx_path}`
Don't forget to reference any documentation in the data dir before answering a question.

Here is the question you need to answer: {task['question']}

Here are the guidelines you MUST follow when answering the question above: {task['guidelines']}

Before answering the question, reference any documentation in the data dir and leverage its information in your reasoning / planning.
"""

    agent = create_code_agent(model_id, api_base, api_key)
    with console.capture() as capture:
        answer = agent.run(prompt)

    logger.warning(f"Task id: {task['task_id']}\tQuestion: {task['question']} Answer: {answer}\n{'=' * 50}")

    answer_dict = {"task_id": str(task["task_id"]), "agent_answer": str(answer)}
    answers_file = base_filename / "answers.jsonl"
    logs_file = base_filename / "logs.txt"

    if is_dev_data:
        # The dev split ships ground truth, so each answer can be scored as soon as it is produced.
        scores = evaluate(agent_answers=pd.DataFrame([answer_dict]), tasks_with_gt=pd.DataFrame([task]))
        entry = {**answer_dict, "answer": task["answer"], "score": scores[0]["score"], "level": scores[0]["level"]}
        append_answer(entry, answers_file)
    else:
        append_answer(answer_dict, answers_file)
    append_console_output(capture.get(), logs_file)

def main():
    args = parse_args()
    logger.warning(f"Starting run with arguments: {args}")

    ctx_path = download_context(str(Path().resolve()))

    runs_dir = Path().resolve() / "runs"
    runs_dir.mkdir(parents=True, exist_ok=True)
    timestamp = time.time() if not args.timestamp else args.timestamp
    base_filename = runs_dir / f"{args.model_id.replace('/', '_').replace('.', '_')}/{args.split}/{int(timestamp)}"

    data = datasets.load_dataset("adyen/data-agents-benchmark", name="tasks", split=args.split, download_mode='force_redownload')
    total = len(data) if args.max_tasks < 0 else min(len(data), args.max_tasks)

    tasks_to_run = get_tasks_to_run(data, total, base_filename)
    with ThreadPoolExecutor(max_workers=args.concurrency) as exe:
        futures = [
            exe.submit(run_single_task, task, args.model_id, args.api_base, args.api_key, ctx_path, base_filename, (args.split == "dev"))
            for task in tasks_to_run
        ]
        for f in tqdm(as_completed(futures), total=len(tasks_to_run), desc="Processing tasks"):
            f.result()

    logger.warning("All tasks processed.")

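# Results are written under runs/<model_id with '/' and '.' replaced by '_'>/<split>/<timestamp>/.
# Re-running with the same --model-id, --split and --timestamp resumes the run: task_ids already
# present in answers.jsonl are skipped by get_tasks_to_run.
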
if __name__ == "__main__":
    main()