File size: 7,831 Bytes
9105935
05fffb5
ce24f5e
9105935
949a27b
8d959a7
ce24f5e
 
247825b
ce24f5e
 
 
247825b
ce24f5e
f2a2029
ce24f5e
 
5159d00
 
a6028d3
 
ce24f5e
 
6045345
 
 
 
ce24f5e
a459383
77fca25
05fffb5
a6028d3
f2a2029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247825b
47ad389
247825b
 
 
 
 
 
 
9105935
87d7825
 
 
87e073d
9105935
6045345
d653859
9105935
247825b
d653859
 
9105935
d653859
 
 
 
 
 
247825b
d653859
 
 
 
 
 
 
 
 
 
 
 
 
949a27b
 
f2a2029
 
 
 
87d7825
 
 
f2a2029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce24f5e
f2a2029
2393801
ce24f5e
 
949a27b
f2a2029
 
ce24f5e
a6028d3
f2a2029
ce24f5e
 
f2a2029
 
 
 
 
 
 
 
ce24f5e
 
 
 
 
f2a2029
94f5e41
ce24f5e
 
a6028d3
 
 
ce24f5e
12de7b7
 
 
 
 
 
ce24f5e
 
2255bb7
 
87d7825
 
 
 
 
 
 
a6028d3
949a27b
 
77fca25
949a27b
 
 
9105935
 
 
 
6045345
 
87d7825
120e7df
6045345
 
 
949a27b
 
5159d00
949a27b
5159d00
949a27b
 
f2a2029
2df63ef
8d959a7
 
 
 
a459383
8d959a7
 
902dd0a
2255bb7
6045345
2255bb7
902dd0a
f2a2029
d1aed4c
 
 
 
 
8d959a7
a459383
0a472e1
 
 
 
 
 
 
 
ce24f5e
9105935
 
 
 
 
ce24f5e
a6028d3
ce24f5e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import importlib
import logging
import os
import pathlib
import random
import signal
import sys
from pathlib import Path
from typing import Optional

import fire
import torch
import transformers
import yaml
from attrdict import AttrDefault

# Add the repository's ``src`` directory to the front of ``sys.path`` *before*
# importing anything from ``axolotl``, so this script works from a source
# checkout without the package being pip-installed.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)

from axolotl.utils.data import load_prepare_datasets  # noqa: E402
from axolotl.utils.models import load_model  # noqa: E402
from axolotl.utils.tokenization import check_dataset_labels  # noqa: E402
from axolotl.utils.trainer import setup_trainer  # noqa: E402
from axolotl.utils.wandb import setup_wandb_env_vars  # noqa: E402

# Log level is taken from the LOG_LEVEL environment variable (default INFO).
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
# Passed to load_prepare_datasets by train(); presumably the directory where the
# prepared dataset is stored between runs — confirm against load_prepare_datasets.
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


def choose_device(cfg):
    """Select the torch device for this run and derive ``cfg.device_map``.

    Sets ``cfg.device`` to ``"cuda"`` when CUDA is available, else ``"mps"``
    when the Apple Metal backend is available, else ``"cpu"``.

    ``cfg.device_map`` maps the whole model (key ``""``) either onto
    ``cfg.local_rank`` (CUDA) or onto the chosen device name.

    Args:
        cfg: Mutable config object; ``local_rank`` is read when CUDA is used,
            and ``device`` / ``device_map`` are written.
    """

    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        try:
            if torch.backends.mps.is_available():
                return "mps"
        except AttributeError:
            # Older torch builds have no ``torch.backends.mps`` at all.
            pass
        # Previously this path returned None whenever MPS was simply
        # unavailable; always fall back to the CPU instead.
        return "cpu"

    cfg.device = get_device()
    if cfg.device == "cuda":
        cfg.device_map = {"": cfg.local_rank}
    else:
        cfg.device_map = {"": cfg.device}


