File size: 16,903 Bytes
f4e5d86
 
 
6045345
bdbca8f
6045345
 
37293dc
6045345
ffd1043
6045345
 
1edc30c
1210dc8
1edc30c
39a208c
 
 
 
 
 
2bc1a5b
8d43785
e44c9e0
 
 
f4e5d86
2bc1a5b
 
 
6045345
 
 
 
f4e5d86
6045345
37293dc
 
6045345
32e6fe9
39a208c
32e6fe9
 
 
 
 
39a208c
933e970
32e6fe9
 
 
39a208c
ca1bb92
32e6fe9
 
 
 
 
 
 
37293dc
 
 
 
32e6fe9
 
 
 
 
 
 
f4e5d86
 
32e6fe9
 
 
 
 
 
6045345
 
 
 
32e6fe9
6045345
 
 
 
55b8542
f4e5d86
 
 
6045345
 
 
55b8542
2bc1a5b
 
6045345
55b8542
6045345
 
 
 
 
55b8542
6cb2310
2bc1a5b
 
 
8746b70
 
55b8542
6cb2310
 
 
 
 
 
55b8542
e44c9e0
55b8542
e44c9e0
55b8542
 
 
 
e44c9e0
 
6045345
a03a7d7
 
 
 
 
 
 
 
1210dc8
aef00b6
1edc30c
aef00b6
 
 
6045345
dd00657
6045345
 
 
 
 
f4e5d86
 
 
6045345
aefb2fc
 
 
 
 
 
 
 
3b4d055
1987e5c
3b4d055
 
 
 
e65aeed
3b4d055
 
 
6045345
55b8542
6045345
 
 
d653859
 
 
2bc1a5b
 
 
 
 
 
d653859
 
 
 
6045345
d653859
 
 
 
 
 
 
f4e5d86
d653859
32e6fe9
6045345
 
 
641f801
6045345
 
 
 
 
 
55b8542
2520ecd
9190ada
 
2520ecd
9190ada
e8aacfb
9190ada
e65aeed
3b4d055
9190ada
1d5ab84
 
 
 
 
 
 
56f9ca5
1d5ab84
56f9ca5
1d5ab84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94f5e41
6045345
 
6dfdd2d
e8aacfb
6045345
 
f4e5d86
3b4d055
6045345
94f5e41
e2e68c3
 
f4e5d86
e2e68c3
4ac9e25
 
7f09106
4ac9e25
7f09106
 
 
 
 
 
 
94f5e41
 
e2e68c3
6dfdd2d
e8aacfb
94f5e41
 
f4e5d86
3b4d055
94f5e41
f4e5d86
6045345
 
 
f4e5d86
6045345
 
6dfdd2d
6045345
 
f4e5d86
3b4d055
6045345
 
bdbca8f
 
aa3c3f9
f4e5d86
 
 
7b5e762
df9528f
 
6045345
 
 
94f5e41
6045345
 
dd00657
6045345
 
f4e5d86
 
 
6045345
f4e5d86
 
 
 
6045345
ce34d64
 
 
e65aeed
ce34d64
42410c7
 
 
cfcc549
 
ad2b48c
247825b
 
 
 
 
 
bdbca8f
ad2b48c
1edc30c
 
 
6045345
32e6fe9
6045345
 
 
8bd7a49
6045345
 
 
7b5e762
6045345
2255bb7
 
6045345
 
 
 
2255bb7
8bd7a49
37293dc
2255bb7
 
 
 
 
 
 
1d5ab84
1b3e401
2255bb7
 
 
 
 
 
 
 
 
 
 
 
 
 
ffd1043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6045345
8bd7a49
6045345
37293dc
6045345
4c90633
9196237
 
 
 
 
 
 
 
 
 
 
6045345
2255bb7
 
 
ffd1043
2255bb7
 
2c73c81
2255bb7
 
 
6045345
2255bb7
 
 
 
 
1d5ab84
2255bb7
 
 
6045345
2255bb7
6045345
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
"""Module for models and model loading"""


import logging
import math
import os
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Tuple  # noqa: F401

import bitsandbytes as bnb
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from transformers import PreTrainedModel  # noqa: F401
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaConfig,
)

try:
    from transformers import (  # pylint: disable=unused-import  # noqa: F401
        LlamaForCausalLM,
    )
except ImportError:
    logging.warning(
        "This version of transformers does not support Llama. Consider upgrading."
    )

from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN

