Spaces:
Running
Running
lukasgarbas
committed on
Commit
•
73d9a01
1
Parent(s):
0b6543b
add gradio app
Browse files
app.py
ADDED
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
from datasets import disable_caching, load_dataset
from transformer_ranker import TransformerRanker, prepare_popular_models
import traceback

from utils import (
    DISABLED_BUTTON_VARIANT, ENABLED_BUTTON_VARIANT, CSS, HEADLINE, FOOTER,
    EmbeddingProgressTracker, check_dataset_exists, check_dataset_is_loaded,
    compute_ratio, ensure_one_lm_selected, get_dataset_info
)

# Avoid writing HF datasets cache files to the Space's disk.
disable_caching()

THEME = "pseudolab/huggingface-korea-theme"
DEFAULT_SAMPLES = 1000  # initial value of the "Samples to use" slider
MAX_SAMPLES = 5000  # hard cap so the CPU demo stays responsive

# Candidate LMs shown in the checkbox group: the library's "base" and
# "large" presets, deduplicated below while preserving order.
LANGUAGE_MODELS = prepare_popular_models('base') + prepare_popular_models('large')

# Add a tiny model for demonstration on CPU
LANGUAGE_MODELS = ['prajjwal1/bert-tiny'] + list(dict.fromkeys(LANGUAGE_MODELS))
LANGUAGE_MODELS.insert(LANGUAGE_MODELS.index("bert-base-cased") + 1, "bert-base-uncased")

# Preselect some small models
DEFAULT_MODELS = [
    "prajjwal1/bert-tiny", "google/electra-small-discriminator",
    "distilbert-base-cased", "sentence-transformers/all-MiniLM-L12-v2"
]
|
28 |
+
|
29 |
+
|
30 |
+
with gr.Blocks(css=CSS, theme=THEME) as demo:

    ########## STEP 1: Load the Dataset ##########

    gr.Markdown(HEADLINE)

    gr.Markdown("## Step 1: Load a Dataset")
    with gr.Group():
        # Per-session state holding the loaded HF dataset object.
        dataset = gr.State(None)

        dataset_name = gr.Textbox(
            label="Enter the name of your dataset",
            placeholder="Examples: trec, ag_news, sst2, conll2003, leondz/wnut_17",
            max_lines=1,
        )
        select_dataset_button = gr.Button(
            value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT
        )

        # Activate the "Load dataset" button if dataset was found
        dataset_name.change(
            check_dataset_exists, inputs=dataset_name, outputs=select_dataset_button
        )

    gr.Markdown(
        "*The number of samples that can be used in this demo is limited to save resources. "
        "To run an estimate on the full dataset, check out the "
        "[library](https://github.com/flairNLP/transformer-ranker).*"
    )

    ########## Step 1.1 Dataset preprocessing ##########

    with gr.Accordion("Dataset settings", open=False) as dataset_config:
        with gr.Row() as dataset_details:
            dataset_name_label = gr.Label("", label="Dataset Name")
            # Numeric sample count kept in state; mirrored into the
            # read-only label below whenever it changes.
            num_samples = gr.State(0)
            num_samples_label = gr.Label("", label="Number of Samples")
            num_samples.change(
                lambda x: str(x), inputs=[num_samples], outputs=[num_samples_label]
            )

        # Column selections are populated by select_dataset() after loading.
        with gr.Row():
            text_column = gr.Dropdown("", label="Text Column")
            text_pair_column = gr.Dropdown("", label="Text Pair Column")

        with gr.Row():
            label_column = gr.Dropdown("", label="Label Column")
            task_category = gr.Dropdown("", label="Task Type")

        with gr.Group():
            # Fraction of the dataset actually used for ranking.
            downsample_ratio = gr.State(0.0)
            num_samples_to_use = gr.Slider(
                20, MAX_SAMPLES, label="Samples to use", value=DEFAULT_SAMPLES, step=1
            )
            downsample_ratio_label = gr.Label("", label="Ratio of dataset to use")
            # Render the ratio as a percentage string.
            downsample_ratio.change(
                lambda x: f"{x:.1%}",
                inputs=[downsample_ratio],
                outputs=[downsample_ratio_label],
            )

        # Recompute the ratio when either the slider or the dataset size changes.
        num_samples_to_use.change(
            compute_ratio,
            inputs=[num_samples_to_use, num_samples],
            outputs=downsample_ratio,
        )
        num_samples.change(
            compute_ratio,
            inputs=[num_samples_to_use, num_samples],
            outputs=downsample_ratio,
        )