def get_multi_line_input() -> Optional[str]:
    """Prompt for an instruction and read stdin until EOF.

    Every line read (newlines included) is concatenated and returned as-is.
    An immediate EOF yields the empty string.
    """
    print("Give me an instruction (Ctrl + D to finish): ")
    return "".join(sys.stdin)


def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    """Run an interactive generate-and-print loop on stdin instructions.

    Registers ``<unk>``, ``<s>`` and ``</s>`` as the tokenizer's unk/bos/eos
    tokens. It then repeatedly:

    1. reads a multi-line instruction;
    2. wraps it with the prompter class named *prompter* from
       ``axolotl.prompters``;
    3. samples up to 100 new tokens on ``cfg.device``;
    4. prints the decoded sequence.

    Returns as soon as an empty instruction is read.
    """
    special_tokens = (
        ("unk_token", "<unk>"),
        ("bos_token", "<s>"),
        ("eos_token", "</s>"),
    )
    for token_role, token_text in special_tokens:
        tokenizer.add_special_tokens({token_role: token_text})

    prompters = importlib.import_module("axolotl.prompters")
    prompter_cls = getattr(prompters, prompter)

    # Fixed sampling settings.
    # TODO: replace with a transformers GenerationConfig.
    generation_kwargs = dict(
        do_sample=True,
        use_cache=True,
        repetition_penalty=1.1,
        max_new_tokens=100,
        temperature=0.9,
        top_p=0.95,
        top_k=40,
        return_dict_in_generate=True,
        output_attentions=False,
        output_hidden_states=False,
        output_scores=False,
    )

    while True:
        # Multi-line input; an empty read ends the session.
        instruction = get_multi_line_input()
        if not instruction:
            break

        prompt = prompter_cls().build_prompt(instruction=instruction)
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
        input_ids = encoded["input_ids"]

        model.eval()
        with torch.no_grad():
            generated = model.generate(
                inputs=input_ids.to(cfg.device), **generation_kwargs
            )

        first_sequence = generated["sequences"].cpu().tolist()[0]
        print(tokenizer.decode(first_sequence))


def choose_config(path: Path):
    """Interactively choose one ``*.yml`` file from the directory *path*.

    Candidates are listed sorted by path, so the numbering is stable between
    runs. The function keeps prompting on stdin until a valid 1-based index
    is entered.

    Args:
        path: Directory to search. A ``str`` is accepted as well as a ``Path``
            (the CLI may hand over a plain string).

    Returns:
        The chosen file as a ``Path``.

    Raises:
        ValueError: If the directory contains no ``*.yml`` files.
    """
    yaml_files = sorted(Path(path).glob("*.yml"))

    if not yaml_files:
        raise ValueError(
            "No YAML config files found in the specified directory. Are you using a .yml extension?"
        )

    print("Choose a YAML file:")
    for idx, file in enumerate(yaml_files, start=1):
        print(f"{idx}. {file}")

    while True:
        try:
            choice = int(input("Enter the number of your choice: "))
        except ValueError:
            print("Invalid input. Please enter a number.")
            continue
        if 1 <= choice <= len(yaml_files):
            return yaml_files[choice - 1]
        print("Invalid choice. Please choose a number from the list.")


def _load_cfg(config, overrides):
    """Load the YAML file *config* into an ``AttrDefault`` and apply CLI overrides.

    Keys missing from the YAML read as ``None``. Only keys already present in
    the YAML are overridden from *overrides*; any other keys are ignored here.
    When the YAML value is a bool, the override is coerced with ``bool()``.
    """
    with open(config, "r", encoding="utf-8") as f:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only point this at trusted config files.
        cfg: AttrDefault = AttrDefault(lambda: None, yaml.load(f, Loader=yaml.Loader))
    cfg_keys = dict(cfg).keys()
    for key, value in overrides.items():
        if key in cfg_keys:
            cfg[key] = bool(value) if isinstance(cfg[key], bool) else value
    return cfg


