File size: 9,549 Bytes
392dfd9 9105935 05fffb5 ce24f5e 949a27b 8d959a7 ce24f5e cc67862 ce24f5e 5159d00 1f5d83e 8bd7a49 5159d00 6045345 32e6fe9 6045345 ce24f5e 392dfd9 a459383 77fca25 05fffb5 a6028d3 f2a2029 1d5ab84 f2a2029 392dfd9 f2a2029 247825b 47ad389 247825b 9105935 87d7825 87e073d 9105935 6045345 d653859 9105935 247825b d653859 b46bc02 d653859 392dfd9 d653859 247825b d653859 56f9ca5 d653859 949a27b f2a2029 87d7825 f2a2029 cc67862 ce24f5e f2a2029 2393801 ce24f5e 949a27b f2a2029 ce24f5e 392dfd9 ce24f5e 93acb64 392dfd9 1d5ab84 f2a2029 ce24f5e f2a2029 94f5e41 ce24f5e a6028d3 ce24f5e 12de7b7 ce24f5e 1f5d83e 32e6fe9 392dfd9 32e6fe9 392dfd9 32e6fe9 21f17cc 32e6fe9 ce24f5e 32e6fe9 87d7825 32e6fe9 87d7825 a6028d3 949a27b 1d5ab84 ae1719d 1d5ab84 a5bf838 1d5ab84 3457810 1d5ab84 949a27b 77fca25 949a27b 9105935 2df63ef 8d959a7 a459383 8d959a7 902dd0a 2255bb7 6045345 2255bb7 902dd0a f2a2029 d1aed4c 8d959a7 a459383 1d5ab84 0a472e1 2bc1a5b 0a472e1 2bc1a5b 0a472e1 2bc1a5b 0a472e1 ce24f5e 2bc1a5b 915c56c 392dfd9 915c56c bdbca8f 392dfd9 915c56c ce24f5e a6028d3 ce24f5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 |
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import logging
import os
import random
import signal
import sys
from pathlib import Path
from typing import Optional, List, Dict, Any, Union
import fire
import torch
import yaml
# axolotl package imports. NOTE: the sys.path tweak that adds src/ runs further down, after these imports, so axolotl must already be importable here
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.validation import validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.data import load_prepare_datasets
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.wandb import setup_wandb_env_vars
# Prepend the "src" directory (sibling of this file's parent directory) to sys.path.
# NOTE(review): this runs after the `axolotl` imports above, so it cannot help
# resolve them -- confirm whether the path tweak should precede those imports.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
# Root logger level comes from the LOG_LEVEL environment variable, defaulting to "INFO".
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
# Passed to load_prepare_datasets() by train() as the prepared-dataset path.
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def choose_device(cfg):
    """Pick the torch device for this process and record it on ``cfg``.

    Reads ``cfg.local_rank`` and sets two attributes:

    * ``cfg.device`` -- ``"cuda:<local_rank>"`` when CUDA is available,
      ``"mps"`` when the MPS backend reports itself available, ``"cpu"``
      otherwise (always a string, never ``None``).
    * ``cfg.device_map`` -- a single-entry dict keyed by the empty string:
      the integer ``cfg.local_rank`` for a CUDA device, else the device string.
    """

    def get_device():
        if torch.cuda.is_available():
            return f"cuda:{cfg.local_rank}"
        try:
            if torch.backends.mps.is_available():
                return "mps"
        except Exception:  # pylint: disable=broad-exception-caught
            # probing the MPS backend failed (presumably a torch build without
            # it); treat that the same as "not available"
            pass
        # explicit fallback: previously this path fell through and returned None
        return "cpu"

    cfg.device = get_device()
    # the CUDA device string carries the rank suffix ("cuda:0"), so an exact
    # comparison against "cuda" would never match
    if cfg.device.startswith("cuda"):
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
    """Print a prompt, then return all of stdin (up to EOF) as one string.

    Line endings are kept as read; the result is ``""`` when stdin is
    already exhausted.
    """
    print("Give me an instruction (Ctrl + D to finish): ")
    # iterating stdin yields one line at a time until EOF
    return "".join(sys.stdin)
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    """Interactive loop: read an instruction from stdin, generate, print.

    ``prompter`` names an attribute of the ``axolotl.prompters`` module; it is
    instantiated for every instruction and the first item yielded by its
    ``build_prompt`` is used as the prompt. The loop ends (returning ``None``)
    as soon as an empty instruction is read.
    """
    special_tokens = (
        ("unk_token", "<unk>"),
        ("bos_token", "<s>"),
        ("eos_token", "</s>"),
    )
    for token_name, token_text in special_tokens:
        tokenizer.add_special_tokens({token_name: token_text})

    prompters = importlib.import_module("axolotl.prompters")
    prompter_cls = getattr(prompters, prompter)

    # sampling settings shared by every generate() call below
    # gc = GenerationConfig()  # TODO swap out and use this # pylint: disable=fixme
    generation_kwargs = {
        "do_sample": True,
        "use_cache": True,
        "repetition_penalty": 1.1,
        "max_new_tokens": 100,
        "temperature": 0.9,
        "top_p": 0.95,
        "top_k": 40,
        "return_dict_in_generate": True,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
    }

    while True:
        # instructions may span several lines; input ends at EOF
        instruction = get_multi_line_input()
        if not instruction:
            return

        prompt_stream = prompter_cls().build_prompt(instruction=instruction)
        prompt: str = next(prompt_stream)
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generated = model.generate(
                inputs=batch["input_ids"].to(cfg.device),
                **generation_kwargs,
            )
        first_sequence = generated["sequences"].cpu().tolist()[0]
        print(tokenizer.decode(first_sequence))
