ndkhanh95 committed
Commit 62a8d6e · verified · 1 Parent(s): 4417b86

Upload Fine_tune_PaliGemma_NoJax.ipynb

Files changed (1)
  1. Fine_tune_PaliGemma_NoJax.ipynb +351 -0
Fine_tune_PaliGemma_NoJax.ipynb ADDED
@@ -0,0 +1,351 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "private_outputs": true,
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "A100",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ "<a href=\"https://colab.research.google.com/github/merveenoyan/smol-vision/blob/main/Fine_tune_PaliGemma.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## PaliGemma Fine-tuning\n",
+ "\n",
+ "In this notebook, we will fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-448) on a small split of the [VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2) dataset. Let's get started by installing the necessary libraries."
+ ],
+ "metadata": {
+ "id": "m8t6tkjuuONX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FrKEBkmJtMan"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We will authenticate to access the model using `notebook_login()`."
+ ],
+ "metadata": {
+ "id": "q_85okyYt1eo"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from huggingface_hub import notebook_login\n",
+ "notebook_login()"
+ ],
+ "metadata": {
+ "id": "NzJZSHD8tZZy"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Let's load the dataset."
+ ],
+ "metadata": {
+ "id": "9_jUBDTEuw1j"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from datasets import load_dataset\n",
+ "ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train[:10%]\")\n"
+ ],
+ "metadata": {
+ "id": "az5kdSbNpjgH"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"]\n",
+ "ds = ds.remove_columns(cols_remove)"
+ ],
+ "metadata": {
+ "id": "GEsDnBNmppIJ"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "split_ds = ds.train_test_split(test_size=0.05) # we'll use a very small split for demo\n",
+ "train_ds = split_ds[\"test\"]"
+ ],
+ "metadata": {
+ "id": "wN1c9Aqhqt47"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "train_ds"
+ ],
+ "metadata": {
+ "id": "TNJW2ty4yy4L"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
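+ {
+ "cell_type": "markdown",
+ "source": [
+ "Optionally, we can peek at a single example to see the fields the collator below will rely on (`question`, `multiple_choice_answer`, and `image`). This is just a quick sketch using the `train_ds` split defined above."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Optional: inspect one training example to see the question, the short\n",
+ "# answer we will train on, and the image size.\n",
+ "example = train_ds[0]\n",
+ "print(example[\"question\"], \"->\", example[\"multiple_choice_answer\"])\n",
+ "print(example[\"image\"].size)"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },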
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Load the processor to preprocess the dataset."
+ ],
+ "metadata": {
+ "id": "OsquATWQu2lJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import PaliGemmaProcessor\n",
+ "model_id = \"google/paligemma-3b-pt-224\"\n",
+ "processor = PaliGemmaProcessor.from_pretrained(model_id)"
+ ],
+ "metadata": {
+ "id": "Zya_PWM3uBWs"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We will now preprocess our examples: we prepare a prompt template, insert the question text, and pass it to the processor together with the batch of images and the answers as the suffix. The pad tokens and image tokens in the labels are set to -100 so the model ignores them, and the preprocessed answers serve as labels so the model learns to generate responses."
+ ],
+ "metadata": {
+ "id": "QZROnV-pu7rt"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "device = \"cuda\"\n",
+ "\n",
+ "image_token = processor.tokenizer.convert_tokens_to_ids(\"<image>\")\n",
+ "def collate_fn(examples):\n",
+ " texts = [\"answer \" + example[\"question\"] for example in examples]\n",
+ " labels = [example['multiple_choice_answer'] for example in examples]\n",
+ " images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
+ " tokens = processor(text=texts, images=images, suffix=labels,\n",
+ " return_tensors=\"pt\", padding=\"longest\",\n",
+ " tokenize_newline_separately=False)\n",
+ "\n",
+ " tokens = tokens.to(torch.bfloat16).to(device)\n",
+ " return tokens\n"
+ ],
+ "metadata": {
+ "id": "hdw3uBcNuGmw"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
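+ {
+ "cell_type": "markdown",
+ "source": [
+ "As an optional sanity check, we can run the collator on a couple of examples before training. This is just a quick sketch relying on `train_ds` and `collate_fn` defined above; it prints the shape of each tensor the processor returns."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Optional sanity check: build one small batch with the collator and\n",
+ "# inspect the tensors the processor returns (input_ids, pixel_values, labels, ...).\n",
+ "batch = collate_fn([train_ds[0], train_ds[1]])\n",
+ "for key, value in batch.items():\n",
+ "    print(key, tuple(value.shape))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },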
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Our dataset is quite general and similar to many of the datasets PaliGemma was trained on, so in this case we do not need to fine-tune the image encoder or the multimodal projector; we will only fine-tune the text decoder."
+ ],
+ "metadata": {
+ "id": "Hi_Y1blXwA04"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import PaliGemmaForConditionalGeneration\n",
+ "import torch\n",
+ "\n",
+ "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
+ "\n",
+ "for param in model.vision_tower.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ "for param in model.multi_modal_projector.parameters():\n",
+ " param.requires_grad = False\n"
+ ],
+ "metadata": {
+ "id": "iZRvrfUquH1y"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
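+ {
+ "cell_type": "markdown",
+ "source": [
+ "To confirm that only the language model will be updated, we can (optionally) count trainable versus total parameters. This small sketch only uses the `model` loaded above."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Optional: count how many parameters remain trainable after freezing\n",
+ "# the vision tower and the multi-modal projector.\n",
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "total = sum(p.numel() for p in model.parameters())\n",
+ "print(f\"trainable params: {trainable:,} || all params: {total:,} || trainable%: {100 * trainable / total:.4f}\")"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ },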
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Alternatively, if you want to do LoRA or QLoRA fine-tuning, you can run the cell below to load the base model either in full precision or quantized and wrap it with a LoRA adapter."
+ ],
+ "metadata": {
+ "id": "uCiVI-xUwSJm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import BitsAndBytesConfig\n",
+ "from peft import get_peft_model, LoraConfig\n",
+ "\n",
+ "bnb_config = BitsAndBytesConfig(\n",
+ " load_in_4bit=True,\n",
+ " bnb_4bit_quant_type=\"nf4\",\n",
+ " bnb_4bit_compute_dtype=torch.bfloat16\n",
+ ")\n",
+ "\n",
+ "lora_config = LoraConfig(\n",
+ " r=8,\n",
+ " target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " task_type=\"CAUSAL_LM\",\n",
+ ")\n",
+ "model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
+ "model = get_peft_model(model, lora_config)\n",
+ "model.print_trainable_parameters()\n",
+ "#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
+ ],
+ "metadata": {
+ "id": "9AYeuyzNuJ9X"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We will now initialize the `TrainingArguments`."
+ ],
+ "metadata": {
+ "id": "logv0oLqwbIe"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import TrainingArguments\n",
+ "args=TrainingArguments(\n",
+ " num_train_epochs=2,\n",
+ " remove_unused_columns=False,\n",
+ " per_device_train_batch_size=4,\n",
+ " gradient_accumulation_steps=4,\n",
+ " warmup_steps=2,\n",
+ " learning_rate=2e-5,\n",
+ " weight_decay=1e-6,\n",
+ " adam_beta2=0.999,\n",
+ " logging_steps=100,\n",
+ " optim=\"adamw_hf\",\n",
+ " save_strategy=\"steps\",\n",
+ " save_steps=1000,\n",
+ " push_to_hub=True,\n",
+ " save_total_limit=1,\n",
+ " output_dir=\"paligemma_vqav2\",\n",
+ " bf16=True,\n",
+ " report_to=[\"tensorboard\"],\n",
+ " dataloader_pin_memory=False\n",
+ " )\n"
+ ],
+ "metadata": {
+ "id": "Il7zKQO9uMPT"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We can now start training."
+ ],
+ "metadata": {
+ "id": "8pR0EaGlwrDp"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import Trainer\n",
+ "\n",
+ "trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataset=train_ds,\n",
+ " data_collator=collate_fn,\n",
+ " args=args\n",
+ " )\n"
+ ],
+ "metadata": {
+ "id": "CguCGDv1uNkF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.train()"
+ ],
+ "metadata": {
+ "id": "9KFPQLrnF2Ha"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "trainer.push_to_hub()"
+ ],
+ "metadata": {
+ "id": "O9fMDEjXSSzF"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "A minimal inference sketch follows below; full steps to infer are [here](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing)."
+ ],
+ "metadata": {
+ "id": "JohfxEJQjLBd"
+ }
+ },
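+ {
+ "cell_type": "markdown",
+ "source": [
+ "As a rough sanity check (not a full evaluation), we can ask the fine-tuned model one of the training questions using the in-memory `model` and the `processor` loaded earlier; nothing here is reloaded from the Hub. If you used the LoRA/QLoRA path, `model` is a PEFT-wrapped model and `generate` still works the same way."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Quick sanity check with the in-memory fine-tuned model and the processor\n",
+ "# loaded earlier; see the linked notebook for full inference steps.\n",
+ "model.eval()\n",
+ "example = train_ds[0]\n",
+ "inputs = processor(text=\"answer \" + example[\"question\"], images=example[\"image\"].convert(\"RGB\"), return_tensors=\"pt\")\n",
+ "inputs = inputs.to(torch.bfloat16).to(device)\n",
+ "\n",
+ "input_len = inputs[\"input_ids\"].shape[-1]\n",
+ "with torch.no_grad():\n",
+ "    generated = model.generate(**inputs, max_new_tokens=20, do_sample=False)\n",
+ "print(\"question:\", example[\"question\"])\n",
+ "print(\"model answer:\", processor.decode(generated[0][input_len:], skip_special_tokens=True))"
+ ],
+ "metadata": {},
+ "execution_count": null,
+ "outputs": []
+ }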
+ ]
+ }