|
|
|
|
|
|
|
import os |
|
import torch |
|
|
|
from huggingface_hub import login as hf_login |
|
from datasets import load_dataset |
|
from peft import LoraConfig |
|
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration, TrainingArguments, Trainer |
|
from datasets.utils.logging import disable_progress_bar |
|
# Turn off the progress bars emitted by the `datasets` library.
disable_progress_bar()

# Hugging Face access token. The empty string is a placeholder; when the
# HF_TOKEN environment variable is set, its value takes precedence.
HF_TOKEN = ""

_env_token = os.environ.get("HF_TOKEN")
if _env_token is not None:
    HF_TOKEN = _env_token
    print("Hugging Face token found in environment variable")

# Authenticate against the Hugging Face Hub with the resolved token.
hf_login(
    token=HF_TOKEN,
    add_to_git_credential=True,
)
|
# --- Run configuration -------------------------------------------------------

# Hub dataset providing the "image" / "caption" pairs used for fine-tuning.
dataset_id = "eltorio/ROCO-radiology"

# System prompt placed in front of every training conversation.
prompt = "You are an expert radiologist certified with over 15 years of experience in diagnostic imaging, describe this image"

# Base checkpoint to fine-tune and the Hub repository name for the result.
source_model_id = "HuggingFaceM4/Idefics3-8B-Llama3"
destination_model_id = "eltorio/ROCO-idefics3-8B"

# Local directories: training outputs/checkpoints and the dataset cache.
output_dir = "IDEFICS3_ROCO"
cache_dir = "/workspace/data"
|
# Load the "train" split of the dataset, caching downloaded files in cache_dir.
train_dataset = load_dataset(dataset_id, split="train", cache_dir=cache_dir)
|
|
|
# Device the model is moved to when doing a full (non-adapter) fine-tune.
DEVICE = "cuda:0"

# Fine-tuning mode switches:
#   USE_QLORA -> load the base model 4-bit quantized and train a LoRA adapter.
#   USE_LORA  -> train a LoRA adapter (with DoRA) on the non-quantized model.
#   both off  -> full fine-tune of the float16 model on DEVICE.
USE_LORA = False
USE_QLORA = True
|
|
|
# Processor (tokenizer + image preprocessing) matching the base checkpoint.
# Image splitting is switched off, so each image is passed as a single input.
processor = AutoProcessor.from_pretrained(
    source_model_id,
    do_image_splitting=False,
)
|
|
|
# Build the model according to the selected fine-tuning mode.
if USE_QLORA or USE_LORA:
    # Adapter configuration shared by LoRA and QLoRA. The regex selects the
    # attention and MLP projection layers inside the text model, the modality
    # projection and the perceiver resampler.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',
        # DoRA is enabled only for plain LoRA, never together with QLoRA.
        use_dora=not USE_QLORA,
        init_lora_weights="gaussian",
    )

    # 4-bit NF4 quantization is applied only in QLoRA mode; plain LoRA loads
    # the float16 weights without a quantization config.
    bnb_config = None
    if USE_QLORA:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        quantization_config=bnb_config,
    )

    # Attach the LoRA adapter to the base model and make sure it is active.
    model.add_adapter(lora_config)
    model.enable_adapters()
