DocWolle commited on
Commit
53751ec
·
verified ·
1 Parent(s): 321cff1

Upload notebook for model generation

Browse files
Generate_tflite_for_whisper_base_with_transcribe_and_translate_signatures.ipynb ADDED
@@ -0,0 +1,380 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "c5g9NTF_Ixad"
7
+ },
8
+ "source": [
9
+ "##Install Tranformers and datasets"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {
16
+ "id": "w4VPaSlnHUvT"
17
+ },
18
+ "outputs": [],
19
+ "source": [
20
+ "!pip install transformers==4.33.0\n",
21
+ "!pip install tensorflow==2.14.0"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {
28
+ "id": "ClniiYCWHK4b"
29
+ },
30
+ "outputs": [],
31
+ "source": [
32
+ "! pip install datasets"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {
38
+ "id": "pljpioLsJOtb"
39
+ },
40
+ "source": [
41
+ "##Load pre trained TF Whisper Base model"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {
48
+ "id": "BJNOxn5vHaGi"
49
+ },
50
+ "outputs": [],
51
+ "source": [
52
+ "import tensorflow as tf\n",
53
+ "from transformers import TFWhisperModel, WhisperFeatureExtractor\n",
54
+ "from datasets import load_dataset\n",
55
+ "\n",
56
+ "model = TFWhisperModel.from_pretrained(\"openai/whisper-base\")\n",
57
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(\"openai/whisper-base\")\n",
58
+ "\n",
59
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
60
+ "inputs = feature_extractor(\n",
61
+ " ds[0][\"audio\"][\"array\"], sampling_rate=ds[0][\"audio\"][\"sampling_rate\"], return_tensors=\"tf\"\n",
62
+ ")\n",
63
+ "input_features = inputs.input_features\n",
64
+ "print(input_features)\n",
65
+ "decoder_input_ids = tf.convert_to_tensor([[1, 1]]) * model.config.decoder_start_token_id\n",
66
+ "last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state\n",
67
+ "list(last_hidden_state.shape)"
68
+ ]
69
+ },
70
+ {
71
+ "cell_type": "markdown",
72
+ "metadata": {
73
+ "id": "W9XP25uhJl44"
74
+ },
75
+ "source": [
76
+ "##Generate Saved model"
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
+ "id": "vpYwMmgyHf0B"
84
+ },
85
+ "outputs": [],
86
+ "source": [
87
+ "model.save('/content/tf_whisper_saved')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "markdown",
92
+ "metadata": {
93
+ "id": "TY_79jFEJYyJ"
94
+ },
95
+ "source": [
96
+ "##Convert saved model to TFLite model"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": null,
102
+ "metadata": {
103
+ "id": "owez2zvzHl-p"
104
+ },
105
+ "outputs": [],
106
+ "source": [
107
+ "import tensorflow as tf\n",
108
+ "\n",
109
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
110
+ "tflite_model_path = 'whisper.tflite'\n",
111
+ "\n",
112
+ "# Convert the model\n",
113
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
114
+ "converter.target_spec.supported_ops = [\n",
115
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
116
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
117
+ "]\n",
118
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
119
+ "tflite_model = converter.convert()\n",
120
+ "\n",
121
+ "# Save the model\n",
122
+ "with open(tflite_model_path, 'wb') as f:\n",
123
+ " f.write(tflite_model)"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "metadata": {
130
+ "id": "tFkzUrjIbNcH"
131
+ },
132
+ "outputs": [],
133
+ "source": [
134
+ "%ls -la"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "markdown",
139
+ "metadata": {
140
+ "id": "fpEnWZt7iQJK"
141
+ },
142
+ "source": [
143
+ "##Evaluate TF model"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {
150
+ "id": "-RuFFohHg2ho"
151
+ },
152
+ "outputs": [],
153
+ "source": [
154
+ "import tensorflow as tf\n",
155
+ "from transformers import WhisperProcessor, TFWhisperForConditionalGeneration\n",
156
+ "from datasets import load_dataset\n",
157
+ "\n",
158
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-base\")\n",
159
+ "model = TFWhisperForConditionalGeneration.from_pretrained(\"openai/whisper-base\")\n",
160
+ "\n",
161
+ "ds = load_dataset(\"google/fleurs\", \"fr_fr\", split=\"test\")\n",
162
+ "\n",
163
+ "inputs = processor(ds[0][\"audio\"][\"array\"], return_tensors=\"tf\")\n",
164
+ "input_features = inputs.input_features\n",
165
+ "\n",
166
+ "generated_ids = model.generate(input_features)\n",
167
+ "\n",
168
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
169
+ "transcription"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "metadata": {
175
+ "id": "U-eKuy_cG4u0"
176
+ },
177
+ "source": [
178
+ "## Evaluate TF Lite model (naive)\n",
179
+ "\n",
180
+ "We can load the model as defined above... but the model is useless on its own. Generation is much more complex that a model forward pass."
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "metadata": {
187
+ "id": "wnfHirgyG0W4"
188
+ },
189
+ "outputs": [],
190
+ "source": [
191
+ "tflite_model_path = 'whisper.tflite'\n",
192
+ "interpreter = tf.lite.Interpreter(tflite_model_path)"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "metadata": {
198
+ "id": "a8VJQuHJKzl4"
199
+ },
200
+ "source": [
201
+ "## Create generation-enabled TF Lite model\n",
202
+ "\n",
203
+ "The solution consists in defining a model whose serving function is the generation call. Here's an example of how to do it:"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "markdown",
208
+ "metadata": {
209
+ "id": "JmIgqWVgVBZN"
210
+ },
211
+ "source": [
212
+ "Now with monkey-patch for fixing NaN errors with -inf values"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "code",
217
+ "execution_count": null,
218
+ "metadata": {
219
+ "id": "e5P8s66yU7Kv"
220
+ },
221
+ "outputs": [],
222
+ "source": [
223
+ "import tensorflow as tf\n",
224
+ "import numpy as np\n",
225
+ "from transformers import TFForceTokensLogitsProcessor, TFLogitsProcessor\n",
226
+ "from typing import List, Optional, Union, Any\n",
227
+ "\n",
228
+ "# Patching methods of class TFForceTokensLogitsProcessor(TFLogitsProcessor):\n",
229
+ "\n",
230
+ "def my__init__(self, force_token_map: List[List[int]]):\n",
231
+ " force_token_map = dict(force_token_map)\n",
232
+ " # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the\n",
233
+ " # index of the array corresponds to the index of the token to be forced, for XLA compatibility.\n",
234
+ " # Indexes without forced tokens will have an negative value.\n",
235
+ " force_token_array = np.ones((max(force_token_map.keys()) + 1), dtype=np.int32) * -1\n",
236
+ " for index, token in force_token_map.items():\n",
237
+ " if token is not None:\n",
238
+ " force_token_array[index] = token\n",
239
+ " self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)\n",
240
+ "\n",
241
+ "def my__call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:\n",
242
+ " def _force_token(generation_idx):\n",
243
+ " batch_size = scores.shape[0]\n",
244
+ " current_token = self.force_token_array[generation_idx]\n",
245
+ "\n",
246
+ " # Original code below generates NaN values when the model is exported to tflite\n",
247
+ " # it just needs to be a negative number so that the forced token's value of 0 is the largest\n",
248
+ " # so it will get chosen\n",
249
+ " #new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(\"inf\")\n",
250
+ " new_scores = tf.ones_like(scores, dtype=scores.dtype) * -float(1)\n",
251
+ " indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)\n",
252
+ " updates = tf.zeros((batch_size,), dtype=scores.dtype)\n",
253
+ " new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)\n",
254
+ " return new_scores\n",
255
+ "\n",
256
+ " scores = tf.cond(\n",
257
+ " tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),\n",
258
+ " # If the current length is geq than the length of force_token_array, the processor does nothing.\n",
259
+ " lambda: tf.identity(scores),\n",
260
+ " # Otherwise, it may force a certain token.\n",
261
+ " lambda: tf.cond(\n",
262
+ " tf.greater_equal(self.force_token_array[cur_len], 0),\n",
263
+ " # Only valid (positive) tokens are forced\n",
264
+ " lambda: _force_token(cur_len),\n",
265
+ " # Otherwise, the processor does nothing.\n",
266
+ " lambda: scores,\n",
267
+ " ),\n",
268
+ " )\n",
269
+ " return scores\n",
270
+ "\n",
271
+ "TFForceTokensLogitsProcessor.__init__ = my__init__\n",
272
+ "TFForceTokensLogitsProcessor.__call__ = my__call__"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": null,
278
+ "metadata": {
279
+ "id": "rIkUCdiyU7ZT"
280
+ },
281
+ "outputs": [],
282
+ "source": [
283
+ "import tensorflow as tf\n",
284
+ "\n",
285
+ "class GenerateModel(tf.Module):\n",
286
+ " def __init__(self, model):\n",
287
+ " super(GenerateModel, self).__init__()\n",
288
+ " self.model = model\n",
289
+ "\n",
290
+ " @tf.function(\n",
291
+ " input_signature=[\n",
292
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
293
+ " ],\n",
294
+ " )\n",
295
+ " def transcribe(self, input_features):\n",
296
+ " outputs = self.model.generate(\n",
297
+ " input_features,\n",
298
+ " max_new_tokens=450, # change as needed\n",
299
+ " return_dict_in_generate=True,\n",
300
+ " forced_decoder_ids=[[2, 50359], [3, 50363]], # forced to transcribe any language with no timestamps\n",
301
+ " )\n",
302
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
303
+ "\n",
304
+ " @tf.function(\n",
305
+ " input_signature=[\n",
306
+ " tf.TensorSpec((1, 80, 3000), tf.float32, name=\"input_features\"),\n",
307
+ " ],\n",
308
+ " )\n",
309
+ " def translate(self, input_features):\n",
310
+ " outputs = self.model.generate(\n",
311
+ " input_features,\n",
312
+ " max_new_tokens=450, # change as needed\n",
313
+ " return_dict_in_generate=True,\n",
314
+ " forced_decoder_ids=[[2, 50358], [3, 50363]], # different forced_decoder_ids\n",
315
+ " )\n",
316
+ " return {\"sequences\": outputs[\"sequences\"]}\n",
317
+ "\n",
318
+ "# Assuming `model` is already defined and loaded\n",
319
+ "saved_model_dir = '/content/tf_whisper_saved'\n",
320
+ "tflite_model_path = 'whisper.tflite'\n",
321
+ "\n",
322
+ "generate_model = GenerateModel(model=model)\n",
323
+ "tf.saved_model.save(generate_model, saved_model_dir, signatures={\n",
324
+ " \"serving_default\": generate_model.transcribe,\n",
325
+ " \"serving_transcribe\": generate_model.transcribe,\n",
326
+ " \"serving_translate\": generate_model.translate\n",
327
+ "\n",
328
+ "})\n",
329
+ "\n",
330
+ "# Convert the model\n",
331
+ "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
332
+ "converter.target_spec.supported_ops = [\n",
333
+ " tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.\n",
334
+ " tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.\n",
335
+ "]\n",
336
+ "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n",
337
+ "tflite_model = converter.convert()\n",
338
+ "\n",
339
+ "# Save the model\n",
340
+ "with open(tflite_model_path, 'wb') as f:\n",
341
+ " f.write(tflite_model)"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {
348
+ "id": "u9MustgMU7oI"
349
+ },
350
+ "outputs": [],
351
+ "source": [
352
+ "# loaded model... now with generate!\n",
353
+ "tflite_model_path = 'whisper.tflite'\n",
354
+ "interpreter = tf.lite.Interpreter(tflite_model_path)\n",
355
+ "\n",
356
+ "tflite_generate = interpreter.get_signature_runner('serving_default')\n",
357
+ "generated_ids = tflite_generate(input_features=input_features)[\"sequences\"]\n",
358
+ "transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]\n",
359
+ "transcription\n",
360
+ "\n",
361
+ "\n"
362
+ ]
363
+ }
364
+ ],
365
+ "metadata": {
366
+ "colab": {
367
+ "machine_shape": "hm",
368
+ "provenance": []
369
+ },
370
+ "kernelspec": {
371
+ "display_name": "Python 3",
372
+ "name": "python3"
373
+ },
374
+ "language_info": {
375
+ "name": "python"
376
+ }
377
+ },
378
+ "nbformat": 4,
379
+ "nbformat_minor": 0
380
+ }