def _find_latest_checkpoint(output_dir):
    """Return the path of ``output_dir/checkpoint-<step>`` with the highest step.

    Returns None when *output_dir* is None or has no such directory. Entries
    whose suffix is not an integer are skipped instead of crashing the sort.
    """
    if output_dir is None:
        return None
    candidates = []
    for checkpoint in Path(output_dir).glob("checkpoint-*"):
        step = checkpoint.name.rsplit("-", 1)[-1]
        if step.isdigit():
            candidates.append((int(step), str(checkpoint)))
    if not candidates:
        return None
    return max(candidates)[1]


def train(
    config: Path = Path("configs/"),
    prepare_ds_only: bool = False,
    **kwargs,
):
    """Fine-tune a model from a YAML config, or run inference / sharding with it.

    Args:
        config: YAML config file, or a directory whose ``*.yml`` files are
            offered in an interactive picker. May arrive as a plain ``str``
            from the CLI, so it is normalised to ``Path`` first.
        prepare_ds_only: Stop right after the datasets have been prepared.
        **kwargs: CLI overrides. Keys already present in the YAML overwrite
            the YAML value. The mere presence of ``inference`` runs the
            interactive inference loop; ``shard`` only saves the model.
    """
    config = Path(config)
    if config.is_dir():
        config = choose_config(config)

    cfg = _load_cfg(config, kwargs)

    # setup some derived config / hyperparams
    cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        # Each process places the whole model on its own local GPU.
        cfg.device_map = {"": cfg.local_rank}
        # NOTE(review): integer division yields 0 when
        # batch_size // micro_batch_size < world_size — confirm configs avoid this.
        cfg.gradient_accumulation_steps = (
            cfg.gradient_accumulation_steps // cfg.world_size
        )
    setup_wandb_env_vars(cfg)
    if cfg.device == "mps":
        # 8-bit loading, tf32 and bf16 are switched off on MPS; a bf16 request
        # is downgraded to fp16.
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False

    # Load the model and tokenizer
    logging.info("loading model, tokenizer, and peft_config...")
    model, tokenizer, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
        cfg.tokenizer_type,
        cfg,
        adapter=cfg.adapter,
        inference=("inference" in kwargs),
    )

    if "inference" in kwargs:
        logging.info("calling do_inference function")
        do_inference(cfg, model, tokenizer)
        return

    if "shard" in kwargs:
        model.save_pretrained(cfg.output_dir)
        return

    train_dataset, eval_dataset = load_prepare_datasets(
        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
    )

    if prepare_ds_only:
        logging.info("Finished preparing dataset. Exiting...")
        return

    if cfg.debug:
        logging.info("check_dataset_labels...")
        # randrange(n) draws from [0, n): the last row is eligible too, and a
        # one-row dataset no longer raises (randrange(0, 0) would).
        sample_indices = [random.randrange(len(train_dataset)) for _ in range(5)]
        check_dataset_labels(train_dataset.select(sample_indices), tokenizer)

    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)

    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        logging.info("Compiling torch model")
        model = torch.compile(model)

    # go ahead and presave, so we have the adapter config available to inspect
    if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
        peft_config.save_pretrained(cfg.output_dir)

    # On Ctrl+C, rank 0 saves the current model before exiting.
    if cfg.local_rank == 0:

        def _save_and_exit(_signum, _frame):
            model.save_pretrained(cfg.output_dir)
            # sys.exit rather than the site-provided ``exit`` builtin, which
            # is absent when Python runs with -S.
            sys.exit(0)

        signal.signal(signal.SIGINT, _save_and_exit)

    logging.info("Starting trainer...")
    resume_from_checkpoint = cfg.resume_from_checkpoint
    if resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
        latest_checkpoint = _find_latest_checkpoint(cfg.output_dir)
        if latest_checkpoint is not None:
            resume_from_checkpoint = latest_checkpoint
            logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    logging.info(
        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
    )
    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
    trainer.save_model(cfg.output_dir)


if __name__ == "__main__":
    # Expose ``train`` as the command-line entry point via python-fire.
    fire.Fire(component=train)