else:
    # Full fine-tune: float16 weights with FlashAttention-2, moved to DEVICE.
    # `attn_implementation` is the public from_pretrained() argument; the
    # underscore-prefixed spelling used previously is a private config field.
    model = Idefics3ForConditionalGeneration.from_pretrained(
        source_model_id,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to(DEVICE)
|
|
|
class MyDataCollator:
    """Turn a list of dataset samples into one padded training batch.

    Each sample must provide an ``"image"`` (an object with a
    ``convert("RGB")`` method) and a ``"caption"`` string. For every sample a
    three-turn conversation is built -- system prompt, user turn holding the
    image, assistant turn holding the caption -- rendered with the
    processor's chat template, and the whole list is then tokenized and
    padded by the processor.

    Args:
        processor: Object exposing ``tokenizer``, ``apply_chat_template`` and
            a ``__call__(text=..., images=..., return_tensors=..., padding=...)``
            interface. It is the processor used for every step of collation.
        system_prompt: Text of the system turn. When ``None`` (the default)
            the module-level ``prompt`` is read at call time.
    """

    def __init__(self, processor, system_prompt=None):
        self.processor = processor
        self.system_prompt = system_prompt
        tokenizer = processor.tokenizer
        # Id of the "<image>" special token, looked up by its position in the
        # tokenizer's list of additional special tokens.
        self.image_token_id = tokenizer.additional_special_tokens_ids[
            tokenizer.additional_special_tokens.index("<image>")
        ]

    def __call__(self, samples):
        """Collate ``samples`` and return the processor batch with ``"labels"``.

        ``"labels"`` is a copy of ``"input_ids"`` in which every padding
        position is replaced by the ``<image>`` token id.
        """
        system_prompt = prompt if self.system_prompt is None else self.system_prompt

        texts = []
        images = []
        for sample in samples:
            messages = [
                {
                    "role": "system",
                    "content": [
                        {"type": "text", "text": system_prompt},
                    ],
                },
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": sample["caption"]},
                    ],
                },
            ]
            # The assistant answer is already part of the conversation, so no
            # generation prompt is appended.
            text = self.processor.apply_chat_template(
                messages, add_generation_prompt=False
            )
            texts.append(text.strip())
            # One image per conversation, wrapped in its own list.
            images.append([sample["image"].convert("RGB")])

        batch = self.processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )

        labels = batch["input_ids"].clone()
        # Padding positions are relabelled with the <image> token id.
        # NOTE(review): presumably the model excludes that id from the loss --
        # confirm against the Idefics3 forward() documentation.
        labels[labels == self.processor.tokenizer.pad_token_id] = self.image_token_id
        batch["labels"] = labels

        return batch
|
|
|
# Collator that converts raw dataset rows into model-ready batches.
data_collator = MyDataCollator(processor)

# Hyper-parameters and bookkeeping options for the Trainer.
training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=False,
    auto_find_batch_size=True,
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    dataloader_pin_memory=False,
    save_total_limit=3,
    # No evaluation is run (no eval dataset is given to the Trainer).
    # `eval_strategy` replaces the deprecated `evaluation_strategy` name.
    eval_strategy="no",
    save_strategy="steps",
    eval_steps=100,
    save_steps=10,
    # NOTE(review): per the Transformers docs this field is only recorded for
    # use by the training script; resuming requires passing a checkpoint to
    # Trainer.train().
    resume_from_checkpoint=True,
    logging_steps=5,
    # Keep the raw "image" / "caption" columns: the collator reads them.
    remove_unused_columns=False,
    # NOTE(review): no hub_model_id is set and destination_model_id is not
    # used in the visible code -- confirm the intended target repository.
    push_to_hub=True,
    label_names=["labels"],
    load_best_model_at_end=False,
    report_to="none",
    optim="paged_adamw_8bit",
)
|
|
|
from transformers.trainer_utils import get_last_checkpoint

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

# TrainingArguments.resume_from_checkpoint is not acted on by the Trainer, so
# the checkpoint has to be handed to train() explicitly. A value of True means
# "continue from the newest checkpoint in output_dir"; when the directory or a
# checkpoint does not exist yet (first run), training starts from scratch
# instead of raising. An explicit checkpoint path is passed through unchanged.
resume_checkpoint = training_args.resume_from_checkpoint or None
if resume_checkpoint is True:
    resume_checkpoint = (
        get_last_checkpoint(training_args.output_dir)
        if os.path.isdir(training_args.output_dir)
        else None
    )

trainer.train(resume_from_checkpoint=resume_checkpoint)