lukasgarbas committed
Commit 73d9a01
1 Parent(s): 0b6543b

add gradio app

Files changed (2)
  1. app.py +274 -0
  2. utils.py +205 -0
app.py ADDED
@@ -0,0 +1,274 @@
+import gradio as gr
+from datasets import disable_caching, load_dataset
+from transformer_ranker import TransformerRanker, prepare_popular_models
+import traceback
+
+from utils import (
+    DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER,
+    EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded,
+    compute_ratio, ensure_one_lm_selected, get_dataset_info
+)
+
+disable_caching()
+
+THEME = "pseudolab/huggingface-korea-theme"
+DEFAULT_SAMPLES = 1000
+MAX_SAMPLES = 5000
+LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')
+
+# Add a tiny model for demonstration on CPU
+LANGUAGE_MODELS = ['prajjwal1/bert-tiny'] + list(dict.fromkeys(LANGUAGE_MODELS))
+LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")
+
+# Preselect some small models
+DEFAULT_MODELS = [
+    "prajjwal1/bert-tiny", "google/electra-small-discriminator",
+    "distilbert-base-cased", "sentence-transformers/all-MiniLM-L12-v2"
+]
+
+
+with gr.Blocks(css=CSS, theme=THEME) as demo:
+
+    ########## STEP 1: Load the Dataset ##########
+
+    gr.Markdown(HEADLINE)
+
+    gr.Markdown("## Step 1: Load a Dataset")
+    with gr.Group():
+        dataset = gr.State(None)
+
+        dataset_name = gr.Textbox(
+            label="Enter the name of your dataset",
+            placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
+            max_lines=1,
+        )
+        select_dataset_button = gr.Button(
+            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
+        )
+
+        # Activate the "Load dataset" button if the dataset was found
+        dataset_name.change(
+            check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
+        )
+
+    gr.Markdown(
+        "*The number of samples that can be used in this demo is limited to save resources. "
+        "To run an estimate on the full dataset, check out the "
+        "[library](https://github.com/flairNLP/transformer-ranker).*"
+    )
+
+    ########## Step 1.1: Dataset preprocessing ##########
+
+    with gr.Accordion("Dataset settings", open=False) as dataset_config:
+        with gr.Row() as dataset_details:
+            dataset_name_label = gr.Label("", label="Dataset Name")
+            num_samples = gr.State(0)
+            num_samples_label = gr.Label("", label="Number of Samples")
+            num_samples.change(
+                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
+            )
+
+        with gr.Row():
+            text_column = gr.Dropdown("", label="Text Column")
+            text_pair_column = gr.Dropdown("", label="Text Pair Column")
+
+        with gr.Row():
+            label_column = gr.Dropdown("", label="Label Column")
+            task_category = gr.Dropdown("", label="Task Type")
+
+        with gr.Group():
+            downsample_ratio = gr.State(0.0)
+            num_samples_to_use = gr.Slider(
+                20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
+            )
+            downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
+            downsample_ratio.change(
+                lambda x: f"{x:.1%}",
+                inputs=[downsample_ratio],
+                outputs=[downsample_ratio_label],
+            )
+
+        num_samples_to_use.change(
+            compute_ratio,
+            inputs=[num_samples_to_use, num_samples],
+            outputs=downsample_ratio,
+        )
+        num_samples.change(
+            compute_ratio,
+            inputs=[num_samples_to_use, num_samples],
+            outputs=downsample_ratio,
+        )
+
+    # Download the dataset and show its details
+    def select_dataset(dataset_name):
+        try:
+            dataset = load_dataset(dataset_name, trust_remote_code=True)
+            dataset_info = get_dataset_info(dataset)
+        except ValueError:
+            # Abort here: without a loaded dataset there is nothing to return below
+            raise gr.Error("Dataset collections are not supported. Please use a single dataset.")
+
+        return (
+            gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
+            gr.Accordion(open=True),
+            dataset_name,
+            dataset,
+            *dataset_info
+        )
+
+    select_dataset_button.click(
+        select_dataset,
+        inputs=[dataset_name],
+        outputs=[
+            select_dataset_button,
+            dataset_config,
+            dataset_name_label,
+            dataset,
+            task_category,
+            text_column,
+            text_pair_column,
+            label_column,
+            num_samples,
+        ],
+        scroll_to_output=True,
+    )
+
+    ########## STEP 2: Select Language Models ##########
+
+    gr.Markdown("## Step 2: Select a List of Language Models")
+    with gr.Group():
+        model_options = [
+            (model_handle.split("/")[-1], model_handle)
+            for model_handle in LANGUAGE_MODELS
+        ]
+        models = gr.CheckboxGroup(
+            choices=model_options, label="Select Models", value=DEFAULT_MODELS
+        )
+
+    ########## STEP 3: Run Language Model Ranking ##########
+
+    gr.Markdown("## Step 3: Rank LMs")
+
+    with gr.Group():
+        with gr.Accordion("Advanced settings", open=False):
+            with gr.Row():
+                estimator = gr.Dropdown(
+                    choices=["hscore", "logme", "knn"],
+                    label="Transferability metric",
+                    value="hscore",
+                )
+                layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
+                layer_pooling = gr.Dropdown(
+                    choices=layer_pooling_options,
+                    label="Layer pooling",
+                    value="layermean",
+                )
+        submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)
+
+    # Make the button active once the dataset is loaded
+    dataset.change(
+        check_dataset_is_loaded,
+        inputs=[dataset, text_column, label_column, task_category],
+        outputs=submit_button
+    )
+
+    label_column.change(
+        check_dataset_is_loaded,
+        inputs=[dataset, text_column, label_column, task_category],
+        outputs=submit_button
+    )
+
+    text_column.change(
+        check_dataset_is_loaded,
+        inputs=[dataset, text_column, label_column, task_category],
+        outputs=submit_button
+    )
+
+    def rank_models(
+        dataset,
+        downsample_ratio,
+        selected_models,
+        layer_pooling,
+        estimator,
+        text_column,
+        text_pair_column,
+        label_column,
+        task_category,
+        progress=gr.Progress(),
+    ):
+
+        if text_column == "-":
+            raise gr.Error("Text column is not set.")
+
+        if label_column == "-":
+            raise gr.Error("Label column is not set.")
+
+        if task_category == "-":
+            raise gr.Error(
+                "Task category is not set. The dataset must support classification or regression tasks."
+            )
+
+        if text_pair_column == "-":
+            text_pair_column = None
+
+        progress(0.0, "Starting")
+
+        with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
+            try:
+                ranker = TransformerRanker(
+                    dataset,
+                    dataset_downsample=downsample_ratio,
+                    text_column=text_column,
+                    text_pair_column=text_pair_column,
+                    label_column=label_column,
+                    task_category=task_category,
+                )
+
+                results = ranker.run(
+                    models=selected_models,
+                    layer_aggregator=layer_pooling,
+                    estimator=estimator,
+                    batch_size=64,
+                    tracker=tracker,
+                )
+
+                sorted_results = sorted(
+                    results._results.items(), key=lambda item: item[1], reverse=True
+                )
+                return [
+                    (i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
+                ]
+            except Exception:
+                traceback.print_exc()
+                raise gr.Error("The dataset is not supported.")
+
+    gr.Markdown("## Results")
+    ranking_results = gr.Dataframe(
+        headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
+    )
+
+    submit_button.click(
+        rank_models,
+        inputs=[
+            dataset,
+            downsample_ratio,
+            models,
+            layer_pooling,
+            estimator,
+            text_column,
+            text_pair_column,
+            label_column,
+            task_category,
+        ],
+        outputs=ranking_results,
+        scroll_to_output=True,
+    )
+
+    gr.Markdown(
+        "*The results are ranked by transferability score, with the most suitable model listed first. "
+        "This lets you focus further exploration and fine-tuning on the higher-ranked models.*"
+    )
+
+    gr.Markdown(FOOTER)
+
+if __name__ == "__main__":
+    demo.queue(default_concurrency_limit=3)
+    demo.launch(max_threads=6)
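The Step 1 note above points to the transformer-ranker library for running an estimate on a full dataset. A minimal sketch of that direct usage follows; the dataset name, downsample ratio, and model list are illustrative placeholders, and the call pattern simply mirrors the calls made in rank_models above (it is assumed, not taken from this commit, that the library can auto-detect the text and label columns when they are omitted):

    from datasets import load_dataset
    from transformer_ranker import TransformerRanker, prepare_popular_models

    # Illustrative inputs: pick any classification dataset and model list
    dataset = load_dataset("trec")
    language_models = prepare_popular_models("base")

    # Same defaults as the app: H-score estimator, mean over layers, batch size 64
    ranker = TransformerRanker(dataset, dataset_downsample=0.2)
    results = ranker.run(
        models=language_models,
        layer_aggregator="layermean",
        estimator="hscore",
        batch_size=64,
    )
    print(results)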
utils.py ADDED
@@ -0,0 +1,205 @@
+import gradio as gr
+from datasets import concatenate_datasets
+from huggingface_hub import HfApi
+from huggingface_hub.errors import HFValidationError
+from requests.exceptions import HTTPError
+from transformer_ranker import Result
+from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
+from transformer_ranker.embedder import Embedder
+import math
+
+DISABLED_BUTTON_VARIANT = "huggingface"
+ENABLED_BUTTON_VARIANT = "primary"
+
+HEADLINE = """
+<h1 align="center">TransformerRanker</h1>
+<p align="center" style="max-width: 560px; margin: auto;">
+A very simple library that helps you find the best-suited language model for your NLP task.
+All you need to do is select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub.
+TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!
+</p>
+<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
+<a href="https://github.com/flairNLP/transformer-ranker">
+<img src="https://img.shields.io/github/stars/flairNLP/transformer-ranker?style=social&label=Repository" alt="GitHub Badge">
+</a>
+<a href="https://pypi.org/project/transformer-ranker/">
+<img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="Package Badge">
+</a>
+<a href="https://github.com/flairNLP/transformer-ranker/blob/main/examples/01-walkthrough.md">
+<img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="Tutorials Badge">
+</a>
+<img src="https://img.shields.io/badge/license-MIT-green?style=flat" alt="License: MIT">
+</p>
+<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
+"""
+
+FOOTER = """
+**Note:** This demonstration currently runs on a CPU and is suited for smaller models only.
+**Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas).
+For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions).
+"""
+
+CSS = """
+.gradio-container{max-width: 800px !important}
+a {color: #ff9d00;}
+@media (prefers-color-scheme: dark) { a {color: #be185d;} }
+"""
+
+
+hf_api = HfApi()
+
+
+def check_dataset_exists(dataset_name):
+    """Enable the load button if the dataset can be found on the Hub"""
+    try:
+        hf_api.dataset_info(dataset_name)
+        return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
+
+    except (HTTPError, HFValidationError):
+        return gr.update(value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT)
+
+
+def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
+    if dataset and text_column != "-" and label_column != "-" and task_category != "-":
+        return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
+    else:
+        return gr.update(interactive=False, variant=DISABLED_BUTTON_VARIANT)
+
+
+def get_dataset_info(dataset):
+    """Show information for the dataset settings"""
+    joined_dataset = concatenate_datasets(list(dataset.values()))
+    datacleaner = DatasetCleaner()
+
+    try:
+        text_column = datacleaner._find_column(joined_dataset, "text column")
+    except ValueError:
+        gr.Warning("Text column cannot be found. Select it in the dataset settings.")
+        text_column = "-"
+
+    try:
+        label_column = datacleaner._find_column(joined_dataset, "label column")
+    except ValueError:
+        gr.Warning("Label column cannot be found. Select it in the dataset settings.")
+        label_column = "-"
+
+    task_category = "-"
+    if label_column != "-":
+        try:
+            # Find or set the task category
+            task_category = datacleaner._find_task_category(joined_dataset, label_column)
+        except ValueError:
+            gr.Warning(
+                "Task category could not be determined. The dataset must support classification or regression tasks.",
+            )
+
+    num_samples = len(joined_dataset)
+
+    return (
+        gr.update(
+            value=task_category,
+            choices=[str(t) for t in TaskCategory],
+            interactive=True,
+        ),
+        gr.update(
+            value=text_column, choices=joined_dataset.column_names, interactive=True
+        ),
+        gr.update(
+            value="-", choices=["-", *joined_dataset.column_names], interactive=True
+        ),
+        gr.update(
+            value=label_column, choices=joined_dataset.column_names, interactive=True
+        ),
+        num_samples,
+    )
+
+
+def compute_ratio(num_samples_to_use, num_samples):
+    # e.g. 1,000 samples to use out of 9,000 gives a downsample ratio of ~0.11
+    if num_samples > 0:
+        return num_samples_to_use / num_samples
+    else:
+        return 0.0
+
+
+def ensure_one_lm_selected(checkbox_values, previous_values):
+    if not any(checkbox_values):
+        return previous_values
+    return checkbox_values
+
+
+# Monkey-patch Embedder so that a progress tracker can be attached to it
+# and notified as batches of sentences are embedded
+_old_embed = Embedder.embed
+
+def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
+    if embedder.tracker is not None:
+        embedder.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))
+
+    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)
+
+Embedder.embed = _new_embed
+
+_old_embed_batch = Embedder.embed_batch
+
+def _new_embed_batch(embedder, *args, **kw):
+    r = _old_embed_batch(embedder, *args, **kw)
+    if embedder.tracker is not None:
+        embedder.tracker.update_batch_complete()
+    return r
+
+Embedder.embed_batch = _new_embed_batch
+
+_old_init = Embedder.__init__
+
+def _new_init(embedder, *args, tracker=None, **kw):
+    _old_init(embedder, *args, **kw)
+    embedder.tracker = tracker
+
+Embedder.__init__ = _new_init
+
+
+class EmbeddingProgressTracker:
+    def __init__(self, *, progress, model_names):
+        self.model_names = model_names
+        self.progress_bar = progress
+
+    @property
+    def total(self):
+        return len(self.model_names)
+
+    def __enter__(self):
+        self.progress_bar = gr.Progress(track_tqdm=False)
+        self.current_model = -1
+        self.batches_complete = 0
+        self.batches_total = None
+        return self
+
+    def __exit__(self, typ, value, tb):
+        if typ is None:
+            self.progress_bar(1.0, desc="Done")
+        else:
+            self.progress_bar(1.0, desc="Error")
+
+        # Do not suppress any errors
+        return False
+
+    def update_num_batches(self, total):
+        # Each call marks the start of the next model's embedding run
+        self.current_model += 1
+        self.batches_complete = 0
+        self.batches_total = total
+        self.update_bar()
+
+    def update_batch_complete(self):
+        self.batches_complete += 1
+        self.update_bar()
+
+    def update_bar(self):
+        i = self.current_model
+
+        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"
+
+        # Fraction of models finished, plus the current model's batch progress
+        progress = i / self.total
+        if self.batches_total is not None:
+            progress += (self.batches_complete / self.batches_total) / self.total
+
+        self.progress_bar(progress=progress, desc=description)
+
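The monkey patches above define a small callback protocol between the embedder and the tracker: each call to the patched Embedder.embed first reports the total number of batches via tracker.update_num_batches (EmbeddingProgressTracker treats each such call as the start of the next model), and each patched embed_batch call then reports one completed batch via tracker.update_batch_complete. A minimal stand-in showing that call order (PrintTracker is a hypothetical helper for illustration, not part of this commit or the library):

    class PrintTracker:
        # Stand-in for EmbeddingProgressTracker that just prints the callbacks
        def update_num_batches(self, total):
            print(f"next model: expecting {total} batches")

        def update_batch_complete(self):
            print("batch finished")

    tracker = PrintTracker()
    tracker.update_num_batches(3)        # called once at the start of Embedder.embed
    for _ in range(3):
        tracker.update_batch_complete()  # called after each embedded batch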