"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""

import logging
import os
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import torch
import transformers.modelcard
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer

from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

configure_logging()
LOG = logging.getLogger("axolotl.train")


@dataclass
class TrainDatasetMeta:
    """
    dataclass to capture the dataset specific options for training
    """

    # dataset the trainer is fit on
    train_dataset: Dataset
    # optional evaluation dataset forwarded to setup_trainer
    eval_dataset: Optional[Dataset] = None
    # optional precomputed step count forwarded to setup_trainer
    total_num_steps: Optional[int] = None


def _find_latest_checkpoint(output_dir) -> Optional[str]:
    """Return the path of the highest-step ``checkpoint-<step>`` directory, or None.

    Only directories whose name is ``checkpoint-`` followed by a decimal step
    number are considered. Any other ``checkpoint-*`` entry (for example a stray
    ``checkpoint-tmp`` or a plain file) is ignored rather than crashing
    auto-resume with a ``ValueError`` from ``int()``.
    """
    prefix = "checkpoint-"
    candidates = []
    for entry in Path(output_dir).glob(f"{prefix}*"):
        step = entry.name[len(prefix) :]
        if step.isdecimal() and entry.is_dir():
            candidates.append((int(step), str(entry)))
    if not candidates:
        return None
    # compare numerically on the step so checkpoint-1000 beats checkpoint-999
    return max(candidates)[1]


def _model_card_name(output_dir: str) -> str:
    """Strip leading ``./`` segments from *output_dir* for use as a model-card name.

    ``str.lstrip("./")`` removes *any* leading '.' and '/' characters (it takes a
    character set, not a prefix), which turned e.g. ``/abs/out`` into ``abs/out``
    and ``.hidden`` into ``hidden``. Only literal ``./`` prefixes are removed here.
    """
    name = output_dir
    while name.startswith("./"):
        name = name[2:]
    return name


def train(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
    dataset_meta: TrainDatasetMeta,
):
    """Load the tokenizer and model, run training, and save the result.

    Args:
        cfg: Training configuration. Note that ``cfg.resume_from_checkpoint`` may
            be filled in here when ``cfg.auto_resume_from_checkpoints`` is set.
        cli_args: Command-line options; ``cli_args.inference`` is forwarded to
            ``load_model``.
        dataset_meta: Train/eval datasets and optional total step count passed
            to ``setup_trainer``.

    Returns:
        A ``(model, tokenizer)`` tuple. With ReLoRA enabled and an adapter that
        cannot be merged (non-LoRA, or a 4/8-bit base), this returns right after
        training without saving here; the final weights are written by
        ``ReLoRACallback.on_train_end``.
    """
    # load the tokenizer first
    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
    tokenizer = load_tokenizer(cfg)

    train_dataset = dataset_meta.train_dataset
    eval_dataset = dataset_meta.eval_dataset
    total_num_steps = dataset_meta.total_num_steps

    # Load the model and tokenizer
    LOG.info("loading model and (optionally) peft_config...")
    model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)

    safe_serialization = cfg.save_safetensors is True

    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        latest_checkpoint = _find_latest_checkpoint(cfg.output_dir)
        if latest_checkpoint is not None:
            cfg.resume_from_checkpoint = latest_checkpoint
            LOG.info(
                f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
            )

    resume_from_checkpoint = cfg.resume_from_checkpoint

    trainer = setup_trainer(
        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
    )

    model.config.use_cache = False

    # Create the output directory up front, before anything is pre-saved into it.
    os.makedirs(cfg.output_dir, exist_ok=True)

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)
    # additionally presave the tokenizer and model configs
    tokenizer.save_pretrained(str(Path(cfg.output_dir)))
    model.config.save_pretrained(str(Path(cfg.output_dir)))

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:

        def terminate_handler(_, __, model):
            if cfg.flash_optimum:
                model = BetterTransformer.reverse(model)
            model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
            sys.exit(0)

        # The lambda reads `model` from this scope when the signal fires, so it
        # saves whatever `model` is bound to at that moment.
        signal.signal(
            signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
        )

    badge_markdown = """[Built with Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl)"""
    # AUTOGENERATED_TRAINER_COMMENT is module-global; only append the badge once so
    # repeated train() calls in one process do not stack duplicate badges.
    if badge_markdown not in transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT:
        transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"

    LOG.info("Starting trainer...")
    if cfg.group_by_length:
        LOG.info("hang tight... sorting dataset for group_by_length")

    if cfg.flash_optimum:
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=True, enable_mem_efficient=True
        ):
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    else:
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    if trainer.is_fsdp_enabled:
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
        LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")

    if cfg.relora_steps:
        if cfg.adapter == "lora" and not (cfg.load_in_4bit or cfg.load_in_8bit):
            model = model.merge_and_unload()
        else:
            # final model weights have already been saved by `ReLoRACallback.on_train_end`
            return model, tokenizer

    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
    if cfg.fsdp:
        trainer.save_model(cfg.output_dir)
    elif cfg.local_rank == 0:
        if cfg.flash_optimum:
            model = BetterTransformer.reverse(model)

        model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

    if not cfg.hub_model_id:
        trainer.create_model_card(model_name=_model_card_name(cfg.output_dir))

    return model, tokenizer