File size: 7,961 Bytes
9105935 05fffb5 ce24f5e 9105935 949a27b 8d959a7 ce24f5e 247825b ce24f5e 247825b ce24f5e f2a2029 ce24f5e 5159d00 a6028d3 ce24f5e 6045345 ce24f5e a459383 77fca25 05fffb5 a6028d3 f2a2029 247825b 47ad389 247825b 9105935 87d7825 87e073d 9105935 6045345 d653859 9105935 247825b d653859 9105935 d653859 247825b d653859 949a27b f2a2029 87d7825 f2a2029 ce24f5e f2a2029 2393801 ce24f5e 949a27b f2a2029 ce24f5e a6028d3 f2a2029 ce24f5e f2a2029 ce24f5e f2a2029 94f5e41 ce24f5e a6028d3 ce24f5e 12de7b7 ce24f5e 2255bb7 87d7825 a6028d3 949a27b 77fca25 949a27b 9105935 6045345 87d7825 120e7df 6045345 949a27b 5159d00 949a27b 5159d00 949a27b f2a2029 2df63ef 8d959a7 a459383 8d959a7 902dd0a 2255bb7 6045345 2255bb7 902dd0a f2a2029 d1aed4c 8d959a7 a459383 0a472e1 ce24f5e 9105935 71a1f7f 915c56c ce24f5e a6028d3 ce24f5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
import importlib
import logging
import os
import pathlib
import random
import signal
import sys
from pathlib import Path
from typing import Optional
import fire
import torch
import transformers
import yaml
from attrdict import AttrDefault
# Make the in-repo ``src`` directory importable so axolotl does not need to be
# pip-installed. This has to run BEFORE any ``axolotl`` import below; the
# ``check_dataset_labels`` import used to sit above the path insertion and
# therefore could not benefit from it.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.utils.data import load_prepare_datasets  # noqa: E402
from axolotl.utils.models import load_model  # noqa: E402
from axolotl.utils.tokenization import check_dataset_labels  # noqa: E402
from axolotl.utils.trainer import setup_trainer  # noqa: E402
from axolotl.utils.wandb import setup_wandb_env_vars  # noqa: E402

# Log level comes from the LOG_LEVEL environment variable, defaulting to INFO.
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))

# Path handed to load_prepare_datasets() as the prepared-dataset location.
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def choose_device(cfg):
    """Pick the compute device and record it, plus a device map, on ``cfg``.

    Preference order is CUDA, then MPS, then CPU. On return ``cfg.device`` is
    one of ``"cuda"``, ``"mps"`` or ``"cpu"`` and ``cfg.device_map`` is a
    one-entry dict keyed by ``""`` whose value is ``cfg.local_rank`` for CUDA
    or the device name otherwise.

    Args:
        cfg: Mutable config object. ``cfg.local_rank`` is read when CUDA is
            selected; ``cfg.device`` and ``cfg.device_map`` are written.
    """

    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        try:
            if torch.backends.mps.is_available():
                return "mps"
        except AttributeError:
            # torch builds without an ``mps`` backend module: treat as
            # "MPS unavailable" and fall through to CPU.
            pass
        # Explicit fallback. Previously "cpu" was only returned from the
        # except branch, so a torch build that has MPS support but no MPS
        # device produced ``None`` here.
        return "cpu"

    cfg.device = get_device()
    if cfg.device == "cuda":
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}
def get_multi_line_input() -> Optional[str]:
    """Print a prompt, then read standard input until EOF.

    Returns:
        Everything read from ``sys.stdin`` as one string, line endings
        included; the empty string when EOF arrives immediately.
    """
    print("Give me an instruction (Ctrl + D to finish): ")
    # Iterating the stream yields each line (with its newline) until EOF.
    return "".join(sys.stdin)
def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    """Interactively read instructions from stdin and print model generations.

    Registers ``<unk>``, ``<s>`` and ``</s>`` as the tokenizer's unk/bos/eos
    tokens, then loops: read an instruction, build a prompt with the prompter
    class, sample a completion, print the decoded first sequence. The loop
    ends (and the function returns) as soon as an empty instruction is read.

    Args:
        cfg: Config object; ``cfg.device`` is where the input ids are moved.
        model: Object providing ``eval()`` and ``generate(...)``.
        tokenizer: Callable tokenizer providing ``add_special_tokens`` and
            ``decode``.
        prompter: Name of the class to look up in ``axolotl.prompters``.
    """
    special_tokens = (
        ("unk_token", "<unk>"),
        ("bos_token", "<s>"),
        ("eos_token", "</s>"),
    )
    for token_role, token_text in special_tokens:
        tokenizer.add_special_tokens({token_role: token_text})

    prompter_cls = getattr(importlib.import_module("axolotl.prompters"), prompter)

    # Sampling settings passed to model.generate() on every turn.
    # gc = GenerationConfig()  # TODO swap out and use this
    generation_kwargs = {
        "do_sample": True,
        "use_cache": True,
        "repetition_penalty": 1.1,
        "max_new_tokens": 100,
        "temperature": 0.9,
        "top_p": 0.95,
        "top_k": 40,
        "return_dict_in_generate": True,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
    }

    while True:
        # support for multiline inputs
        user_text = get_multi_line_input()
        if not user_text:
            return

        prompt = prompter_cls().build_prompt(instruction=user_text)
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            input_ids = encoded["input_ids"].to(cfg.device)
            output = model.generate(inputs=input_ids, **generation_kwargs)
            first_sequence = output["sequences"].cpu().tolist()[0]
            print(tokenizer.decode(first_sequence))
