File size: 23,527 Bytes
8cec513
6910e6a
a5bf838
8cec513
6910e6a
cc3cebf
1210dc8
71a43f8
131afdb
a5bf838
7b55fe6
cc3cebf
 
 
 
6910e6a
44454ae
7b55fe6
553a86b
 
1210dc8
8cec513
 
 
 
 
 
 
 
 
 
 
 
 
 
f6060a6
992e742
f6060a6
8cec513
2414673
8cec513
 
 
094fc2c
 
 
 
 
 
8cec513
 
 
 
 
 
 
 
 
2642cae
 
8cec513
 
5b67ea9
5a5d474
 
 
 
 
 
 
8cec513
 
 
 
 
 
0865613
 
 
 
 
 
 
782b6a4
 
0865613
8cec513
 
 
 
 
 
 
 
782b6a4
 
8cec513
96deb6b
 
 
 
 
 
 
5f79b82
 
 
 
e923e62
5f79b82
 
 
 
9ec2077
 
2d8def6
 
 
44454ae
62eaee7
44454ae
ff939d8
 
 
 
44454ae
 
 
 
eb41f76
6b3b271
44454ae
 
19a600a
 
 
 
 
 
 
 
 
 
 
 
eb41f76
6b3b271
19a600a
 
eb41f76
 
 
 
 
 
 
 
 
6910e6a
6b3b271
eb41f76
 
1115c50
f5a828a
 
 
 
 
 
1115c50
2ce5c0d
 
 
802f966
 
 
 
33e1170
802f966
 
 
7b55fe6
 
8cec513
0ce1a65
 
 
 
 
 
 
 
 
 
 
 
 
2ea70eb
 
 
 
 
0ce1a65
 
7de912e
cc3cebf
 
269c543
 
 
601b77b
269c543
cc3cebf
269c543
601b77b
269c543
cc3cebf
 
 
0f10080
 
 
 
131afdb
 
 
 
317fa25
 
 
 
 
131afdb
 
 
00568c1
 
 
 
 
 
 
 
 
 
 
 
 
 
2bb0b78
2ce5c0d
2bb0b78
59a31fe
 
 
3437149
 
 
 
 
3aad5f3
 
 
 
3c71c8d
553a86b
3c71c8d
 
 
 
7ee3c4c
 
 
 
 
2642cae
 
 
 
1d7da3b
48f4c05
 
52dd92a
 
 
 
 
 
 
 
 
48f4c05
52dd92a
 
 
 
 
 
 
 
dd00657
15d3a65
 
 
4cb7900
 
553a86b
52dd92a
15d3a65
 
 
8487b97
 
 
 
 
bde3c5a
 
 
 
 
 
 
 
 
 
 
 
 
15d3a65
 
 
2824423
553a86b
52dd92a
 
a5bf838
1c33eb8
b832a0a
 
 
1c33eb8
bfd27ba
babf0fd
 
14668fa
 
 
 
 
1edc30c
 
553a86b
1edc30c
 
1072f28
553a86b
1a82082
 
1210dc8
c01015f
553a86b
1210dc8
 
 
1edc30c
eea2731
553a86b
eea2731
 
2f586d1
 
 
 
eea2731
19cf0bd
cb9d3af
 
553a86b
cb9d3af
e79c8e6
 
 
 
 
af29d81
 
 
 
 
6b3b271
96bd6ae
6b3b271
96bd6ae
6b3b271
96bd6ae
 
00568c1
 
 
 
 
2bb0b78
 
 
 
 
 
00568c1
 
 
 
 
 
 
e30f1e3
 
 
 
 
 
 
 
 
62eaee7
e7d3e2d
 
590d603
 
e7d3e2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f79b82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383f88d
 
 
 
 
 
 
 
 
 
 
 
 
 
8c2e05a
 
 
 
 
383f88d
 
 
e7d3e2d
9923b72
 
 
 
 
 
 
 
 
44c9d01
 
 
 
 
 
