winglian committed
Commit 5ea3aa3 · unverified · 1 parent: f1f60cb

Fix Deepspeed loading (#950)

* add check for zero3
* freeze parameters
* fixes for deepspeed loading
* fix model parameter check
* unfrozen parameters in example mixtral and logging when unfreezing

deepspeed/zero3_bf16.json ADDED
@@ -0,0 +1,39 @@
+{
+  "zero_optimization": {
+    "stage": 3,
+    "overlap_comm": true,
+    "contiguous_gradients": true,
+    "sub_group_size": 0,
+    "reduce_bucket_size": "auto",
+    "stage3_prefetch_bucket_size": "auto",
+    "stage3_param_persistence_threshold": "auto",
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
+    "stage3_gather_16bit_weights_on_model_save": true
+  },
+  "bf16": {
+    "enabled": true
+  },
+  "fp16": {
+    "enabled": "auto",
+    "auto_cast": false,
+    "loss_scale": 0,
+    "initial_scale_power": 32,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "optimizer": {
+    "type": "AdamW",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
+  "gradient_accumulation_steps": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
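
For reference, a minimal sketch (in Python; the relative path assumes you run it from the repo root) of loading this new config and checking the fields that matter for sharded loading. The `"auto"` values are resolved by the Hugging Face Trainer from the training arguments at runtime, and an axolotl YAML would typically point at the file via its `deepspeed:` key.

```python
# Sketch only: sanity-check the new ZeRO-3 config before referencing it from a
# training config (e.g. `deepspeed: deepspeed/zero3_bf16.json` in an axolotl YAML).
import json

with open("deepspeed/zero3_bf16.json", encoding="utf-8") as fin:
    ds_config = json.load(fin)

zero = ds_config["zero_optimization"]
assert zero["stage"] == 3  # ZeRO-3: parameters, gradients and optimizer states are sharded
assert zero["stage3_gather_16bit_weights_on_model_save"]  # saved checkpoints get full, un-sharded weights
assert ds_config["bf16"]["enabled"] is True  # bf16 variant of the config
```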
examples/mistral/mixtral.yml CHANGED
@@ -14,6 +14,15 @@ dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./qlora-out
 
+## You can optionally freeze the entire model and unfreeze a subset of parameters
+unfrozen_parameters:
+# - lm_head.*
+# - model.embed_tokens.*
+# - model.layers.2[0-9]+.block_sparse_moe.gate.*
+# - model.layers.2[0-9]+.block_sparse_moe.experts.*
+# - model.layers.3[0-9]+.block_sparse_moe.gate.*
+# - model.layers.3[0-9]+.block_sparse_moe.experts.*
+
 adapter: qlora
 lora_model_dir:
 
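The commented-out entries above are regular expressions matched against parameter names by the new `freeze_parameters_except` helper added below in `src/axolotl/utils/freeze.py`. A small sketch of how one of them behaves (the parameter names are illustrative of Mixtral's naming, not read from a real checkpoint):

```python
# Sketch: freeze_parameters_except escapes periods to literal dots and uses
# re.match(), so a pattern only needs to match a prefix of the parameter name.
import re

pattern = "model.layers.2[0-9]+.block_sparse_moe.gate.*"
compiled = re.compile(pattern.replace(".", "\\."))

print(bool(compiled.match("model.layers.25.block_sparse_moe.gate.weight")))  # True  -> kept trainable
print(bool(compiled.match("model.layers.5.self_attn.q_proj.weight")))        # False -> left frozen
```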
src/axolotl/cli/train.py CHANGED
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
 
 def do_cli(config: Path = Path("examples/"), **kwargs):
     # pylint: disable=duplicate-code
-    print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
+    print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
     parser = transformers.HfArgumentParser((TrainerCliArgs))
src/axolotl/train.py CHANGED
@@ -18,6 +18,7 @@ from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
 from axolotl.monkeypatch import neft_embeddings
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.freeze import freeze_parameters_except
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.trainer import setup_trainer
 
@@ -78,6 +79,9 @@ def train(
     )
     resume_from_checkpoint = cfg.resume_from_checkpoint
 
+    if cfg.unfrozen_parameters:
+        freeze_parameters_except(model, cfg.unfrozen_parameters)
+
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
     )
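
A short sketch of how the new guard reads its input (assuming `DictDefault` behaves here as it does elsewhere in the codebase, where a missing key reads back as a falsy value, so configs without `unfrozen_parameters` leave the model untouched):

```python
# Sketch: cfg is a DictDefault built from the YAML; the hook only runs when
# unfrozen_parameters is actually set.
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"unfrozen_parameters": ["lm_head.*", "model.embed_tokens.*"]})

if cfg.unfrozen_parameters:  # falsy when the key is absent from the YAML
    print(cfg.unfrozen_parameters)  # ['lm_head.*', 'model.embed_tokens.*']
```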
src/axolotl/utils/freeze.py ADDED
@@ -0,0 +1,38 @@
+"""
+module to freeze/unfreeze parameters by name
+"""
+import logging
+import re
+
+from axolotl.utils.distributed import is_main_process
+
+LOG = logging.getLogger("axolotl.utils.freeze")
+
+
+def freeze_parameters_except(model, regex_patterns):
+    """
+    Freezes all layers of the given model except for the layers that match given regex patterns.
+    Periods in the patterns are treated as literal periods, not as wildcard characters.
+
+    Parameters:
+    - model (nn.Module): The PyTorch model to be modified.
+    - regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
+
+    Returns:
+    None; the model is modified in place.
+    """
+    # Escape periods and compile the regex patterns
+    compiled_patterns = [
+        re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
+    ]
+
+    # First, freeze all parameters in the model
+    for param in model.parameters():
+        param.requires_grad = False
+
+    # Unfreeze layers that match the regex patterns
+    for name, param in model.named_parameters():
+        if any(pattern.match(name) for pattern in compiled_patterns):
+            if is_main_process():
+                LOG.debug(f"unfreezing {name}")
+            param.requires_grad = True
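
A minimal usage sketch of the new helper on a toy module (the layer names are chosen to echo the Mixtral example above, not taken from a real model):

```python
# Sketch: freeze everything except lm_head.* on a tiny stand-in model and
# inspect which parameters still require gradients.
import torch.nn as nn

from axolotl.utils.freeze import freeze_parameters_except

toy = nn.ModuleDict(
    {
        "embed_tokens": nn.Embedding(8, 4),
        "lm_head": nn.Linear(4, 8),
    }
)

freeze_parameters_except(toy, ["lm_head.*"])

for name, param in toy.named_parameters():
    print(name, param.requires_grad)
# embed_tokens.weight False
# lm_head.weight True
# lm_head.bias True
```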
src/axolotl/utils/models.py CHANGED
@@ -21,6 +21,7 @@ from transformers import ( # noqa: F401
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
+from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -285,6 +286,9 @@ def load_model(
         model_kwargs["max_memory"] = cfg.max_memory
     model_kwargs["torch_dtype"] = cfg.torch_dtype
 
+    if is_deepspeed_zero3_enabled():
+        del model_kwargs["device_map"]
+
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
     if cfg.gptq:
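
The reasoning behind the new check: under ZeRO-3, DeepSpeed partitions parameters across ranks while the checkpoint loads, so an accelerate-style `device_map` passed to `from_pretrained` would conflict with that placement. A simplified sketch of the resulting call shape (the model id is a placeholder for `cfg.base_model`):

```python
# Sketch (simplified from load_model): drop device_map when ZeRO-3 owns placement.
from transformers import AutoModelForCausalLM
from transformers.deepspeed import is_deepspeed_zero3_enabled

model_kwargs = {"device_map": "auto", "torch_dtype": "auto"}

if is_deepspeed_zero3_enabled():
    # DeepSpeed ZeRO-3 shards parameters itself; an explicit device_map would
    # fight that partitioning while the checkpoint is loaded.
    del model_kwargs["device_map"]

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mixtral-8x7B-v0.1",  # placeholder; axolotl passes cfg.base_model here
    **model_kwargs,
)
```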
src/axolotl/utils/trainer.py CHANGED
@@ -276,6 +276,7 @@ def prepare_optim_env(cfg):
         setup_fsdp_envs(cfg)
     elif cfg.deepspeed:
         os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
+        os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
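
A short sketch of what `prepare_optim_env` now exports for a DeepSpeed run; accelerate's DeepSpeed plugin can pick the config path up from the environment, so no separate `accelerate config` step is needed (the path here is just the new example config standing in for `cfg.deepspeed`):

```python
# Sketch: the two environment variables exported when cfg.deepspeed is set.
import os

cfg_deepspeed = "deepspeed/zero3_bf16.json"  # stand-in for cfg.deepspeed

os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg_deepspeed
```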