{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"## PaliGemma Fine-tuning\n",
"\n",
"In this notebook, we will fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma-3b-pt-448) on a small split of [VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2) dataset. Let's get started by installing necessary libraries."
],
"metadata": {
"id": "m8t6tkjuuONX"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FrKEBkmJtMan"
},
"outputs": [],
"source": [
"!pip install -q -U git+https://github.com/huggingface/transformers.git datasets accelerate"
]
},
{
"cell_type": "markdown",
"source": [
"We will authenticate to access the model using `notebook_login()`."
],
"metadata": {
"id": "q_85okyYt1eo"
}
},
{
"cell_type": "code",
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
],
"metadata": {
"id": "NzJZSHD8tZZy"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Let's load the dataset."
],
"metadata": {
"id": "9_jUBDTEuw1j"
}
},
{
"cell_type": "code",
"source": [
"from datasets import load_dataset\n",
"ds = load_dataset('HuggingFaceM4/VQAv2', split=\"train[:10%]\")\n"
],
"metadata": {
"id": "az5kdSbNpjgH"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"cols_remove = [\"question_type\", \"answers\", \"answer_type\", \"image_id\", \"question_id\"]\n",
"ds = ds.remove_columns(cols_remove)"
],
"metadata": {
"id": "GEsDnBNmppIJ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"split_ds = ds.train_test_split(test_size=0.05) # we'll use a very small split for demo\n",
"train_ds = split_ds[\"test\"]"
],
"metadata": {
"id": "wN1c9Aqhqt47"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_ds"
],
"metadata": {
"id": "TNJW2ty4yy4L"
},
"execution_count": null,
"outputs": []
},
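{
"cell_type": "markdown",
"source": [
"Before preprocessing, we can peek at a single example as a quick sanity check. This assumes the columns that remain after the cleanup above: `question`, `multiple_choice_answer` and `image`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# quick sanity check: inspect one training example\n",
"example = train_ds[0]\n",
"print(\"question:\", example[\"question\"])\n",
"print(\"answer:\", example[\"multiple_choice_answer\"])\n",
"print(\"image size:\", example[\"image\"].size)  # PIL image, (width, height)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},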
{
"cell_type": "markdown",
"source": [
"Load the processor to preprocess the dataset."
],
"metadata": {
"id": "OsquATWQu2lJ"
}
},
{
"cell_type": "code",
"source": [
"from transformers import PaliGemmaProcessor\n",
"model_id = \"google/paligemma-3b-pt-224\"\n",
"processor = PaliGemmaProcessor.from_pretrained(model_id)"
],
"metadata": {
"id": "Zya_PWM3uBWs"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We will preprocess our examples. We need to prepare a prompt template and pass the text input inside, pass it with batches of images to processor. Then we will set the pad tokens and image tokens to -100 to let the model ignore them. We will pass our preprocessed input as labels to make the model learn how to generate responses."
],
"metadata": {
"id": "QZROnV-pu7rt"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"device = \"cuda\"\n",
"\n",
"image_token = processor.tokenizer.convert_tokens_to_ids(\"\")\n",
"def collate_fn(examples):\n",
" texts = [\"answer \" + example[\"question\"] for example in examples]\n",
" labels= [example['multiple_choice_answer'] for example in examples]\n",
" images = [example[\"image\"].convert(\"RGB\") for example in examples]\n",
" tokens = processor(text=texts, images=images, suffix=labels,\n",
" return_tensors=\"pt\", padding=\"longest\",\n",
" tokenize_newline_separately=False)\n",
"\n",
" tokens = tokens.to(torch.bfloat16).to(device)\n",
" return tokens\n"
],
"metadata": {
"id": "hdw3uBcNuGmw"
},
"execution_count": null,
"outputs": []
},
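{
"cell_type": "markdown",
"source": [
"As an optional check of the collator (a small sketch, not required for training), we can run it on two examples and look at the shapes and dtypes it returns; `labels` should have the same shape as `input_ids`, with the masked positions set to -100."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# run the collator on two examples and inspect the resulting tensors\n",
"batch = collate_fn([train_ds[0], train_ds[1]])\n",
"for key, value in batch.items():\n",
"    print(key, tuple(value.shape), value.dtype)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},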
{
"cell_type": "markdown",
"source": [
"Our dataset is a very general one and similar to many datasets that PaliGemma was trained with. In this case, we do not need to fine-tune the image encoder, the multimodal projector but we will only fine-tune the text decoder."
],
"metadata": {
"id": "Hi_Y1blXwA04"
}
},
{
"cell_type": "code",
"source": [
"from transformers import PaliGemmaForConditionalGeneration\n",
"import torch\n",
"\n",
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)\n",
"\n",
"for param in model.vision_tower.parameters():\n",
" param.requires_grad = False\n",
"\n",
"for param in model.multi_modal_projector.parameters():\n",
" param.requires_grad = False\n"
],
"metadata": {
"id": "iZRvrfUquH1y"
},
"execution_count": null,
"outputs": []
},
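{
"cell_type": "markdown",
"source": [
"As a sanity check, we can count how many parameters remain trainable after freezing the vision tower and the projector; only the language decoder should be left."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# count trainable vs. total parameters after freezing\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f\"trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.1f}%)\")"
],
"metadata": {},
"execution_count": null,
"outputs": []
},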
{
"cell_type": "markdown",
"source": [
"Alternatively, if you want to do LoRA and QLoRA fine-tuning, you can run below cells to load the adapter either in full precision or quantized."
],
"metadata": {
"id": "uCiVI-xUwSJm"
}
},
{
"cell_type": "code",
"source": [
"from transformers import BitsAndBytesConfig\n",
"from peft import get_peft_model, LoraConfig\n",
"\n",
"bnb_config = BitsAndBytesConfig(\n",
" load_in_4bit=True,\n",
" bnb_4bit_quant_type=\"nf4\",\n",
" bnb_4bit_compute_type=torch.bfloat16\n",
")\n",
"\n",
"lora_config = LoraConfig(\n",
" r=8,\n",
" target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
" task_type=\"CAUSAL_LM\",\n",
")\n",
"model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={\"\":0})\n",
"model = get_peft_model(model, lora_config)\n",
"model.print_trainable_parameters()\n",
"#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344\n"
],
"metadata": {
"id": "9AYeuyzNuJ9X"
},
"execution_count": null,
"outputs": []
},
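{
"cell_type": "markdown",
"source": [
"The cell above loads the base model quantized to 4-bit (QLoRA). If you would rather do plain LoRA in full bf16 precision, the same setup without quantization is sketched below; run one variant or the other, not both."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# plain LoRA variant: load the base model in bf16 (no quantization) and attach the same adapters\n",
"model = PaliGemmaForConditionalGeneration.from_pretrained(\n",
"    model_id, torch_dtype=torch.bfloat16, device_map={\"\": 0}\n",
")\n",
"model = get_peft_model(model, lora_config)\n",
"model.print_trainable_parameters()"
],
"metadata": {},
"execution_count": null,
"outputs": []
},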
{
"cell_type": "markdown",
"source": [
"We will now initialize the `TrainingArguments`."
],
"metadata": {
"id": "logv0oLqwbIe"
}
},
{
"cell_type": "code",
"source": [
"from transformers import TrainingArguments\n",
"args=TrainingArguments(\n",
" num_train_epochs=2,\n",
" remove_unused_columns=False,\n",
" per_device_train_batch_size=4,\n",
" gradient_accumulation_steps=4,\n",
" warmup_steps=2,\n",
" learning_rate=2e-5,\n",
" weight_decay=1e-6,\n",
" adam_beta2=0.999,\n",
" logging_steps=100,\n",
" optim=\"adamw_hf\",\n",
" save_strategy=\"steps\",\n",
" save_steps=1000,\n",
" push_to_hub=True,\n",
" save_total_limit=1,\n",
" output_dir=\"paligemma_vqav2\",\n",
" bf16=True,\n",
" report_to=[\"tensorboard\"],\n",
" dataloader_pin_memory=False\n",
" )\n"
],
"metadata": {
"id": "Il7zKQO9uMPT"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"We can now start training."
],
"metadata": {
"id": "8pR0EaGlwrDp"
}
},
{
"cell_type": "code",
"source": [
"from transformers import Trainer\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" train_dataset=train_ds ,\n",
" data_collator=collate_fn,\n",
" args=args\n",
" )\n"
],
"metadata": {
"id": "CguCGDv1uNkF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"trainer.train()"
],
"metadata": {
"id": "9KFPQLrnF2Ha"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"trainer.push_to_hub()"
],
"metadata": {
"id": "O9fMDEjXSSzF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"You can find steps to infer [here](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing)."
],
"metadata": {
"id": "JohfxEJQjLBd"
}
}
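,
{
"cell_type": "markdown",
"source": [
"As a minimal inference sketch (assuming the full fine-tuning path above and that the checkpoint ended up in `paligemma_vqav2`, the `output_dir` we set), we build the same `answer` prompt, call `generate`, and decode only the newly generated tokens."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# minimal inference sketch: load the fine-tuned checkpoint and answer a question about one image\n",
"model = PaliGemmaForConditionalGeneration.from_pretrained(\"paligemma_vqav2\", torch_dtype=torch.bfloat16).to(device)\n",
"\n",
"example = train_ds[0]\n",
"prompt = \"answer \" + example[\"question\"]\n",
"inputs = processor(text=prompt, images=example[\"image\"].convert(\"RGB\"), return_tensors=\"pt\").to(device)\n",
"inputs[\"pixel_values\"] = inputs[\"pixel_values\"].to(torch.bfloat16)\n",
"\n",
"with torch.no_grad():\n",
"    generated = model.generate(**inputs, max_new_tokens=20)\n",
"\n",
"# decode only the tokens generated after the prompt\n",
"answer = processor.decode(generated[0][inputs[\"input_ids\"].shape[-1]:], skip_special_tokens=True)\n",
"print(example[\"question\"], \"->\", answer)"
],
"metadata": {},
"execution_count": null,
"outputs": []
}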
]
}