1bc1186
 
 
a1da39c
 
 
 
 
 
 
ef24342
 
 
 
 
 
 
 
 
 
 
 
 
 
bdfefaf
 
 
 
 
6910e6a
 
 
 
 
 
 
 
 
 
e923e62
6910e6a
 
 
e923e62
 
 
 
6910e6a
e923e62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6910e6a
cda52dc
 
 
 
 
e923e62
 
 
5a5d474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d7da3b
 
 
ab5cd28
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
"""Module for working with config dicts"""
import json
import logging
import os
from pathlib import Path
from typing import Optional

import torch
from transformers.utils import is_torch_bf16_gpu_available

from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.config.models.input.v0_4_1 import (
    AxolotlConfigWCapabilities,
    AxolotlInputConfig,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model_config

LOG = logging.getLogger("axolotl")


def choose_device(cfg):
    """
    Select the compute device and a matching ``device_map``, writing both onto ``cfg``.

    ``cfg.device`` becomes ``"cuda:<cfg.local_rank>"`` when CUDA is available,
    otherwise ``"mps"`` when the MPS backend is available, otherwise ``"cpu"``.
    ``"cpu"`` is also the answer if probing the backends raises for any reason.

    ``cfg.device_map`` is then chosen as follows:

    * single process (``world_size == 1``): keep any user-supplied value,
      otherwise ``"auto"``;
    * multiple processes on CUDA: ``{"": torch.cuda.current_device()}``;
    * multiple processes elsewhere: ``{"": cfg.device}``;
    * forced to ``None`` whenever an ``ACCELERATE_USE_*`` environment variable
      is present.
    """

    def _probe_device():
        # Best-effort probe: any failure (missing backend, odd torch build)
        # degrades to CPU instead of aborting configuration.
        try:
            if torch.cuda.is_available():
                return f"cuda:{cfg.local_rank}"
            if torch.backends.mps.is_available():
                return "mps"
        except Exception:  # pylint: disable=broad-exception-caught
            return "cpu"
        return "cpu"

    cfg.device = _probe_device()

    if cfg.world_size == 1:
        cfg.device_map = cfg.device_map or "auto"
    elif cfg.device.startswith("cuda"):
        cfg.device_map = {"": torch.cuda.current_device()}
    else:
        cfg.device_map = {"": cfg.device}

    # Under `accelerate launch`, no device map may be passed through:
    # accelerate decides which parts of the model go on which GPU.
    launched_by_accelerate = any(
        name.startswith("ACCELERATE_USE_") for name in os.environ
    )
    if launched_by_accelerate:
        cfg.device_map = None


def normalize_config(cfg):
    """
    Populate derived settings on ``cfg`` in place.

    Statement order is significant: later steps read fields written by
    earlier ones. In order, this:

    * derives ``gradient_accumulation_steps`` / ``batch_size`` from each other
      and defaults ``eval_batch_size`` to ``micro_batch_size``;
    * reads ``WORLD_SIZE`` / ``LOCAL_RANK`` from the environment (defaults 1 / 0);
    * defaults ``eval_table_size``, ``eval_max_new_tokens`` and
      ``eval_causal_lm_metrics``;
    * picks the device via ``choose_device`` and resolves ``ddp``;
    * resolves ``bf16: auto``, applies MPS restrictions, sets the process-wide
      ``torch.backends.cuda.matmul.allow_tf32`` flag on non-MPS devices, and
      selects ``torch_dtype``;
    * converts ``saves_per_epoch`` / ``evals_per_epoch`` into fractional
      ``save_steps`` / ``eval_steps``;
    * defaults ``dataset_processes`` to ``os.cpu_count()``;
    * loads the base model config (``load_model_config``) and sets
      ``model_config_type``, ``tokenizer_config`` and the
      ``is_{llama,falcon,mistral,qwen}_derived_model`` flags;
    * wraps a single ``pretraining_dataset`` dict in a list, defaults
      ``gradient_checkpointing_kwargs``, and logs baseline GPU memory.

    Args:
        cfg: attribute-style config, mutated in place. Presumably a
            ``DictDefault`` where unset keys read as ``None`` (the ``is None``
            checks below rely on that) -- confirm with callers.
    """
    # setup some derived config / hyperparams
    # Derive whichever of gradient_accumulation_steps / batch_size is unset
    # from the other.
    # NOTE(review): if both are None, `cfg.batch_size // ...` raises TypeError;
    # presumably prevented by earlier validation -- confirm.
    cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
        cfg.batch_size // cfg.micro_batch_size
    )
    cfg.batch_size = (
        cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
    )
    if cfg.eval_batch_size is None:
        cfg.eval_batch_size = cfg.micro_batch_size
    cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
    cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    cfg.eval_table_size = cfg.eval_table_size or 0
    cfg.eval_max_new_tokens = cfg.eval_max_new_tokens or 128
    cfg.eval_causal_lm_metrics = cfg.eval_causal_lm_metrics or [
        "sacrebleu",
        "comet",
        "ter",
        "chrf",
    ]
    choose_device(cfg)
    # DDP defaults to on whenever more than one process is running.
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
        # Pin each process to its LOCAL_RANK index (same value as
        # cfg.local_rank). This overrides the device_map choose_device picked,
        # including the None it sets under `accelerate launch`.
        cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
        # batch_size becomes the total across all world_size processes.
        cfg.batch_size = cfg.batch_size * cfg.world_size

    # Resolve `bf16: auto` to a concrete bool; without bf16 support, fall back
    # to fp16 unless the user set fp16 explicitly.
    if cfg.bf16 == "auto":
        if is_torch_bf16_gpu_available():
            LOG.debug("bf16 support detected, enabling for this configuration.")
            cfg.bf16 = True
        else:
            LOG.debug("bf16 support not detected, disabling for this configuration.")
            cfg.bf16 = False
            if cfg.fp16 is None:
                cfg.fp16 = True

    if cfg.device == "mps":
        # MPS: no 8-bit loading and no tf32; a bf16 request falls back to fp16.
        cfg.load_in_8bit = False
        cfg.tf32 = False
        if cfg.bf16:
            cfg.fp16 = True
        cfg.bf16 = False
    else:
        # Global torch switch (not a cfg field); bf16 takes precedence over fp16.
        torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
        if cfg.bf16:
            cfg.fp16 = False

    # bf16 wins over fp16; 8-bit loading also selects float16.
    if cfg.bf16 or cfg.bfloat16:
        cfg.torch_dtype = torch.bfloat16
    elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
        cfg.torch_dtype = torch.float16
    else:
        cfg.torch_dtype = torch.float32

    # Express "N per epoch" as a fraction of the whole run
    # (1 / (per_epoch * num_epochs)). Presumably the trainer reads a float < 1
    # as a ratio of total steps -- confirm against the trainer setup.
    if cfg.saves_per_epoch:
        save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
        if save_steps < 1.0:  # prevent saves on every step
            cfg.save_steps = save_steps
    if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
        eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
        if eval_steps < 1.0:  # prevent evals on every step
            cfg.eval_steps = eval_steps

    cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()

    if not cfg.base_model_config:
        cfg.base_model_config = cfg.base_model

    model_config = load_model_config(cfg)
    cfg.model_config_type = model_config.model_type

    cfg.tokenizer_config = (
        cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
    )

    # Each is_*_derived_model flag is true if the loaded config's model_type
    # matches, the user already set the flag, or (except qwen) the model name /
    # type_of_model contains the family name.
    # figure out if the model is llama
    cfg.is_llama_derived_model = (
        (hasattr(model_config, "model_type") and model_config.model_type == "llama")
        or cfg.is_llama_derived_model
        or "llama" in cfg.base_model.lower()
        or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
    )

    # figure out if the model is falcon
    cfg.is_falcon_derived_model = (
        (
            hasattr(model_config, "model_type")
            and model_config.model_type
            in [
                "falcon",
                "RefinedWebModel",
                "RefinedWeb",
            ]
        )
        or cfg.is_falcon_derived_model
        or "falcon" in cfg.base_model.lower()
        or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
    )

    # Unlike llama/falcon, only the last path component of base_model is
    # checked for "mistral".
    cfg.is_mistral_derived_model = (
        (
            hasattr(model_config, "model_type")
            and model_config.model_type
            in [
                "mistral",
            ]
        )
        or cfg.is_mistral_derived_model
        or "mistral" in cfg.base_model.lower().split("/")[-1]
        or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
    )

    # qwen is detected only from model_type or an explicit flag (no name match).
    cfg.is_qwen_derived_model = (
        hasattr(model_config, "model_type")
        and model_config.model_type
        in [
            "qwen",
        ]
    ) or cfg.is_qwen_derived_model

    # Accept a single pretraining dataset dict as shorthand for a one-item list.
    if isinstance(cfg.pretraining_dataset, dict):
        cfg.pretraining_dataset = [cfg.pretraining_dataset]

    # Default to reentrant checkpointing unless parameters are partially
    # frozen, kwargs were given explicitly, or RL training is configured.
    if (
        cfg.gradient_checkpointing
        and cfg.unfrozen_parameters is None
        and cfg.gradient_checkpointing_kwargs is None
        and cfg.rl is None
    ):
        cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}

    log_gpu_memory_usage(LOG, "baseline", cfg.device)