if TYPE_CHECKING:
    from peft import PeftConfig  # noqa: F401

    from axolotl.utils.dict import DictDefault  # noqa: F401


def load_tokenizer(
    tokenizer_config,
    tokenizer_type,
    cfg,
):
    """
    Load a tokenizer and apply axolotl's token adjustments.

    :param tokenizer_config: name or path handed to ``from_pretrained``
    :param tokenizer_type: name of a ``transformers`` tokenizer class; when
        falsy, ``AutoTokenizer`` is used instead
    :param cfg: config read for ``trust_remote_code``, ``special_tokens`` and
        ``tokens``
    :return: the tokenizer, after:

        * the pad token is set to ``LLAMA_DEFAULT_PAD_TOKEN`` for Llama tokenizers;
        * a ``[PAD]`` pad token is added for ``GPTNeoXTokenizerFast`` (this also
          sets ``TOKENIZERS_PARALLELISM=false`` in the process environment);
        * every ``cfg.special_tokens`` entry and all ``cfg.tokens`` are added.

    The EOS/BOS/PAD/UNK ids and strings are debug-logged as loaded, i.e.
    before any of the adjustments above.
    """
    tokenizer_cls = (
        getattr(transformers, tokenizer_type) if tokenizer_type else AutoTokenizer
    )
    tokenizer = tokenizer_cls.from_pretrained(
        tokenizer_config,
        trust_remote_code=cfg.trust_remote_code or False,
    )

    token_labels = (
        ("EOS", "eos"),
        ("BOS", "bos"),
        ("PAD", "pad"),
        ("UNK", "unk"),
    )
    for label, attr in token_labels:
        token_id = getattr(tokenizer, f"{attr}_token_id")
        token = getattr(tokenizer, f"{attr}_token")
        logging.debug(f"{label}: {token_id} / {token}")

    class_name = tokenizer.__class__.__name__
    if class_name in ("LlamaTokenizer", "LlamaTokenizerFast"):
        tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
    elif class_name == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    special_tokens = cfg.special_tokens
    if special_tokens:
        for key, value in special_tokens.items():
            tokenizer.add_special_tokens({key: value})

    extra_tokens = cfg.tokens
    if extra_tokens:
        tokenizer.add_tokens(list(extra_tokens))

    return tokenizer


