Implement fused modules (#747)
Browse files* MLP: Memory saving
* Remove RMSNorm restrictions
* Map packed weights to original
* FusedAttention module
* Simplify code
* Move fused modules
* Fix critical typo
* Split inplace
* Add FFT config
* Add validation of fused arguments
* Add fused arguments to config
* Update docs
* Fix validation logic
* Add fused modules to flash attn
* Only fuse during training
* Remove timing
* Formatting
* Formatting
* Formatting
* chore: lint
* chore: lint
* add e2e tests for fused llama
* no lora for tests
---------
Co-authored-by: Wing Lian <[email protected]>
- README.md +2 -0
- examples/llama-2/README.md +7 -3
- examples/llama-2/fft_optimized.yml +73 -0
- src/axolotl/monkeypatch/fused_modules.py +0 -0
- src/axolotl/monkeypatch/llama_attn_hijack_flash.py +123 -5
- src/axolotl/monkeypatch/utils.py +13 -0
- src/axolotl/train.py +6 -4
- src/axolotl/utils/config.py +9 -0
- src/axolotl/utils/models.py +14 -0
- tests/e2e/test_fused_llama.py +117 -0
README.md
CHANGED
@@ -684,6 +684,8 @@ xformers_attention:
|
|
684 |
flash_attention:
|
685 |
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
686 |
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
|
|
|
|
687 |
# Whether to use scaled-dot-product attention
|
688 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
689 |
sdp_attention:
|
|
|
684 |
flash_attention:
|
685 |
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
686 |
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
687 |
+
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
688 |
+
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
689 |
# Whether to use scaled-dot-product attention
|
690 |
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
691 |
sdp_attention:
|
examples/llama-2/README.md
CHANGED
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
|
|
9 |
micro_batch_size: 1
|
10 |
|
11 |
```shell
|
12 |
-
accelerate launch
|
13 |
-
|
14 |
```
|
15 |
or
|
16 |
|
17 |
```shell
|
18 |
-
accelerate launch
|
|
|
19 |
|
|
|
|
|
|
|
|
|
20 |
```
|
|
|
9 |
micro_batch_size: 1
|
10 |
|
11 |
```shell
|
12 |
+
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
|
|
13 |
```
|
14 |
or
|
15 |
|
16 |
```shell
|
17 |
+
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
18 |
+
```
|
19 |
|
20 |
+
To launch a full finetuning with 16-bit precision:
|
21 |
+
|
22 |
+
```shell
|
23 |
+
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
24 |
```
|
examples/llama-2/fft_optimized.yml
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: NousResearch/Llama-2-7b-hf
|
2 |
+
base_model_config: NousResearch/Llama-2-7b-hf
|
3 |
+
model_type: LlamaForCausalLM
|
4 |
+
tokenizer_type: LlamaTokenizer
|
5 |
+
is_llama_derived_model: true
|
6 |
+
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
strict: false
|
10 |
+
|
11 |
+
datasets:
|
12 |
+
- path: mhenrichsen/alpaca_2k_test
|
13 |
+
type: alpaca
|
14 |
+
dataset_prepared_path: last_run_prepared
|
15 |
+
val_set_size: 0.01
|
16 |
+
output_dir: ./out
|
17 |
+
|
18 |
+
sequence_len: 4096
|
19 |
+
sample_packing: true
|
20 |
+
pad_to_sequence_len: true
|
21 |
+
|
22 |
+
adapter:
|
23 |
+
lora_model_dir:
|
24 |
+
lora_r:
|
25 |
+
lora_alpha:
|
26 |
+
lora_dropout:
|
27 |
+
lora_target_linear:
|
28 |
+
lora_fan_in_fan_out:
|
29 |
+
|
30 |
+
wandb_project:
|
31 |
+
wandb_entity:
|
32 |
+
wandb_watch:
|
33 |
+
wandb_run_id:
|
34 |
+
wandb_log_model:
|
35 |
+
|
36 |
+
gradient_accumulation_steps: 1
|
37 |
+
micro_batch_size: 1
|
38 |
+
num_epochs: 1
|
39 |
+
optimizer: adamw_bnb_8bit
|
40 |
+
lr_scheduler: cosine
|
41 |
+
learning_rate: 0.0002
|
42 |
+
|
43 |
+
train_on_inputs: false
|
44 |
+
group_by_length: false
|
45 |
+
bf16: true
|
46 |
+
fp16: false
|
47 |
+
tf32: false
|
48 |
+
|
49 |
+
gradient_checkpointing: true
|
50 |
+
early_stopping_patience:
|
51 |
+
resume_from_checkpoint:
|
52 |
+
local_rank:
|
53 |
+
logging_steps: 1
|
54 |
+
xformers_attention:
|
55 |
+
flash_attention: true
|
56 |
+
flash_attn_cross_entropy: false
|
57 |
+
flash_attn_rms_norm: true
|
58 |
+
flash_attn_fuse_qkv: false
|
59 |
+
flash_attn_fuse_mlp: true
|
60 |
+
|
61 |
+
warmup_steps: 100
|
62 |
+
eval_steps: 0.05
|
63 |
+
eval_table_size:
|
64 |
+
save_steps:
|
65 |
+
debug:
|
66 |
+
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
67 |
+
weight_decay: 0.1
|
68 |
+
fsdp:
|
69 |
+
fsdp_config:
|
70 |
+
special_tokens:
|
71 |
+
bos_token: "<s>"
|
72 |
+
eos_token: "</s>"
|
73 |
+
unk_token: "<unk>"
|
src/axolotl/monkeypatch/fused_modules.py
ADDED
File without changes
|
src/axolotl/monkeypatch/llama_attn_hijack_flash.py
CHANGED
@@ -13,12 +13,18 @@ import transformers
|
|
13 |
from einops import rearrange
|
14 |
from flash_attn.bert_padding import pad_input, unpad_input
|
15 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
|
16 |
from transformers.models.llama.modeling_llama import (
|
17 |
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
18 |
)
|
19 |
-
from transformers.models.llama.modeling_llama import
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
-
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
22 |
|
23 |
try:
|
24 |
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
@@ -38,6 +44,28 @@ except ImportError:
|
|
38 |
LOG = logging.getLogger("axolotl")
|
39 |
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
def replace_llama_attn_with_flash_attn(
|
42 |
packed: Optional[bool] = False,
|
43 |
cross_entropy: Optional[bool] = False,
|
@@ -86,6 +114,91 @@ def replace_llama_attn_with_flash_attn(
|
|
86 |
)
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
90 |
# requires the attention mask to be the same as the key_padding_mask
|
91 |
def _prepare_decoder_attention_mask(
|
@@ -147,9 +260,14 @@ def flashattn_forward(
|
|
147 |
value_states = torch.cat(value_states, dim=-1)
|
148 |
|
149 |
else:
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
query_states = query_states.view(
|
155 |
bsz, q_len, self.num_heads, self.head_dim
|
|
|
13 |
from einops import rearrange
|
14 |
from flash_attn.bert_padding import pad_input, unpad_input
|
15 |
from transformers.modeling_outputs import BaseModelOutputWithPast
|
16 |
+
from transformers.models.llama.modeling_llama import LlamaAttention
|
17 |
from transformers.models.llama.modeling_llama import (
|
18 |
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
19 |
)
|
20 |
+
from transformers.models.llama.modeling_llama import (
|
21 |
+
LlamaMLP,
|
22 |
+
apply_rotary_pos_emb,
|
23 |
+
repeat_kv,
|
24 |
+
)
|
25 |
+
from xformers.ops import SwiGLU
|
26 |
|
27 |
+
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
28 |
|
29 |
try:
|
30 |
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
|
|
44 |
LOG = logging.getLogger("axolotl")
|
45 |
|
46 |
|
47 |
+
def replace_llama_mlp_with_swiglu(model):
|
48 |
+
for name, module in model.named_modules():
|
49 |
+
if isinstance(module, LlamaMLP):
|
50 |
+
mlp = FusedMLP(
|
51 |
+
module.config, module.gate_proj, module.up_proj, module.down_proj
|
52 |
+
)
|
53 |
+
set_module_name(model, name, mlp)
|
54 |
+
|
55 |
+
|
56 |
+
def replace_llama_qkv_with_fused(model):
|
57 |
+
for name, module in model.named_modules():
|
58 |
+
if isinstance(module, LlamaAttention):
|
59 |
+
qkv = FusedAttention(
|
60 |
+
module.config,
|
61 |
+
module.q_proj,
|
62 |
+
module.k_proj,
|
63 |
+
module.v_proj,
|
64 |
+
module.o_proj,
|
65 |
+
)
|
66 |
+
set_module_name(model, name, qkv)
|
67 |
+
|
68 |
+
|
69 |
def replace_llama_attn_with_flash_attn(
|
70 |
packed: Optional[bool] = False,
|
71 |
cross_entropy: Optional[bool] = False,
|
|
|
114 |
)
|
115 |
|
116 |
|
117 |
+
class FusedAttention(LlamaAttention):
|
118 |
+
"""
|
119 |
+
Fused QKV Attention layer for incrementally improved training efficiency
|
120 |
+
"""
|
121 |
+
|
122 |
+
def __init__(
|
123 |
+
self,
|
124 |
+
config,
|
125 |
+
q: torch.nn.Linear, # pylint: disable=invalid-name
|
126 |
+
k: torch.nn.Linear, # pylint: disable=invalid-name
|
127 |
+
v: torch.nn.Linear, # pylint: disable=invalid-name
|
128 |
+
o: torch.nn.Linear, # pylint: disable=invalid-name
|
129 |
+
):
|
130 |
+
super().__init__(config)
|
131 |
+
self.config = config
|
132 |
+
self.init_device = next(iter(q.state_dict().values())).device
|
133 |
+
|
134 |
+
# define equivalent fused qkv projection
|
135 |
+
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
136 |
+
self.qkv_proj = torch.nn.Linear(
|
137 |
+
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
138 |
+
)
|
139 |
+
self.o_proj = o
|
140 |
+
|
141 |
+
# overwrite initialized weights with pretrained weights
|
142 |
+
self.qkv_proj.weight.data = torch.cat(
|
143 |
+
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
144 |
+
)
|
145 |
+
|
146 |
+
def _post_training(self, model, name):
|
147 |
+
q_proj, k_proj, v_proj = torch.split(
|
148 |
+
self.qkv_proj.weight.data, self.out_features, dim=0
|
149 |
+
)
|
150 |
+
|
151 |
+
new_attn = LlamaAttention(self.config)
|
152 |
+
new_attn.q_proj.weight.data = q_proj
|
153 |
+
new_attn.k_proj.weight.data = k_proj
|
154 |
+
new_attn.v_proj.weight.data = v_proj
|
155 |
+
|
156 |
+
set_module_name(model, name, new_attn)
|
157 |
+
|
158 |
+
|
159 |
+
class FusedMLP(torch.nn.Module):
|
160 |
+
"""
|
161 |
+
Fused MLP layer for incrementally improved training efficiency
|
162 |
+
"""
|
163 |
+
|
164 |
+
def __init__(
|
165 |
+
self,
|
166 |
+
config,
|
167 |
+
gate_proj: torch.nn.Linear,
|
168 |
+
up_proj: torch.nn.Linear,
|
169 |
+
down_proj: torch.nn.Linear,
|
170 |
+
):
|
171 |
+
super().__init__()
|
172 |
+
self.config = config
|
173 |
+
self.swiglu = SwiGLU(
|
174 |
+
in_features=config.hidden_size,
|
175 |
+
hidden_features=config.intermediate_size,
|
176 |
+
bias=False,
|
177 |
+
_pack_weights=True,
|
178 |
+
)
|
179 |
+
# overwrite initialized weights with pretrained weights
|
180 |
+
self.swiglu.w12.weight.data = torch.cat(
|
181 |
+
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
182 |
+
)
|
183 |
+
self.swiglu.w3.weight.data = down_proj.weight.data
|
184 |
+
|
185 |
+
def _post_training(self, model, name):
|
186 |
+
w1, w2 = torch.split( # pylint: disable=invalid-name
|
187 |
+
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
188 |
+
)
|
189 |
+
|
190 |
+
# Assign the split weights back to the original layers
|
191 |
+
new_mlp = LlamaMLP(self.config)
|
192 |
+
new_mlp.gate_proj.weight.data = w1
|
193 |
+
new_mlp.up_proj.weight.data = w2
|
194 |
+
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
195 |
+
|
196 |
+
set_module_name(model, name, new_mlp)
|
197 |
+
|
198 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
199 |
+
return self.swiglu(x)
|
200 |
+
|
201 |
+
|
202 |
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
203 |
# requires the attention mask to be the same as the key_padding_mask
|
204 |
def _prepare_decoder_attention_mask(
|
|
|
260 |
value_states = torch.cat(value_states, dim=-1)
|
261 |
|
262 |
else:
|
263 |
+
if isinstance(self, FusedAttention):
|
264 |
+
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
265 |
+
self.out_features, dim=-1
|
266 |
+
)
|
267 |
+
else:
|
268 |
+
query_states = self.q_proj(hidden_states)
|
269 |
+
key_states = self.k_proj(hidden_states)
|
270 |
+
value_states = self.v_proj(hidden_states)
|
271 |
|
272 |
query_states = query_states.view(
|
273 |
bsz, q_len, self.num_heads, self.head_dim
|
src/axolotl/monkeypatch/utils.py
CHANGED
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
|
101 |
max_seq_lens.append(max_seq_len)
|
102 |
|
103 |
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
max_seq_lens.append(max_seq_len)
|
102 |
|
103 |
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
104 |
+
|
105 |
+
|
106 |
+
def set_module_name(model, name, value):
|
107 |
+
if "." in name:
|
108 |
+
parent_name = name.rsplit(".", 1)[0]
|
109 |
+
child_name = name[len(parent_name) + 1 :]
|
110 |
+
parent = model.get_submodule(parent_name)
|
111 |
+
else:
|
112 |
+
parent_name = ""
|
113 |
+
parent = model
|
114 |
+
child_name = name
|
115 |
+
|
116 |
+
setattr(parent, child_name, value)
|
src/axolotl/train.py
CHANGED
@@ -40,10 +40,7 @@ class TrainDatasetMeta:
|
|
40 |
|
41 |
|
42 |
def train(
|
43 |
-
*,
|
44 |
-
cfg: DictDefault,
|
45 |
-
cli_args: TrainerCliArgs,
|
46 |
-
dataset_meta: TrainDatasetMeta,
|
47 |
):
|
48 |
# load the tokenizer first
|
49 |
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
@@ -120,6 +117,11 @@ def train(
|
|
120 |
|
121 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
122 |
|
|
|
|
|
|
|
|
|
|
|
123 |
if trainer.is_fsdp_enabled:
|
124 |
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
125 |
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
|
|
40 |
|
41 |
|
42 |
def train(
|
43 |
+
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
|
|
|
|
|
|
44 |
):
|
45 |
# load the tokenizer first
|
46 |
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
|
|
117 |
|
118 |
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
119 |
|
120 |
+
# post training
|
121 |
+
for name, module in model.named_modules():
|
122 |
+
if hasattr(module, "_post_training"):
|
123 |
+
module._post_training(model, name) # pylint: disable=protected-access
|
124 |
+
|
125 |
if trainer.is_fsdp_enabled:
|
126 |
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
127 |
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
src/axolotl/utils/config.py
CHANGED
@@ -189,9 +189,15 @@ def validate_config(cfg):
|
|
189 |
if not cfg.load_in_4bit:
|
190 |
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
191 |
|
|
|
|
|
|
|
192 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
193 |
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
194 |
|
|
|
|
|
|
|
195 |
if cfg.relora_steps:
|
196 |
if cfg.adapter not in ("lora", "qlora"):
|
197 |
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
@@ -205,6 +211,9 @@ def validate_config(cfg):
|
|
205 |
if cfg.lr_scheduler == "one_cycle":
|
206 |
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
207 |
|
|
|
|
|
|
|
208 |
if cfg.trust_remote_code:
|
209 |
LOG.warning(
|
210 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
|
|
189 |
if not cfg.load_in_4bit:
|
190 |
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
191 |
|
192 |
+
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
193 |
+
raise ValueError("Fused modules are not supported with QLoRA")
|
194 |
+
|
195 |
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
196 |
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
197 |
|
198 |
+
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
199 |
+
raise ValueError("Fused modules are not supported with LoRA")
|
200 |
+
|
201 |
if cfg.relora_steps:
|
202 |
if cfg.adapter not in ("lora", "qlora"):
|
203 |
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
|
|
211 |
if cfg.lr_scheduler == "one_cycle":
|
212 |
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
213 |
|
214 |
+
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
215 |
+
raise ValueError("Fused modules are not supported with ReLoRA")
|
216 |
+
|
217 |
if cfg.trust_remote_code:
|
218 |
LOG.warning(
|
219 |
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
src/axolotl/utils/models.py
CHANGED
@@ -272,6 +272,20 @@ def load_model(
|
|
272 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
273 |
**model_kwargs,
|
274 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
276 |
# This is a WIP, still an issue with the backward pass
|
277 |
# RuntimeError: grad can be implicitly created only for scalar outputs
|
|
|
272 |
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
273 |
**model_kwargs,
|
274 |
)
|
275 |
+
|
276 |
+
if cfg.flash_attention and not inference:
|
277 |
+
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
278 |
+
replace_llama_mlp_with_swiglu,
|
279 |
+
replace_llama_qkv_with_fused,
|
280 |
+
)
|
281 |
+
|
282 |
+
if cfg.flash_attn_fuse_mlp:
|
283 |
+
LOG.info("patching with SwiGLU")
|
284 |
+
replace_llama_mlp_with_swiglu(model)
|
285 |
+
|
286 |
+
if cfg.flash_attn_fuse_qkv:
|
287 |
+
LOG.info("patching with fused QKV")
|
288 |
+
replace_llama_qkv_with_fused(model)
|
289 |
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
290 |
# This is a WIP, still an issue with the backward pass
|
291 |
# RuntimeError: grad can be implicitly created only for scalar outputs
|
tests/e2e/test_fused_llama.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
E2E tests for lora llama
|
3 |
+
"""
|
4 |
+
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import tempfile
|
8 |
+
import unittest
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
from transformers.utils import is_torch_bf16_gpu_available
|
12 |
+
|
13 |
+
from axolotl.cli import load_datasets
|
14 |
+
from axolotl.common.cli import TrainerCliArgs
|
15 |
+
from axolotl.train import train
|
16 |
+
from axolotl.utils.config import normalize_config
|
17 |
+
from axolotl.utils.dict import DictDefault
|
18 |
+
|
19 |
+
LOG = logging.getLogger("axolotl.tests.e2e")
|
20 |
+
os.environ["WANDB_DISABLED"] = "true"
|
21 |
+
|
22 |
+
|
23 |
+
class TestFusedLlama(unittest.TestCase):
|
24 |
+
"""
|
25 |
+
Test case for Llama models using Fused layers
|
26 |
+
"""
|
27 |
+
|
28 |
+
def test_lora_packing(self):
|
29 |
+
# pylint: disable=duplicate-code
|
30 |
+
output_dir = tempfile.mkdtemp()
|
31 |
+
cfg = DictDefault(
|
32 |
+
{
|
33 |
+
"base_model": "JackFram/llama-68m",
|
34 |
+
"base_model_config": "JackFram/llama-68m",
|
35 |
+
"flash_attention": True,
|
36 |
+
"flash_attn_fuse_qkv": True,
|
37 |
+
"flash_attn_fuse_mlp": True,
|
38 |
+
"sample_packing": True,
|
39 |
+
"sequence_len": 1024,
|
40 |
+
"load_in_8bit": True,
|
41 |
+
"val_set_size": 0.1,
|
42 |
+
"special_tokens": {
|
43 |
+
"unk_token": "<unk>",
|
44 |
+
"bos_token": "<s>",
|
45 |
+
"eos_token": "</s>",
|
46 |
+
},
|
47 |
+
"datasets": [
|
48 |
+
{
|
49 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
50 |
+
"type": "alpaca",
|
51 |
+
},
|
52 |
+
],
|
53 |
+
"num_epochs": 2,
|
54 |
+
"micro_batch_size": 2,
|
55 |
+
"gradient_accumulation_steps": 1,
|
56 |
+
"output_dir": output_dir,
|
57 |
+
"learning_rate": 0.00001,
|
58 |
+
"optimizer": "adamw_torch",
|
59 |
+
"lr_scheduler": "cosine",
|
60 |
+
"max_steps": 20,
|
61 |
+
"save_steps": 10,
|
62 |
+
"eval_steps": 10,
|
63 |
+
}
|
64 |
+
)
|
65 |
+
normalize_config(cfg)
|
66 |
+
cli_args = TrainerCliArgs()
|
67 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
68 |
+
|
69 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
70 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
71 |
+
|
72 |
+
def test_fft_packing(self):
|
73 |
+
# pylint: disable=duplicate-code
|
74 |
+
output_dir = tempfile.mkdtemp()
|
75 |
+
cfg = DictDefault(
|
76 |
+
{
|
77 |
+
"base_model": "JackFram/llama-68m",
|
78 |
+
"base_model_config": "JackFram/llama-68m",
|
79 |
+
"flash_attention": True,
|
80 |
+
"flash_attn_fuse_qkv": True,
|
81 |
+
"flash_attn_fuse_mlp": True,
|
82 |
+
"sample_packing": True,
|
83 |
+
"sequence_len": 1024,
|
84 |
+
"val_set_size": 0.1,
|
85 |
+
"special_tokens": {
|
86 |
+
"unk_token": "<unk>",
|
87 |
+
"bos_token": "<s>",
|
88 |
+
"eos_token": "</s>",
|
89 |
+
},
|
90 |
+
"datasets": [
|
91 |
+
{
|
92 |
+
"path": "mhenrichsen/alpaca_2k_test",
|
93 |
+
"type": "alpaca",
|
94 |
+
},
|
95 |
+
],
|
96 |
+
"num_epochs": 2,
|
97 |
+
"micro_batch_size": 2,
|
98 |
+
"gradient_accumulation_steps": 1,
|
99 |
+
"output_dir": output_dir,
|
100 |
+
"learning_rate": 0.00001,
|
101 |
+
"optimizer": "adamw_torch",
|
102 |
+
"lr_scheduler": "cosine",
|
103 |
+
"max_steps": 20,
|
104 |
+
"save_steps": 10,
|
105 |
+
"eval_steps": 10,
|
106 |
+
}
|
107 |
+
)
|
108 |
+
if is_torch_bf16_gpu_available():
|
109 |
+
cfg.bf16 = True
|
110 |
+
else:
|
111 |
+
cfg.fp16 = True
|
112 |
+
normalize_config(cfg)
|
113 |
+
cli_args = TrainerCliArgs()
|
114 |
+
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
115 |
+
|
116 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
117 |
+
assert (Path(output_dir) / "pytorch_model.bin").exists()
|