Initial conversion
Browse files
Assumes macOS. Does not push anywhere.
- .gitignore +6 -0
- app.py +250 -0
- requirements.txt +4 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
exported/
|
2 |
+
.ipynb_checkpoints/
|
3 |
+
.vscode/
|
4 |
+
__pycache__/
|
5 |
+
Untitled.ipynb
|
6 |
+
test.py
|
app.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from pathlib import Path
|
4 |
+
|
5 |
+
from huggingface_hub import hf_hub_download, HfApi
|
6 |
+
from coremltools import ComputeUnit
|
7 |
+
from transformers.onnx.utils import get_preprocessor
|
8 |
+
|
9 |
+
from exporters.coreml import export
|
10 |
+
from exporters.coreml.features import FeaturesManager
|
11 |
+
from exporters.coreml.validate import validate_model_outputs
|
12 |
+
|
13 |
+
# UI label -> value lookup tables. Each `*_labels` list holds the keys of the
# matching table in insertion order and is used as the choices of a radio group.

# Label -> coremltools ComputeUnit member.
compute_units_mapping = {
    "All": ComputeUnit.ALL,
    "CPU": ComputeUnit.CPU_ONLY,
    "CPU + GPU": ComputeUnit.CPU_AND_GPU,
    "CPU + NE": ComputeUnit.CPU_AND_NE,
}
compute_units_labels = list(compute_units_mapping)

# Label -> short framework identifier.
framework_mapping = {
    "PyTorch": "pt",
    "TensorFlow": "tf",
}
framework_labels = list(framework_mapping)

# Label -> precision name (also used as the prefix of the exported file name).
precision_mapping = {
    "Float32": "float32",
    "Float16 quantization": "float16",
}
precision_labels = list(precision_mapping)

# Label -> absolute tolerance used for validation; None defers to the
# model configuration's own default.
tolerance_mapping = {
    "Model default": None,
    "1e-2": 1e-2,
    "1e-3": 1e-3,
    "1e-4": 1e-4,
}
tolerance_labels = list(tolerance_mapping)
|
40 |
+
|
41 |
+
def error_str(error, title="Error"):
    """Format *error* as a Markdown snippet with a level-4 heading.

    Args:
        error: Anything that can be formatted into a string (typically an
            exception or a message). Falsy values mean "no error".
        title: Heading text placed above the error.

    Returns:
        The Markdown text, or an empty string when *error* is falsy.
    """
    if not error:
        return ""
    return f"""#### {title}
{error}"""
|
44 |
+
|
45 |
+
def url_to_model_id(model_id_str):
    """Turn a Hugging Face Hub URL into a model id.

    Strings that do not start with ``https://huggingface.co/`` are returned
    unchanged, so plain ids such as ``apple/mobilevit-small`` pass through.

    For URLs, the id is taken from the beginning of the path rather than from
    the end of the string, which fixes three cases:

    * models without an organisation
      (``https://huggingface.co/distilbert-base-uncased``),
    * URLs with a trailing slash,
    * deep links such as ``.../org/model/tree/main``.

    At most the first two path segments are kept. A deep link into a model
    without an organisation is therefore not disambiguated.

    Args:
        model_id_str: Model id or Hub URL as typed by the user.

    Returns:
        The model id as a string.
    """
    prefix = "https://huggingface.co/"
    if not model_id_str.startswith(prefix):
        return model_id_str

    # Drop any query string or fragment, then ignore empty segments produced
    # by leading, trailing or doubled slashes.
    path = model_id_str[len(prefix):].split("?", 1)[0].split("#", 1)[0]
    segments = [segment for segment in path.split("/") if segment]
    return "/".join(segments[:2])
|
48 |
+
|
49 |
+
def supported_frameworks(model_id):
    """
    List the framework labels (`PyTorch`, `TensorFlow`) advertised by a model's Hub tags.

    The model info is fetched with `HfApi.model_info`; only the `pytorch` and
    `tf` tags are considered, every other tag is ignored. The labels are
    returned sorted alphabetically.
    """
    tag_to_label = {"pytorch": "PyTorch", "tf": "TensorFlow"}
    hub_tags = HfApi().model_info(model_id).tags
    return sorted(tag_to_label[tag] for tag in hub_tags if tag in tag_to_label)