def load_model(
    base_model,
    base_model_config,
    model_type,
    tokenizer,
    cfg,
    adapter="lora",
    inference=False,
):
    # type: (str, str, str, AutoTokenizer, DictDefault, Optional[str], bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Load a model from a base model and a model type.

    Steps, in order:

    * apply at most one llama attention monkeypatch (flash / xformers / sdp /
      landmark, in that precedence) and, independently, the xpos rope patch;
    * choose ``torch_dtype`` from ``cfg`` (bf16, else fp16 for fp16/8-bit,
      else fp32);
    * load weights via GPTQ 4-bit, a llama class, the ``transformers`` class
      named by ``model_type``, or ``AutoModelForCausalLM``; any exception there
      triggers a retry with a plain ``AutoModelForCausalLM.from_pretrained``;
    * resize token embeddings to ``len(tokenizer)`` rounded up to a multiple
      of 32;
    * optionally run ``prepare_model_for_kbit_training`` and attach the adapter.

    :param base_model: model id or path passed to ``from_pretrained``
    :param base_model_config: config id or path used by the llama and GPTQ loaders
    :param model_type: name of a ``transformers`` model class, or falsy to use
        ``AutoConfig``/``AutoModelForCausalLM``
    :param tokenizer: tokenizer whose length sizes the embedding matrix; the
        landmark-attention branch adds ``MEM_TOKEN`` to it as a special token
    :param cfg: axolotl config; note this function also writes
        ``cfg.is_llama_derived_model``
    :param adapter: adapter name forwarded to :func:`load_adapter`
    :param inference: the flash-attention patch is only applied when this is False
    :return: ``(model, lora_config)`` where ``lora_config`` is whatever
        :func:`load_adapter` returns (None when ``adapter`` is None)
    """

    # TODO refactor as a kwarg
    # Kept as a local because the GPTQ branch below forces it to False, which
    # then changes the kbit-preparation and DDP device-placement checks.
    load_in_8bit = cfg.load_in_8bit
    # NOTE(review): base_model is matched case-sensitively while cfg.model_type
    # is lower-cased first; confirm the asymmetry is intended.
    cfg.is_llama_derived_model = "llama" in base_model or (
        cfg.model_type and "llama" in cfg.model_type.lower()
    )

    # Class used by the llama loading branch: transformers' LlamaForCausalLM when
    # the guarded module-level import succeeded, else None; landmark attention
    # replaces it with its own class below. It must NOT be re-imported under the
    # name ``LlamaForCausalLM`` inside this function: an import binds a local
    # name for the whole function body, which made every non-landmark llama load
    # raise UnboundLocalError and silently fall back to AutoModelForCausalLM.
    llama_model_cls = globals().get("LlamaForCausalLM")

    # Only one attention patch is applied (elif chain). If flash_attention is set
    # but the device is mps/cpu or this is inference, no attention patch is
    # applied at all.
    if cfg.is_llama_derived_model and cfg.flash_attention:
        if cfg.device not in ["mps", "cpu"] and inference is False:
            from axolotl.flash_attn import replace_llama_attn_with_flash_attn

            logging.info("patching with flash attention")
            replace_llama_attn_with_flash_attn()
    elif cfg.is_llama_derived_model and cfg.xformers_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_attention,
        )

        logging.info("patching with xformers attention")
        hijack_llama_attention()
    elif cfg.is_llama_derived_model and cfg.sdp_attention:
        from axolotl.monkeypatch.llama_attn_hijack_xformers import (
            hijack_llama_sdp_attention,
        )

        logging.info("patching with sdp attention")
        hijack_llama_sdp_attention()
    elif cfg.is_llama_derived_model and cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import (
            MEM_TOKEN,
            LlamaForCausalLM as LandmarkLlamaForCausalLM,
        )

        logging.info("patching with landmark attention")
        llama_model_cls = LandmarkLlamaForCausalLM

        # TODO: Check if this would overwrite previous additional_special_tokens
        tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})

    if cfg.is_llama_derived_model and cfg.xpos_rope:
        from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
            replace_llama_rope_with_xpos_rope,
        )

        logging.info("patching with xpos rope")
        replace_llama_rope_with_xpos_rope()

    if cfg.bf16 or cfg.bfloat16:
        torch_dtype = torch.bfloat16
    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32
    try:
        if cfg.gptq:
            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
                replace_peft_model_with_int4_lora_model,
            )

            replace_peft_model_with_int4_lora_model()
    except Exception as err:
        logging.exception(err)
        raise

    try:
        from peft import prepare_model_for_kbit_training
    except ImportError:
        # For backward compatibility
        from peft import (
            prepare_model_for_int8_training as prepare_model_for_kbit_training,
        )

    model_kwargs = {}
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        # QLoRA: 4-bit NF4 weights with double quantization; the compute dtype
        # follows the dtype chosen above.
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    try:
        if cfg.gptq and cfg.is_llama_derived_model:
            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
            from huggingface_hub import snapshot_download

            # Resolve a local snapshot and use the first *.pt, then
            # *.safetensors, then *.bin file found; on any error fall back to
            # cfg.base_model as the path.
            # NOTE(review): the fallback reads cfg.base_model, not the
            # base_model argument; confirm they always agree.
            try:
                snapshot_download_kwargs = {}
                if cfg.base_model_ignore_patterns:
                    snapshot_download_kwargs[
                        "ignore_patterns"
                    ] = cfg.base_model_ignore_patterns
                cache_model_path = Path(
                    snapshot_download(base_model, **snapshot_download_kwargs)
                )
                files = (
                    list(cache_model_path.glob("*.pt"))
                    + list(cache_model_path.glob("*.safetensors"))
                    + list(cache_model_path.glob("*.bin"))
                )
                if len(files) > 0:
                    model_path = str(files[0])
                else:
                    logging.warning(
                        "unable to find a cached model file, this will likely fail..."
                    )
                    model_path = str(cache_model_path)
            except Exception:  # pylint: disable=broad-exception-caught
                model_path = cfg.base_model
            model, _ = load_llama_model_4bit_low_ram(
                base_model_config if base_model_config else base_model,
                model_path,
                device_map=cfg.device_map,
                half=cfg.fp16,
                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
                is_v1_model=cfg.gptq_model_v1
                if cfg.gptq_model_v1 is not None
                else True,
            )
            # GPTQ weights are already quantized; skip the 8-bit handling below.
            load_in_8bit = False
        elif cfg.is_llama_derived_model and llama_model_cls is not None:
            config = LlamaConfig.from_pretrained(base_model_config)
            model = llama_model_cls.from_pretrained(
                base_model,
                config=config,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map="auto" if cfg.world_size == 1 else cfg.device_map,
                **model_kwargs,
            )
        # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
        #     This is a WIP, still an issue with the backward pass
        #     RuntimeError: grad can be implicitly created only for scalar outputs
        #     TODO: try config.sequence_parallel = False
        #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
        #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
        #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
        #     from flash_attn.utils.pretrained import state_dict_from_pretrained
        #     from flash_attn.models.gpt import GPTLMHeadModel
        #     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
        #     from transformers import GPTNeoXConfig
        #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
        #     config.use_flash_attn = True
        #     config.fused_bias_fc = True
        #     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
        #     config.activation_function = "gelu_fast"
        #     config.fused_dropout_add_ln = True
        #     # config.residual_in_fp32 = True
        #
        #     model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
        #         base_model,
        #         config,
        #         dtype=torch_dtype,
        #         device=cfg.device,
        #     )
        #     model.train() # sets to train instead of eval mode
        elif model_type:
            model = getattr(transformers, model_type).from_pretrained(
                base_model,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
        else:
            config = AutoConfig.from_pretrained(
                base_model,
                trust_remote_code=cfg.trust_remote_code or False,
            )
            # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
            # when training starts
            if hasattr(config, "max_seq_len") and cfg.sequence_len > config.max_seq_len:
                config.max_seq_len = cfg.sequence_len
                logging.warning(f"increasing context length to {cfg.sequence_len}")
            elif (
                hasattr(config, "max_sequence_length")
                and cfg.sequence_len > config.max_sequence_length
            ):
                config.max_sequence_length = cfg.sequence_len
                logging.warning(f"increasing context length to {cfg.sequence_len}")
            model = AutoModelForCausalLM.from_pretrained(
                base_model,
                config=config,
                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
                torch_dtype=torch_dtype,
                device_map=cfg.device_map,
                trust_remote_code=cfg.trust_remote_code or False,
                **model_kwargs,
            )
    except Exception as err:  # pylint: disable=broad-exception-caught
        # NOTE(review): this retry passes neither a config override nor
        # load_in_4bit; for qlora the quantization_config in model_kwargs still
        # requests 4-bit loading.
        logging.error(
            "Exception raised attempting to load model, retrying with AutoModelForCausalLM"
        )
        logging.exception(err)
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )

    # Round the vocab size up to a multiple of 32 before resizing
    # (NOTE(review): presumably for hardware-friendly shapes; confirm).
    embeddings_len = math.ceil(len(tokenizer) / 32) * 32
    model.resize_token_embeddings(embeddings_len)

    if not cfg.gptq and (
        (cfg.adapter == "lora" and load_in_8bit)
        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
    ):
        logging.info("converting PEFT model w/ prepare_model_for_kbit_training")
        model = prepare_model_for_kbit_training(model)

    model, lora_config = load_adapter(model, cfg, adapter)

    if cfg.ddp and not load_in_8bit:
        model.to(f"cuda:{cfg.local_rank}")

    if cfg.gptq:
        # Scales to half: cast scales and bias (and zeros for v1 models) of the
        # 4-bit quantized linear modules to fp16.
        logging.info("Fitting 4bit scales and zeros to half")
        for _, module in model.named_modules():
            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
                type(module)
            ):
                if hasattr(module, "is_v1_model") and module.is_v1_model:
                    module.zeros = module.zeros.half()
                module.scales = module.scales.half()
                module.bias = module.bias.half()

    if (
        torch.cuda.device_count() > 1
        and int(os.getenv("WORLD_SIZE", "1")) > 1
        and (cfg.gptq or cfg.load_in_4bit)
    ):
        # llama is PROBABLY model parallelizable, but the default isn't that it is
        # so let's only set it for the 4bit, see
        # https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)

    if not any(
        param.requires_grad for _, param in model.named_parameters(recurse=True)
    ):
        logging.warning("there are no parameters that require gradient updates")
    model.config.use_cache = False

    if cfg.flash_optimum:
        model = BetterTransformer.transform(model)

    # TODO resume_from_checkpoint handling
    return model, lora_config


def load_adapter(model, cfg, adapter):
    # type: (PreTrainedModel, DictDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Attach the PEFT adapter named by ``adapter`` to ``model``.

    ``None`` leaves the model untouched and yields no config. ``"lora"`` and
    ``"qlora"`` both go through :func:`load_lora`; ``"llama-adapter"`` goes
    through :func:`load_llama_adapter`.

    :return: ``(model, peft_config)``
    :raises NotImplementedError: for any other adapter name
    """
    if adapter is None:
        return model, None

    if adapter == "llama-adapter":
        return load_llama_adapter(model, cfg)

    if adapter in ("lora", "qlora"):
        return load_lora(model, cfg)

    raise NotImplementedError(f"{adapter} peft adapter not available")