|
101 |
+
|
102 |
+
# Download the dataset and show details
|
103 |
+
def select_dataset(dataset_name):
|
104 |
+
try:
|
105 |
+
dataset = load_dataset(dataset_name, trust_remote_code=True)
|
106 |
+
dataset_info = get_dataset_info(dataset)
|
107 |
+
except ValueError:
|
108 |
+
gr.Warning("Dataset collections are not supported. Please use a single dataset.")
|
109 |
+
|
110 |
+
return (
|
111 |
+
gr.update(value="Loaded", interactive=False, variant=DISABLED_BUTTON_VARIANT),
|
112 |
+
gr.Accordion(open=True),
|
113 |
+
dataset_name,
|
114 |
+
dataset,
|
115 |
+
*dataset_info
|
116 |
+
)
|
117 |
+
|
118 |
+
select_dataset_button.click(
|
119 |
+
select_dataset,
|
120 |
+
inputs=[dataset_name],
|
121 |
+
outputs=[
|
122 |
+
select_dataset_button,
|
123 |
+
dataset_config,
|
124 |
+
dataset_name_label,
|
125 |
+
dataset,
|
126 |
+
task_category,
|
127 |
+
text_column,
|
128 |
+
text_pair_column,
|
129 |
+
label_column,
|
130 |
+
num_samples,
|
131 |
+
],
|
132 |
+
scroll_to_output=True,
|
133 |
+
)
|
134 |
+
|
135 |
+
########## STEP 2 ##########
|
136 |
+
|
137 |
+
gr.Markdown("## Step 2: Select a List of Language Models")
|
138 |
+
with gr.Group():
|
139 |
+
model_options = [
|
140 |
+
(model_handle.split("/")[-1], model_handle)
|
141 |
+
for model_handle in LANGUAGE_MODELS
|
142 |
+
]
|
143 |
+
models = gr.CheckboxGroup(
|
144 |
+
choices=model_options, label="Select Models", value=DEFAULT_MODELS
|
145 |
+
)
|
146 |
+
|
147 |
+
########## STEP 3: Run Language Model Ranking ##########
|
148 |
+
|
149 |
+
gr.Markdown("## Step 3: Rank LMs")
|
150 |
+
|
151 |
+
with gr.Group():
|
152 |
+
with gr.Accordion("Advanced settings", open=False):
|
153 |
+
with gr.Row():
|
154 |
+
estimator = gr.Dropdown(
|
155 |
+
choices=["hscore", "logme", "knn"],
|
156 |
+
label="Transferability metric",
|
157 |
+
value="hscore",
|
158 |
+
)
|
159 |
+
layer_pooling_options = ["lastlayer", "layermean", "bestlayer"]
|
160 |
+
layer_pooling = gr.Dropdown(
|
161 |
+
choices=["lastlayer", "layermean", "bestlayer"],
|
162 |
+
label="Layer pooling",
|
163 |
+
value="layermean",
|
164 |
+
)
|
165 |
+
submit_button = gr.Button("Run Ranking", interactive=False, variant=DISABLED_BUTTON_VARIANT)
|
166 |
+
|
167 |
+
# Make button active if the dataset is loaded
|
168 |
+
dataset.change(
|
169 |
+
check_dataset_is_loaded,
|
170 |
+
inputs=[dataset, text_column, label_column, task_category],
|
171 |
+
outputs=submit_button
|
172 |
+
)
|
173 |
+
|
174 |
+
label_column.change(
|
175 |
+
check_dataset_is_loaded,
|
176 |
+
inputs=[dataset, text_column, label_column, task_category],
|
177 |
+
outputs=submit_button
|
178 |
+
)
|
179 |
+
|
180 |
+
text_column.change(
|
181 |
+
check_dataset_is_loaded,
|
182 |
+
inputs=[dataset, text_column, label_column, task_category],
|
183 |
+
outputs=submit_button
|
184 |
+
)
|
185 |
+
|
186 |
+
def rank_models(
|
187 |
+
dataset,
|
188 |
+
downsample_ratio,
|
189 |
+
selected_models,
|
190 |
+
layer_pooling,
|
191 |
+
estimator,
|
192 |
+
text_column,
|
193 |
+
text_pair_column,
|
194 |
+
label_column,
|
195 |
+
task_category,
|
196 |
+
progress=gr.Progress(),
|
197 |
+
):
|
198 |
+
|
199 |
+
if text_column == "-":
|
200 |
+
raise gr.Error("Text column is not set.")
|
201 |
+
|
202 |
+
if label_column == "-":
|
203 |
+
raise gr.Error("Label column is not set.")
|
204 |
+
|
205 |
+
if task_category == "-":
|
206 |
+
raise gr.Error(
|
207 |
+
"Task category is not set. The dataset must support classification or regression tasks."
|
208 |
+
)
|
209 |
+
|
210 |
+
if text_pair_column == "-":
|
211 |
+
text_pair_column = None
|
212 |
+
|
213 |
+
progress(0.0, "Starting")
|
214 |
+
|
215 |
+
with EmbeddingProgressTracker(progress=progress, model_names=selected_models) as tracker:
|
216 |
+
try:
|
217 |
+
ranker = TransformerRanker(
|
218 |
+
dataset,
|
219 |
+
dataset_downsample=downsample_ratio,
|
220 |
+
text_column=text_column,
|
221 |
+
text_pair_column=text_pair_column,
|
222 |
+
label_column=label_column,
|
223 |
+
task_category=task_category,
|
224 |
+
)
|
225 |
+
|
226 |
+
results = ranker.run(
|
227 |
+
models=selected_models,
|
228 |
+
layer_aggregator=layer_pooling,
|
229 |
+
estimator=estimator,
|
230 |
+
batch_size=64,
|
231 |
+
tracker=tracker,
|
232 |
+
)
|
233 |
+
|
234 |
+
sorted_results = sorted(
|
235 |
+
results._results.items(), key=lambda item: item[1], reverse=True
|
236 |
+
)
|
237 |
+
return [
|
238 |
+
(i + 1, model, score) for i, (model, score) in enumerate(sorted_results)
|
239 |
+
]
|
240 |
+
except Exception as e:
|
241 |
+
gr.Error("The dataset is not supported.")
|
242 |
+
|
243 |
+
gr.Markdown("## Results")
|
244 |
+
ranking_results = gr.Dataframe(
|
245 |
+
headers=["Rank", "Model", "Score"], datatype=["number", "str", "number"]
|
246 |
+
)
|
247 |
+
|
248 |
+
submit_button.click(
|
249 |
+
rank_models,
|
250 |
+
inputs=[
|
251 |
+
dataset,
|
252 |
+
downsample_ratio,
|
253 |
+
models,
|
254 |
+
layer_pooling,
|
255 |
+
estimator,
|
256 |
+
text_column,
|
257 |
+
text_pair_column,
|
258 |
+
label_column,
|
259 |
+
task_category,
|
260 |
+
],
|
261 |
+
outputs=ranking_results,
|
262 |
+
scroll_to_output=True,
|
263 |
+
)
|
264 |
+
|
265 |
+
gr.Markdown(
|
266 |
+
"*The results are ranked by their transferability score, with the most suitable model listed first. "
|
267 |
+
"This ranking allows focusing on the higher-ranked models for further exploration and fine-tuning.*"
|
268 |
+
)
|
269 |
+
|
270 |
+
gr.Markdown(FOOTER)
|
271 |
+
|
272 |
+
if __name__ == "__main__":
|
273 |
+
demo.queue(default_concurrency_limit=3)
|
274 |
+
demo.launch(max_threads=6)
|
utils.py
ADDED
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
from datasets import concatenate_datasets
from huggingface_hub import HfApi
from huggingface_hub.errors import HFValidationError
from requests.exceptions import HTTPError
# NOTE(review): `Result` appears unused in this module — confirm before removing.
from transformer_ranker import Result
from transformer_ranker.datacleaner import DatasetCleaner, TaskCategory
from transformer_ranker.embedder import Embedder
import math

