strip out hacky qlora-fsdp workarounds now that qlora-fsdp fixes are upstreamed (#1428)
examples/llama-2/qlora-fsdp.yml
CHANGED
@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 4
 num_epochs: 4
-optimizer:
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.00001
 
@@ -66,5 +66,11 @@ weight_decay: 0.0
 fsdp:
   - full_shard
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
 special_tokens:

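For a quick check that a config picks up the new options, something like the following works (a sketch, assuming PyYAML is installed and the script is run from a checkout of the repository root; it is not part of this change):

# Sanity-check sketch: load the updated example config and print the FSDP options added above.
import yaml

with open("examples/llama-2/qlora-fsdp.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["optimizer"])   # adamw_torch
print(cfg["fsdp"])        # ['full_shard']
for key in (
    "fsdp_limit_all_gathers",
    "fsdp_sync_module_states",
    "fsdp_offload_params",
    "fsdp_use_orig_params",
    "fsdp_cpu_ram_efficient_loading",
    "fsdp_state_dict_type",
):
    print(key, "=", cfg["fsdp_config"].get(key))
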
requirements.txt
CHANGED
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.9.0
-transformers @ git+https://github.com/huggingface/transformers.git@
+transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
 tokenizers==0.15.0
-bitsandbytes
-accelerate==0.
+bitsandbytes==0.43.0
+accelerate==0.28.0
 deepspeed==0.13.1
 pydantic==2.6.3
 addict
@@ -40,4 +40,3 @@ gcsfs
 # adlfs
 
 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
-fastcore>=1.5.29

src/axolotl/core/policies/__init__.py
DELETED
(empty file, no content to show)

src/axolotl/core/policies/auto_wrap.py
DELETED
@@ -1,55 +0,0 @@
-"""module for building the auto wrap policy for FSDP"""
-import functools
-
-from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
-from torch.distributed.fsdp.wrap import (
-    _or_policy,
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
-
-SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
-    "llama",
-    "mistral",
-    "mixtral",
-]
-
-
-def get_wrapping_policy_factory(model_type):
-    if model_type == "llama":
-        layer_to_wrap = LlamaDecoderLayer
-    elif model_type == "mistral":
-        layer_to_wrap = MistralDecoderLayer
-    elif model_type == "mixtral":
-        layer_to_wrap = MixtralDecoderLayer
-
-    def get_wrapping_policy():
-        """This checks for lora layers (has weight and requires_grad)"""
-
-        def lambda_policy_fn(module):
-            return (
-                len(list(module.named_children())) == 0
-                and getattr(module, "weight", None) is not None
-                and module.weight.requires_grad
-            )
-
-        lambda_policy = functools.partial(
-            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
-        )
-        transformer_layer_name = layer_to_wrap
-        transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls=(
-                PrefixEncoder,
-                PromptEncoder,
-                PromptEmbedding,
-                transformer_layer_name,
-            ),
-        )
-        policies = [lambda_policy, transformer_wrap_policy]
-        return functools.partial(_or_policy, policies=policies)
-
-    return get_wrapping_policy

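The deleted factory built an FSDP auto-wrap policy keyed on the model's decoder layer class; after this change the equivalent wrapping is expected to come from the fsdp_transformer_layer_cls_to_wrap setting in the example config above, handled by the upstream stack. For reference only, a rough sketch of an equivalent policy for llama models (not part of this change):

# Reference sketch: a transformer auto-wrap policy built directly on the torch API
# that the removed module used to wrap.
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

llama_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
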
src/axolotl/core/trainer_builder.py
CHANGED
@@ -8,7 +8,6 @@ import importlib
 import importlib.util
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
 from collections import defaultdict
@@ -19,10 +18,7 @@ from typing import Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
-from accelerate import FullyShardedDataParallelPlugin
-from accelerate.utils import str_to_bool
 from datasets import Dataset
-from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -35,7 +31,6 @@ from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 
-from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -591,51 +586,14 @@ class AxolotlTrainer(Trainer):
 
     @wraps(Trainer.create_accelerator_and_postprocess)
     def create_accelerator_and_postprocess(self):
-        rank = int(os.environ.get("LOCAL_RANK", 0))
         res = super().create_accelerator_and_postprocess()
 
-        if self.args.qlora is False:
-            return res
-
-        # the rest of this method override is specific to fsdp + qlora (for now)
-        sync_module_states = (
-            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
-        )
-
-        mp_policy = None
-        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
-        if amp == "fp16":
-            mp_policy = MixedPrecision(
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-                buffer_dtype=torch.float32,
-            )
-        elif amp == "bf16":
-            mp_policy = MixedPrecision(
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-                buffer_dtype=torch.float32,
-            )
-
-        # If somehow we figure out how we want to parameterize we want to autocast buffers...
-        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
-        # load_param_skip_names = ['inv_freq']
-
         if self.is_fsdp_enabled:
-            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrapping_policy(),
-                cpu_offload=False,
-                use_orig_params=False,
-                limit_all_gathers=True,
-                param_init_fn=lambda module: module.to_empty(
-                    device=torch.device("cuda"), recurse=False
-                )
-                if (rank != 0 and sync_module_states)
-                else None,
-                mixed_precision_policy=mp_policy,
-            )
-            self.accelerator.state.fsdp_plugin = fsdp_plugin
+            if (
+                "limit_all_gathers" in self.args.fsdp_config
+                and self.args.fsdp_config["limit_all_gathers"]
+            ):
+                self.accelerator.state.fsdp_plugin.limit_all_gathers = True
 
         return res
 
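The override is now a thin shim: accelerate builds its own FSDP plugin from the exported FSDP_* environment variables, and the trainer only forwards limit_all_gathers onto it. A minimal sketch of that pattern outside any Trainer (not axolotl's code; the plain dict stands in for TrainingArguments.fsdp_config, and inside the trainer the plugin lives at accelerator.state.fsdp_plugin):

# Sketch: flip limit_all_gathers on an already-created accelerate FSDP plugin,
# instead of rebuilding the whole plugin by hand as the removed block did.
from accelerate import FullyShardedDataParallelPlugin

fsdp_plugin = FullyShardedDataParallelPlugin()
fsdp_config = {"limit_all_gathers": True}  # stand-in for TrainingArguments.fsdp_config

if "limit_all_gathers" in fsdp_config and fsdp_config["limit_all_gathers"]:
    fsdp_plugin.limit_all_gathers = True

print(fsdp_plugin.limit_all_gathers)
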
src/axolotl/utils/models.py
CHANGED
@@ -5,16 +5,14 @@ import logging
 import math
 import os
 import types
-from typing import Any, Dict,
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
-import safetensors
 import torch
 import transformers
 from accelerate import init_empty_weights
-from bitsandbytes.nn import
-from fastcore.parallel import parallel
+from bitsandbytes.nn import Params4bit
 from peft import (
     LoftQConfig,
     PeftConfig,
@@ -23,7 +21,7 @@ from peft import (
     prepare_model_for_kbit_training,
 )
 from peft.tuners.lora import QuantLinear
-from torch import
+from torch import nn
 from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
@@ -35,9 +33,7 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
-from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -272,117 +268,6 @@ def load_tokenizer(cfg):
     return tokenizer
 
 
-def replace_linear(
-    model: nn.Module,
-    linear_replacement: Type[nn.Module],
-    quant_config: Union[dict, None] = None,
-    skip_modules=None,
-    **kwargs,
-):
-    """
-    Replace linear modules with a new Linear module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        linear_replacement (`torch.nn.Module`):
-            The linear module that replaces the old one. Only expects standard arguments.
-            If other arguments need to be passed, use a lambda.
-        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
-            List of modules names not to convert. Defaults to `lm_head`.
-    """
-    if skip_modules is None:
-        skip_modules = ["lm_head"]
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_linear(
-                module, linear_replacement, quant_config, skip_modules, **kwargs
-            )
-
-        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
-            if issubclass(linear_replacement, Linear4bit):
-                model._modules[  # pylint: disable=protected-access
-                    name
-                ] = linear_replacement(
-                    module.in_features,
-                    module.out_features,
-                    module.bias is not None,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(
-                    f"Unsupported linear replacement: {type(linear_replacement)}"
-                )
-    return model
-
-
-def load_and_quantize(
-    module: nn.Module,
-    name: str,
-    value: Tensor,
-    device: torch.device = None,
-    dtype: torch.dtype = None,
-    skip_names: Optional[List[str]] = None,
-    is_meta_rank: bool = False,
-    low_memory: bool = True,
-    verbose: bool = False,
-    quant_method: str = "bnb",
-):
-    """
-    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
-
-    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
-    """
-
-    if skip_names is None:
-        skip_names = []
-
-    def place_on_device(value):
-        if is_meta_rank:
-            device = "meta"
-        elif low_memory:
-            device = "cpu"
-        else:
-            device = "cuda"
-        return value.to(device=device, dtype=dtype)
-
-    if any(skip_name in name for skip_name in skip_names):
-        if verbose:
-            print(f"Skipping {name} because it is in skip_names")
-        return
-
-    module_key, _, value_key = name.rpartition(".")
-    try:
-        submodule = module.get_submodule(module_key)
-    except AttributeError as exc:
-        print(f"Module {module_key} not found:\n{exc}")
-        return
-
-    try:
-        if quant_method == "bnb":
-            param = submodule.get_parameter(value_key)
-            if isinstance(param, Params4bit):
-                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
-                # shape as the quantized Params4bit with an initialized quant_state. However,
-                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
-                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
-                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
-                value = type(param)(
-                    value.to(device=device, dtype=dtype).data, **param.__dict__
-                ).cuda(device)
-                if is_meta_rank:
-                    value = type(param)(value.data.to("meta"), **value.__dict__)
-                elif low_memory:
-                    value = type(param)(value.data.to("cpu"), **value.__dict__)
-            else:
-                value = type(param)(place_on_device(value).data)
-
-    except AttributeError:
-        # it's a buffer
-        value = place_on_device(value)
-
-    setattr(submodule, value_key, value)
-
-
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
@@ -568,6 +453,7 @@ def load_model(
             "bnb_4bit_compute_dtype": cfg.torch_dtype,
             "bnb_4bit_use_double_quant": True,
             "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_quant_storage": torch.bfloat16,
         }
 
         if cfg.bnb_config_kwargs:
@@ -617,78 +503,10 @@ def load_model(
         model_kwargs["attn_implementation"] = "eager"
         model_config._attn_implementation = "eager"  # pylint: disable=protected-access
 
-    qlora_fsdp = (
-        cfg.fsdp
-        and cfg.adapter == "qlora"
-        and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
-    )
+    qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
-        if qlora_fsdp:
-            if cfg.bf16 or cfg.bfloat16:
-                torch_dtype, compute_dtype = torch.float32, torch.bfloat16
-            elif cfg.fp16 or cfg.float16:
-                torch_dtype, compute_dtype = torch.float32, torch.float16
-            else:
-                torch_dtype, compute_dtype = torch.float32, torch.float16
-
-            with init_empty_weights():
-                LOG.info("Loading model with empty weights.")
-                model = AutoModelForCausalLM.from_config(model_config)
-                model.model = replace_linear(
-                    model.model,
-                    Linear4bit,
-                    compute_dtype=compute_dtype,
-                    quant_type="nf4",
-                    quant_storage=torch_dtype,
-                )
-
-            model.is_loaded_in_4bit = True
-
-            # Grab the safetensors files that hold the weights
-            try:
-                idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
-                files, _ = hub.get_checkpoint_shard_files(base_model, idx)
-            except OSError:
-                try:
-                    # This means the model doesn't have a model.safetensors.index.json because it is not sharded
-                    files = []
-                    files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
-                except OSError as exc:
-                    # This means the model probably doesn't have a safetensors file
-                    raise exc
-
-            # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
-            # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
-            def load_and_quantize_parallel(name_param, model, **kwargs):
-                name, param = name_param
-                load_and_quantize(model, name, param, **kwargs)
-
-            param_count = sum((p.numel() for n, p in model.named_parameters()))
-            for filename in files:
-                weights = safetensors.torch.load_file(filename)
-                quant_method = "bnb"
-                devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
-                left = int(os.cpu_count() / torch.cuda.device_count())
-                right = int(
-                    8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
-                )
-                n_workers = min(left, right)
-                parallel(
-                    load_and_quantize_parallel,
-                    weights.items(),
-                    n_workers=n_workers,
-                    threadpool=True,
-                    model=model,
-                    dtype=torch_dtype,
-                    device=cfg.local_rank,
-                    skip_names=[],
-                    is_meta_rank=(cfg.local_rank != 0),
-                    verbose=False,
-                    quant_method=quant_method,
-                )
-
-        elif (
+        if (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -715,32 +533,6 @@ def load_model(
             if cfg.flash_attn_fuse_qkv:
                 LOG.info("patching with fused QKV")
                 replace_llama_qkv_with_fused(model)
-        # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
-        #     This is a WIP, still an issue with the backward pass
-        #     RuntimeError: grad can be implicitly created only for scalar outputs
-        #     TODO: try config.sequence_parallel = False
-        #     # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
-        #     # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
-        #     # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-        #     from flash_attn.utils.pretrained import state_dict_from_pretrained
-        #     from flash_attn.models.gpt import GPTLMHeadModel
-        #     from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
-        #     from transformers import GPTNeoXConfig
-        #     config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
-        #     config.use_flash_attn = True
-        #     config.fused_bias_fc = True
-        #     config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-        #     config.activation_function = "gelu_fast"
-        #     config.fused_dropout_add_ln = True
-        #     # config.residual_in_fp32 = True
-        #
-        #     model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
-        #         base_model,
-        #         config,
-        #         dtype=torch_dtype,
-        #         device=cfg.device,
-        #     )
-        #     model.train() # sets to train instead of eval mode
         elif model_type == "MambaLMHeadModel":
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name

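With the custom replace_linear/load_and_quantize path gone, 4-bit loading goes through the stock transformers route, and the piece mirrored from load_model above is the new bnb_4bit_quant_storage entry, which keeps the quantized weights in a dtype FSDP can flat-shard. A hedged sketch of that loading path outside axolotl (the checkpoint name and dtypes are illustrative, not taken from this diff):

# Sketch of the upstream QLoRA + FSDP-friendly loading path this PR relies on.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_quant_storage=torch.bfloat16,  # the key added to load_model's kwargs above
)

model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-hf",  # illustrative checkpoint
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
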
src/axolotl/utils/trainer.py
CHANGED
@@ -304,6 +304,10 @@ def setup_fsdp_envs(cfg):
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:
         os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
+    if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
+    if cfg.fsdp_config.fsdp_use_orig_params:
+        os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_state_dict_type:
         os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
     if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:

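The two new flags follow the same pattern as the existing ones: each truthy fsdp_config key is exported as an FSDP_* environment variable before accelerate spins up. A standalone sketch of that mapping (the helper name and the dict are hypothetical; only the variable names come from the diff above):

# Hypothetical helper mirroring the lines added to setup_fsdp_envs.
import os

def export_new_fsdp_envs(fsdp_config: dict) -> None:
    if fsdp_config.get("fsdp_cpu_ram_efficient_loading"):
        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
    if fsdp_config.get("fsdp_use_orig_params"):
        os.environ["FSDP_USE_ORIG_PARAMS"] = "true"

export_new_fsdp_envs(
    {"fsdp_cpu_ram_efficient_loading": True, "fsdp_use_orig_params": False}
)
print(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING"))
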
tests/e2e/test_mixtral.py
CHANGED
@@ -77,7 +77,7 @@ class TestMixtral(unittest.TestCase):
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
-            == torch.
+            == torch.float32
         )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
-            == torch.
+            == torch.float32
         )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 