pcuenq HF staff commited on
Commit
e05f54a
·
1 Parent(s): 7d18807

Initial conversion

Browse files

Assumes macOS. Does not push anywhere.

Files changed (3) hide show
  1. .gitignore +6 -0
  2. app.py +250 -0
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ exported/
2
+ .ipynb_checkpoints/
3
+ .vscode/
4
+ __pycache__/
5
+ Untitled.ipynb
6
+ test.py
app.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ from pathlib import Path
4
+
5
+ from huggingface_hub import hf_hub_download, HfApi
6
+ from coremltools import ComputeUnit
7
+ from transformers.onnx.utils import get_preprocessor
8
+
9
+ from exporters.coreml import export
10
+ from exporters.coreml.features import FeaturesManager
11
+ from exporters.coreml.validate import validate_model_outputs
12
+
13
+ compute_units_mapping = {
14
+ "All": ComputeUnit.ALL,
15
+ "CPU": ComputeUnit.CPU_ONLY,
16
+ "CPU + GPU": ComputeUnit.CPU_AND_GPU,
17
+ "CPU + NE": ComputeUnit.CPU_AND_NE,
18
+ }
19
+ compute_units_labels = list(compute_units_mapping.keys())
20
+
21
+ framework_mapping = {
22
+ "PyTorch": "pt",
23
+ "TensorFlow": "tf",
24
+ }
25
+ framework_labels = list(framework_mapping.keys())
26
+
27
+ precision_mapping = {
28
+ "Float32": "float32",
29
+ "Float16 quantization": "float16",
30
+ }
31
+ precision_labels = list(precision_mapping.keys())
32
+
33
+ tolerance_mapping = {
34
+ "Model default": None,
35
+ "1e-2": 1e-2,
36
+ "1e-3": 1e-3,
37
+ "1e-4": 1e-4,
38
+ }
39
+ tolerance_labels = list(tolerance_mapping.keys())
40
+
41
+ def error_str(error, title="Error"):
42
+ return f"""#### {title}
43
+ {error}""" if error else ""
44
+
45
+ def url_to_model_id(model_id_str):
46
+ if not model_id_str.startswith("https://huggingface.co/"): return model_id_str
47
+ return model_id_str.split("/")[-2] + "/" + model_id_str.split("/")[-1]
48
+
49
+ def supported_frameworks(model_id):
50
+ """
51
+ Return a list of supported frameworks (`PyTorch` or `TensorFlow`) for a given model_id.
52
+ Only PyTorch and Tensorflow are supported.
53
+ """
54
+ api = HfApi()
55
+ model_info = api.model_info(model_id)
56
+ tags = model_info.tags
57
+ frameworks = [tag for tag in tags if tag in ["pytorch", "tf"]]
58
+ return sorted(["PyTorch" if f == "pytorch" else "TensorFlow" for f in frameworks])
59
+
60
+ def on_model_change(model):
61
+ model = url_to_model_id(model)
62
+ tasks = None
63
+ error = None
64
+
65
+ try:
66
+ config_file = hf_hub_download(model, filename="config.json")
67
+ if config_file is None:
68
+ raise Exception(f"Model {model} not found")
69
+
70
+ with open(config_file, "r") as f:
71
+ config_json = f.read()
72
+
73
+ config = json.loads(config_json)
74
+ model_type = config["model_type"]
75
+
76
+ features = FeaturesManager.get_supported_features_for_model_type(model_type)
77
+ tasks = list(features.keys())
78
+
79
+ frameworks = supported_frameworks(model)
80
+ selected_framework = frameworks[0] if len(frameworks) > 0 else None
81
+ return (
82
+ gr.update(visible=bool(model_type)), # Settings column
83
+ gr.update(choices=tasks, value=tasks[0] if tasks else None), # Tasks
84
+ gr.update(visible=len(frameworks)>1, choices=frameworks, value=selected_framework), # Frameworks
85
+ gr.update(value=error_str(error)), # Error
86
+ )
87
+ except Exception as e:
88
+ error = e
89
+ model_type = None
90
+
91
+
92
+ def convert_model(preprocessor, model, model_coreml_config, compute_units, precision, tolerance, output, use_past=False, seq2seq=None):
93
+ coreml_config = model_coreml_config(model.config, use_past=use_past, seq2seq=seq2seq)
94
+
95
+ mlmodel = export(
96
+ preprocessor,
97
+ model,
98
+ coreml_config,
99
+ quantize=precision,
100
+ compute_units=compute_units,
101
+ )
102
+
103
+ filename = output
104
+ if seq2seq == "encoder":
105
+ filename = filename.parent / ("encoder_" + filename.name)
106
+ elif seq2seq == "decoder":
107
+ filename = filename.parent / ("decoder_" + filename.name)
108
+ filename = filename.as_posix()
109
+
110
+ mlmodel.save(filename)
111
+
112
+ if tolerance is None:
113
+ tolerance = coreml_config.atol_for_validation
114
+ validate_model_outputs(coreml_config, preprocessor, model, mlmodel, tolerance)
115
+
116
+
117
+ def convert(model, task, compute_units, precision, tolerance, framework):
118
+ model = url_to_model_id(model)
119
+ compute_units = compute_units_mapping[compute_units]
120
+ precision = precision_mapping[precision]
121
+ tolerance = tolerance_mapping[tolerance]
122
+ framework = framework_mapping[framework]
123
+
124
+ # TODO: support legacy format
125
+ output = Path("exported")/model/"coreml"/task
126
+ output.mkdir(parents=True, exist_ok=True)
127
+ output = output/f"{precision}_model.mlpackage"
128
+
129
+ try:
130
+ preprocessor = get_preprocessor(model)
131
+ model = FeaturesManager.get_model_from_feature(task, model, framework=framework)
132
+ _, model_coreml_config = FeaturesManager.check_supported_model_or_raise(model, feature=task)
133
+
134
+ if task in ["seq2seq-lm", "speech-seq2seq"]:
135
+ # Convert encoder / decoder
136
+ convert_model(
137
+ preprocessor,
138
+ model,
139
+ model_coreml_config,
140
+ compute_units,
141
+ precision,
142
+ tolerance,
143
+ output,
144
+ seq2seq="encoder"
145
+ )
146
+ convert_model(
147
+ preprocessor,
148
+ model,
149
+ model_coreml_config,
150
+ compute_units,
151
+ precision,
152
+ tolerance,
153
+ output,
154
+ seq2seq="decoder"
155
+ )
156
+ else:
157
+ convert_model(
158
+ preprocessor,
159
+ model,
160
+ model_coreml_config,
161
+ compute_units,
162
+ precision,
163
+ tolerance,
164
+ output,
165
+ )
166
+
167
+ # TODO: push to hub, whatever
168
+ return "Done"
169
+ except Exception as e:
170
+ return error_str(e)
171
+
172
+ DESCRIPTION = """
173
+ ## Convert a transformers model to Core ML
174
+
175
+ With this Space you can try to convert a transformers model to Core ML. It uses the 🤗 Hugging Face [Exporters repo](https://huggingface.co/exporters) under the hood.
176
+
177
+ Note that not all models are supported. If you get an error on a model you'd like to convert, please open an issue on the [repo](https://github.com/huggingface/exporters).
178
+
179
+ After conversion, you can choose to submit a PR to the original repo, or create your own repo with just the converted Core ML weights.
180
+ """
181
+
182
+ with gr.Blocks() as demo:
183
+ gr.Markdown(DESCRIPTION)
184
+ with gr.Row():
185
+ with gr.Column(scale=2):
186
+ gr.Markdown("## 1. Load model info")
187
+ input_model = gr.Textbox(
188
+ max_lines=1,
189
+ label="Model name or URL, such as apple/mobilevit-small",
190
+ placeholder="distilbert-base-uncased",
191
+ value="distilbert-base-uncased",
192
+ )
193
+ btn_get_tasks = gr.Button("Load")
194
+ with gr.Column(scale=3):
195
+ with gr.Column(visible=False) as group_settings:
196
+ gr.Markdown("## 2. Select Task")
197
+ radio_tasks = gr.Radio(label="Choose the task for the converted model.")
198
+ gr.Markdown("The `default` task is suitable for feature extraction.")
199
+ radio_framework = gr.Radio(
200
+ visible=False,
201
+ label="Framework",
202
+ choices=framework_labels,
203
+ value=framework_labels[0],
204
+ )
205
+ radio_compute = gr.Radio(
206
+ label="Compute Units",
207
+ choices=compute_units_labels,
208
+ value=compute_units_labels[0],
209
+ )
210
+ radio_precision = gr.Radio(
211
+ label="Precision",
212
+ choices=precision_labels,
213
+ value=precision_labels[0],
214
+ )
215
+ radio_tolerance = gr.Radio(
216
+ label="Absolute Tolerance for Validation",
217
+ choices=tolerance_labels,
218
+ value=tolerance_labels[0],
219
+ )
220
+ btn_convert = gr.Button("Convert")
221
+ gr.Markdown("Conversion will take a few minutes.")
222
+
223
+
224
+ error_output = gr.Markdown(label="Output")
225
+
226
+ btn_get_tasks.click(
227
+ fn=on_model_change,
228
+ inputs=input_model,
229
+ outputs=[group_settings, radio_tasks, radio_framework, error_output],
230
+ queue=False,
231
+ scroll_to_output=True
232
+ )
233
+
234
+ btn_convert.click(
235
+ fn=convert,
236
+ inputs=[input_model, radio_tasks, radio_compute, radio_precision, radio_tolerance, radio_framework],
237
+ outputs=error_output,
238
+ scroll_to_output=True
239
+ )
240
+
241
+ # gr.HTML("""
242
+ # <div style="border-top: 1px solid #303030;">
243
+ # <br>
244
+ # <p>Footer</p><br>
245
+ # <p><img src="https://visitor-badge.glitch.me/badge?page_id=pcuenq.transformers-to-coreml" alt="visitors"></p>
246
+ # </div>
247
+ # """)
248
+
249
+ demo.queue(concurrency_count=1, max_size=10)
250
+ demo.launch(debug=True, share=False)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ huggingface_hub
2
+ transformers
3
+ coremltools
4
+ git+https://github.com/huggingface/exporters.git