# Button styling: "huggingface" renders the muted/disabled look,
# "primary" the highlighted call-to-action.
DISABLED_BUTTON_VARIANT = "huggingface"
ENABLED_BUTTON_VARIANT = "primary"

# Static HTML/Markdown fragments rendered by app.py.
HEADLINE = """
<h1 align="center">TransformerRanker</h1>
<p align="center" style="max-width: 560px; margin: auto;">
A very simple library that helps you find the best-suited language model for your NLP task.
All you need to do is to select a dataset and a list of pre-trained language models (LMs) from the 🤗 HuggingFace Hub.
TransformerRanker will quickly estimate which of these LMs will perform best on the given dataset!
</p>
<p align="center" style="font-weight: bold; margin-top: 20px; display: flex; justify-content: center; gap: 10px;">
  <a href="https://github.com/flairNLP/transformer-ranker">
    <img src="https://img.shields.io/github/stars/flairNLP/transformer-ranker?style=social&label=Repository" alt="GitHub Badge">
  </a>
  <a href="https://pypi.org/project/transformer-ranker/">
    <img src="https://img.shields.io/badge/Package-orange?style=flat&logo=python" alt="Package Badge">
  </a>
  <a href="https://github.com/flairNLP/transformer-ranker/blob/main/examples/01-walkthrough.md">
    <img src="https://img.shields.io/badge/Tutorials-blue?style=flat&logo=readthedocs&logoColor=white" alt="Tutorials Badge">
  </a>
  <img src="https://img.shields.io/badge/license-MIT-green?style=flat" alt="License: MIT">
</p>
<p align="center">Developed at <a href="https://www.informatik.hu-berlin.de/en/forschung-en/gebiete/ml-en/">Humboldt University of Berlin</a>.</p>
"""