def normalize_cfg_datasets(cfg):
    """
    Propagate a top-level ``chat_template: chatml`` to dataset entries that need it.

    Does nothing unless ``cfg.chat_template`` is ``"chatml"`` and
    ``cfg.datasets`` is non-empty. Otherwise, for each dataset entry:

    * ``type: sharegpt`` without a ``conversation`` gets ``conversation: chatml``;
    * ``type: orpo.chat_template`` without a ``chat_template`` gets
      ``chat_template: chatml``.

    Each change is logged at INFO level. Entries are updated in place.
    """
    if not (cfg.chat_template == "chatml" and cfg.datasets):
        return

    for position, dataset in enumerate(cfg.datasets):
        needs_conversation = dataset.type == "sharegpt" and not dataset.conversation
        if needs_conversation:
            LOG.info(
                f"updating dataset {dataset.path} with `conversation: chatml` to match your chat_template"
            )
            cfg.datasets[position].conversation = "chatml"

        needs_template = (
            dataset.type == "orpo.chat_template" and not dataset.chat_template
        )
        if needs_template:
            LOG.info(
                f"updating dataset {dataset.path} with `chat_template: chatml` to match your chat_template"
            )
            cfg.datasets[position].chat_template = "chatml"


def validate_config(cfg: DictDefault, capabilities: Optional[dict] = None):
    """
    Validate ``cfg`` against the axolotl input schema and return the validated config.

    If ``capabilities`` is truthy, ``AxolotlConfigWCapabilities`` is used and
    ``capabilities`` is passed alongside the config's own keys. Otherwise
    ``AxolotlInputConfig`` is used. The validated model is dumped with
    ``exclude_none=True``, so fields whose value is None are omitted. The
    result is wrapped in a new ``DictDefault``.
    """
    raw_config = cfg.to_dict()
    if capabilities:
        validated = AxolotlConfigWCapabilities(
            **raw_config, capabilities=capabilities
        )
    else:
        validated = AxolotlInputConfig(**raw_config)
    return DictDefault(dict(validated.model_dump(exclude_none=True)))


