winglian committed 2a1589f (unverified) · parent: 7d55607

strip out hacky qlora-fsdp workarounds now that qlora-fsdp fixes are upstreamed (#1428)

examples/llama-2/qlora-fsdp.yml CHANGED
@@ -36,7 +36,7 @@ wandb_log_model:
 gradient_accumulation_steps: 4
 micro_batch_size: 4
 num_epochs: 4
-optimizer: paged_adamw_8bit
+optimizer: adamw_torch
 lr_scheduler: cosine
 learning_rate: 0.00001
 
@@ -66,5 +66,11 @@ weight_decay: 0.0
 fsdp:
   - full_shard
 fsdp_config:
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
+  fsdp_cpu_ram_efficient_loading: true
   fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
+  fsdp_state_dict_type: SHARDED_STATE_DICT
 special_tokens:
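
To sanity-check a downstream config against the updated example, a minimal sketch (assuming PyYAML is installed and the repository root is the working directory):

import yaml

with open("examples/llama-2/qlora-fsdp.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print("optimizer:", cfg["optimizer"])  # adamw_torch after this change
# Keys added by this commit: limit_all_gathers, sync_module_states, offload_params,
# use_orig_params, cpu_ram_efficient_loading, state_dict_type (all fsdp_-prefixed).
for key, value in sorted(cfg["fsdp_config"].items()):
    print(f"{key}: {value}")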
requirements.txt CHANGED
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.9.0
-transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
+transformers @ git+https://github.com/huggingface/transformers.git@73a73b415e36f41481369f6129cb4b62bb127a78
 tokenizers==0.15.0
-bitsandbytes>=0.43.0
-accelerate==0.26.1
+bitsandbytes==0.43.0
+accelerate==0.28.0
 deepspeed==0.13.1
 pydantic==2.6.3
 addict
@@ -40,4 +40,3 @@ gcsfs
 # adlfs
 
 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
-fastcore>=1.5.29
src/axolotl/core/policies/__init__.py DELETED
File without changes
src/axolotl/core/policies/auto_wrap.py DELETED
@@ -1,55 +0,0 @@
-"""module for building the auto wrap policy for FSDP"""
-import functools
-
-from peft import PrefixEncoder, PromptEmbedding, PromptEncoder
-from torch.distributed.fsdp.wrap import (
-    _or_policy,
-    lambda_auto_wrap_policy,
-    transformer_auto_wrap_policy,
-)
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
-from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
-
-SUPPORTED_AUTO_WRAP_MODEL_TYPES = [
-    "llama",
-    "mistral",
-    "mixtral",
-]
-
-
-def get_wrapping_policy_factory(model_type):
-    if model_type == "llama":
-        layer_to_wrap = LlamaDecoderLayer
-    elif model_type == "mistral":
-        layer_to_wrap = MistralDecoderLayer
-    elif model_type == "mixtral":
-        layer_to_wrap = MixtralDecoderLayer
-
-    def get_wrapping_policy():
-        """This checks for lora layers (has weight and requires_grad)"""
-
-        def lambda_policy_fn(module):
-            return (
-                len(list(module.named_children())) == 0
-                and getattr(module, "weight", None) is not None
-                and module.weight.requires_grad
-            )
-
-        lambda_policy = functools.partial(
-            lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn
-        )
-        transformer_layer_name = layer_to_wrap
-        transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls=(
-                PrefixEncoder,
-                PromptEncoder,
-                PromptEmbedding,
-                transformer_layer_name,
-            ),
-        )
-        policies = [lambda_policy, transformer_wrap_policy]
-        return functools.partial(_or_policy, policies=policies)
-
-    return get_wrapping_policy
src/axolotl/core/trainer_builder.py CHANGED
@@ -8,7 +8,6 @@ import importlib
 import importlib.util
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
 from collections import defaultdict
@@ -19,10 +18,7 @@ from typing import Dict, List, Literal, Optional, Type, Union
 
 import torch
 import transformers
-from accelerate import FullyShardedDataParallelPlugin
-from accelerate.utils import str_to_bool
 from datasets import Dataset
-from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
@@ -35,7 +31,6 @@ from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer
 
-from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -591,51 +586,14 @@ class AxolotlTrainer(Trainer):
 
     @wraps(Trainer.create_accelerator_and_postprocess)
     def create_accelerator_and_postprocess(self):
-        rank = int(os.environ.get("LOCAL_RANK", 0))
         res = super().create_accelerator_and_postprocess()
 
-        if self.args.qlora is False:
-            return res
-
-        # the rest of this method override is specific to fsdp + qlora (for now)
-        sync_module_states = (
-            str_to_bool(os.environ.get("FSDP_SYNC_MODULE_STATES", "True")) == 1
-        )
-
-        mp_policy = None
-        amp = os.environ["ACCELERATE_MIXED_PRECISION"]
-        if amp == "fp16":
-            mp_policy = MixedPrecision(
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-                buffer_dtype=torch.float32,
-            )
-        elif amp == "bf16":
-            mp_policy = MixedPrecision(
-                param_dtype=torch.float32,
-                reduce_dtype=torch.float32,
-                buffer_dtype=torch.float32,
-            )
-
-        # If somehow we figure out how we want to parameterize we want to autocast buffers...
-        # mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32)
-        # load_param_skip_names = ['inv_freq']
-
         if self.is_fsdp_enabled:
-            wrapping_policy = get_wrapping_policy_factory(self.args.model_type)
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrapping_policy(),
-                cpu_offload=False,
-                use_orig_params=False,
-                limit_all_gathers=True,
-                param_init_fn=lambda module: module.to_empty(
-                    device=torch.device("cuda"), recurse=False
-                )
-                if (rank != 0 and sync_module_states)
-                else None,
-                mixed_precision_policy=mp_policy,
-            )
-            self.accelerator.state.fsdp_plugin = fsdp_plugin
+            if (
+                "limit_all_gathers" in self.args.fsdp_config
+                and self.args.fsdp_config["limit_all_gathers"]
+            ):
+                self.accelerator.state.fsdp_plugin.limit_all_gathers = True
 
         return res
 
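The override no longer builds its own FullyShardedDataParallelPlugin; it only adjusts the plugin accelerate has already created. A self-contained sketch of that pattern (the SimpleNamespace objects are hypothetical stand-ins; only the limit_all_gathers handling mirrors the diff):

from types import SimpleNamespace


def apply_limit_all_gathers(accelerator, fsdp_config: dict) -> None:
    # Toggle a single flag on the existing FSDP plugin instead of replacing it wholesale.
    if fsdp_config.get("limit_all_gathers"):
        accelerator.state.fsdp_plugin.limit_all_gathers = True


# Hypothetical stand-ins, just to show the call shape.
plugin = SimpleNamespace(limit_all_gathers=False)
accelerator = SimpleNamespace(state=SimpleNamespace(fsdp_plugin=plugin))
apply_limit_all_gathers(accelerator, {"limit_all_gathers": True})
assert plugin.limit_all_gathers is True
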
src/axolotl/utils/models.py CHANGED
@@ -5,16 +5,14 @@ import logging
 import math
 import os
 import types
-from typing import Any, Dict, List, Optional, Tuple, Type, Union  # noqa: F401
+from typing import Any, Dict, Optional, Tuple, Union  # noqa: F401
 
 import addict
 import bitsandbytes as bnb
-import safetensors
 import torch
 import transformers
 from accelerate import init_empty_weights
-from bitsandbytes.nn import Linear4bit, Params4bit
-from fastcore.parallel import parallel
+from bitsandbytes.nn import Params4bit
 from peft import (
     LoftQConfig,
     PeftConfig,
@@ -23,7 +21,7 @@ from peft import (
     prepare_model_for_kbit_training,
 )
 from peft.tuners.lora import QuantLinear
-from torch import Tensor, nn
+from torch import nn
 from transformers import (  # noqa: F401
     AddedToken,
     AutoConfig,
@@ -35,9 +33,7 @@ from transformers import (  # noqa: F401
     PreTrainedTokenizerBase,
 )
 from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
-from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub
 
-from axolotl.core.policies.auto_wrap import SUPPORTED_AUTO_WRAP_MODEL_TYPES
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.monkeypatch.multipack import (
     SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -272,117 +268,6 @@ def load_tokenizer(cfg):
     return tokenizer
 
 
-def replace_linear(
-    model: nn.Module,
-    linear_replacement: Type[nn.Module],
-    quant_config: Union[dict, None] = None,
-    skip_modules=None,
-    **kwargs,
-):
-    """
-    Replace linear modules with a new Linear module.
-    Parameters:
-        model (`torch.nn.Module`):
-            Input model or `torch.nn.Module` as the function is run recursively.
-        linear_replacement (`torch.nn.Module`):
-            The linear module that replaces the old one. Only expects standard arguments.
-            If other arguments need to be passed, use a lambda.
-        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
-            List of modules names not to convert. Defaults to `lm_head`.
-    """
-    if skip_modules is None:
-        skip_modules = ["lm_head"]
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            replace_linear(
-                module, linear_replacement, quant_config, skip_modules, **kwargs
-            )
-
-        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
-            if issubclass(linear_replacement, Linear4bit):
-                model._modules[  # pylint: disable=protected-access
-                    name
-                ] = linear_replacement(
-                    module.in_features,
-                    module.out_features,
-                    module.bias is not None,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(
-                    f"Unsupported linear replacement: {type(linear_replacement)}"
-                )
-    return model
-
-
-def load_and_quantize(
-    module: nn.Module,
-    name: str,
-    value: Tensor,
-    device: torch.device = None,
-    dtype: torch.dtype = None,
-    skip_names: Optional[List[str]] = None,
-    is_meta_rank: bool = False,
-    low_memory: bool = True,
-    verbose: bool = False,
-    quant_method: str = "bnb",
-):
-    """
-    Loads `value` tensor into submodule of `module`, optionally skipping `skip_names` and converting to `dtype`.
-
-    Quantizes `Params4bit` on `device` then places on "cpu" if low_memory=True or "meta" if is_meta_rank=True.
-    """
-
-    if skip_names is None:
-        skip_names = []
-
-    def place_on_device(value):
-        if is_meta_rank:
-            device = "meta"
-        elif low_memory:
-            device = "cpu"
-        else:
-            device = "cuda"
-        return value.to(device=device, dtype=dtype)
-
-    if any(skip_name in name for skip_name in skip_names):
-        if verbose:
-            print(f"Skipping {name} because it is in skip_names")
-        return
-
-    module_key, _, value_key = name.rpartition(".")
-    try:
-        submodule = module.get_submodule(module_key)
-    except AttributeError as exc:
-        print(f"Module {module_key} not found:\n{exc}")
-        return
-
-    try:
-        if quant_method == "bnb":
-            param = submodule.get_parameter(value_key)
-            if isinstance(param, Params4bit):
-                # With `sync_module_states=True`, a meta device Params4bit needs to be the same
-                # shape as the quantized Params4bit with an initialized quant_state. However,
-                # FSDP only syncs parameters and buffers, so the quant_state isn't copied. This
-                # workaround quantizes Params4bit to initialize quant_state on all ranks, then
-                # replaces Params4bit's data with a meta tensor to free memory on non-rank 0.
-                value = type(param)(
-                    value.to(device=device, dtype=dtype).data, **param.__dict__
-                ).cuda(device)
-                if is_meta_rank:
-                    value = type(param)(value.data.to("meta"), **value.__dict__)
-                elif low_memory:
-                    value = type(param)(value.data.to("cpu"), **value.__dict__)
-            else:
-                value = type(param)(place_on_device(value).data)
-
-    except AttributeError:
-        # it's a buffer
-        value = place_on_device(value)
-
-    setattr(submodule, value_key, value)
-
-
 def load_model(
     cfg: DictDefault,
     tokenizer: PreTrainedTokenizerBase,
@@ -568,6 +453,7 @@ def load_model(
             "bnb_4bit_compute_dtype": cfg.torch_dtype,
             "bnb_4bit_use_double_quant": True,
             "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_quant_storage": torch.bfloat16,
         }
 
     if cfg.bnb_config_kwargs:
@@ -617,78 +503,10 @@ def load_model(
         model_kwargs["attn_implementation"] = "eager"
         model_config._attn_implementation = "eager"  # pylint: disable=protected-access
 
-    qlora_fsdp = (
-        cfg.fsdp
-        and cfg.adapter == "qlora"
-        and model_config.model_type in SUPPORTED_AUTO_WRAP_MODEL_TYPES
-    )
+    qlora_fsdp = cfg.fsdp and cfg.adapter == "qlora"
 
     try:
-        if qlora_fsdp:
-            if cfg.bf16 or cfg.bfloat16:
-                torch_dtype, compute_dtype = torch.float32, torch.bfloat16
-            elif cfg.fp16 or cfg.float16:
-                torch_dtype, compute_dtype = torch.float32, torch.float16
-            else:
-                torch_dtype, compute_dtype = torch.float32, torch.float16
-
-            with init_empty_weights():
-                LOG.info("Loading model with empty weights.")
-                model = AutoModelForCausalLM.from_config(model_config)
-                model.model = replace_linear(
-                    model.model,
-                    Linear4bit,
-                    compute_dtype=compute_dtype,
-                    quant_type="nf4",
-                    quant_storage=torch_dtype,
-                )
-
-            model.is_loaded_in_4bit = True
-
-            # Grab the safetensors files that hold the weights
-            try:
-                idx = hub.cached_file(base_model, SAFE_WEIGHTS_INDEX_NAME)
-                files, _ = hub.get_checkpoint_shard_files(base_model, idx)
-            except OSError:
-                try:
-                    # This means the model doesn't have a model.safetensors.index.json because it is not sharded
-                    files = []
-                    files.append(hub.cached_file(base_model, SAFE_WEIGHTS_NAME))
-                except OSError as exc:
-                    # This means the model probably doesn't have a safetensors file
-                    raise exc
-
-            # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly
-            # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage
-            def load_and_quantize_parallel(name_param, model, **kwargs):
-                name, param = name_param
-                load_and_quantize(model, name, param, **kwargs)
-
-            param_count = sum((p.numel() for n, p in model.named_parameters()))
-            for filename in files:
-                weights = safetensors.torch.load_file(filename)
-                quant_method = "bnb"
-                devprops = torch.cuda.get_device_properties(torch.cuda.current_device())
-                left = int(os.cpu_count() / torch.cuda.device_count())
-                right = int(
-                    8 * (devprops.total_memory / 1e9 / 40) * (70 / (param_count / 1e9))
-                )
-                n_workers = min(left, right)
-                parallel(
-                    load_and_quantize_parallel,
-                    weights.items(),
-                    n_workers=n_workers,
-                    threadpool=True,
-                    model=model,
-                    dtype=torch_dtype,
-                    device=cfg.local_rank,
-                    skip_names=[],
-                    is_meta_rank=(cfg.local_rank != 0),
-                    verbose=False,
-                    quant_method=quant_method,
-                )
-
-        elif (
+        if (
             model_config.model_type == "llama"
             and not cfg.trust_remote_code
             and not cfg.gptq
@@ -715,32 +533,6 @@ def load_model(
             if cfg.flash_attn_fuse_qkv:
                 LOG.info("patching with fused QKV")
                 replace_llama_qkv_with_fused(model)
-        # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
-        # This is a WIP, still an issue with the backward pass
-        # RuntimeError: grad can be implicitly created only for scalar outputs
-        # TODO: try config.sequence_parallel = False
-        # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12
-        # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components
-        # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442
-        # from flash_attn.utils.pretrained import state_dict_from_pretrained
-        # from flash_attn.models.gpt import GPTLMHeadModel
-        # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config
-        # from transformers import GPTNeoXConfig
-        # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model))
-        # config.use_flash_attn = True
-        # config.fused_bias_fc = True
-        # config.fused_mlp = True  # GPT-NeoX-20B uses "gelu_fast"
-        # config.activation_function = "gelu_fast"
-        # config.fused_dropout_add_ln = True
-        # # config.residual_in_fp32 = True
-        #
-        # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained(
-        #     base_model,
-        #     config,
-        #     dtype=torch_dtype,
-        #     device=cfg.device,
-        # )
-        # model.train()  # sets to train instead of eval mode
         elif model_type == "MambaLMHeadModel":
             # FIXME this is janky at best and hacked together to make it work
             MambaLMHeadModel = fix_mamba_attn_for_loss()  # pylint: disable=invalid-name
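
The only functional addition here is the bnb_4bit_quant_storage entry; the rest of the diff removes the hand-rolled replace_linear / load_and_quantize path. Expressed as a plain transformers BitsAndBytesConfig, the resulting quantization settings would look roughly like this (a sketch, assuming the pinned transformers and bitsandbytes releases in requirements.txt, which accept bnb_4bit_quant_storage):

import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # stands in for cfg.torch_dtype
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    # Storing the 4-bit weights in bf16 lets FSDP flat-shard them like regular
    # parameters, which is what makes the upstream QLoRA+FSDP path work here.
    bnb_4bit_quant_storage=torch.bfloat16,
)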
src/axolotl/utils/trainer.py CHANGED
@@ -304,6 +304,10 @@ def setup_fsdp_envs(cfg):
         os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_sync_module_states:
         os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
+    if cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
+        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
+    if cfg.fsdp_config.fsdp_use_orig_params:
+        os.environ["FSDP_USE_ORIG_PARAMS"] = "true"
     if cfg.fsdp_config.fsdp_state_dict_type:
        os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
     if cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap:
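
A self-contained sketch of the config-to-environment mapping this hunk extends (an approximation of the pattern, not the axolotl function itself; the FSDP_* variables are the ones accelerate's FSDP plugin consumes):

import os


def set_new_fsdp_envs(fsdp_config: dict) -> None:
    # Same pattern as setup_fsdp_envs: each truthy fsdp_config flag becomes an
    # FSDP_* environment variable; false or absent flags are not exported.
    if fsdp_config.get("fsdp_cpu_ram_efficient_loading"):
        os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"
    if fsdp_config.get("fsdp_use_orig_params"):
        os.environ["FSDP_USE_ORIG_PARAMS"] = "true"


# With the example config above (fsdp_use_orig_params: false), only the first variable is set.
set_new_fsdp_envs({"fsdp_cpu_ram_efficient_loading": True, "fsdp_use_orig_params": False})
print(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING"))  # -> "true"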
tests/e2e/test_mixtral.py CHANGED
@@ -77,7 +77,7 @@ class TestMixtral(unittest.TestCase):
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
-            == torch.uint8
+            == torch.float32
         )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 
@@ -131,7 +131,7 @@ class TestMixtral(unittest.TestCase):
         model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (
             model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
-            == torch.uint8
+            == torch.float32
         )
         assert (Path(temp_dir) / "adapter_model.bin").exists()
 