|
59 |
+
|
60 |
+
def on_model_change(model):
    """Load a model's config from the Hub and refresh the settings UI.

    Downloads ``config.json`` for the model, reads its ``model_type``, asks
    ``FeaturesManager`` which tasks that model type supports and queries the
    frameworks advertised by the model's Hub tags.

    Args:
        model: Model id or Hub URL typed into the text box.

    Returns:
        A 4-tuple of ``gr.update`` objects, in this order: settings column,
        task choices, framework choices, error Markdown.

        On failure the settings column is hidden, the task and framework
        widgets are left untouched and the error text is shown. Previously
        the failure path returned ``None``, so the error never reached the UI.
    """
    model = url_to_model_id(model)

    try:
        config_file = hf_hub_download(model, filename="config.json")
        if config_file is None:
            raise Exception(f"Model {model} not found")

        with open(config_file, "r", encoding="utf-8") as f:
            config = json.load(f)
        model_type = config["model_type"]

        features = FeaturesManager.get_supported_features_for_model_type(model_type)
        tasks = list(features.keys())

        frameworks = supported_frameworks(model)
    except Exception as e:
        # Always return one update per output component so the error is
        # rendered instead of being silently dropped.
        return (
            gr.update(visible=False),       # Settings column
            gr.update(),                    # Tasks (unchanged)
            gr.update(),                    # Frameworks (unchanged)
            gr.update(value=error_str(e)),  # Error
        )

    selected_framework = frameworks[0] if frameworks else None
    return (
        gr.update(visible=bool(model_type)),                                                     # Settings column
        gr.update(choices=tasks, value=tasks[0] if tasks else None),                             # Tasks
        gr.update(visible=len(frameworks) > 1, choices=frameworks, value=selected_framework),    # Frameworks
        gr.update(value=""),                                                                     # Error (cleared)
    )
|
90 |
+
|
91 |
+
|
92 |
+
def convert_model(preprocessor, model, model_coreml_config, compute_units, precision, tolerance, output, use_past=False, seq2seq=None):
    """Export one model to Core ML, save it to disk, then validate its outputs.

    Args:
        preprocessor: Forwarded to `export` and `validate_model_outputs`.
        model: Source model; its `config` attribute is used to build the Core ML config.
        model_coreml_config: Callable invoked as
            `model_coreml_config(model.config, use_past=..., seq2seq=...)`.
        compute_units: Forwarded to `export` as `compute_units`.
        precision: Forwarded to `export` as `quantize`.
        tolerance: Absolute tolerance for validation; when `None`, the config's
            `atol_for_validation` is used instead.
        output: Path object for the destination package.
        use_past: Forwarded to `model_coreml_config`.
        seq2seq: Forwarded to `model_coreml_config`. `"encoder"` or `"decoder"`
            also prefixes the saved file name with `encoder_` / `decoder_`.
    """
    config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)
    converted = export(preprocessor, model, config, quantize=precision, compute_units=compute_units)

    # Encoder and decoder halves are stored next to each other under prefixed names.
    prefix = ""
    if seq2seq == "encoder":
        prefix = "encoder_"
    elif seq2seq == "decoder":
        prefix = "decoder_"
    destination = output.parent / (prefix + output.name) if prefix else output

    converted.save(destination.as_posix())

    atol = coreml_atol = config.atol_for_validation if tolerance is None else tolerance
    validate_model_outputs(config, preprocessor, model, converted, atol)
