|
"""Module for models and model loading""" |
|
|
|
|
|
|
|
import logging |
|
import math |
|
import os |
|
import types |
|
from typing import Any, Dict, Optional, Tuple, Union |
|
|
|
import addict |
|
import bitsandbytes as bnb |
|
import torch |
|
import transformers |
|
import transformers.modeling_utils |
|
from accelerate import init_empty_weights |
|
from bitsandbytes.nn import Params4bit |
|
from peft import ( |
|
LoftQConfig, |
|
PeftConfig, |
|
PeftModel, |
|
PeftModelForCausalLM, |
|
prepare_model_for_kbit_training, |
|
) |
|
from peft.tuners.lora import QuantLinear |
|
from torch import nn |
|
from transformers import ( |
|
AddedToken, |
|
AutoConfig, |
|
AutoModelForCausalLM, |
|
AutoTokenizer, |
|
BitsAndBytesConfig, |
|
GPTQConfig, |
|
PreTrainedModel, |
|
PreTrainedTokenizerBase, |
|
) |
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
|
|
|
from axolotl.models.mamba import fix_mamba_attn_for_loss |
|
from axolotl.monkeypatch.multipack import ( |
|
SUPPORTED_MULTIPACK_MODEL_TYPES, |
|
patch_for_multipack, |
|
) |
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN |
|
from axolotl.utils.bench import log_gpu_memory_usage |
|
from axolotl.utils.chat_templates import chat_templates |
|
from axolotl.utils.dict import DictDefault |
|
from axolotl.utils.distributed import zero_only |
|
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper |
|
from axolotl.utils.lora_embeddings import get_linear_embedding_layers |
|
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant |
|
|
|
LOG = logging.getLogger("axolotl") |
|
|
|
|
|
|
|
def get_module_class_from_name(module, name):
    """
    Find the class of the first submodule whose class name equals ``name``.

    The search is depth-first, pre-order, left-to-right over
    ``module.children()``, starting with ``module`` itself.

    Args:
        module (`torch.nn.Module`): Root of the module tree to search.
        name (`str`): Class name to look for (compared to ``__class__.__name__``).

    Returns:
        The matching class, or ``None`` when no module in the tree matches.
    """
    pending = [module]
    while pending:
        current = pending.pop()
        if current.__class__.__name__ == name:
            return current.__class__
        # Push children in reverse so the left-most child is examined next,
        # reproducing a recursive pre-order traversal.
        pending.extend(reversed(list(current.children())))
    return None
