|
"""Module for validating config files"""
|
|
|
import logging

import torch
|
|
|
|
|
def validate_config(cfg):
    """Reject unsupported or conflicting option combinations in ``cfg``.

    Checks run top to bottom and stop at the first fatal problem, so at most
    one ``ValueError`` is reported per call. Non-fatal concerns are reported
    with ``logging.warning`` and validation continues.

    Args:
        cfg: Config object whose options are read as attributes
            (``cfg.adapter``, ``cfg.load_in_8bit``, ...). Presumably unset
            options read as a falsy value rather than raising -- confirm
            against the caller.

    Raises:
        ValueError: If any of the following holds:
            - ``gradient_accumulation_steps`` and ``batch_size`` are both set;
            - the deprecated ``load_4bit`` option is set;
            - ``adapter == "qlora"`` with ``merge_lora`` and any of
              ``load_in_8bit``, ``gptq`` or ``load_in_4bit`` set;
            - ``adapter == "qlora"`` without ``merge_lora`` and either
              ``load_in_8bit`` or ``gptq`` is set, or ``load_in_4bit`` is not;
            - ``push_dataset_to_hub`` is set but ``hf_use_auth_token`` is not
              exactly ``True``;
            - ``fsdp`` is set and ``base_model`` contains "falcon"
              (case-insensitive);
            - ``flash_optimum`` is ``True`` together with ``fp16``/``bf16``,
              or with a torch major version below 2.
    """
    if cfg.gradient_accumulation_steps and cfg.batch_size:
        raise ValueError(
            "please set only one of gradient_accumulation_steps or batch_size"
        )
    if cfg.batch_size:
        logging.warning(
            "%s\n%s",
            "batch_size is not recommended. Please use gradient_accumulation_steps instead.",
            "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
        )
    if cfg.load_4bit:
        raise ValueError(
            "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
        )

    if cfg.adapter == "qlora":
        if cfg.merge_lora:
            # Merging is rejected for every quantized loading mode.
            if cfg.load_in_8bit:
                raise ValueError("Can't merge qlora if loaded in 8bit")

            if cfg.gptq:
                raise ValueError("Can't merge qlora if gptq")

            if cfg.load_in_4bit:
                raise ValueError("Can't merge qlora if loaded in 4bit")

        else:
            # Outside of merging, qlora must be loaded in 4bit and nothing else.
            if cfg.load_in_8bit:
                raise ValueError("Can't load qlora in 8bit")

            if cfg.gptq:
                raise ValueError("Can't load qlora if gptq")

            if not cfg.load_in_4bit:
                raise ValueError("Require cfg.load_in_4bit to be True for qlora")

    if not cfg.load_in_8bit and cfg.adapter == "lora":
        logging.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")

    if cfg.trust_remote_code:
        logging.warning(
            "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
        )

    # Identity check: only the literal True is accepted for the auth flag.
    if cfg.push_dataset_to_hub and cfg.hf_use_auth_token is not True:
        raise ValueError(
            "Require cfg.hf_use_auth_token to be True for push_dataset_to_hub"
        )

    if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp:
        raise ValueError("FSDP is not supported for falcon models")

    if cfg.flash_optimum is True:
        if cfg.adapter:
            logging.warning(
                "BetterTransformers probably doesn't work with PEFT adapters"
            )
        if cfg.fp16 or cfg.bf16:
            raise ValueError("AMP is not supported with BetterTransformer")
        if cfg.float16 is not True:
            logging.warning(
                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
            )
        # Only the major component of the torch version string is compared.
        if int(torch.__version__.split(".")[0]) < 2:
            logging.warning("torch>=2.0.0 required")
            raise ValueError(
                f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
            )
|
|
|
|
|
|
|
|
|
|
|
|