Maverick17
committed on
Update README.md
Added finetuning script description
README.md
CHANGED
This model is a fine-tuned version of [HuggingFaceM4/Idefics3-8B-Llama3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) on the [Agent-Eval-Refine/GUI-Dense-Descriptions](https://huggingface.co/datasets/Agent-Eval-Refine/GUI-Dense-Descriptions) dataset.
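Each record in that dataset pairs a GUI screenshot with a dense natural-language description; the fine-tuning script below consumes exactly those `image` and `text` fields. As a quick orientation, here is a minimal sketch (it assumes nothing beyond those two fields) for peeking at one training example:

```python
from datasets import load_dataset

# Load the training split and inspect a single record.
ds = load_dataset("Agent-Eval-Refine/GUI-Dense-Descriptions", split="train")
example = ds[0]

print(example["image"].size)   # PIL screenshot, (width, height)
print(example["text"][:300])   # start of the dense GUI description
```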
## Finetuning script

```python
# !pip install git+https://github.com/andimarafioti/transformers.git@e1b7c0a05ab65e4ddb62a407fe12f8ec13a916f0
# !pip install accelerate datasets peft bitsandbytes
# !pip install flash-attn --no-build-isolation

import pandas as pd
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import (
    AutoProcessor,
    BitsAndBytesConfig,
    Idefics3ForConditionalGeneration,
)
import os
from PIL import Image
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
from huggingface_hub import notebook_login

notebook_login()

gui_dense_desc_dataset = load_dataset("Agent-Eval-Refine/GUI-Dense-Descriptions")
train_ds = gui_dense_desc_dataset["train"]

# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

USE_LORA = False
USE_QLORA = True
model_id = "HuggingFaceM4/Idefics3-8B-Llama3"

processor = AutoProcessor.from_pretrained(model_id)

if USE_QLORA or USE_LORA:
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=[
            "down_proj",
            "o_proj",
            "k_proj",
            "q_proj",
            "gate_proj",
            "up_proj",
            "v_proj",
        ],
        use_dora=False if USE_QLORA else True,
        init_lora_weights="gaussian",
    )
    lora_config.inference_mode = False
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config if USE_QLORA else None,
        _attn_implementation="flash_attention_2",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    print(model.get_nb_trainable_parameters())
else:
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
        device_map="auto",
    )

# if you'd like to fine-tune only the LLM, freeze the vision encoder
for param in model.model.vision_model.parameters():
    param.requires_grad = False

image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]


def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image = example["image"]
        image_description = example["text"]
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": "Provide a detailed description of the image.",
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": image_description}],
            },
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=False)
        texts.append(text.strip())
        images.append([image])

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch

training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=5,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="adamw_torch",
    bf16=True,
    output_dir="./idefics3-llama-gui-dense-descriptions",
    hub_model_id="idefics3-llama-gui-dense-descriptions",
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)

trainer.train()

trainer.push_to_hub()
```
Training took approximately 40 minutes on 2x H100 GPUs (80 GB each).
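`trainer.push_to_hub()` uploads the LoRA adapter under the `hub_model_id` configured above. If you prefer a standalone checkpoint, one option is to merge the adapter into the base weights with PEFT; the snippet below is only a sketch and assumes an adapter repo id of `idefics3-llama-gui-dense-descriptions` and an illustrative output directory (adjust both to your setup):

```python
import torch
from peft import PeftModel
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

base_id = "HuggingFaceM4/Idefics3-8B-Llama3"
adapter_id = "idefics3-llama-gui-dense-descriptions"  # illustrative; use your <user>/<repo> on the Hub

# Load the base model in bf16 and attach the trained LoRA adapter.
base = Idefics3ForConditionalGeneration.from_pretrained(base_id, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, adapter_id)

# Fold the adapter into the base weights and save a standalone checkpoint.
merged = model.merge_and_unload()
merged.save_pretrained("./idefics3-llama-gui-dense-descriptions-merged")
AutoProcessor.from_pretrained(base_id).save_pretrained("./idefics3-llama-gui-dense-descriptions-merged")
```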
## Intended usage
```python