FOOTER = """
**Note:** This demonstration currently runs on a CPU and is suited for smaller models only.
**Developers:** [@plonerma](https://huggingface.co/plonerma) and [@lukasgarbas](https://huggingface.co/lukasgarbas).
For feedback, suggestions, or contributions, reach out via GitHub or leave a message in the [discussions](https://huggingface.co/spaces/lukasgarbas/transformer-ranker/discussions).
"""

CSS = """
.gradio-container{max-width: 800px !important}
a {color: #ff9d00;}
@media (prefers-color-scheme: dark) { a {color: #be185d;} }
"""


# Shared Hub client used for dataset existence checks.
hf_api = HfApi()
|
50 |
+
|
51 |
+
|
52 |
+
def check_dataset_exists(dataset_name):
    """Enable the load button only when the name resolves on the HF Hub."""
    try:
        hf_api.dataset_info(dataset_name)
    except (HTTPError, HFValidationError):
        # Unknown or malformed name: keep the button disabled.
        return gr.update(value="Load dataset", interactive=False, variant=DISABLED_BUTTON_VARIANT)
    return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
|
60 |
+
|
61 |
+
def check_dataset_is_loaded(dataset, text_column, label_column, task_category):
    """Enable the ranking button once a dataset and all required columns are set."""
    ready = bool(dataset) and "-" not in (text_column, label_column, task_category)
    if not ready:
        return gr.update(interactive=False, variant=DISABLED_BUTTON_VARIANT)
    return gr.update(interactive=True, variant=ENABLED_BUTTON_VARIANT)
|
66 |
+
|
67 |
+
|
68 |
+
def get_dataset_info(dataset):
    """Show information for dataset settings.

    Inspects a loaded dataset (dict of splits) and returns Gradio updates
    for, in order: task-category dropdown, text column, text-pair column,
    label column, and the total number of samples. Callers rely on this
    exact tuple order (see select_dataset's outputs in app.py).
    """
    # Merge all splits so column/task detection sees every sample.
    joined_dataset = concatenate_datasets(list(dataset.values()))
    datacleaner = DatasetCleaner()

    # NOTE(review): relies on DatasetCleaner's private helpers
    # (_find_column/_find_task_category); may break on library upgrades.
    try:
        text_column = datacleaner._find_column(joined_dataset, "text column")
    except ValueError:
        # "-" is the sentinel for "not set"; the user must pick manually.
        gr.Warning("Text column can not be found. Select it in the dataset settings.")
        text_column = "-"

    try:
        label_column = datacleaner._find_column(joined_dataset, "label column")
    except ValueError:
        gr.Warning("Label column can not be found. Select it in the dataset settings.")
        label_column = "-"

    task_category = "-"
    if label_column != "-":
        try:
            # Find or set the task_category
            task_category = datacleaner._find_task_category(joined_dataset, label_column)
        except ValueError:
            gr.Warning(
                "Task category could not be determined. The dataset must support classification or regression tasks.",
            )
            pass

    num_samples = len(joined_dataset)

    return (
        gr.update(
            value=task_category,
            choices=[str(t) for t in TaskCategory],
            interactive=True,
        ),
        gr.update(
            value=text_column, choices=joined_dataset.column_names, interactive=True
        ),
        # The optional text-pair column always starts unset ("-").
        gr.update(
            value="-", choices=["-", *joined_dataset.column_names], interactive=True
        ),
        gr.update(
            value=label_column, choices=joined_dataset.column_names, interactive=True
        ),
        num_samples,
    )
