winglian tmm1 commited on
Commit
0ddfb24
·
unverified ·
1 Parent(s): 89134f2

LISA (#1469)

Browse files

* add lisa support

* fix default and fix attribute traversal for layers

* improve lisa callback logging

* fix LISA by ensuring params are not frozen during __init__

* example config for lisa

---------

Co-authored-by: Aman Karmani <[email protected]>

examples/llama-2/lisa.yml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: NousResearch/Llama-2-7b-hf
2
+ model_type: LlamaForCausalLM
3
+ tokenizer_type: LlamaTokenizer
4
+
5
+ load_in_8bit: false
6
+ load_in_4bit: false
7
+ strict: false
8
+
9
+ datasets:
10
+ - path: teknium/GPT4-LLM-Cleaned
11
+ type: alpaca
12
+ dataset_prepared_path: last_run_prepared
13
+ val_set_size: 0.05
14
+ output_dir: ./lisa-out
15
+
16
+ sequence_len: 4096
17
+ sample_packing: true
18
+ pad_to_sequence_len: true
19
+
20
+ adapter:
21
+ lora_model_dir:
22
+ lora_r:
23
+ lora_alpha:
24
+ lora_dropout:
25
+ lora_target_linear:
26
+ lora_fan_in_fan_out:
27
+
28
+ lisa_n_layers: 4
29
+ lisa_step_interval: 20
30
+ lisa_layers_attribute: model.layers
31
+
32
+ wandb_project:
33
+ wandb_entity:
34
+ wandb_watch:
35
+ wandb_name:
36
+ wandb_log_model:
37
+
38
+ gradient_accumulation_steps: 2
39
+ micro_batch_size: 1
40
+ num_epochs: 1
41
+ optimizer: adamw_bnb_8bit
42
+ lr_scheduler: cosine
43
+ learning_rate: 5e-5 # recommendation from lisa paper for 7b
44
+
45
+ train_on_inputs: false
46
+ group_by_length: false
47
+ bf16: auto
48
+ fp16:
49
+ tf32: false
50
+
51
+ gradient_checkpointing: true
52
+ early_stopping_patience:
53
+ resume_from_checkpoint:
54
+ local_rank:
55
+ logging_steps: 1
56
+ xformers_attention:
57
+ flash_attention: true
58
+ flash_attn_cross_entropy: false
59
+ flash_attn_rms_norm: true
60
+ flash_attn_fuse_qkv: false
61
+ flash_attn_fuse_mlp: true
62
+
63
+ warmup_steps: 100
64
+ evals_per_epoch: 4
65
+ eval_table_size:
66
+ saves_per_epoch: 1
67
+ debug:
68
+ deepspeed:
69
+ weight_decay: 0.1
70
+ fsdp:
71
+ fsdp_config:
72
+ special_tokens:
73
+ bos_token: "<s>"
74
+ eos_token: "</s>"
75
+ unk_token: "<unk>"
src/axolotl/core/trainer_builder.py CHANGED
@@ -45,6 +45,7 @@ from axolotl.utils.callbacks import (
45
  causal_lm_bench_eval_callback_factory,
46
  log_prediction_callback_factory,
47
  )
 
48
  from axolotl.utils.collators import (
49
  BatchSamplerDataCollatorForSeq2Seq,
50
  DataCollatorForSeq2Seq,
@@ -200,6 +201,18 @@ class AxolotlTrainingArguments(TrainingArguments):
200
  orpo_alpha: Optional[float] = field(
201
  default=None,
202
  )
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
 
205
  class AxolotlTrainer(Trainer):
@@ -938,6 +951,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
938
  )
939
  callbacks.append(early_stop_cb)
940
 
 
 
941
  return callbacks
942
 
943
  def _get_trainer_cls(self):
@@ -1229,6 +1244,15 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
1229
  "relora_prune_ratio"
1230
  ] = self.cfg.relora_prune_ratio
1231
 
 
 
 
 
 
 
 
 
 
1232
  training_arguments_kwargs = self.hook_pre_create_training_args(
1233
  training_arguments_kwargs
1234
  )
 
45
  causal_lm_bench_eval_callback_factory,
46
  log_prediction_callback_factory,
47
  )
48
+ from axolotl.utils.callbacks.lisa import lisa_callback_factory
49
  from axolotl.utils.collators import (
50
  BatchSamplerDataCollatorForSeq2Seq,
51
  DataCollatorForSeq2Seq,
 
201
  orpo_alpha: Optional[float] = field(
202
  default=None,
203
  )
204
+ lisa_n_layers: Optional[int] = field(
205
+ default=None,
206
+ metadata={"help": "the number of activate layers in LISA"},
207
+ )
208
+ lisa_step_interval: Optional[int] = field(
209
+ default=None,
210
+ metadata={"help": "how often to switch layers in LISA"},
211
+ )
212
+ lisa_layers_attribute: Optional[str] = field(
213
+ default=None,
214
+ metadata={"help": "path under the model to access the layers"},
215
+ )
216
 
217
 
218
  class AxolotlTrainer(Trainer):
 
951
  )
952
  callbacks.append(early_stop_cb)
953
 
954
+ if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
955
+ callbacks.append(lisa_callback_factory(trainer))
956
  return callbacks
957
 
958
  def _get_trainer_cls(self):
 
1244
  "relora_prune_ratio"
1245
  ] = self.cfg.relora_prune_ratio
