Nanobit committed
Commit 697c50d · unverified · 1 Parent(s): 90e0d67

Feat: Allow usage of native Mistral FA when no sample_packing (#669)


* Allow usage of native Mistral FA when no sample_packing

* fix: do not apply custom patch when sample_pack off

* chore: lint

* chore: pin transformers to v4.35.0.dev0

* fix: split sample_packing to separate test
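
In practical terms, a Mistral config that requests Flash Attention but leaves sample packing off now goes through transformers' native FA2 path instead of the custom monkeypatch. A hedged sketch of such a config, written as a Python dict that mirrors the e2e test fixtures in this PR (the key names come from those tests; the values are illustrative):

# Illustrative config fragment (Python dict mirroring this PR's e2e test fixtures).
cfg_fragment = {
    "base_model": "openaccess-ai-collective/tiny-mistral",
    "flash_attention": True,   # request Flash Attention
    "sample_packing": False,   # off -> native transformers FA2, no custom patch
}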

requirements.txt CHANGED
@@ -4,7 +4,7 @@ torch==2.0.1
 auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
-transformers @ git+https://github.com/huggingface/transformers.git@5e11d72d4d0939138fbabfebe9a69d2061519547
+transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
 bitsandbytes>=0.41.1
 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
 deepspeed
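
The requirements bump moves the transformers pin to a later development commit on the 4.35.0.dev0 line, which carries native Mistral Flash Attention 2 support. A minimal sanity check of an installed environment might look like the sketch below; the ">= 4.35.0.dev0" threshold is an assumption inferred from the pinned commit, not something stated in the PR.

# Sanity-check sketch: is the installed transformers build new enough for
# native Mistral FA2? The version threshold is an assumption based on the pin.
from packaging import version

import transformers

print("transformers:", transformers.__version__)
assert version.parse(transformers.__version__) >= version.parse("4.35.0.dev0")

# MistralForCausalLM is only importable from builds that include Mistral support.
from transformers import MistralForCausalLM  # noqa: F401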
src/axolotl/utils/models.py CHANGED
@@ -149,7 +149,7 @@ def load_model(
         # Note: This might overwrite previous additional_special_tokens
         tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
 
-    if cfg.is_mistral_derived_model and cfg.flash_attention:
+    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.mistral_attn_hijack_flash import (
             replace_mistral_attn_with_flash_attn,
         )
@@ -200,7 +200,11 @@ def load_model(
     )
     # sample packing uses custom FA2 patch
    if cfg.flash_attention and not cfg.sample_packing:
-        if cfg.is_llama_derived_model or cfg.is_falcon_derived_model:
+        if (
+            cfg.is_llama_derived_model
+            or cfg.is_falcon_derived_model
+            or cfg.is_mistral_derived_model
+        ):
             model_kwargs["use_flash_attention_2"] = True
     try:
         if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
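
The net effect is a two-way choice for Mistral: the custom Flash Attention monkeypatch (which understands packed samples) is applied only when sample_packing is on, while the transformers-native FA2 kernel is requested via use_flash_attention_2 otherwise. The standalone sketch below restates that decision; Cfg and select_attention_path are hypothetical names for illustration, the real code is not an elif chain, and only the cfg flag names and the "use_flash_attention_2" kwarg mirror the diff above.

# Sketch of the attention-path selection introduced by this commit.
from dataclasses import dataclass


@dataclass
class Cfg:
    is_llama_derived_model: bool = False
    is_falcon_derived_model: bool = False
    is_mistral_derived_model: bool = False
    flash_attention: bool = False
    sample_packing: bool = False


def select_attention_path(cfg: Cfg) -> dict:
    model_kwargs = {}
    if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
        # Packed samples need the custom monkeypatch
        # (replace_mistral_attn_with_flash_attn in the real code).
        model_kwargs["attention"] = "custom-mistral-fa-patch"  # placeholder marker
    elif cfg.flash_attention and not cfg.sample_packing:
        if (
            cfg.is_llama_derived_model
            or cfg.is_falcon_derived_model
            or cfg.is_mistral_derived_model
        ):
            # No packing: rely on transformers' native FA2 implementation.
            model_kwargs["use_flash_attention_2"] = True
    return model_kwargs


# Mistral + FA without packing now takes the native path.
assert select_attention_path(
    Cfg(is_mistral_derived_model=True, flash_attention=True)
) == {"use_flash_attention_2": True}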
tests/e2e/test_mistral.py CHANGED
@@ -71,53 +71,6 @@ class TestMistral(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "adapter_model.bin").exists()
 
-    def test_lora_packing(self):
-        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
-        cfg = DictDefault(
-            {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
-                "base_model_config": "openaccess-ai-collective/tiny-mistral",
-                "flash_attention": True,
-                "sample_packing": True,
-                "sequence_len": 1024,
-                "load_in_8bit": True,
-                "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
-                "lora_dropout": 0.05,
-                "lora_target_linear": True,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "max_steps": 20,
-                "save_steps": 10,
-                "eval_steps": 10,
-            }
-        )
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "adapter_model.bin").exists()
-
     def test_ft(self):
         # pylint: disable=duplicate-code
         output_dir = tempfile.mkdtemp()
@@ -161,48 +114,3 @@
 
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(output_dir) / "pytorch_model.bin").exists()
-
-    def test_ft_packing(self):
-        # pylint: disable=duplicate-code
-        output_dir = tempfile.mkdtemp()
-        cfg = DictDefault(
-            {
-                "base_model": "openaccess-ai-collective/tiny-mistral",
-                "base_model_config": "openaccess-ai-collective/tiny-mistral",
-                "flash_attention": True,
-                "sample_packing": True,
-                "sequence_len": 1024,
-                "val_set_size": 0.1,
-                "special_tokens": {
-                    "unk_token": "<unk>",
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                },
-                "datasets": [
-                    {
-                        "path": "mhenrichsen/alpaca_2k_test",
-                        "type": "alpaca",
-                    },
-                ],
-                "num_epochs": 2,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 1,
-                "output_dir": output_dir,
-                "learning_rate": 0.00001,
-                "optimizer": "adamw_torch",
-                "lr_scheduler": "cosine",
-                "max_steps": 20,
-                "save_steps": 10,
-                "eval_steps": 10,
-            }
-        )
-        if is_torch_bf16_gpu_available():
-            cfg.bf16 = True
-        else:
-            cfg.fp16 = True
-        normalize_config(cfg)
-        cli_args = TrainerCliArgs()
-        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
-
-        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
-        assert (Path(output_dir) / "pytorch_model.bin").exists()
tests/e2e/test_mistral_samplepack.py ADDED
@@ -0,0 +1,118 @@
+"""
+E2E tests for lora llama
+"""
+
+import logging
+import os
+import tempfile
+import unittest
+from pathlib import Path
+
+from transformers.utils import is_torch_bf16_gpu_available
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestMistral(unittest.TestCase):
+    """
+    Test case for Llama models using LoRA
+    """
+
+    def test_lora_packing(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model_config": "openaccess-ai-collective/tiny-mistral",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 32,
+                "lora_alpha": 64,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "adapter_model.bin").exists()
+
+    def test_ft_packing(self):
+        # pylint: disable=duplicate-code
+        output_dir = tempfile.mkdtemp()
+        cfg = DictDefault(
+            {
+                "base_model": "openaccess-ai-collective/tiny-mistral",
+                "base_model_config": "openaccess-ai-collective/tiny-mistral",
+                "flash_attention": True,
+                "sample_packing": True,
+                "sequence_len": 1024,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 2,
+                "micro_batch_size": 2,
+                "gradient_accumulation_steps": 1,
+                "output_dir": output_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "adamw_torch",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "eval_steps": 10,
+            }
+        )
+        if is_torch_bf16_gpu_available():
+            cfg.bf16 = True
+        else:
+            cfg.fp16 = True
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(output_dir) / "pytorch_model.bin").exists()
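
With the sample-packing tests split into their own module, they can be run in isolation from the rest of the Mistral e2e suite. A minimal sketch using pytest's Python entry point is shown below; the file path is the one added by this commit, the -v flag is illustrative, and the tests assume a Flash-Attention-capable GPU.

# Run only the new sample-packing e2e tests.
import sys

import pytest

sys.exit(pytest.main(["-v", "tests/e2e/test_mistral_samplepack.py"]))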