"""
Usage: python mteb_meta.py path_to_results_folder

Creates evaluation results metadata for the model card.
E.g.
---
tags:
- mteb
model-index:
- name: SGPT-5.8B-weightedmean-msmarco-specb-bitfit
  results:
  - task:
      type: classification
    dataset:
      type: mteb/banking77
      name: MTEB Banking77
      config: default
      split: test
      revision: 44fa15921b4c889113cc5df03dd4901b49161ab7
    metrics:
    - type: accuracy
      value: 84.49350649350649
---
"""

import json
import logging
import os
import sys

from mteb import MTEB

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
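
# Expected layout (assumed): one JSON file per task inside the results folder,
# named after the task, e.g. results/<model_name>/Banking77Classification.json,
# matching what `MTEB(...).run(model, output_folder=...)` typically writes.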
results_folder = sys.argv[1].rstrip("/")
model_name = results_folder.split("/")[-1]

all_results = {}

for file_name in os.listdir(results_folder):
    if not file_name.endswith(".json"):
        logger.info(f"Skipping non-json {file_name}")
        continue
    with open(os.path.join(results_folder, file_name), "r", encoding="utf-8") as f:
        results = json.load(f)
    # Key the results by task name, i.e. the file name without its .json extension
    all_results[file_name.replace(".json", "")] = results
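
# Illustrative shape of a single loaded result dict (assumed; exact keys vary by task):
# {
#     "dataset_revision": "...",
#     "test": {"accuracy": 0.8449, "f1": 0.8418, "evaluation_time": 12.3},
# }
# Multilingual tasks nest a per-language dict of scores under each split instead.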

MARKER = "---"
TAGS = "tags:"
MTEB_TAG = "- mteb"
HEADER = "model-index:"
MODEL = f"- name: {model_name}"
RES = "  results:"

META_STRING = "\n".join([MARKER, TAGS, MTEB_TAG, HEADER, MODEL, RES])

# YAML templates for one task entry and one metric line; the indentation must
# match the model-index format shown in the module docstring.
ONE_TASK = "  - task:\n      type: {}\n    dataset:\n      type: {}\n      name: {}\n      config: {}\n      split: {}\n      revision: {}\n    metrics:"
ONE_METRIC = "    - type: {}\n      value: {}"
# Sub-keys that are not reportable metrics (std deviations, timings, duplicated main scores, thresholds)
SKIP_KEYS = ["std", "evaluation_time", "main_score", "threshold"]

for ds_name, res_dict in sorted(all_results.items()):
    mteb_desc = (
        MTEB(tasks=[ds_name.replace("CQADupstackRetrieval", "CQADupstackAndroidRetrieval")])
        .tasks[0]
        .description
    )
    hf_hub_name = mteb_desc.get("hf_hub_name", mteb_desc.get("beir_name"))
    if "CQADupstack" in ds_name:
        hf_hub_name = "BeIR/cqadupstack"
    mteb_type = mteb_desc["type"]
    revision = res_dict.get("dataset_revision")  # may be None for older result files
    split = "test"
    if ds_name == "MSMARCO":
        # MSMARCO is scored on its dev (or validation) split rather than test
        split = "dev" if "dev" in res_dict else "validation"
    if split not in res_dict:
        logger.info(f"Skipping {ds_name} as split {split} not present.")
        continue
    res_dict = res_dict.get(split)
    for lang in mteb_desc["eval_langs"]:
        mteb_name = f"MTEB {ds_name}"
        mteb_name += f" ({lang})" if len(mteb_desc["eval_langs"]) > 1 else ""

        # Monolingual tasks store their scores directly, without a language key
        test_result_lang = res_dict.get(lang) if len(mteb_desc["eval_langs"]) > 1 else res_dict

        # Skip languages that were not evaluated for this model
        if test_result_lang is None:
            continue
        META_STRING += "\n" + ONE_TASK.format(
            mteb_type,
            hf_hub_name,
            mteb_name,
            lang if len(mteb_desc["eval_langs"]) > 1 else "default",
            split,
            revision,
        )
        for metric, score in test_result_lang.items():
            if not isinstance(score, dict):
                score = {metric: score}
            for sub_metric, sub_score in score.items():
                if any(x in sub_metric for x in SKIP_KEYS):
                    continue
                META_STRING += "\n" + ONE_METRIC.format(
                    f"{metric}_{sub_metric}" if metric != sub_metric else metric,
                    # MTEB scores are in [0, 1]; report them as 0-100 on the model card
                    sub_score * 100,
                )

META_STRING += "\n" + MARKER

if os.path.exists("./mteb_metadata.md"):
    logger.warning("Overwriting mteb_metadata.md")
with open("./mteb_metadata.md", "w", encoding="utf-8") as f:
    f.write(META_STRING)