import os | |
from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl | |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR | |
class SavePeftModelCallback(TrainerCallback):
    """Trainer callback that also saves the model via ``save_pretrained`` at each checkpoint.

    On every ``on_save`` event, the model the Trainer passes in ``kwargs["model"]`` is
    written into ``<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>/adapter_model``.
    NOTE(review): the model is presumably a ``peft.PeftModel`` (per the class name), in
    which case only the adapter weights/config are written — confirm against the caller.
    """

    def on_save(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ) -> TrainerControl:
        """Save the model into an ``adapter_model`` subfolder of the current checkpoint.

        Args:
            args: Training arguments. ``output_dir`` is the checkpoint root, and
                ``should_save`` decides whether this process writes any files.
            state: Trainer state. ``global_step`` selects the checkpoint folder.
            control: Trainer control flags, returned unchanged.
            **kwargs: Objects supplied by the Trainer's callback handler. Must
                contain ``"model"`` on processes that save.

        Returns:
            ``control``, unmodified.

        Raises:
            KeyError: If this process saves and ``kwargs`` has no ``"model"`` entry.
        """
        # on_save is dispatched on every process in distributed runs. Only write from
        # the process(es) the Trainer itself saves from. Otherwise all ranks write the
        # same adapter files into the same folder concurrently.
        if not args.should_save:
            return control

        # "<prefix>-<step>" naming, intended to match the Trainer's checkpoint folders.
        checkpoint_folder = os.path.join(
            args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        )
        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)
        return control