def choose_config(path: Path):
    """Interactively pick one ``*.yml`` file from a directory.

    Prints a numbered menu of the YAML files directly inside ``path`` and
    keeps prompting on stdin until a valid number is entered.

    Args:
        path: Directory to search (non-recursively) for ``*.yml`` files.

    Returns:
        The ``Path`` of the chosen file.

    Raises:
        ValueError: If the directory contains no ``*.yml`` files.
    """
    # glob() yields entries in arbitrary filesystem order; sorting keeps the
    # menu numbering stable from one run to the next.
    yaml_files = sorted(path.glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for number, file in enumerate(yaml_files, start=1):
        print(f"{number}. {file}")

    while True:
        try:
            choice = int(input("Enter the number of your choice: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if 1 <= choice <= len(yaml_files):
            return yaml_files[choice - 1]
        print("Invalid choice. Please choose a number from the list.")
def _latest_checkpoint(output_dir) -> Optional[str]:
    """Return the highest-numbered ``checkpoint-<N>`` path in ``output_dir``, or None."""
    numbered = []
    for candidate in Path(output_dir).glob("checkpoint-*"):
        try:
            step = int(candidate.name.rsplit("-", 1)[-1])
        except ValueError:
            # Entries such as ``checkpoint-tmp`` carry no step number; skip
            # them instead of aborting the whole run.
            continue
        numbered.append((step, str(candidate)))
    if not numbered:
        return None
    return max(numbered)[1]


def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    """Load a YAML config, then run inference, sharding, dataset prep or training.

    Args:
        config: Path to a YAML config file, or to a directory; for a
            directory the user is asked to choose one of its ``*.yml`` files.
        prepare_ds_only: When true, return right after the datasets have been
            loaded/prepared.
        **kwargs: Command-line overrides. A key that already exists in the
            YAML replaces that value (booleans are coerced with ``bool``).
            The mere presence of ``inference`` switches to the interactive
            inference loop; the presence of ``shard`` saves the loaded model
            to ``cfg.output_dir`` and returns.
    """
    if Path(config).is_dir():
        # ``config`` is a plain str when it comes from the command line, so
        # wrap it: choose_config() calls ``.glob`` on its argument.
        config = choose_config(Path(config))

    # load the config from the yaml file
    # NOTE(security): yaml.Loader can construct arbitrary Python objects.
    # Only point this at trusted config files (or switch to yaml.safe_load
    # once it is confirmed no config relies on Python-specific tags).
    with open(config, "r", encoding="utf-8") as f:
        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))

    # if there are any options passed in the cli, if it is something that seems valid from the yaml,
    # then overwrite the value
    cfg_keys = dict(cfg).keys()
    for key, value in kwargs.items():
        if key not in cfg_keys:
            continue
        # handle booleans
        cfg[key] = bool(value) if isinstance(cfg[key], bool) else value

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    # An explicit ``ddp`` setting wins; otherwise infer it from the world size.
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        cfg.gradient_accumulation_steps = (
            cfg.gradient_accumulation_steps // cfg.world_size
        )
    setup_wandb_env_vars(cfg)

    if cfg.device == "mps":
        # On MPS, disable 8-bit loading and tf32, and swap bf16 for fp16.
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False

    # Load the model and tokenizer
    logging.info("loading model, tokenizer, and peft_config...")
    model, tokenizer, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        cfg.tokenizer_type,
        cfg,
        adapter=cfg.adapter,
        inference=("inference" in kwargs),
    )

    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
        return

    if "shard" in kwargs:
        model.save_pretrained(cfg.output_dir)
        return

    train_dataset, eval_dataset = load_prepare_datasets(
        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
    )

    if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")
        return

    if cfg.debug:
        logging.info("check_dataset_labels...")
        # randrange() already excludes its stop value, so the bound is
        # len() itself; ``len() - 1`` skipped the last row and raised on a
        # one-row dataset.
        sample_rows = [random.randrange(len(train_dataset)) for _ in range(5)]
        check_dataset_labels(train_dataset.select(sample_rows), tokenizer)

    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

    model.config.use_cache = False

    # Compare the numeric major version; comparing version strings is
    # lexicographic and would misorder e.g. "10.0" and "2".
    torch_major = int(torch.__version__.split(".")[0])
    if torch_major >= 2 and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:

        def _save_and_exit(signum, frame):
            model.save_pretrained(cfg.output_dir)
            sys.exit(0)

        signal.signal(signal.SIGINT, _save_and_exit)

    logging.info("Starting trainer...")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        latest = _latest_checkpoint(cfg.output_dir)
        if latest is not None:
            resume_from_checkpoint = latest
            logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    logging.info(
        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
    )

    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    # NOTE(review): the stock transformers.Trainer exposes save_model(), not
    # save_pretrained(); confirm that whatever setup_trainer() returns really
    # has this method (the other save paths here use model.save_pretrained).
    trainer.save_pretrained(cfg.output_dir)
    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
# Script entry point: expose ``train`` (and its keyword overrides) as a CLI.
if __name__ == "__main__":
    fire.Fire(train)
|