Tristan Thrush commited on
Commit
9bb22fc
·
1 Parent(s): 7f35e51

start of select any metric feature

Browse files
Files changed (2) hide show
  1. app.py +56 -3
  2. requirements.txt +1 -0
app.py CHANGED
@@ -4,9 +4,11 @@ from pathlib import Path
4
 
5
  import pandas as pd
6
  import streamlit as st
7
- from datasets import get_dataset_config_names
8
  from dotenv import load_dotenv
9
  from huggingface_hub import list_datasets
 
 
10
 
11
  from utils import (get_compatible_models, get_key, get_metadata, http_get,
12
  http_post)
@@ -30,8 +32,50 @@ TASK_TO_ID = {
30
  "summarization": 8,
31
  }
32
 
 
 
 
 
 
 
 
 
 
33
  supported_tasks = list(TASK_TO_ID.keys())
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  ###########
37
  ### APP ###
@@ -242,7 +286,16 @@ with st.expander("Advanced configuration"):
242
  with st.form(key="form"):
243
 
244
  compatible_models = get_compatible_models(selected_task, selected_dataset)
245
-
 
 
 
 
 
 
 
 
 
246
  selected_models = st.multiselect("Select the models you wish to evaluate", compatible_models)
247
  print("Selected models:", selected_models)
248
  submit_button = st.form_submit_button("Make submission")
@@ -264,7 +317,7 @@ with st.form(key="form"):
264
  "disk_size_gb": 150,
265
  },
266
  "evaluation": {
267
- "metrics": [],
268
  "models": selected_models,
269
  },
270
  },
 
4
 
5
  import pandas as pd
6
  import streamlit as st
7
+ from datasets import get_dataset_config_names, list_metrics, load_metric
8
  from dotenv import load_dotenv
9
  from huggingface_hub import list_datasets
10
+ from tqdm import tqdm
11
+ import inspect
12
 
13
  from utils import (get_compatible_models, get_key, get_metadata, http_get,
14
  http_post)
 
32
  "summarization": 8,
33
  }
34
 
35
+ TASK_TO_DEFAULT_METRICS = {
36
+ "binary_classification": ["f1", "precision", "recall", "auc", "accuracy"],
37
+ "multi_class_classification": ["f1_micro", "f1_macro", "f1_weighted", "precision_macro", "precision_micro", "precision_weighted", "recall_macro", "recall_micro", "recall_weighted", "accuracy"],
38
+ "entity_extraction": ["precision", "recall", "f1", "accuracy"],
39
+ "extractive_question_answering": [],
40
+ "translation": ["sacrebleu", "gen_len"],
41
+ "summarization": ["rouge1", "rouge2", "rougeL", "rougeLsum", "gen_len"],
42
+ }
43
+
44
  supported_tasks = list(TASK_TO_ID.keys())
45
 
46
+ @st.cache
47
+ def get_supported_metrics():
48
+ metrics = list_metrics()
49
+ supported_metrics = {}
50
+ for metric in tqdm(metrics):
51
+ try:
52
+ metric_func = load_metric(metric)
53
+ except Exception as e:
54
+ print(e)
55
+ print("Skipping the following metric, which cannot load:", metric)
56
+
57
+ argspec = inspect.getfullargspec(metric_func.compute)
58
+ if (
59
+ "references" in argspec.kwonlyargs
60
+ and "predictions" in argspec.kwonlyargs
61
+ ):
62
+ # We require that "references" and "predictions" are arguments
63
+ # to the metric function. We also require that the other arguments
64
+ # besides "references" and "predictions" have defaults and so do not
65
+ # need to be specified explicitly.
66
+ defaults = True
67
+ for key, value in argspec.kwonlydefaults.items():
68
+ if key not in ("references", "predictions"):
69
+ if value is None:
70
+ defaults = False
71
+ break
72
+
73
+ if defaults:
74
+ supported_metrics[metric] = argspec.kwonlydefaults
75
+ return supported_metrics
76
+
77
+ supported_metrics = get_supported_metrics()
78
+
79
 
80
  ###########
81
  ### APP ###
 
286
  with st.form(key="form"):
287
 
288
  compatible_models = get_compatible_models(selected_task, selected_dataset)
289
+ st.markdown("The following metrics will be computed")
290
+ html_string = " ".join(["<div style=\"padding-right:5px;padding-left:5px;padding-top:5px;padding-bottom:5px;float:left\"><div style=\"background-color:#D3D3D3;border-radius:5px;display:inline-block;padding-right:5px;padding-left:5px;color:white\">" + metric + "</div></div>" for metric in TASK_TO_DEFAULT_METRICS[selected_task]])
291
+ st.markdown(html_string, unsafe_allow_html=True)
292
+ selected_metrics = st.multiselect(
293
+ "(Optional) Select additional metrics",
294
+ list(set(supported_metrics.keys()) - set(TASK_TO_DEFAULT_METRICS[selected_task])),
295
+ )
296
+ for metric_name in selected_metrics:
297
+ argument_string = ", ".join(["-".join(key, value) for key, value in supported_metrics[metric].items()])
298
+ st.info(f"Note! The arguments for {metric_name} are: {argument_string}")
299
  selected_models = st.multiselect("Select the models you wish to evaluate", compatible_models)
300
  print("Selected models:", selected_models)
301
  submit_button = st.form_submit_button("Make submission")
 
317
  "disk_size_gb": 150,
318
  },
319
  "evaluation": {
320
+ "metrics": selected_metrics,
321
  "models": selected_models,
322
  },
323
  },
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  huggingface-hub==0.4.0
2
  python-dotenv
3
  streamlit==1.2.0
 
4
  py7zr
 
1
  huggingface-hub==0.4.0
2
  python-dotenv
3
  streamlit==1.2.0
4
+ datasets
5
  py7zr