winglian committed · Commit 2752d5f · unverified · 1 Parent(s): 9e300ac

multipack for gemma (#1313)


* multipack for gemma
* chore: lint
* handle cache_position kwarg in updated llama modeling
* add position_ids to rotary embed call for updated llama modeling
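Taken together, the changes add "gemma" to the model types axolotl can patch for sample packing (multipack) and keep the patched llama forward compatible with the newer transformers signature. Below is a rough sketch of applying the new patch by hand, using only names that appear in the diffs in this commit; the assumption that axolotl normally triggers this from its model-loading path when sample_packing is enabled is mine, since that code is not part of this diff.

    # hypothetical standalone use of the multipack patch added in this commit
    from axolotl.monkeypatch.multipack import (
        SUPPORTED_MULTIPACK_MODEL_TYPES,
        patch_for_multipack,
    )

    model_type = "gemma"  # e.g. AutoConfig.from_pretrained(...).model_type
    if model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
        # swaps transformers' _get_unpad_data for axolotl's packing-aware version
        patch_for_multipack(model_type)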

examples/gemma/qlora.yml CHANGED
(whitespace-only lint changes: each removed/re-added pair below shows identical visible content)
@@ -1,49 +1,49 @@
 # use google/gemma-7b if you have access
-base_model: mhenrichsen/gemma-7b
+base_model: mhenrichsen/gemma-7b
 model_type: AutoModelForCausalLM
 tokenizer_type: AutoTokenizer
-
+
 load_in_8bit: false
 load_in_4bit: true
 strict: false
-
+
 # huggingface repo
 datasets:
   - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 val_set_size: 0.1
 output_dir: ./out
-
+
 adapter: qlora
 lora_r: 32
 lora_alpha: 16
 lora_dropout: 0.05
 lora_target_linear: true
-
+
 sequence_len: 4096
 sample_packing: false
 pad_to_sequence_len: false
-
+
 wandb_project:
 wandb_entity:
 wandb_watch:
 wandb_name:
 wandb_log_model:
-
-
+
+
 gradient_accumulation_steps: 3
 micro_batch_size: 2
 num_epochs: 4
 optimizer: adamw_bnb_8bit
 lr_scheduler: cosine
 learning_rate: 0.0002
-
+
 train_on_inputs: false
 group_by_length: false
 bf16: auto
 fp16:
 tf32: false
-
+
 gradient_checkpointing: true
 early_stopping_patience:
 resume_from_checkpoint:
@@ -51,7 +51,7 @@ local_rank:
 logging_steps: 1
 xformers_attention:
 flash_attention: true
-
+
 warmup_ratio: 0.1
 evals_per_epoch: 4
 eval_table_size:
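The shipped example keeps sample_packing: false. With gemma now in the multipack list, packing can be turned on; the fragment below is only a sketch of overrides one might apply on top of the example above, not something this commit changes, and pad_to_sequence_len: true is my assumption about the usual pairing rather than a value taken from this file.

    # hypothetical overrides on top of examples/gemma/qlora.yml
    sequence_len: 4096
    sample_packing: true       # usable for gemma once the multipack patch is applied
    pad_to_sequence_len: true  # assumption: commonly enabled together with packing
    flash_attention: true      # already true in the example; the patch hooks its unpad path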
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
+transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
 accelerate==0.26.1
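The transformers pin moves to a later commit, presumably one that ships the gemma modeling code the monkeypatch reaches into. A small sanity check one could run after installing from this file; _get_unpad_data is the internal helper named in the multipack diff below, so whether it exists depends entirely on the pinned revision.

    # check that the installed transformers build exposes what the gemma patch targets
    import transformers
    from transformers.models.gemma import modeling_gemma

    print("transformers version:", transformers.__version__)
    # internal helper replaced by the monkeypatch; present at this pin, may move in later releases
    print("has _get_unpad_data:", hasattr(modeling_gemma, "_get_unpad_data"))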
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(
+        value_states, seq_len=kv_seq_len, position_ids=position_ids
+    )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
@@ -425,7 +427,9 @@ def flashattn_forward(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]

-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    cos, sin = self.rotary_emb(
+        value_states, seq_len=kv_seq_len, position_ids=position_ids
+    )
     query_states, key_states = apply_rotary_pos_emb(
         query_states, key_states, cos, sin, position_ids
     )
@@ -688,6 +692,9 @@ def llama_model_forward(
     output_attentions: Optional[bool] = None,
     output_hidden_states: Optional[bool] = None,
     return_dict: Optional[bool] = None,
+    cache_position: Optional[  # pylint: disable=unused-argument
+        torch.LongTensor
+    ] = None,
 ) -> Union[Tuple, BaseModelOutputWithPast]:
     output_attentions = (
         output_attentions
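The last hunk is purely about signature compatibility: newer llama modeling code passes a cache_position keyword, so the patched forward must accept it even though it is ignored. A toy illustration of that pattern follows; the function name and body are made up for the example, and only the cache_position keyword comes from the diff above.

    from typing import Optional

    import torch


    def patched_forward(
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,  # accepted but unused
    ) -> torch.Tensor:
        # stand-in body; the real patch runs flash attention over unpadded sequences
        del cache_position
        if position_ids is None:
            position_ids = torch.arange(input_ids.shape[-1]).unsqueeze(0)
        return position_ids


    # a caller built against newer transformers can pass cache_position without a TypeError
    _ = patched_forward(torch.tensor([[1, 2, 3]]), cache_position=torch.arange(3))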
src/axolotl/monkeypatch/multipack.py CHANGED
@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
 from axolotl.monkeypatch.utils import get_unpad_data

-SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
+SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]


 def patch_for_multipack(model_type):
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
         transformers.models.phi.modeling_phi._get_unpad_data = (  # pylint: disable=protected-access
             get_unpad_data
         )
+    elif model_type == "gemma":
+        transformers.models.gemma.modeling_gemma._get_unpad_data = (  # pylint: disable=protected-access
+            get_unpad_data
+        )
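A quick way to confirm the new branch takes effect, not part of the commit: after calling patch_for_multipack("gemma"), the module-level helper in transformers' gemma modeling file should be axolotl's packing-aware replacement. The attribute is the protected name from the diff above, so this is only a debugging aid.

    import transformers

    from axolotl.monkeypatch.multipack import patch_for_multipack
    from axolotl.monkeypatch.utils import get_unpad_data

    patch_for_multipack("gemma")
    assert (
        transformers.models.gemma.modeling_gemma._get_unpad_data  # pylint: disable=protected-access
        is get_unpad_data
    )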