def legacy_validate_config(cfg):
    """
    Pre-validate the raw yaml configuration before any model architecture information is known.

    Raises on unsupported or conflicting option combinations. Logs warnings
    for combinations that are allowed but likely to misbehave. A few
    deprecated settings are migrated in place rather than rejected:

    * dataset ``type: sharegpt:chat`` / ``sharegpt_simple`` -> ``sharegpt``;
    * ``wandb_run_id`` is copied into an unset ``wandb_name``;
    * ``noisy_embedding_alpha`` is copied into an unset ``neftune_noise_alpha``.

    Args:
        cfg: attribute-style config, possibly mutated in place (see above).
            Presumably a ``DictDefault`` where unset keys read as ``None`` --
            confirm with callers.

    Raises:
        ValueError: for unsupported or conflicting settings.
        DeprecationWarning: if the removed ``max_packed_sequence_len`` is set.
    """
    # --- mixed precision / hardware ------------------------------------------
    if is_torch_bf16_gpu_available():
        if not cfg.bf16 and not cfg.bfloat16:
            LOG.info("bf16 support detected, but not enabled for this configuration.")
    elif (
        not cfg.merge_lora
        and not cfg.is_preprocess
        and (cfg.bf16 is True or cfg.bfloat16 is True)
    ):
        # Only an explicit `True` is rejected; e.g. `bf16: auto` passes here.
        raise ValueError(
            "bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
        )
    if (
        # pylint: disable=too-many-boolean-expressions
        not (cfg.bf16 or cfg.bfloat16)
        and (cfg.fp16 or cfg.float16)
        and not cfg.adapter
        and not cfg.flash_attention
        and cfg.sample_packing
    ):
        # Typical failures in this setup:
        #   ValueError: Attempting to unscale FP16 gradients.
        #   RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
        LOG.warning(
            "Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
        )

    # --- packing ----------------------------------------------------------------
    if cfg.max_packed_sequence_len:
        raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")

    if cfg.sample_packing and cfg.rl:
        raise ValueError("`sample_packing: true` does not work with RLHF training")

    if cfg.sample_packing and not cfg.pad_to_sequence_len:
        LOG.warning(
            "`pad_to_sequence_len: true` is recommended when using sample_packing"
        )

    # --- batch sizing -----------------------------------------------------------
    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )
    if cfg.batch_size:
        LOG.warning(
            "%s\n%s",
            "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
            "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
        )
    if (
        cfg.eval_batch_size
        and cfg.micro_batch_size
        and cfg.eval_batch_size != cfg.micro_batch_size
    ):
        LOG.warning(
            "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
        )

    # --- adapters (QLoRA / LoRA / ReLoRA) ---------------------------------------
    if cfg.adapter == "qlora":
        if cfg.merge_lora:
            # can't merge qlora if loaded in 8bit or 4bit
            if cfg.load_in_8bit:
                raise ValueError("Can't merge qlora if loaded in 8bit")

            if cfg.gptq:
                raise ValueError("Can't merge qlora if gptq")

            if cfg.load_in_4bit:
                raise ValueError("Can't merge qlora if loaded in 4bit")

        else:
            if cfg.load_in_8bit:
                raise ValueError("Can't load qlora in 8bit")

            if cfg.gptq:
                raise ValueError("Can't load qlora if gptq")

            if not cfg.load_in_4bit:
                raise ValueError("Require cfg.load_in_4bit to be True for qlora")

        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
            raise ValueError("Fused modules are not supported with QLoRA")

    # LoftQ initialisation replaces the need for 8-bit loading with LoRA.
    loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
        LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

    if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
        raise ValueError("Fused modules are not supported with LoRA")

    if cfg.adapter and cfg.peft_layers_to_transform and cfg.unfrozen_parameters:
        raise ValueError(
            "`unfrozen_parameters` used with `peft_layers_to_transform` can have unexpected behavior."
        )

    if cfg.relora_steps:
        if cfg.adapter not in ("lora", "qlora"):
            raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")

        if cfg.fsdp:
            raise ValueError("fsdp not supported with ReLoRA")

        if cfg.deepspeed:
            raise ValueError("deepspeed not supported with ReLoRA")

        if cfg.lr_scheduler == "one_cycle":
            raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")

        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
            raise ValueError("Fused modules are not supported with ReLoRA")

    # --- model loading / hub --------------------------------------------------
    if cfg.trust_remote_code:
        LOG.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )

    if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
        raise ValueError(
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

    if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
        raise ValueError("FSDP is not supported for falcon models")

    if (
        cfg.base_model and "mpt" in cfg.base_model.lower()
    ) and cfg.gradient_checkpointing:
        raise ValueError("gradient_checkpointing is not supported for MPT models")

    if cfg.flash_optimum is True:
        if cfg.adapter:
            LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
        if cfg.fp16 or cfg.bf16:
            raise ValueError("AMP is not supported with BetterTransformer")
        if cfg.float16 is not True and cfg.bfloat16 is not True:
            LOG.warning(
                "You should probably set bfloat16 or float16 to true to "
                "load the model in float16 for BetterTransformers"
            )
        if int(torch.__version__.split(".", maxsplit=1)[0]) < 2:
            LOG.warning("torch>=2.0.0 required")
            raise ValueError(
                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
            )

    if cfg.pretraining_dataset and cfg.group_by_length:
        LOG.warning(
            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
        )
    if cfg.pretraining_dataset and not cfg.max_steps:
        raise ValueError(
            "max_steps must be set when using iterable pretraining_dataset, Trainer can't infer length and schedule optimizer/learning rate without it!"
        )

    # `optimizer` may be unset; normalise to "" so the substring checks below
    # (adamw / 8bit / bnb) never run `in` against None.
    optimizer_name = cfg.optimizer or ""

    if any((cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon)) and (
        "adamw" not in optimizer_name
    ):
        LOG.warning("adamw hyperparameters found, but no adamw optimizer set")

    if cfg.push_to_hub_model_id:
        raise ValueError(
            "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
        )

    if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
        LOG.warning(
            "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
        )

    if cfg.gptq and cfg.revision_of_model:
        raise ValueError(
            "revision_of_model is not supported for GPTQ models. "
            "Please download the model from HuggingFace Hub manually for correct branch, "
            "point to its path, and remove revision_of_model from the config."
        )

    # --- attention implementations --------------------------------------------
    # if cfg.sample_packing and cfg.sdp_attention:
    #     # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
    #     raise ValueError(
    #         "sample_packing not compatible with sdp_attention. Use flash_attention"
    #     )

    if cfg.sample_packing and cfg.xformers_attention:
        raise ValueError(
            "sample_packing not compatible with xformers_attention. Use flash_attention"
        )

    if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
        # https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
        LOG.warning(
            "sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
            "This may work on H100s."
        )

    # --- checkpoint / evaluation scheduling -----------------------------------
    if cfg.early_stopping_patience:
        if not cfg.save_steps or not cfg.eval_steps:
            raise ValueError(
                "`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
            )
        if cfg.save_steps % cfg.eval_steps != 0:
            raise ValueError(
                "`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
            )

    # Migrate deprecated sharegpt dataset type aliases in place.
    if cfg.datasets:
        for idx, ds_cfg in enumerate(cfg.datasets):
            if not ds_cfg.type:
                continue
            if ds_cfg.type == "sharegpt:chat":
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                cfg.datasets[idx].type = "sharegpt"
            if "sharegpt_simple" in ds_cfg.type:
                LOG.warning(
                    PendingDeprecationWarning(
                        "`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
                    )
                )
                cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
                    "sharegpt_simple", "sharegpt"
                )

    if cfg.saves_per_epoch and cfg.save_steps:
        raise ValueError(
            "save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
        )
    if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
        raise ValueError(
            "save_strategy must be empty or set to `steps` when used with saves_per_epoch."
        )
    if cfg.evals_per_epoch and cfg.eval_steps:
        raise ValueError(
            "eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
        )
    if (
        cfg.evals_per_epoch
        and cfg.evaluation_strategy
        and cfg.evaluation_strategy != "steps"
    ):
        raise ValueError(
            "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
        )
    if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
        raise ValueError(
            "save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
        )

    if (
        cfg.evaluation_strategy
        and cfg.eval_steps
        and cfg.evaluation_strategy != "steps"
    ):
        raise ValueError(
            "evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
        )

    if (
        cfg.val_set_size == 0
        and (cfg.eval_steps or cfg.evaluation_strategy)
        and not cfg.test_datasets
    ):
        raise ValueError(
            "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
        )

    if (
        cfg.sample_packing
        and cfg.eval_table_size
        and cfg.eval_sample_packing is not False
    ):
        raise ValueError(
            "eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
        )

    if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
        raise ValueError(
            "load_in_8bit and load_in_4bit are not supported without setting an adapter. "
            "If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
        )

    # --- deprecated / renamed options -----------------------------------------
    if cfg.rope_scaling:
        LOG.warning("`rope_scaling` should now be be a key under `model_config`")

    if cfg.wandb_run_id and not cfg.wandb_name:
        cfg.wandb_name = cfg.wandb_run_id

        LOG.warning(
            "wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
        )

    if cfg.noisy_embedding_alpha is not None:
        # Deprecated, use neftune_noise_alpha
        LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
        if cfg.neftune_noise_alpha is None:
            cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
        else:
            # User is providing both; bail and have them sort out their settings
            raise ValueError(
                "noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
            )

    if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
        raise ValueError("neftune_noise_alpha must be > 0.0")

    if cfg.max_memory is not None and cfg.gpu_memory_limit is not None:
        raise ValueError(
            "max_memory and gpu_memory_limit are mutually exclusive and cannot be used together."
        )

    if (
        cfg.unfrozen_parameters
        and cfg.gradient_checkpointing_kwargs
        and cfg.gradient_checkpointing_kwargs.use_reentrant is True
    ):
        # https://github.com/huggingface/transformers/issues/21381
        raise ValueError(
            "`use_reentrant` must be false when used with partially frozen model."
        )

    # --- deepspeed (only when `deepspeed` points at an existing JSON file) ----
    if cfg.deepspeed and Path(cfg.deepspeed).is_file():
        with open(cfg.deepspeed, encoding="utf-8") as file:
            deepspeed_cfg: DictDefault = DictDefault(json.load(file))
        if cfg.flash_attention:
            if (
                deepspeed_cfg.zero_optimization
                and deepspeed_cfg.zero_optimization.stage == 3
            ):
                if not (
                    (
                        deepspeed_cfg.bf16
                        and deepspeed_cfg.bf16.enabled  # pylint: disable=no-member
                        is True
                    )
                    or (
                        deepspeed_cfg.fp16
                        and deepspeed_cfg.fp16.enabled  # pylint: disable=no-member
                        is True
                    )
                ):
                    raise ValueError(
                        "bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
                    )
        if "8bit" in optimizer_name and deepspeed_cfg.optimizer:
            LOG.warning(
                f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
            )

    if cfg.test_datasets and cfg.val_set_size:
        raise ValueError(
            "non-zero val_set_size should not be used with test_datasets configuration"
        )

    if cfg.fsdp and "bnb" in optimizer_name:
        raise ValueError(f"FSDP not compatible with {cfg.optimizer}")

    # --- causal-LM evaluation -------------------------------------------------
    if cfg.do_causal_lm_eval and cfg.eval_sample_packing:
        raise ValueError(
            "do_causal_lm_eval is enabled, eval_sample_packing must be set to False"
        )

    if cfg.eval_causal_lm_metrics:
        supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
        if not isinstance(cfg.eval_causal_lm_metrics, list):
            raise ValueError("eval_causal_lm_metrics must be a list")
        # only ["sacrebleu", "comet", "ter", "chrf"] supported
        if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
            raise ValueError(
                f"eval_causal_lm_metrics must be one of {supported_metrics}"
            )

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
    # no 8bit adaAmw w bf16

    # GPT-NeoX
    # evals broken when extending context len
    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward                        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
    # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
    # attention_mask = causal_mask + attention_mask
    # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3