1246
 
1247
+ if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
1248
+ training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
1249
+ training_arguments_kwargs[
1250
+ "lisa_step_interval"
1251
+ ] = self.cfg.lisa_step_interval
1252
+ training_arguments_kwargs[
1253
+ "lisa_layers_attribute"
1254
+ ] = self.cfg.lisa_layers_attribute
1255
+
1256
  training_arguments_kwargs = self.hook_pre_create_training_args(
1257
  training_arguments_kwargs
1258
  )
src/axolotl/utils/callbacks/lisa.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ module for LISA
3
+
4
+ Adapted from https://github.com/OptimalScale/LMFlow/pull/701 for HF transformers & Axolotl
5
+ Arxiv: https://arxiv.org/abs/2403.17919
6
+ License: Apache 2.0
7
+ """
8
+
9
+ import logging
10
+ from functools import reduce
11
+ from typing import TYPE_CHECKING
12
+
13
+ import numpy as np
14
+ from transformers import TrainerCallback
15
+
16
+ if TYPE_CHECKING:
17
+ from axolotl.core.trainer_builder import AxolotlTrainer
18
+
19
+ LOG = logging.getLogger("axolotl.callbacks.lisa")
20
+
21
+
22
+ def lisa_callback_factory(trainer: "AxolotlTrainer"):
23
+ class LISACallback(TrainerCallback):
24
+ """trainer callback for lisa layer switching"""
25
+
26
+ def __init__(
27
+ self, n_layers, step_interval, trainer, layers_attribute="model.layers"
28
+ ):
29
+ super().__init__()
30
+ self.n_layers = n_layers
31
+ self.step_interval = step_interval
32
+ self.layers_attribute = layers_attribute
33
+ self.trainer = trainer
34
+
35
+ reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
36
+
37
+ self.total_layers = len(
38
+ reduce(getattr, self.layers_attribute.split("."), self.trainer.model)
39
+ )
40
+ self.active_layers_indices = []
41
+
42
+ layers = reduce(
43
+ getattr, self.layers_attribute.split("."), self.trainer.model
44
+ )
45
+ LOG.info(
46
+ f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
47
+ )
48
+
49
+ def freeze_all_layers(self):
50
+ layers = reduce(
51
+ getattr, self.layers_attribute.split("."), self.trainer.model
52
+ )
53
+ for layer in layers:
54
+ for param in layer.parameters():
55
+ param.requires_grad = False
56
+
57
+ def on_step_begin(
58
+ self, args, state, control, **kwargs
59
+ ): # pylint: disable=unused-argument
60
+ # Check if it's time to switch active layers, including at step 0
61
+ if state.global_step % self.step_interval == 0 or state.global_step == 1:
62
+ self.switch_active_layers()
63
+
64
+ def switch_active_layers(self):
65
+ # First, disable gradients for all layers
66
+ self.freeze_all_layers()
67
+
68
+ # Randomly select n_layers to activate
69
+ layers = reduce(
70
+ getattr, self.layers_attribute.split("."), self.trainer.model
71
+ )
72
+ self.active_layers_indices = np.random.choice(
73
+ range(self.total_layers), self.n_layers, replace=False
74
+ )
75
+ LOG.info(
76
+ f"Activating layers at indices: {self.active_layers_indices} for the next steps."
77
+ )
78
+
79
+ # Enable gradients only for the selected layers
80
+ for idx in self.active_layers_indices:
81
+ for param in layers[idx].parameters():
82
+ param.requires_grad = True
83
+
84
+ lisa_callback = LISACallback(
85
+ n_layers=trainer.args.lisa_n_layers,
86
+ step_interval=trainer.args.lisa_step_interval,
87
+ trainer=trainer,
88
+ layers_attribute=trainer.args.lisa_layers_attribute,
89
+ )
90
+
91
+ return lisa_callback
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -370,6 +370,23 @@ class MLFlowConfig(BaseModel):
370
  hf_mlflow_log_artifacts: Optional[bool] = None
371
 
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  class WandbConfig(BaseModel):
374
  """wandb configuration subset"""
375
 
@@ -404,6 +421,7 @@ class AxolotlInputConfig(
404
  HyperparametersConfig,
405
  WandbConfig,
406
  MLFlowConfig,
 
407
  RemappedParameters,
408
  DeprecatedParameters,
409
  BaseModel,
 
370
  hf_mlflow_log_artifacts: Optional[bool] = None
371
 
372
 
373
+ class LISAConfig(BaseModel):
374
+ """LISA options"""
375
+
376
+ lisa_n_layers: Optional[int] = Field(
377
+ default=None,
378
+ metadata={"help": "the number of activate layers in LISA"},
379
+ )
380
+ lisa_step_interval: Optional[int] = Field(
381
+ default=None,
382
+ metadata={"help": "how often to switch layers in LISA"},
383
+ )
384
+ lisa_layers_attribute: Optional[str] = Field(
385
+ default="model.layers",
386
+ metadata={"help": "path under the model to access the layers"},
387
+ )
388
+
389
+
390
  class WandbConfig(BaseModel):
391
  """wandb configuration subset"""
392
 
 
421
  HyperparametersConfig,
422
  WandbConfig,
423
  MLFlowConfig,
424
+ LISAConfig,
425
  RemappedParameters,
426
  DeprecatedParameters,
427
  BaseModel,