multipack for gemma (#1313)
Browse files* multipack for gemma
* chore: lint
* handle cache_position kwarg in updated llama modeling
* add position_ids to rotary embed call for updated llama modeling
examples/gemma/qlora.yml
CHANGED
@@ -1,49 +1,49 @@
|
|
1 |
# use google/gemma-7b if you have access
|
2 |
-
base_model: mhenrichsen/gemma-7b
|
3 |
model_type: AutoModelForCausalLM
|
4 |
tokenizer_type: AutoTokenizer
|
5 |
-
|
6 |
load_in_8bit: false
|
7 |
load_in_4bit: true
|
8 |
strict: false
|
9 |
-
|
10 |
# huggingface repo
|
11 |
datasets:
|
12 |
- path: mhenrichsen/alpaca_2k_test
|
13 |
type: alpaca
|
14 |
val_set_size: 0.1
|
15 |
output_dir: ./out
|
16 |
-
|
17 |
adapter: qlora
|
18 |
lora_r: 32
|
19 |
lora_alpha: 16
|
20 |
lora_dropout: 0.05
|
21 |
lora_target_linear: true
|
22 |
-
|
23 |
sequence_len: 4096
|
24 |
sample_packing: false
|
25 |
pad_to_sequence_len: false
|
26 |
-
|
27 |
wandb_project:
|
28 |
wandb_entity:
|
29 |
wandb_watch:
|
30 |
wandb_name:
|
31 |
wandb_log_model:
|
32 |
-
|
33 |
-
|
34 |
gradient_accumulation_steps: 3
|
35 |
micro_batch_size: 2
|
36 |
num_epochs: 4
|
37 |
optimizer: adamw_bnb_8bit
|
38 |
lr_scheduler: cosine
|
39 |
learning_rate: 0.0002
|
40 |
-
|
41 |
train_on_inputs: false
|
42 |
group_by_length: false
|
43 |
bf16: auto
|
44 |
fp16:
|
45 |
tf32: false
|
46 |
-
|
47 |
gradient_checkpointing: true
|
48 |
early_stopping_patience:
|
49 |
resume_from_checkpoint:
|
@@ -51,7 +51,7 @@ local_rank:
|
|
51 |
logging_steps: 1
|
52 |
xformers_attention:
|
53 |
flash_attention: true
|
54 |
-
|
55 |
warmup_ratio: 0.1
|
56 |
evals_per_epoch: 4
|
57 |
eval_table_size:
|
|
|
1 |
# use google/gemma-7b if you have access
|
2 |
+
base_model: mhenrichsen/gemma-7b
|
3 |
model_type: AutoModelForCausalLM
|
4 |
tokenizer_type: AutoTokenizer
|
5 |
+
|
6 |
load_in_8bit: false
|
7 |
load_in_4bit: true
|
8 |
strict: false
|
9 |
+
|
10 |
# huggingface repo
|
11 |
datasets:
|
12 |
- path: mhenrichsen/alpaca_2k_test
|
13 |
type: alpaca
|
14 |
val_set_size: 0.1
|
15 |
output_dir: ./out
|
16 |
+
|
17 |
adapter: qlora
|
18 |
lora_r: 32
|
19 |
lora_alpha: 16
|
20 |
lora_dropout: 0.05
|
21 |
lora_target_linear: true
|
22 |
+
|
23 |
sequence_len: 4096
|
24 |
sample_packing: false
|
25 |
pad_to_sequence_len: false
|
26 |
+
|
27 |
wandb_project:
|
28 |
wandb_entity:
|
29 |
wandb_watch:
|
30 |
wandb_name:
|
31 |
wandb_log_model:
|
32 |
+
|
33 |
+
|
34 |
gradient_accumulation_steps: 3
|
35 |
micro_batch_size: 2
|
36 |
num_epochs: 4
|
37 |
optimizer: adamw_bnb_8bit
|
38 |
lr_scheduler: cosine
|
39 |
learning_rate: 0.0002
|
40 |
+
|
41 |
train_on_inputs: false
|
42 |
group_by_length: false
|
43 |
bf16: auto
|
44 |
fp16:
|
45 |
tf32: false
|
46 |
+
|
47 |
gradient_checkpointing: true
|
48 |
early_stopping_patience:
|
49 |
resume_from_checkpoint:
|
|
|
51 |
logging_steps: 1
|
52 |
xformers_attention:
|
53 |
flash_attention: true
|
54 |
+
|
55 |
warmup_ratio: 0.1
|
56 |
evals_per_epoch: 4
|
57 |
eval_table_size:
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
packaging==23.2
|
3 |
peft @ git+https://github.com/huggingface/peft.git
|
4 |
-
transformers @ git+https://github.com/huggingface/transformers.git@
|
5 |
tokenizers==0.15.0
|
6 |
bitsandbytes>=0.41.1
|
7 |
accelerate==0.26.1
|
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
packaging==23.2
|
3 |
peft @ git+https://github.com/huggingface/peft.git
|
4 |
+
transformers @ git+https://github.com/huggingface/transformers.git@ae49b218c3d718df90d8e4a109016450fb8f0632
|
5 |
tokenizers==0.15.0
|
6 |
bitsandbytes>=0.41.1
|
7 |
accelerate==0.26.1
|
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
|
|
275 |
kv_seq_len = key_states.shape[-2]
|
276 |
if past_key_value is not None:
|
277 |
kv_seq_len += past_key_value[0].shape[-2]
|
278 |
-
cos, sin = self.rotary_emb(
|
|
|
|
|
279 |
query_states, key_states = apply_rotary_pos_emb(
|
280 |
query_states, key_states, cos, sin, position_ids
|
281 |
)
|
@@ -425,7 +427,9 @@ def flashattn_forward(
|
|
425 |
if past_key_value is not None:
|
426 |
kv_seq_len += past_key_value[0].shape[-2]
|
427 |
|
428 |
-
cos, sin = self.rotary_emb(
|
|
|
|
|
429 |
query_states, key_states = apply_rotary_pos_emb(
|
430 |
query_states, key_states, cos, sin, position_ids
|
431 |
)
|
@@ -688,6 +692,9 @@ def llama_model_forward(
|
|
688 |
output_attentions: Optional[bool] = None,
|
689 |
output_hidden_states: Optional[bool] = None,
|
690 |
return_dict: Optional[bool] = None,
|
|
|
|
|
|
|
691 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
692 |
output_attentions = (
|
693 |
output_attentions
|
|
|
275 |
kv_seq_len = key_states.shape[-2]
|
276 |
if past_key_value is not None:
|
277 |
kv_seq_len += past_key_value[0].shape[-2]
|
278 |
+
cos, sin = self.rotary_emb(
|
279 |
+
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
280 |
+
)
|
281 |
query_states, key_states = apply_rotary_pos_emb(
|
282 |
query_states, key_states, cos, sin, position_ids
|
283 |
)
|
|
|
427 |
if past_key_value is not None:
|
428 |
kv_seq_len += past_key_value[0].shape[-2]
|
429 |
|
430 |
+
cos, sin = self.rotary_emb(
|
431 |
+
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
432 |
+
)
|
433 |
query_states, key_states = apply_rotary_pos_emb(
|
434 |
query_states, key_states, cos, sin, position_ids
|
435 |
)
|
|
|
692 |
output_attentions: Optional[bool] = None,
|
693 |
output_hidden_states: Optional[bool] = None,
|
694 |
return_dict: Optional[bool] = None,
|
695 |
+
cache_position: Optional[ # pylint: disable=unused-argument
|
696 |
+
torch.LongTensor
|
697 |
+
] = None,
|
698 |
) -> Union[Tuple, BaseModelOutputWithPast]:
|
699 |
output_attentions = (
|
700 |
output_attentions
|
src/axolotl/monkeypatch/multipack.py
CHANGED
@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
|
6 |
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
7 |
from axolotl.monkeypatch.utils import get_unpad_data
|
8 |
|
9 |
-
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
10 |
|
11 |
|
12 |
def patch_for_multipack(model_type):
|
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
|
|
28 |
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
29 |
get_unpad_data
|
30 |
)
|
|
|
|
|
|
|
|
|
|
6 |
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
7 |
from axolotl.monkeypatch.utils import get_unpad_data
|
8 |
|
9 |
+
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
10 |
|
11 |
|
12 |
def patch_for_multipack(model_type):
|
|
|
28 |
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
29 |
get_unpad_data
|
30 |
)
|
31 |
+
elif model_type == "gemma":
|
32 |
+
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
33 |
+
get_unpad_data
|
34 |
+
)
|