casperhansen and winglian committed
Commit 15d3a65 · unverified · 1 Parent(s): a21935f

Implement fused modules (#747)


* MLP: Memory saving

* Remove RMSNorm restrictions

* Map packed weights to original

* FusedAttention module

* Simplify code

* Move fused modules

* Fix critical typo

* Split inplace

* Add FFT config

* Add validation of fused arguments

* Add fused arguments to config

* Update docs

* Fix validation logic

* Add fused modules to flash attn

* Only fuse during training

* Remove timing

* Formatting

* Formatting

* Formatting

* chore: lint

* chore: lint

* add e2e tests for fused llama

* no lora for tests

---------

Co-authored-by: Wing Lian <[email protected]>

README.md CHANGED
@@ -684,6 +684,8 @@ xformers_attention:
 flash_attention:
 flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
 flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
+flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
+flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
 # Whether to use scaled-dot-product attention
 # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
 sdp_attention:
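The two new flags only take effect together with flash attention and are rejected for LoRA/QLoRA/ReLoRA runs (see the config validation changes further down). A minimal sketch of such a config fragment, written as the same DictDefault mapping the e2e tests below use; values are illustrative, not defaults:

```python
from axolotl.utils.dict import DictDefault

# Illustrative fragment only: fused modules piggyback on flash attention
# and are applied to the model only for training, never for inference.
cfg = DictDefault(
    {
        "flash_attention": True,
        "flash_attn_fuse_qkv": True,  # pack q_proj/k_proj/v_proj into one Linear
        "flash_attn_fuse_mlp": True,  # pack gate_proj/up_proj into one SwiGLU
        "adapter": None,  # fused modules are rejected for LoRA/QLoRA/ReLoRA
    }
)
```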
examples/llama-2/README.md CHANGED
@@ -9,12 +9,16 @@ gradient_accumulation_steps: 2
 micro_batch_size: 1
 
 ```shell
-accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
-
+accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
 ```
 or
 
 ```shell
-accelerate launch scripts/finetune.py examples/llama-2/lora.yml
+accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
+```
 
+To launch a full finetuning with 16-bit precision:
+
+```shell
+accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
 ```
examples/llama-2/fft_optimized.yml ADDED
@@ -0,0 +1,73 @@
+base_model: NousResearch/Llama-2-7b-hf
+base_model_config: NousResearch/Llama-2-7b-hf
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./out
+
+sequence_len: 4096
+sample_packing: true
+pad_to_sequence_len: true
+
+adapter:
+lora_model_dir:
+lora_r:
+lora_alpha:
+lora_dropout:
+lora_target_linear:
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 1
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+flash_attn_cross_entropy: false
+flash_attn_rms_norm: true
+flash_attn_fuse_qkv: false
+flash_attn_fuse_mlp: true
+
+warmup_steps: 100
+eval_steps: 0.05
+eval_table_size:
+save_steps:
+debug:
+deepspeed: #deepspeed/zero2.json # multi-gpu only
+weight_decay: 0.1
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
src/axolotl/monkeypatch/fused_modules.py ADDED
(empty file)
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -13,12 +13,18 @@ import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import LlamaAttention
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OriginalLlamaDecoderLayer,
 )
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
+from transformers.models.llama.modeling_llama import (
+    LlamaMLP,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+from xformers.ops import SwiGLU
 
-from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
 
 try:
     from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
@@ -38,6 +44,28 @@ except ImportError:
 LOG = logging.getLogger("axolotl")
 
 
+def replace_llama_mlp_with_swiglu(model):
+    for name, module in model.named_modules():
+        if isinstance(module, LlamaMLP):
+            mlp = FusedMLP(
+                module.config, module.gate_proj, module.up_proj, module.down_proj
+            )
+            set_module_name(model, name, mlp)
+
+
+def replace_llama_qkv_with_fused(model):
+    for name, module in model.named_modules():
+        if isinstance(module, LlamaAttention):
+            qkv = FusedAttention(
+                module.config,
+                module.q_proj,
+                module.k_proj,
+                module.v_proj,
+                module.o_proj,
+            )
+            set_module_name(model, name, qkv)
+
+
 def replace_llama_attn_with_flash_attn(
     packed: Optional[bool] = False,
     cross_entropy: Optional[bool] = False,
@@ -86,6 +114,91 @@ def replace_llama_attn_with_flash_attn(
     )
 
 
+class FusedAttention(LlamaAttention):
+    """
+    Fused QKV Attention layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        q: torch.nn.Linear,  # pylint: disable=invalid-name
+        k: torch.nn.Linear,  # pylint: disable=invalid-name
+        v: torch.nn.Linear,  # pylint: disable=invalid-name
+        o: torch.nn.Linear,  # pylint: disable=invalid-name
+    ):
+        super().__init__(config)
+        self.config = config
+        self.init_device = next(iter(q.state_dict().values())).device
+
+        # define equivalent fused qkv projection
+        self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
+        self.qkv_proj = torch.nn.Linear(
+            q.in_features, sum(self.out_features), device=self.init_device, bias=False
+        )
+        self.o_proj = o
+
+        # overwrite initialized weights with pretrained weights
+        self.qkv_proj.weight.data = torch.cat(
+            (q.weight.data, k.weight.data, v.weight.data), dim=0
+        )
+
+    def _post_training(self, model, name):
+        q_proj, k_proj, v_proj = torch.split(
+            self.qkv_proj.weight.data, self.out_features, dim=0
+        )
+
+        new_attn = LlamaAttention(self.config)
+        new_attn.q_proj.weight.data = q_proj
+        new_attn.k_proj.weight.data = k_proj
+        new_attn.v_proj.weight.data = v_proj
+
+        set_module_name(model, name, new_attn)
+
+
+class FusedMLP(torch.nn.Module):
+    """
+    Fused MLP layer for incrementally improved training efficiency
+    """
+
+    def __init__(
+        self,
+        config,
+        gate_proj: torch.nn.Linear,
+        up_proj: torch.nn.Linear,
+        down_proj: torch.nn.Linear,
+    ):
+        super().__init__()
+        self.config = config
+        self.swiglu = SwiGLU(
+            in_features=config.hidden_size,
+            hidden_features=config.intermediate_size,
+            bias=False,
+            _pack_weights=True,
+        )
+        # overwrite initialized weights with pretrained weights
+        self.swiglu.w12.weight.data = torch.cat(
+            (gate_proj.weight.data, up_proj.weight.data), dim=0
+        )
+        self.swiglu.w3.weight.data = down_proj.weight.data
+
+    def _post_training(self, model, name):
+        w1, w2 = torch.split(  # pylint: disable=invalid-name
+            self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
+        )
+
+        # Assign the split weights back to the original layers
+        new_mlp = LlamaMLP(self.config)
+        new_mlp.gate_proj.weight.data = w1
+        new_mlp.up_proj.weight.data = w2
+        new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
+
+        set_module_name(model, name, new_mlp)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=invalid-name
+        return self.swiglu(x)
+
+
 # Disable the transformation of the attention mask in LlamaModel as the flash attention
 # requires the attention mask to be the same as the key_padding_mask
 def _prepare_decoder_attention_mask(
@@ -147,9 +260,14 @@ def flashattn_forward(
         value_states = torch.cat(value_states, dim=-1)
 
     else:
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
+        if isinstance(self, FusedAttention):
+            query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
+                self.out_features, dim=-1
+            )
+        else:
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
 
     query_states = query_states.view(
         bsz, q_len, self.num_heads, self.head_dim
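The fusion is purely a re-packing of pretrained weights: concatenating the q/k/v weight matrices along the output dimension and splitting the result of a single matmul reproduces the three separate projections, which is why `_post_training` can split them back losslessly. A standalone sketch of that equivalence with toy sizes (not part of the patch above):

```python
import torch

hidden = 64
x = torch.randn(2, 8, hidden)

q = torch.nn.Linear(hidden, hidden, bias=False)
k = torch.nn.Linear(hidden, hidden, bias=False)
v = torch.nn.Linear(hidden, hidden, bias=False)

# pack the three projections into one Linear, as FusedAttention does
qkv = torch.nn.Linear(hidden, 3 * hidden, bias=False)
qkv.weight.data = torch.cat((q.weight.data, k.weight.data, v.weight.data), dim=0)

# one matmul, then split along the feature dimension
q2, k2, v2 = qkv(x).split([hidden, hidden, hidden], dim=-1)
assert torch.allclose(q2, q(x), atol=1e-6)
assert torch.allclose(k2, k(x), atol=1e-6)
assert torch.allclose(v2, v(x), atol=1e-6)
```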
src/axolotl/monkeypatch/utils.py CHANGED
@@ -101,3 +101,16 @@ def get_cu_seqlens_from_pos_ids(position_ids):
         max_seq_lens.append(max_seq_len)
 
     return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def set_module_name(model, name, value):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
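`set_module_name` is the small helper both replacement passes rely on: given a dotted name from `named_modules()`, it resolves the parent module and overwrites the child attribute in place. A hedged usage sketch on a toy model (the `Block` class is illustrative, not axolotl code):

```python
import torch

from axolotl.monkeypatch.utils import set_module_name


class Block(torch.nn.Module):
    """Toy block with a single linear projection."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)


model = torch.nn.Sequential(Block(), Block())

# swap every Linear for an Identity, addressing each by its dotted name ("0.proj", "1.proj")
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        set_module_name(model, name, torch.nn.Identity())
```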
src/axolotl/train.py CHANGED
@@ -40,10 +40,7 @@ class TrainDatasetMeta:
 
 
 def train(
-    *,
-    cfg: DictDefault,
-    cli_args: TrainerCliArgs,
-    dataset_meta: TrainDatasetMeta,
+    *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
 ):
     # load the tokenizer first
     LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
@@ -120,6 +117,11 @@ def train(
 
     LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
+    # post training
+    for name, module in model.named_modules():
+        if hasattr(module, "_post_training"):
+            module._post_training(model, name)  # pylint: disable=protected-access
+
     if trainer.is_fsdp_enabled:
         trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
         LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
src/axolotl/utils/config.py CHANGED
@@ -189,9 +189,15 @@ def validate_config(cfg):
         if not cfg.load_in_4bit:
             raise ValueError("Require cfg.load_in_4bit to be True for qlora")
 
+        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
+            raise ValueError("Fused modules are not supported with QLoRA")
+
     if not cfg.load_in_8bit and cfg.adapter == "lora":
         LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
 
+    if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
+        raise ValueError("Fused modules are not supported with LoRA")
+
     if cfg.relora_steps:
         if cfg.adapter not in ("lora", "qlora"):
             raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -205,6 +211,9 @@ def validate_config(cfg):
         if cfg.lr_scheduler == "one_cycle":
             raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
 
+        if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
+            raise ValueError("Fused modules are not supported with ReLoRA")
+
     if cfg.trust_remote_code:
         LOG.warning(
             "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
src/axolotl/utils/models.py CHANGED
@@ -272,6 +272,20 @@ def load_model(
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             **model_kwargs,
         )
+
+        if cfg.flash_attention and not inference:
+            from axolotl.monkeypatch.llama_attn_hijack_flash import (
+                replace_llama_mlp_with_swiglu,
+                replace_llama_qkv_with_fused,
+            )
+
+            if cfg.flash_attn_fuse_mlp:
+                LOG.info("patching with SwiGLU")
+                replace_llama_mlp_with_swiglu(model)
+
+            if cfg.flash_attn_fuse_qkv:
+                LOG.info("patching with fused QKV")
+                replace_llama_qkv_with_fused(model)
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
     # This is a WIP, still an issue with the backward pass
     # RuntimeError: grad can be implicitly created only for scalar outputs
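Outside of `load_model`, the same two patches can be applied directly to any Hugging Face Llama checkpoint. A hedged sketch (the model name is borrowed from the e2e tests; flash-attn and xformers must be installed for the monkeypatch module to import):

```python
from transformers import LlamaForCausalLM

from axolotl.monkeypatch.llama_attn_hijack_flash import (
    replace_llama_mlp_with_swiglu,
    replace_llama_qkv_with_fused,
)

model = LlamaForCausalLM.from_pretrained("JackFram/llama-68m")

# fuse gate/up projections into a packed SwiGLU and q/k/v into a single Linear
replace_llama_mlp_with_swiglu(model)
replace_llama_qkv_with_fused(model)
```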
tests/e2e/test_fused_llama.py ADDED
@@ -0,0 +1,117 @@
+"""
+E2E tests for fused llama
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestFusedLlama(unittest.TestCase):
+    """
+    Test case for Llama models using Fused layers
+    """
+
+    def test_lora_packing(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "base_model_config": "JackFram/llama-68m",
+                "flash_attention": True,
+                "flash_attn_fuse_qkv": True,
+                "flash_attn_fuse_mlp": True,
+                "sample_packing": True,
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()
+
+    def test_fft_packing(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "base_model_config": "JackFram/llama-68m",
+                "flash_attention": True,
+                "flash_attn_fuse_qkv": True,
+                "flash_attn_fuse_mlp": True,
+                "sample_packing": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()