|
"""Callbacks for Trainer class""" |
|
|
|
import os |
|
|
|
from transformers import ( |
|
TrainerCallback, |
|
TrainingArguments, |
|
TrainerState, |
|
TrainerControl, |
|
) |
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR |
|
|
|
|
|
class SavePeftModelCallback(TrainerCallback):
    """Callback to save the PEFT adapter alongside each Trainer checkpoint.

    Whenever the Trainer saves a checkpoint, the model it passes to the
    callback is written with ``save_pretrained`` into
    ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Write the adapter into the checkpoint folder for the current step.

        Args:
            args: Training arguments. ``output_dir`` locates the checkpoint
                folder, and ``should_save`` decides whether this process
                performs the write.
            state: Trainer state. ``global_step`` names the checkpoint folder.
            control: Trainer control object, returned unchanged.
            **kwargs: Event payload from the Trainer. It must contain
                ``"model"``, an object exposing ``save_pretrained``
                (presumably a PEFT model; confirm against the training setup).

        Returns:
            The unchanged ``control`` object.

        Raises:
            KeyError: If a saving process receives no ``"model"`` entry.
        """
        # In distributed training the Trainer dispatches callback events on
        # every process. Without this guard, all ranks would write the same
        # adapter files concurrently. ``should_save`` is the flag the Trainer
        # itself uses to pick the saving process(es).
        if not args.should_save:
            return control

        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        return control
|
|