|
115 |
+
|
116 |
+
|
117 |
+
def convert(model, task, compute_units, precision, tolerance, framework):
    """Convert a Hub model to Core ML and report the outcome as text.

    The exported package is written to
    ``exported/<model id>/coreml/<task>/<precision>_model.mlpackage``. For the
    ``seq2seq-lm`` and ``speech-seq2seq`` tasks the encoder and the decoder
    are converted separately.

    Args:
        model: Model id or Hub URL.
        task: Feature/task name to export.
        compute_units: A key of ``compute_units_mapping``.
        precision: A key of ``precision_mapping``.
        tolerance: A key of ``tolerance_mapping``.
        framework: A key of ``framework_mapping``.

    Returns:
        ``"Done"`` on success, otherwise the error formatted by ``error_str``.
        Nothing is raised to the caller: option lookups and directory creation
        now run inside the ``try`` too, so an unset selection (for example a
        framework of ``None``) is reported instead of surfacing as an
        unhandled ``KeyError``.
    """
    try:
        model_id = url_to_model_id(model)

        try:
            compute_units = compute_units_mapping[compute_units]
            precision = precision_mapping[precision]
            tolerance = tolerance_mapping[tolerance]
            framework = framework_mapping[framework]
        except KeyError as e:
            raise ValueError(f"Unsupported or missing option: {e.args[0]!r}") from e

        # TODO: support legacy format
        output = Path("exported") / model_id / "coreml" / task
        output.mkdir(parents=True, exist_ok=True)
        output = output / f"{precision}_model.mlpackage"

        preprocessor = get_preprocessor(model_id)
        hf_model = FeaturesManager.get_model_from_feature(task, model_id, framework=framework)
        _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(hf_model, feature=task)

        # Seq2seq tasks are exported as two packages (encoder, then decoder);
        # every other task is exported as a single package.
        if task in ["seq2seq-lm", "speech-seq2seq"]:
            parts = ("encoder", "decoder")
        else:
            parts = (None,)

        for part in parts:
            convert_model(
                preprocessor,
                hf_model,
                model_coreml_config,
                compute_units,
                precision,
                tolerance,
                output,
                seq2seq=part,
            )

        # TODO: push to hub, whatever
        return "Done"
    except Exception as e:
        return error_str(e)
|
171 |
+
|
172 |
+
# Markdown shown at the top of the page.
DESCRIPTION = """
## Convert a transformers model to Core ML

With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.

Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).

After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
"""

# UI definition. NOTE(review): the nesting below was reconstructed from a
# whitespace-stripped copy of the file; confirm it against the original layout.
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        # Left column: model selection.
        with gr.Column(scale=2):
            gr.Markdown("## 1. Load model info")
            input_model = gr.Textbox(
                max_lines=1,
                label="Model name or URL, such as apple/mobilevit-small",
                placeholder="distilbert-base-uncased",
                value="distilbert-base-uncased",
            )
            btn_get_tasks = gr.Button("Load")
        # Right column: conversion settings, hidden until on_model_change
        # makes `group_settings` visible.
        with gr.Column(scale=3):
            with gr.Column(visible=False) as group_settings:
                gr.Markdown("## 2. Select Task")
                # Choices are filled in by on_model_change.
                radio_tasks = gr.Radio(label="Choose the task for the converted model.")
                gr.Markdown("The `default` task is suitable for feature extraction.")
                # Starts hidden; on_model_change shows it only when more than
                # one framework is available.
                radio_framework = gr.Radio(
                    visible=False,
                    label="Framework",
                    choices=framework_labels,
                    value=framework_labels[0],
                )
                radio_compute = gr.Radio(
                    label="Compute Units",
                    choices=compute_units_labels,
                    value=compute_units_labels[0],
                )
                radio_precision = gr.Radio(
                    label="Precision",
                    choices=precision_labels,
                    value=precision_labels[0],
                )
                radio_tolerance = gr.Radio(
                    label="Absolute Tolerance for Validation",
                    choices=tolerance_labels,
                    value=tolerance_labels[0],
                )
                btn_convert = gr.Button("Convert")
                gr.Markdown("Conversion will take a few minutes.")


    # Shared output area: receives load errors and the conversion result.
    error_output = gr.Markdown(label="Output")

    # "Load" button: the four outputs match, in order, the tuple returned by
    # on_model_change.
    btn_get_tasks.click(
        fn=on_model_change,
        inputs=input_model,
        outputs=[group_settings, radio_tasks, radio_framework, error_output],
        queue=False,
        scroll_to_output=True
    )

    # "Convert" button: the inputs match, in order, the parameters of convert().
    btn_convert.click(
        fn=convert,
        inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
        outputs=error_output,
        scroll_to_output=True
    )

    # gr.HTML("""
    #     <div style="border-top: 1px solid #303030;">
    #       <br>
    #       <p>Footer</p><br>
    #       <p><img src="https://visitor-badge.glitch.me/badge?page_id=pcuenq.transformers-to-coreml" alt="visitors"></p>
    #     </div>
    #     """)

# Enable the request queue, then start the app when the module is executed.
demo.queue(concurrency_count=1, max_size=10)
demo.launch(debug=True, share=False)
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
huggingface_hub
|
2 |
+
transformers
|
3 |
+
coremltools
|
4 |
+
git+https://github.com/huggingface/exporters.git
|