|
115 |
+
|
116 |
+
|
117 |
+
def compute_ratio(num_samples_to_use, num_samples):
    """Return the fraction of the dataset to use, clamped to [0.0, 1.0].

    Fix: the slider maximum (5000) can exceed the dataset size, which
    previously produced ratios above 1.0 — rendered as ">100%" and passed
    as an invalid `dataset_downsample` value. Returns 0.0 for an empty or
    not-yet-loaded dataset.
    """
    if num_samples <= 0:
        return 0.0
    return min(num_samples_to_use / num_samples, 1.0)
|
122 |
+
|
123 |
+
|
124 |
+
def ensure_one_lm_selected(checkbox_values, previous_values):
    """Keep at least one model checked: revert to the previous selection
    whenever the new selection is empty."""
    return checkbox_values if any(checkbox_values) else previous_values
|
128 |
+
|
129 |
+
|
130 |
+
# Apply monkey patch to enable callbacks.
# The wrappers below thread an optional `tracker` through Embedder so the
# UI can show per-batch progress without modifying transformer-ranker.
_old_embed = Embedder.embed

def _new_embed(embedder, sentences, batch_size: int = 32, **kw):
    # Announce how many batches this model will process before embedding
    # starts (drives EmbeddingProgressTracker.update_num_batches).
    if embedder.tracker is not None:
        embedder.tracker.update_num_batches(math.ceil(len(sentences) / batch_size))

    return _old_embed(embedder, sentences, batch_size=batch_size, **kw)

Embedder.embed = _new_embed

_old_embed_batch = Embedder.embed_batch

def _new_embed_batch(embedder, *args, **kw):
    r = _old_embed_batch(embedder, *args, **kw)
    # Report completion only after the original call returns.
    if embedder.tracker is not None:
        embedder.tracker.update_batch_complete()
    return r

Embedder.embed_batch = _new_embed_batch

_old_init = Embedder.__init__

def _new_init(embedder, *args, tracker=None, **kw):
    # Accept an extra keyword-only `tracker` and stash it on the instance;
    # None (the default) disables all progress reporting above.
    _old_init(embedder, *args, **kw)
    embedder.tracker = tracker

Embedder.__init__ = _new_init
|
158 |
+
|
159 |
+
|
160 |
+
class EmbeddingProgressTracker:
    """Context manager that maps embedder callbacks onto a progress bar.

    Overall progress is (models finished + fraction of current model's
    batches) / total models. The `progress` callable is invoked as
    progress(progress=..., desc=...) during updates and progress(1.0,
    desc=...) on exit.
    """

    def __init__(self, *, progress, model_names):
        self.model_names = model_names
        self.progress_bar = progress

    @property
    def total(self):
        # Number of models scheduled for ranking.
        return len(self.model_names)

    def __enter__(self):
        # Fix: previously this overwrote the injected `progress` with a
        # freshly constructed gr.Progress(track_tqdm=False), which is not
        # bound to the running event and never updated the UI. Keep the
        # callback stored in __init__.
        self.current_model = -1
        self.batches_complete = 0
        self.batches_total = None
        return self

    def __exit__(self, typ, value, tb):
        if typ is None:
            self.progress_bar(1.0, desc="Done")
        else:
            self.progress_bar(1.0, desc="Error")

        # Do not suppress any errors
        return False

    def update_num_batches(self, total):
        # Called once per model, before its first batch (see _new_embed).
        self.current_model += 1
        self.batches_complete = 0
        self.batches_total = total
        self.update_bar()

    def update_batch_complete(self):
        self.batches_complete += 1
        self.update_bar()

    def update_bar(self):
        i = self.current_model

        description = f"Running {self.model_names[i]} ({i + 1} / {self.total})"

        progress = i / self.total
        if self.batches_total is not None:
            progress += (self.batches_complete / self.batches_total) / self.total

        self.progress_bar(progress=progress, desc=description)
|
205 |
+
|