|
|
|
|
|
def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDefault]):
    """
    Validate that ``cfg`` agrees with the loaded model config.

    Raises:
        ValueError: if ``cfg.gptq`` is set but the model config has no GPTQ
            ``quantization_config``; if the model config carries a
            ``quantization_config`` while ``cfg.gptq`` is unset; or if an adapter
            is trained with new ``cfg.tokens`` but ``cfg.lora_modules_to_save``
            does not list every embedding layer of the model type.
    """
    quant_config = (
        model_config.quantization_config
        if hasattr(model_config, "quantization_config")
        else None
    )
    has_quant_config = bool(quant_config)
    is_gptq_quantized = (
        has_quant_config
        and "quant_method" in quant_config
        and quant_config["quant_method"] == "gptq"
    )

    if cfg.gptq and not is_gptq_quantized:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if not cfg.gptq and has_quant_config:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )

    embedding_layers = get_linear_embedding_layers(model_config.model_type)
    if cfg.adapter and cfg.tokens:
        configured = cfg.lora_modules_to_save
        # New tokens grow the embeddings, so every embedding layer must be trained
        # (saved) alongside the adapter.
        if not configured or any(layer not in configured for layer in embedding_layers):
            lora_modules_to_save = ", ".join(f"`{layer}`" for layer in embedding_layers)
            raise ValueError(
                f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`."
            )
|
|
|
|
|
def load_model_config(cfg):
    """
    Load (and validate) the Hugging Face config for the model described by ``cfg``.

    The config source is ``cfg.base_model_config``, falling back to
    ``cfg.base_model`` and then ``cfg.tokenizer_config``. ``cfg.revision_of_model``
    is forwarded as ``revision``, and ``cfg.overrides_of_model_config`` entries are
    set as attributes on the loaded config before it is checked with
    ``check_model_config``.

    Returns:
        The loaded config, or a minimal ``addict.Dict({"model_type": "mamba"})``
        stub when ``AutoConfig`` raises ``ValueError`` for a source whose name
        contains ``"mamba"`` (the stub is returned without overrides or checks).

    Raises:
        ValueError: re-raised from ``AutoConfig.from_pretrained`` for non-mamba
            sources, or from ``check_model_config``.
    """
    model_config_name = cfg.base_model_config or cfg.base_model
    if not model_config_name and cfg.tokenizer_config:
        model_config_name = cfg.tokenizer_config
    trust_remote_code = cfg.trust_remote_code is True
    config_kwargs = {}
    if cfg.revision_of_model:
        config_kwargs["revision"] = cfg.revision_of_model

    try:
        model_config = AutoConfig.from_pretrained(
            model_config_name,
            trust_remote_code=trust_remote_code,
            **config_kwargs,
        )
    except ValueError as err:
        # Guard against a missing name: `"mamba" in None` would raise a TypeError
        # and hide the original ValueError from AutoConfig.
        if model_config_name and "mamba" in str(model_config_name):
            return addict.Dict(
                {
                    "model_type": "mamba",
                }
            )
        raise err

    if cfg.overrides_of_model_config:
        for key, val in cfg.overrides_of_model_config.items():
            setattr(model_config, key, val)

    check_model_config(cfg, model_config)

    return model_config
|
|
|
|
|
def load_tokenizer(cfg):
    """
    Load and configure the tokenizer described by ``cfg``.

    Steps, in order:

    * instantiate ``cfg.tokenizer_type`` (a class exported by ``transformers``)
      or ``AutoTokenizer`` from ``cfg.tokenizer_config``;
    * give Llama/CodeLlama tokenizers ``LLAMA_DEFAULT_EOS_TOKEN`` as pad token
      when none is set, and add a ``[PAD]`` pad token to GPT-NeoX fast
      tokenizers (also disabling ``TOKENIZERS_PARALLELISM``);
    * left-pad for Mistral-derived models using flash attention without
      sample packing;
    * backfill unset special tokens on Qwen-derived tokenizers;
    * apply ``cfg.special_tokens``, ``cfg.tokens`` and any
      ``additional_special_tokens``;
    * attach the chat template named by ``cfg.chat_template``.

    Args:
        cfg: Axolotl configuration (attribute-style access).

    Returns:
        The configured tokenizer.

    Raises:
        ValueError: if an adapter is trained, a special token is changed to a
            string that encodes to more than two tokens, and
            ``cfg.lora_modules_to_save`` does not cover the model's embedding
            layers.
    """
    model_config = load_model_config(cfg)
    tokenizer_kwargs = {}
    use_fast = True

    if cfg.tokenizer_use_fast is not None:
        use_fast = cfg.tokenizer_use_fast
    if cfg.tokenizer_legacy is not None:
        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy

    tokenizer_cls = AutoTokenizer
    if cfg.tokenizer_type:
        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)

    tokenizer = tokenizer_cls.from_pretrained(
        cfg.tokenizer_config,
        trust_remote_code=cfg.trust_remote_code or False,
        use_fast=use_fast,
        **tokenizer_kwargs,
    )

    # Llama-family tokenizers may have no pad token; reuse the default EOS token.
    if (
        tokenizer.__class__.__name__
        in [
            "LlamaTokenizer",
            "LlamaTokenizerFast",
            "CodeLlamaTokenizer",
            "CodeLlamaTokenizerFast",
        ]
        and hasattr(tokenizer, "pad_token")
        and not tokenizer.pad_token
    ):
        tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN

    if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
        tokenizer.padding_side = "left"

    # Qwen tokenizers can leave these unset; fall back to `eod_id` / "<|endoftext|>".
    if cfg.is_qwen_derived_model:
        token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
        for attr_name in token_ids:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, tokenizer.eod_id)

        token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
        for attr_name in token_names:
            if getattr(tokenizer, attr_name) is None:
                setattr(tokenizer, attr_name, "<|endoftext|>")

    additional_special_tokens = None
    if cfg.special_tokens:
        special_tokens = cfg.special_tokens.to_dict()
        additional_special_tokens = special_tokens.pop(
            "additional_special_tokens", None
        )
        lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
        for k, val in special_tokens.items():
            # A changed special token that encodes to more than two pieces is
            # presumably a genuinely new token, so with an adapter the embedding
            # layers must be listed in `lora_modules_to_save`.
            if (
                (getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
                and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
                and cfg.adapter
                and (
                    not cfg.lora_modules_to_save
                    or not all(
                        x in cfg.lora_modules_to_save for x in lora_modules_to_save
                    )
                )
            ):
                lora_modules_to_save = ", ".join(
                    [f"`{x}`" for x in lora_modules_to_save]
                )
                raise ValueError(
                    f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
                )

            tokenizer.add_special_tokens(
                {k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
            )

        # Fast Llama/CodeLlama tokenizers build their post-processor template from
        # the current BOS/EOS tokens, so it must be rebuilt when EITHER of them is
        # overridden (previously this required both, leaving a stale template).
        bos_or_eos_in_special_tokens = (
            "bos_token" in cfg.special_tokens or "eos_token" in cfg.special_tokens
        )
        if (
            tokenizer.__class__.__name__
            in (
                "LlamaTokenizerFast",
                "CodeLlamaTokenizerFast",
            )
            and bos_or_eos_in_special_tokens
        ):
            tokenizer.update_post_processor()

    if cfg.tokens:
        tokenizer.add_tokens(
            [
                AddedToken(token, rstrip=False, lstrip=False, normalized=False)
                for token in cfg.tokens
            ]
        )

    if additional_special_tokens is not None:
        tokenizer.add_special_tokens(
            {"additional_special_tokens": additional_special_tokens}
        )

    with zero_only():
        LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
        LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
        LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
        LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")

    if cfg.chat_template:
        chat_template_string = chat_templates(cfg.chat_template)
        # Only the chatml template has its default system prompt substituted.
        if cfg.default_system_message and cfg.chat_template == "chatml":
            chat_template_string = chat_template_string.replace(
                "You are a helpful assistant.", cfg.default_system_message
            )

        tokenizer.chat_template = chat_template_string
    else:
        LOG.info(
            "No Chat template selected. Consider adding a chat template for easier inference."
        )
    return tokenizer
|
|
|
|
|
def load_model(
    cfg: DictDefault,
    tokenizer: PreTrainedTokenizerBase,
    inference: bool = False,
    reference_model: bool = False,
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
    """
    Load a model for a given configuration and tokenizer.

    The work happens in phases, in this order:

    1. Apply monkeypatches selected by ``cfg`` (gradient checkpointing,
       btlm/stablelm/llama/mistral attention, multipack, unsloth). They are
       applied before the model is instantiated below.
    2. Build ``model_kwargs`` for ``from_pretrained``: device map / max memory,
       dtype, revision, GPTQ or bitsandbytes quantization config, attention
       implementation, ``low_cpu_mem_usage``.
    3. Instantiate the model with one of several loaders (sharded FSDP loaders,
       ``LlamaForCausalLM``, Mamba, a named ``transformers`` class, or
       ``AutoModelForCausalLM``).
    4. Post-process: merge a loaded PEFT model, resize embeddings to the
       tokenizer, raise ``max_position_embeddings``, copy BOS/EOS ids from the
       tokenizer, cast norm/embedding modules, optionally run
       ``prepare_model_for_kbit_training``.
    5. Attach the adapter (or only build its config), move the model to the
       local GPU under DDP, and apply optional BetterTransformer / unsloth LoRA
       patches.

    Args:
        cfg: Axolotl configuration.
        tokenizer: Its length drives embedding resizing; its BOS/EOS ids are
            copied onto ``model.config``.
        inference: When true, skips the packed flash-attention Llama patch, the
            ``_expand_mask`` patch, and the Llama fused MLP/QKV replacement.
        reference_model: When true and ``cfg.lora_model_dir`` is unset, no
            adapter is loaded and ``None`` is returned as the config.

    Returns:
        ``(model, lora_config)``; ``lora_config`` is ``None`` when no adapter
        config was produced.

    Raises:
        ValueError: if both ``cfg.sample_packing`` and ``cfg.s2_attention`` are set.
        NotImplementedError: if ``cfg.s2_attention`` is set for a Llama-derived
            model that uses neither flash nor xformers attention.
    """
    base_model = cfg.base_model
    model_type = cfg.type_of_model
    model_config = load_model_config(cfg)

    # Read once here; consulted later when deciding whether to move the model
    # onto this rank's GPU under DDP.
    load_in_8bit = cfg.load_in_8bit

    # Swap the checkpoint function used by transformers' modeling utils for the
    # unsloth-style gradient checkpointing wrapper.
    if cfg.gradient_checkpointing == "unsloth":
        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper

    if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
        if cfg.flash_attention:
            from axolotl.monkeypatch.btlm_attn_hijack_flash import (
                replace_btlm_attn_with_flash_attn,
            )

            replace_btlm_attn_with_flash_attn(cfg.base_model)

    if (
        hasattr(model_config, "model_type")
        and model_config.model_type == "stablelm_epoch"
    ):
        if cfg.flash_attention and cfg.sample_packing:
            from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
                replace_stablelm_attn_with_flash_attn,
            )

            replace_stablelm_attn_with_flash_attn(cfg.base_model)

    if cfg.sample_packing and cfg.s2_attention:
        raise ValueError(
            "Received `sample_packing=true` and `s2_attention=true`; however, \
            shifted-sparse attention does not currently support sample packing."
        )

    # Attention patching: model types in SUPPORTED_MULTIPACK_MODEL_TYPES get the
    # generic multipack patch; other Llama-derived models use the dedicated
    # Llama flash / xformers / multipack-mask patches.
    if (
        cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
        and cfg.flash_attention
        and cfg.sample_packing
    ):
        patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
    elif cfg.is_llama_derived_model:
        if cfg.flash_attention:
            from axolotl.monkeypatch.llama_attn_hijack_flash import (
                replace_llama_attn_with_flash_attn,
            )

            if cfg.sample_packing:
                # The packed flash-attention patch is skipped on mps/cpu and
                # for inference.
                if cfg.device not in ["mps", "cpu"] and not inference:
                    LOG.info("patching with flash attention for sample packing")
                    replace_llama_attn_with_flash_attn(
                        packed=True,
                        cross_entropy=cfg.flash_attn_cross_entropy,
                        rms_norm=cfg.flash_attn_rms_norm,
                    )
            elif cfg.s2_attention:
                LOG.info("patching w/ flash-enabled, shifted-sparse attention")
                replace_llama_attn_with_flash_attn(
                    packed=False,
                    cross_entropy=cfg.flash_attn_cross_entropy,
                    rms_norm=cfg.flash_attn_rms_norm,
                    use_shifted_sparse_attn=True,
                )
        elif cfg.xformers_attention:
            from axolotl.monkeypatch.llama_attn_hijack_xformers import (
                hijack_llama_attention,
            )

            LOG.info("patching with xformers attention")
            hijack_llama_attention()
        elif cfg.sample_packing:
            from axolotl.monkeypatch.llama_patch_multipack import (
                hijack_llama_prepare_4d_mask,
            )

            LOG.info("patching llama _prepare_4d_causal_attention_mask*")
            hijack_llama_prepare_4d_mask()
        elif cfg.s2_attention:
            raise NotImplementedError(
                "Shifted-sparse attention not currently implemented without flash attention."
            )

    if cfg.unsloth_cross_entropy_loss:
        from axolotl.monkeypatch.unsloth_ import integrate_cross_entropy_loss_patch

        integrate_cross_entropy_loss_patch()

    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        from axolotl.monkeypatch.unsloth_ import patch_self_attn_lora

        patch_self_attn_lora()

    if (
        cfg.model_config_type == "mistral"
        and cfg.flash_attention
        and cfg.sample_packing
    ):
        from axolotl.monkeypatch.mistral_attn_hijack_flash import (
            replace_mistral_attn_with_flash_attn,
        )

        LOG.info("patching mistral with flash attention")
        replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)

    if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

        LOG.info("patching _expand_mask")
        hijack_expand_mask()

    # ---- Build the kwargs passed to the model loader ----
    model_kwargs: Dict[str, Any] = {}

    if cfg.model_kwargs:
        for key, val in cfg.model_kwargs.items():
            model_kwargs[key] = val

    max_memory = cfg.max_memory
    device_map = cfg.device_map

    # gpu_memory_limit caps every visible CUDA device (ints are taken as GiB)
    # and allows up to 256GiB of CPU memory.
    if cfg.gpu_memory_limit:
        gpu_memory_limit = (
            str(cfg.gpu_memory_limit) + "GiB"
            if isinstance(cfg.gpu_memory_limit, int)
            else cfg.gpu_memory_limit
        )

        max_memory = {}
        for i in range(torch.cuda.device_count()):
            max_memory[i] = gpu_memory_limit
        max_memory["cpu"] = "256GiB"

    # With a memory budget, derive the device map from an empty-weight copy of
    # the model; the budget is then consumed and cleared.
    if max_memory is not None:
        from accelerate import infer_auto_device_map

        with init_empty_weights():
            model_canvas = AutoModelForCausalLM.from_config(
                model_config, trust_remote_code=cfg.trust_remote_code or False
            )
        model_canvas.tie_weights()
        device_map = infer_auto_device_map(
            model_canvas,
            max_memory=max_memory,
            dtype=cfg.torch_dtype,
        )
        max_memory = None

    model_kwargs["device_map"] = device_map
    model_kwargs["torch_dtype"] = cfg.torch_dtype

    # When MPS is available, this overrides any device map computed above.
    if torch.backends.mps.is_available():
        model_kwargs["device_map"] = "mps:0"

    # No device_map under DeepSpeed ZeRO-3 (presumably because ZeRO-3 manages
    # parameter placement itself — confirm).
    if is_deepspeed_zero3_enabled():
        del model_kwargs["device_map"]

    if cfg.revision_of_model:
        model_kwargs["revision"] = cfg.revision_of_model

    if cfg.gptq:
        if not hasattr(model_config, "quantization_config"):
            LOG.warning("model config does not contain quantization_config information")
        else:
            if cfg.gptq_disable_exllama is not None:
                model_config.quantization_config[
                    "disable_exllama"
                ] = cfg.gptq_disable_exllama
            model_kwargs["quantization_config"] = GPTQConfig(
                **model_config.quantization_config
            )
    if cfg.adapter == "qlora" and cfg.load_in_4bit:
        bnb_config = {
            "load_in_4bit": True,
            "llm_int8_threshold": 6.0,
            "llm_int8_has_fp16_weight": False,
            "bnb_4bit_compute_dtype": cfg.torch_dtype,
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_quant_storage": torch.bfloat16,
        }
        # jamba / qwen2_moe use float32 quant storage unless DeepSpeed is on.
        if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
            bnb_config["bnb_4bit_quant_storage"] = torch.float32

        # User-provided kwargs override the defaults above.
        if cfg.bnb_config_kwargs:
            bnb_config.update(cfg.bnb_config_kwargs)

        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )
    elif cfg.adapter == "lora" and cfg.load_in_8bit:
        bnb_config = {
            "load_in_8bit": True,
        }
        # Keep jamba's mamba modules out of 8-bit quantization.
        if cfg.model_config_type == "jamba":
            bnb_config["llm_int8_skip_modules"] = ["mamba"]
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            **bnb_config,
        )

    if cfg.load_in_8bit and cfg.adapter is not None:
        model_kwargs["load_in_8bit"] = True
    if cfg.load_in_4bit and cfg.adapter is not None:
        model_kwargs["load_in_4bit"] = True

    # The bare load_in_*bit flags are dropped whenever a quantization config is
    # passed (or GPTQ is used), so only one quantization spec reaches the loader.
    if "quantization_config" in model_kwargs or cfg.gptq:
        if "load_in_8bit" in model_kwargs:
            del model_kwargs["load_in_8bit"]
        if "load_in_4bit" in model_kwargs:
            del model_kwargs["load_in_4bit"]

    # Attention implementation. With sample packing, flash_attention_2 is only
    # requested for multipack-supported model types; others fall back to eager.
    if cfg.flash_attention:
        if not cfg.sample_packing:
            # NOTE(review): no-op branch; flash_attention_2 is selected below
            # regardless of s2_attention.
            if cfg.s2_attention:
                pass

            model_kwargs["attn_implementation"] = "flash_attention_2"
            model_config._attn_implementation = (
                "flash_attention_2"
            )
        else:
            if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
                model_kwargs["attn_implementation"] = "flash_attention_2"
                model_config._attn_implementation = (
                    "flash_attention_2"
                )
            else:
                model_kwargs["attn_implementation"] = "eager"
                model_config._attn_implementation = (
                    "eager"
                )
    elif cfg.sdp_attention:
        model_kwargs["attn_implementation"] = "sdpa"
        model_config._attn_implementation = "sdpa"
    elif cfg.eager_attention:
        model_kwargs["attn_implementation"] = "eager"
        model_config._attn_implementation = "eager"

    if cfg.low_cpu_mem_usage:
        model_kwargs["low_cpu_mem_usage"] = True

    qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"

    # ---- Instantiate the model ----
    try:
        # Set by loaders that have already placed the weights, so the DDP
        # `model.to(...)` below is skipped.
        skip_move_to_device = False
        if (
            cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        ) and not qlora_fsdp:
            model = load_sharded_model(
                base_model,
                model_config,
                cfg,
                torch_dtype=cfg.torch_dtype,
            )
            skip_move_to_device = True
        elif (
            qlora_fsdp
            and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
            and cfg.model_config_type == "dbrx"
        ):
            quant_storage = cfg.torch_dtype
            model = load_sharded_model_quant(
                base_model,
                model_config,
                cfg,
                quant_storage=quant_storage,
            )
            skip_move_to_device = True
        elif (
            model_config.model_type == "llama"
            and not cfg.trust_remote_code
            and not cfg.gptq
        ):
            from transformers import LlamaForCausalLM

            model = LlamaForCausalLM.from_pretrained(
                base_model,
                config=model_config,
                **model_kwargs,
            )

            # Optional fused SwiGLU MLP / fused QKV replacements (training only).
            if cfg.flash_attention and not inference:
                from axolotl.monkeypatch.llama_attn_hijack_flash import (
                    is_xformers_swiglu_available,
                    replace_llama_mlp_with_swiglu,
                    replace_llama_qkv_with_fused,
                )

                if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
                    LOG.info("patching with SwiGLU")
                    replace_llama_mlp_with_swiglu(model)

                if cfg.flash_attn_fuse_qkv:
                    LOG.info("patching with fused QKV")
                    replace_llama_qkv_with_fused(model)
        elif model_type == "MambaLMHeadModel":
            # Mamba's loader takes `dtype`/`device` instead of
            # `torch_dtype`/`device_map`, so the kwargs are translated.
            # NOTE(review): `del model_kwargs["device_map"]` raises KeyError if
            # the ZeRO-3 branch above already removed it.
            MambaLMHeadModel = fix_mamba_attn_for_loss()

            model_kwargs["dtype"] = model_kwargs["torch_dtype"]
            model_kwargs["device"] = torch.cuda.current_device()
            del model_kwargs["torch_dtype"]
            del model_kwargs["device_map"]

            model = MambaLMHeadModel.from_pretrained(
                base_model,
                **model_kwargs,
            )
        elif model_type and not cfg.trust_remote_code:
            # `model_type` names a class exported by `transformers`.
            if cfg.gptq:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    config=model_config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
            else:
                model = getattr(transformers, model_type).from_pretrained(
                    base_model,
                    config=model_config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
        else:
            # Extend the config's context length if the requested sequence
            # length exceeds it (configs name this field differently).
            if (
                hasattr(model_config, "max_seq_len")
                and model_config.max_seq_len
                and cfg.sequence_len > model_config.max_seq_len
            ):
                model_config.max_seq_len = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            elif (
                hasattr(model_config, "max_sequence_length")
                and model_config.max_sequence_length
                and cfg.sequence_len > model_config.max_sequence_length
            ):
                model_config.max_sequence_length = cfg.sequence_len
                LOG.warning(f"increasing context length to {cfg.sequence_len}")
            if cfg.gptq:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    config=model_config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
            else:
                # QLoRA + FSDP with CPU-RAM-efficient loading: load without a
                # device map and leave placement to FSDP.
                if qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
                    skip_move_to_device = True
                    if "device_map" in model_kwargs:
                        del model_kwargs["device_map"]

                model = AutoModelForCausalLM.from_pretrained(
                    base_model,
                    config=model_config,
                    trust_remote_code=cfg.trust_remote_code or False,
                    **model_kwargs,
                )
    except Exception as err:
        LOG.exception(err)
        raise err

    # ---- Post-load adjustments ----
    if isinstance(model, (PeftModel, PeftModelForCausalLM)) and not qlora_fsdp:
        model = model.merge_and_unload()

    # Grow (never shrink) the input embeddings to fit the tokenizer, optionally
    # rounded up to a multiple of 32.
    embeddings_len = (
        math.ceil(len(tokenizer) / 32) * 32
        if cfg.resize_token_embeddings_to_32x
        else len(tokenizer)
    )
    if (
        hasattr(model, "get_input_embeddings")
        and model.get_input_embeddings().num_embeddings < embeddings_len
    ):
        model.resize_token_embeddings(embeddings_len)
    else:
        model.tie_weights()

    if (
        hasattr(model, "config")
        and hasattr(model.config, "max_position_embeddings")
        and model.config.max_position_embeddings
        and cfg.sequence_len > model.config.max_position_embeddings
    ):
        LOG.warning(
            f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
        )
        model.config.max_position_embeddings = cfg.sequence_len

    # Sync BOS/EOS ids on the model config with the tokenizer. Only truthy ids
    # are replaced, so an existing id of 0 is left unchanged.
    if (
        hasattr(model, "config")
        and hasattr(model.config, "bos_token_id")
        and model.config.bos_token_id
        and model.config.bos_token_id != tokenizer.bos_token_id
    ):
        model.config.bos_token_id = tokenizer.bos_token_id

    if (
        hasattr(model, "config")
        and hasattr(model.config, "eos_token_id")
        and model.config.eos_token_id
        and model.config.eos_token_id != tokenizer.eos_token_id
    ):
        model.config.eos_token_id = tokenizer.eos_token_id

    if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
        log_gpu_memory_usage(LOG, "after model load", model.device)

    # Outside FSDP: upcast norm layers and modules whose names end in ".gate" to
    # float32, plus embedding layers except for btlm.
    embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
    if not cfg.fsdp:
        for name, module in model.named_modules():
            if "norm" in name or name.endswith(".gate"):
                module.to(torch.float32)
            if model_config.model_type == "btlm":
                continue
            if any(m in name for m in embedding_modules):
                if hasattr(module, "weight"):
                    module.to(torch.float32)

    needs_fa2_dtype = cfg.adapter or cfg.fsdp
    skip_prepare_model_for_kbit_training = False

    # Register MoE blocks as ZeRO-3 leaf modules via DeepSpeed.
    if is_deepspeed_zero3_enabled():
        from deepspeed.utils import (
            set_z3_leaf_modules,
        )

        if cfg.model_config_type == "mixtral":
            moe_block = get_module_class_from_name(model, "MixtralSparseMoeBlock")
            set_z3_leaf_modules(model, [moe_block])
        elif cfg.model_config_type == "dbrx":
            moe_block = get_module_class_from_name(model, "DbrxFFN")
            set_z3_leaf_modules(model, [moe_block])

    # Cases that skip prepare_model_for_kbit_training: qwen + LoRA, LoftQ-initialised
    # LoRA, and QLoRA/FSDP CPU-RAM-efficient loading.
    if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
        skip_prepare_model_for_kbit_training = True

    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if cfg.adapter == "lora" and loftq_bits:
        skip_prepare_model_for_kbit_training = True

    if qlora_fsdp or (cfg.fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading):
        skip_prepare_model_for_kbit_training = True

    if cfg.adapter in ["lora", "qlora"]:
        if cfg.gradient_checkpointing:
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs=cfg.gradient_checkpointing_kwargs
            )
        if (
            cfg.load_in_8bit or cfg.load_in_4bit
        ) and not skip_prepare_model_for_kbit_training:
            LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=cfg.gradient_checkpointing
            )
        needs_fa2_dtype = True

    # NOTE: this casts norm and embedding modules (including those upcast to
    # float32 above) to cfg.torch_dtype; ".gate" modules are not re-cast.
    if (needs_fa2_dtype or cfg.flash_attention) and not qlora_fsdp:
        LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
        for name, module in model.named_modules():
            if "norm" in name:
                module.to(cfg.torch_dtype)
            if any(m in name for m in embedding_modules):
                if hasattr(module, "weight"):
                    module.to(cfg.torch_dtype)

    # ---- Adapter ----
    lora_config = None
    if not reference_model or cfg.lora_model_dir:
        # For these RL methods with an unmerged adapter, only the LoRA config is
        # built here; the model is not wrapped (presumably the trainer applies
        # it later — confirm).
        if (
            cfg.adapter
            and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
            and not cfg.merge_lora
        ):
            _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
        else:
            model, lora_config = load_adapter(model, cfg, cfg.adapter)

    # Under DDP, move the model onto this rank's GPU unless it was loaded in
    # 8-bit, in 4-bit for RL, or already placed by a sharded/FSDP loader.
    if (
        cfg.ddp
        and not load_in_8bit
        and not (cfg.rl and cfg.load_in_4bit)
        and not skip_move_to_device
    ):
        model.to(f"cuda:{cfg.local_rank}")

    # Single process with several GPUs: flag the model as model-parallel.
    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
        setattr(model, "is_parallelizable", True)
        setattr(model, "model_parallel", True)

    # Collected only to warn when nothing is trainable.
    requires_grad = []
    for name, param in model.named_parameters(recurse=True):
        if param.requires_grad:
            requires_grad.append(f"{name}: {param.requires_grad}")
    if len(requires_grad) == 0:
        LOG.warning("there are no parameters that require gradient updates")
    if hasattr(model, "config"):
        model.config.use_cache = False

    if cfg.flash_optimum:
        from optimum.bettertransformer import BetterTransformer

        model = BetterTransformer.transform(model)

    if cfg.adapter is not None:
        log_gpu_memory_usage(LOG, "after adapters", model.device)

    if cfg.unsloth_lora_mlp:
        from axolotl.monkeypatch.unsloth_ import integrate_lora_mlp_patch

        integrate_lora_mlp_patch(model)
    if cfg.unsloth_lora_qkv or cfg.unsloth_lora_o:
        from axolotl.monkeypatch.unsloth_ import integrate_lora_patch

        integrate_lora_patch(model, cfg)

    return model, lora_config
|
|
|
|
|
def load_adapter(model, cfg, adapter, inference=False):
    """
    Attach the PEFT adapter named by ``adapter`` to ``model``.

    Args:
        model: Base model to wrap.
        cfg: Axolotl configuration forwarded to the adapter loader.
        adapter: ``None``, ``"lora"``, ``"qlora"`` or ``"llama-adapter"``.
        inference: Forwarded to the LoRA loader.

    Returns:
        ``(model, peft_config)``; ``(model, None)`` when ``adapter`` is ``None``.

    Raises:
        NotImplementedError: for any other adapter name.
    """
    if adapter is None:
        return model, None

    # Called before dispatch, so it also runs for unknown adapter names.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    is_lora_family = adapter in ["lora", "qlora"]
    if is_lora_family:
        result = load_lora(model, cfg, inference=inference)
    elif adapter == "llama-adapter":
        result = load_llama_adapter(model, cfg)
    else:
        raise NotImplementedError(f"{adapter} peft adapter not available")
    return result
|
|
|
|
|
def load_llama_adapter(model, cfg):
    """
    Wrap ``model`` with a PEFT adaption-prompt ("llama-adapter") adapter.

    The adapter shape comes from ``cfg.peft_adapter.layers`` and
    ``cfg.peft_adapter.len``. When ``cfg.lora_model_dir`` is set, existing
    adapter weights are loaded from that directory instead of creating a fresh
    adapter.

    Returns:
        ``(model, peft_config)`` where ``peft_config`` is the freshly built
        ``AdaptionPromptConfig`` (returned even when weights were loaded).
    """
    from peft import AdaptionPromptConfig, get_peft_model

    adapter_settings = cfg.peft_adapter
    peft_config = AdaptionPromptConfig(
        adapter_layers=adapter_settings.layers,
        adapter_len=adapter_settings.len,
        task_type="CAUSAL_LM",
    )

    if not cfg.lora_model_dir:
        model = get_peft_model(model, peft_config)
    else:
        LOG.debug("Loading pretrained PEFT - llama_adapter")
        # NOTE(review): `is_trainable` is not passed here (unlike load_lora);
        # confirm whether loaded adapter weights are meant to stay frozen.
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            torch_dtype=torch.float16,
        )

    model.print_trainable_parameters()
    return model, peft_config
|
|
|
|
|
def find_all_linear_names(model):
    """
    Collect the leaf names of linear-like submodules, for use as LoRA targets.

    A module counts as linear if it is a bitsandbytes 4/8-bit linear, a
    ``torch.nn.Linear`` or a PEFT ``QuantLinear``. It also counts if its class
    name contains ``"Linear"``, except for ``LlamaLinearScalingRotaryEmbedding``.
    Only the last dotted component of each module path is kept, and the model's
    output embedding layer is excluded.

    Returns:
        A list of unique module names (order unspecified).
    """
    linear_types = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
    excluded_class_names = ("LlamaLinearScalingRotaryEmbedding",)

    def _looks_linear(candidate):
        if isinstance(candidate, linear_types):
            return True
        class_name = candidate.__class__.__name__
        return "Linear" in class_name and class_name not in excluded_class_names

    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, candidate in model.named_modules()
        if _looks_linear(candidate)
    }

    # Index 1 of the embedding-layer list is treated as the output embedding.
    output_embedding = get_linear_embedding_layers(model.config.model_type)[1]
    leaf_names.discard(output_embedding)

    return list(leaf_names)
|
|
|
|
|
def setup_quantized_meta_for_peft(model: nn.Module):
    """Replaces `quant_state.to` with a dummy function to prevent PEFT from moving `quant_state` to meta device

    For every ``Params4bit`` parameter, the original ``quant_state.to`` is
    stashed on ``quant_state._orig_to`` and replaced by a no-op that returns
    the quant state unchanged. Calling this twice is safe: the original method
    is only stashed once, so it can still be restored afterwards.
    """

    def temp_to_method(self, *args, **kwargs):
        return self

    for param in model.parameters():
        if not isinstance(param, Params4bit):
            continue
        quant_state = param.quant_state
        # Without this guard a second call would stash the no-op stub as the
        # "original" and the real `to` could never be restored.
        if getattr(quant_state, "_orig_to", None) is None:
            quant_state._orig_to = quant_state.to
        quant_state.to = types.MethodType(temp_to_method, quant_state)
|
|
|
|
|
def setup_quantized_peft_meta_for_training(model: nn.Module):
    """Replaces dummy `quant_state.to` method with the original function to allow training to continue

    Undoes ``setup_quantized_meta_for_peft``. Parameters that were never patched,
    or were already restored (``_orig_to`` is ``None``), are left alone. This makes
    repeated calls safe instead of overwriting ``quant_state.to`` with ``None``.
    """
    for param in model.parameters():
        if not isinstance(param, Params4bit):
            continue
        quant_state = param.quant_state
        original_to = getattr(quant_state, "_orig_to", None)
        if original_to is None:
            continue
        quant_state.to = original_to
        quant_state._orig_to = None
|
|
|
|
|
def load_lora(model, cfg, inference=False, config_only=False):
    """
    Build a ``LoraConfig`` from ``cfg`` and, unless ``config_only``, apply it.

    Target modules are ``cfg.lora_target_modules`` plus, when
    ``cfg.lora_target_linear`` is set, every linear module name discovered in
    ``model``. LoftQ, DoRA, rsLoRA and layer replication settings are only
    forwarded to ``LoraConfig`` when set in ``cfg``.

    When ``cfg.lora_model_dir`` is set, existing adapter weights are loaded
    (trainable unless ``inference``); otherwise a fresh adapter is created.
    On non-zero local ranks with FSDP CPU-RAM-efficient loading, 4-bit
    ``quant_state.to`` is stubbed out while wrapping and restored afterwards.

    Returns:
        ``(None, lora_config)`` when ``config_only``; otherwise
        ``(peft_model, lora_config)``.
    """
    from peft import LoraConfig, get_peft_model

    target_modules = list(cfg.lora_target_modules or [])
    if cfg.lora_target_linear:
        linear_names = find_all_linear_names(model)
        LOG.info(f"found linear modules: {repr(linear_names)}")
        target_modules = list(set(target_modules + linear_names))

    extra_kwargs = {}
    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
    if loftq_bits:
        extra_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
        extra_kwargs["init_lora_weights"] = "loftq"
    # Optional features: forwarded only when truthy so peft's defaults apply otherwise.
    for cfg_key, lora_kwarg in (
        ("peft_use_dora", "use_dora"),
        ("peft_use_rslora", "use_rslora"),
        ("peft_layer_replication", "layer_replication"),
    ):
        setting = getattr(cfg, cfg_key)
        if setting:
            extra_kwargs[lora_kwarg] = setting

    lora_config = LoraConfig(
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        target_modules=target_modules,
        layers_to_transform=cfg.peft_layers_to_transform,
        lora_dropout=cfg.lora_dropout,
        fan_in_fan_out=cfg.lora_fan_in_fan_out,
        modules_to_save=cfg.lora_modules_to_save or None,
        bias="none",
        task_type="CAUSAL_LM",
        **extra_kwargs,
    )

    if config_only:
        return None, lora_config

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    stub_quant_state = (
        cfg.fsdp
        and cfg.adapter
        and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading
        and local_rank != 0
    )

    if stub_quant_state:
        setup_quantized_meta_for_peft(model)

    if not cfg.lora_model_dir:
        model = get_peft_model(model, lora_config)
    else:
        LOG.debug("Loading pretrained PEFT - LoRA")
        load_kwargs: Any = {}
        if cfg.lora_on_cpu:
            load_kwargs["max_memory"] = {"cpu": "256GiB"}
            load_kwargs["device_map"] = {"": "cpu"}
        model = PeftModel.from_pretrained(
            model,
            cfg.lora_model_dir,
            is_trainable=(not inference),
            **load_kwargs,
        )

    if local_rank != 0:
        if stub_quant_state:
            setup_quantized_peft_meta_for_training(model)
    else:
        try:
            model.print_trainable_parameters()
        except AttributeError as exc:
            LOG.warning(
                "Exception caught during model.print_trainable_parameters(): %s", exc
            )

    return model, lora_config
|
|
|
|
|
def ensure_dtype(model, dtype=torch.bfloat16):
    """
    Cast every submodule whose ``weight`` is not already ``dtype`` to ``dtype``.

    Modules with no ``weight`` attribute, or whose weight is ``None``, are
    skipped. Each conversion is printed. Unlike a blanket
    ``except AttributeError``, errors raised by ``module.to`` itself are no
    longer silently swallowed.

    Args:
        model: A ``torch.nn.Module`` tree.
        dtype: Target dtype (default ``torch.bfloat16``).
    """
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        weight_dtype = getattr(weight, "dtype", None)
        if weight_dtype is None or weight_dtype == dtype:
            continue
        print(f"Converting module {name}: {weight_dtype} -> {dtype}")
        module.to(dtype)
|
|