def load_llama_adapter(model, cfg):
    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Wrap ``model`` with a PEFT adaption-prompt ("llama-adapter") adapter.

    A fresh adapter is built from ``cfg.peft_adapter.layers`` and
    ``cfg.peft_adapter.len``. If ``cfg.lora_model_dir`` is set, saved adapter
    weights are loaded from that directory instead (in fp16).

    :return: ``(model, peft_config)``. NOTE(review): when loading from
        ``cfg.lora_model_dir`` the returned ``peft_config`` is the one built from
        ``cfg``, not necessarily the config saved alongside the adapter.
    """
    from peft import AdaptionPromptConfig, PeftModel, get_peft_model

    peft_config = AdaptionPromptConfig(
        adapter_layers=cfg.peft_adapter.layers,  # layers (L)
        adapter_len=cfg.peft_adapter.len,  # prompt length (K)
        task_type="CAUSAL_LM",
    )

    if cfg.lora_model_dir:
        # Fixes the "pretained" typo; this path loads a llama-adapter, not LoRA.
        logging.info("Loading pretrained llama-adapter")
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            device_map=cfg.device_map,
            torch_dtype=torch.float16,
        )
    else:
        model = get_peft_model(model, peft_config)

    model.print_trainable_parameters()

    return model, peft_config


def find_all_linear_names(bits, model):
    """
    Return the distinct leaf names of all linear layers in ``model``.

    Used to build LoRA ``target_modules`` when ``lora_target_linear`` is set.

    :param bits: 4 selects ``bnb.nn.Linear4bit``, 8 selects
        ``bnb.nn.Linear8bitLt``; any other value selects ``torch.nn.Linear``
    :param model: module whose ``named_modules()`` are scanned
    :return: sorted list of the last dotted component of each matching module
        name, excluding ``lm_head``. Sorting keeps the result (and the adapter
        config built from it) identical across runs; iterating a set of strings
        does not, because string hashing is randomized per process.
    """
    # bnb is only touched for 4/8-bit, matching the original lazy lookup.
    if bits == 4:
        cls = bnb.nn.Linear4bit
    elif bits == 8:
        cls = bnb.nn.Linear8bitLt
    else:
        cls = torch.nn.Linear

    lora_module_names = {
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, cls)
    }

    # needed for 16-bit: the output head is never a LoRA target
    lora_module_names.discard("lm_head")

    return sorted(lora_module_names)


def load_lora(model, cfg):
    # type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
    """
    Wrap ``model`` with a LoRA adapter (used for both "lora" and "qlora").

    Target modules are ``cfg.lora_target_modules`` plus, when
    ``cfg.lora_target_linear`` is set, every linear layer name returned by
    :func:`find_all_linear_names` for the configured bit width (4 if
    ``load_in_4bit``, else 8 if ``load_in_8bit``, else full precision).
    If ``cfg.lora_model_dir`` is set, saved adapter weights are loaded from it
    instead of creating a fresh adapter.

    :return: ``(model, lora_config)``. NOTE(review): when loading from
        ``cfg.lora_model_dir`` the returned config is the one built from ``cfg``,
        not necessarily the config saved with the adapter.
    """

    from peft import LoraConfig, PeftModel, get_peft_model

    lora_target_modules = list(cfg.lora_target_modules or [])

    if cfg.lora_target_linear:
        bits = None
        if cfg.load_in_4bit:
            bits = 4
        elif cfg.load_in_8bit:
            bits = 8

        linear_names = find_all_linear_names(bits, model)
        logging.info(f"found linear modules: {repr(linear_names)}")
        # De-duplicate while keeping a stable order (configured modules first,
        # then discovered ones); list(set(...)) reordered them between runs.
        lora_target_modules = list(dict.fromkeys(lora_target_modules + linear_names))

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=lora_target_modules,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
        bias="none",
        task_type="CAUSAL_LM",
    )

    if cfg.lora_model_dir:
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            device_map=cfg.device_map,
            # torch_dtype=torch.float16,
        )
    else:
        model = get_peft_model(model, lora_config)

    model.print_trainable_parameters()

    return model, lora_config