File size: 2,843 Bytes
c2dbf2c a5bf838 71a43f8 a5bf838 1d7da3b 3aad5f3 3c71c8d dd00657 52dd92a dd00657 1d7da3b 48f4c05 52dd92a 48f4c05 52dd92a dd00657 fe0e69f 52dd92a 2824423 52dd92a a5bf838 1c33eb8 b832a0a 1c33eb8 bfd27ba babf0fd 1edc30c 71a43f8 1edc30c 1d7da3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
"""Module for validating config files"""
import logging
import torch
def validate_config(cfg):
    """Check a training config for conflicting or unsupported options.

    Incompatible combinations raise ``ValueError``; discouraged-but-allowed
    settings only emit a warning through the root ``logging`` logger.

    Args:
        cfg: Config object read via attribute access (``cfg.adapter``,
            ``cfg.batch_size``, ...). Unset options are presumably falsy
            (e.g. ``None``) rather than raising ``AttributeError`` --
            NOTE(review): confirm against the caller's config type.

    Raises:
        ValueError: If any of the following holds:
            * both ``gradient_accumulation_steps`` and ``batch_size`` are set;
            * the deprecated ``load_4bit`` option is set;
            * ``adapter == "qlora"`` with ``merge_lora`` and any of
              ``load_in_8bit`` / ``gptq`` / ``load_in_4bit``;
            * ``adapter == "qlora"`` without ``merge_lora`` and with
              ``load_in_8bit`` or ``gptq``, or without ``load_in_4bit``;
            * ``push_dataset_to_hub`` is set but ``hf_use_auth_token`` is
              not ``True``;
            * ``base_model`` contains "falcon" (case-insensitive) and
              ``fsdp`` is set;
            * ``flash_optimum`` is ``True`` together with ``fp16``/``bf16``,
              or with an installed torch whose major version is below 2.
    """
    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )
    if cfg.batch_size:
        logging.warning(
            "%s\n%s",
            "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
            "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
        )
    if cfg.load_4bit:
        raise ValueError(
            "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
        )

    if cfg.adapter == "qlora":
        if cfg.merge_lora:
            # can't merge qlora if loaded in 8bit or 4bit
            if cfg.load_in_8bit:
                raise ValueError("Can't merge qlora if loaded in 8bit")

            if cfg.gptq:
                raise ValueError("Can't merge qlora if gptq")

            if cfg.load_in_4bit:
                raise ValueError("Can't merge qlora if loaded in 4bit")

        else:
            # Not merging: qlora must be loaded in 4bit and nothing else.
            if cfg.load_in_8bit:
                raise ValueError("Can't load qlora in 8bit")

            if cfg.gptq:
                raise ValueError("Can't load qlora if gptq")

            if not cfg.load_in_4bit:
                raise ValueError("Require cfg.load_in_4bit to be True for qlora")

    if not cfg.load_in_8bit and cfg.adapter == "lora":
        logging.warning(
            "We recommend setting `load_in_8bit: true` for LORA finetuning"
        )

    if cfg.trust_remote_code:
        logging.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )

    if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
        raise ValueError(
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

    # base_model may be unset, so guard before calling .lower() on it.
    if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
        raise ValueError("FSDP is not supported for falcon models")

    if cfg.flash_optimum is True:
        if cfg.adapter:
            logging.warning(
                "BetterTransformers probably doesn't work with PEFT adapters"
            )
        if cfg.fp16 or cfg.bf16:
            raise ValueError("AMP is not supported with BetterTransformer")
        if cfg.float16 is not True:
            logging.warning(
                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
            )
        # The version is a string such as "2.0.1+cu118"; its major component
        # must be converted to int, because comparing str with int raises
        # TypeError on Python 3.
        torch_major = int(torch.__version__.split(".")[0])
        if torch_major < 2:
            logging.warning("torch>=2.0.0 required")
            raise ValueError(
                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
            )

    # TODO
    # MPT 7b
    # https://github.com/facebookresearch/bitsandbytes/issues/25
    # no 8bit adamw w bf16
|