def choose_config(path: Path):
    """Interactively pick one ``*.yml`` file from the directory ``path``.

    The candidates are listed in sorted order (so the numbering is the same
    on every run) and the user is asked for a 1-based number until a valid
    one is entered.

    Returns:
        The ``Path`` of the chosen file.

    Raises:
        ValueError: if ``path`` contains no ``*.yml`` files.
    """
    # Path.glob yields entries in arbitrary order; sort for a stable menu
    yaml_files = sorted(path.glob("*.yml"))
    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files, start=1):
        print(f"{idx}. {file}")

    while True:
        try:
            choice = int(input("Enter the number of your choice: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if 1 <= choice <= len(yaml_files):
            return yaml_files[choice - 1]
        print("Invalid choice. Please choose a number from the list.")
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    """Return ``True`` when no element of ``list1`` is contained in ``list2``.

    For a dict ``list2`` the membership test is against its keys.
    """
    for candidate in list1:
        if candidate in list2:
            return False
    return True
def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    """Load a YAML config, then prepare data, train, infer, shard or merge LoRA.

    Args:
        config: Path to a YAML config file, or to a directory, in which case
            the user is asked to pick one of its ``*.yml`` files.
        prepare_ds_only: When true, return right after the dataset has been
            prepared, before any model is loaded.
        **kwargs: Command-line overrides. A key already present in the config
            (or any key when ``cfg.strict`` is ``False``) overwrites the config
            value. The presence of the keys ``inference``, ``shard``,
            ``merge_lora`` and ``debug`` additionally selects the mode:

            * ``merge_lora`` (with ``cfg.adapter`` set): merge the adapter into
              the base model, save it under ``<output_dir>/merged`` and return.
            * ``inference``: run the interactive ``do_inference`` loop and return.
            * ``shard``: save the loaded model to ``cfg.output_dir`` and return.
            * ``debug`` (or ``cfg.debug``): run ``check_dataset_labels`` on a
              few random training rows.

    Returns:
        None.
    """
    if Path(config).is_dir():
        config = choose_config(config)

    # load the config from the yaml file
    # NOTE(security): yaml.Loader can construct arbitrary Python objects; only
    # point this at trusted config files (yaml.safe_load would be the safer
    # choice if the configs never rely on Python-specific tags).
    with open(config, encoding="utf-8") as file:
        cfg: DictDefault = DictDefault(yaml.load(file, Loader=yaml.Loader))

    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
    # then overwrite the value
    cfg_keys = cfg.keys()
    for k, value in kwargs.items():
        # if not strict, allow writing to cfg even if it's not in the yml already
        if k in cfg_keys or cfg.strict is False:
            # handle booleans: keep the config value's type when it is a bool
            if isinstance(cfg[k], bool):
                cfg[k] = bool(value)
            else:
                cfg[k] = value

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    # an explicit cfg.ddp wins; otherwise DDP is assumed when WORLD_SIZE != 1
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.gradient_accumulation_steps = (
            cfg.gradient_accumulation_steps // cfg.world_size
        )
    setup_wandb_env_vars(cfg)
    if cfg.device == "mps":
        # on MPS: turn off 8-bit loading and tf32, and swap bf16 for fp16
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False

    validate_config(cfg)

    # load the tokenizer first
    logging.info("loading tokenizer...")
    tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)

    # stay None in the modes that do not need a dataset
    train_dataset = None
    eval_dataset = None
    if check_not_in(
        ["inference", "shard", "merge_lora"], kwargs
    ):  # don't need to load dataset for these
        train_dataset, eval_dataset = load_prepare_datasets(
            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
        )

    # only sample rows when a non-empty dataset was actually loaded
    if (
        (cfg.debug or "debug" in kwargs)
        and train_dataset is not None
        and len(train_dataset) > 0
    ):
        logging.info("check_dataset_labels...")
        # randrange(n) covers every index 0..n-1 and also works when n == 1
        sample_indices = [random.randrange(len(train_dataset)) for _ in range(5)]
        check_dataset_labels(train_dataset.select(sample_indices), tokenizer)

    if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")
        return

    # Load the model and tokenizer
    logging.info("loading model and peft_config...")
    model, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        tokenizer,
        cfg,
        adapter=cfg.adapter,
        inference=("inference" in kwargs),
    )

    if "merge_lora" in kwargs and cfg.adapter is not None:
        logging.info("running merge of LoRA with base model")
        model = model.merge_and_unload()
        model.to(dtype=torch.float16)

        if cfg.local_rank == 0:
            logging.info("saving merged model")
            model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
        return

    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
        return

    if "shard" in kwargs:
        model.save_pretrained(cfg.output_dir)
        return

    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:

        def _save_and_exit(signum, frame):  # pylint: disable=unused-argument
            model.save_pretrained(cfg.output_dir)
            sys.exit(0)

        signal.signal(signal.SIGINT, _save_and_exit)

    logging.info("Starting trainer...")
    if cfg.group_by_length:
        logging.info("hang tight... sorting dataset for group_by_length")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        possible_checkpoints = [
            str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
        ]
        if len(possible_checkpoints) > 0:
            # order by the numeric suffix after the last "-" and take the highest
            sorted_paths = sorted(
                possible_checkpoints, key=lambda path: int(path.split("-")[-1])
            )
            resume_from_checkpoint = sorted_paths[-1]
            logging.info(
                f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
            )
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")

    # pylint: disable=fixme
    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
    if cfg.local_rank == 0:
        model.save_pretrained(cfg.output_dir)

    # pylint: disable=fixme
    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
# Script entry point: hand train() to fire.Fire when this file is run directly.
if __name__ == "__main__":
    fire.Fire(train)
|