winglian commited on
Commit
16bb627
·
unverified ·
2 Parent(s): 06674a1 fd2c981

Merge pull request #92 from OpenAccess-AI-Collective/flash-optimum

Browse files
README.md CHANGED
@@ -421,6 +421,8 @@ optimizer:
421
  # specify weight decay
422
  weight_decay:
423
 
 
 
424
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
425
  xformers_attention:
426
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
 
421
  # specify weight decay
422
  weight_decay:
423
 
424
+ # whether to bettertransformers
425
+ flash_optimum:
426
  # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
427
  xformers_attention:
428
  # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
examples/pythia-12b/README.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # Pythia 12B
2
+
3
+ - Single-GPU A100 only (?)
4
+
5
+ ```shell
6
+ python scripts/finetune.py examples/pythia-12b/config.yml
7
+ ```
8
+
9
+ ⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️
examples/pythia-12b/config.yml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: EleutherAI/pythia-12b-deduped
2
+ base_model_config: EleutherAI/pythia-12b-deduped
3
+ base_model_ignore_patterns: pytorch* # prefer safetensors
4
+ model_type: GPTNeoXForCausalLM
5
+ tokenizer_type: AutoTokenizer
6
+ load_in_8bit: false
7
+ load_in_4bit: false
8
+ gptq: false
9
+ device_map: auto
10
+ datasets:
11
+ - path: vicgalle/alpaca-gpt4
12
+ type: alpaca
13
+ dataset_prepared_path: last_run_prepared
14
+ val_set_size: 0.05
15
+ adapter:
16
+ lora_model_dir:
17
+ sequence_len: 2048
18
+ max_packed_sequence_len: 2048
19
+ lora_r: 64
20
+ lora_alpha: 32
21
+ lora_dropout: 0.0
22
+ lora_target_modules:
23
+ lora_target_linear: true
24
+ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
25
+ wandb_project:
26
+ wandb_watch:
27
+ wandb_run_id:
28
+ wandb_log_model:
29
+ output_dir: ./pythia-12b
30
+ gradient_accumulation_steps: 1
31
+ micro_batch_size: 1
32
+ num_epochs: 5
33
+ learning_rate: 0.00003
34
+ optimizer: adamw_bnb_8bit
35
+ lr_scheduler: cosine
36
+ train_on_inputs: false
37
+ group_by_length: false
38
+ bf16: false
39
+ fp16: false
40
+ float16: true
41
+ tf32: true
42
+ flash_optimum: true
43
+ early_stopping_patience:
44
+ resume_from_checkpoint:
45
+ local_rank:
46
+ gradient_checkpointing: true
47
+ fsdp:
48
+ fsdp_config:
49
+ collator_pad_to_longest: true
requirements.txt CHANGED
@@ -11,6 +11,7 @@ sentencepiece
11
  wandb
12
  einops
13
  xformers
 
14
  # qlora things
15
  bert-score==0.3.13
16
  evaluate==0.4.0
 
11
  wandb
12
  einops
13
  xformers
14
+ optimum
15
  # qlora things
16
  bert-score==0.3.13
17
  evaluate==0.4.0
scripts/finetune.py CHANGED
@@ -12,13 +12,14 @@ from typing import Any, Dict, List, Optional, Union
12
  import fire
13
  import torch
14
  import yaml
 
 
 
15
  from transformers import GenerationConfig, TextStreamer
16
 
17
- from axolotl.utils.data import load_prepare_datasets
18
  from axolotl.utils.dict import DictDefault
19
  from axolotl.utils.models import load_model, load_tokenizer
20
-
21
- # add src to the pythonpath so we don't need to pip install this
22
  from axolotl.utils.tokenization import check_dataset_labels
23
  from axolotl.utils.trainer import setup_trainer
24
  from axolotl.utils.validation import validate_config
@@ -217,9 +218,20 @@ def train(
217
  if (
218
  check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
219
  ): # don't need to load dataset for these
220
- train_dataset, eval_dataset = load_prepare_datasets(
221
- tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
222
- )
 
 
 
 
 
 
 
 
 
 
 
223
 
224
  if cfg.debug or "debug" in kwargs:
225
  logging.info("check_dataset_labels...")
@@ -285,12 +297,15 @@ def train(
285
 
286
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
287
  if cfg.local_rank == 0:
 
 
 
 
 
 
 
288
  signal.signal(
289
- signal.SIGINT,
290
- lambda signal, frame: (
291
- model.save_pretrained(cfg.output_dir),
292
- sys.exit(0),
293
- ),
294
  )
295
 
296
  logging.info("Starting trainer...")
@@ -313,13 +328,21 @@ def train(
313
 
314
  if not Path(cfg.output_dir).is_dir():
315
  os.makedirs(cfg.output_dir, exist_ok=True)
316
- trainer.train(resume_from_checkpoint=resume_from_checkpoint)
 
 
 
 
 
 
317
 
318
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
319
 
320
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
321
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
322
  if cfg.local_rank == 0:
 
 
323
  model.save_pretrained(cfg.output_dir)
324
 
325
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
 
12
  import fire
13
  import torch
14
  import yaml
15
+
16
+ # add src to the pythonpath so we don't need to pip install this
17
+ from optimum.bettertransformer import BetterTransformer
18
  from transformers import GenerationConfig, TextStreamer
19
 
20
+ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
21
  from axolotl.utils.dict import DictDefault
22
  from axolotl.utils.models import load_model, load_tokenizer
 
 
23
  from axolotl.utils.tokenization import check_dataset_labels
24
  from axolotl.utils.trainer import setup_trainer
25
  from axolotl.utils.validation import validate_config
 
218
  if (
219
  check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
220
  ): # don't need to load dataset for these
221
+ if not cfg.pretraining_dataset:
222
+ train_dataset, eval_dataset = load_prepare_datasets(
223
+ tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
224
+ )
225
+ else:
226
+ train_dataset = load_pretraining_dataset(
227
+ cfg.pretraining_dataset,
228
+ tokenizer,
229
+ max_tokens=cfg.sequence_len,
230
+ seed=cfg.seed,
231
+ )
232
+ # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
233
+ train_dataset = train_dataset.with_format("torch")
234
+ eval_dataset = None
235
 
236
  if cfg.debug or "debug" in kwargs:
237
  logging.info("check_dataset_labels...")
 
297
 
298
  # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
299
  if cfg.local_rank == 0:
300
+
301
+ def terminate_handler(_, __, model):
302
+ if cfg.flash_optimum:
303
+ model = BetterTransformer.reverse(model)
304
+ model.save_pretrained(cfg.output_dir)
305
+ sys.exit(0)
306
+
307
  signal.signal(
308
+ signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model)
 
 
 
 
309
  )
310
 
311
  logging.info("Starting trainer...")
 
328
 
329
  if not Path(cfg.output_dir).is_dir():
330
  os.makedirs(cfg.output_dir, exist_ok=True)
331
+ if cfg.flash_optimum:
332
+ with torch.backends.cuda.sdp_kernel(
333
+ enable_flash=True, enable_math=True, enable_mem_efficient=True
334
+ ):
335
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
336
+ else:
337
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
338
 
339
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
340
 
341
  # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
342
  # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
343
  if cfg.local_rank == 0:
344
+ if cfg.flash_optimum:
345
+ model = BetterTransformer.reverse(model)
346
  model.save_pretrained(cfg.output_dir)
347
 
348
  # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
src/axolotl/utils/callbacks.py CHANGED
@@ -2,13 +2,14 @@
2
 
3
  import os
4
 
 
5
  from transformers import (
6
  TrainerCallback,
7
  TrainerControl,
8
  TrainerState,
9
  TrainingArguments,
10
  )
11
- from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
12
 
13
 
14
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
@@ -30,3 +31,39 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-
30
  kwargs["model"].save_pretrained(peft_model_path)
31
 
32
  return control
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import os
4
 
5
+ from optimum.bettertransformer import BetterTransformer
6
  from transformers import (
7
  TrainerCallback,
8
  TrainerControl,
9
  TrainerState,
10
  TrainingArguments,
11
  )
12
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
13
 
14
 
15
  class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods
 
31
  kwargs["model"].save_pretrained(peft_model_path)
32
 
33
  return control
34
+
35
+
36
+ class SaveBetterTransformerModelCallback(
37
+ TrainerCallback
38
+ ): # pylint: disable=too-few-public-methods
39
+ """Callback to save the BetterTransformer wrapped model"""
40
+
41
+ def on_step_end(
42
+ self,
43
+ args: TrainingArguments,
44
+ state: TrainerState,
45
+ control: TrainerControl,
46
+ **kwargs,
47
+ ):
48
+ # Save
49
+ if (
50
+ args.save_strategy == IntervalStrategy.STEPS
51
+ and args.save_steps > 0
52
+ and state.global_step % args.save_steps == 0
53
+ ):
54
+ control.should_save = True
55
+
56
+ if control.should_save:
57
+ checkpoint_folder = os.path.join(
58
+ args.output_dir,
59
+ f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
60
+ )
61
+
62
+ model = BetterTransformer.reverse(kwargs["model"])
63
+ model.save_pretrained(checkpoint_folder)
64
+ # FIXME - need to cleanup old checkpoints
65
+
66
+ # since we're saving here, we don't need the trainer loop to attempt to save too b/c
67
+ # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
68
+ control.should_save = False
69
+ return control
src/axolotl/utils/data.py CHANGED
@@ -1,10 +1,11 @@
1
  """Module containing data utilities"""
2
-
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
 
8
  from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
9
  from huggingface_hub import hf_hub_download
10
  from transformers import PreTrainedTokenizerBase
@@ -394,8 +395,127 @@ def load_prepare_datasets(
394
  index=cfg.dataset_shard_idx,
395
  )
396
 
397
- dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
398
- train_dataset = dataset["train"]
399
- eval_dataset = dataset["test"]
 
 
 
 
400
 
401
  return train_dataset, eval_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Module containing data utilities"""
2
+ import functools
3
  import logging
4
  from hashlib import md5
5
  from pathlib import Path
6
  from typing import List, Tuple, Union
7
 
8
+ import torch
9
  from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
10
  from huggingface_hub import hf_hub_download
11
  from transformers import PreTrainedTokenizerBase
 
395
  index=cfg.dataset_shard_idx,
396
  )
397
 
398
+ if cfg.val_set_size:
399
+ dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
400
+ train_dataset = dataset["train"]
401
+ eval_dataset = dataset["test"]
402
+ else:
403
+ train_dataset = dataset
404
+ eval_dataset = None
405
 
406
  return train_dataset, eval_dataset
407
+
408
+
409
+ def encode_pretraining(tokenizer, max_tokens, examples):
410
+ res = tokenizer(
411
+ examples["text"],
412
+ truncation=True,
413
+ max_length=max_tokens - 2,
414
+ add_special_tokens=True,
415
+ )
416
+ # Convert to PyTorch tensors
417
+ input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
418
+ attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
419
+ new_input_ids = []
420
+ new_attention_mask = []
421
+ # Append EOS and PAD tokens to input_ids, and correct attention_mask
422
+ for i, _ in enumerate(input_ids):
423
+ input_ids[i] = torch.cat(
424
+ (
425
+ input_ids[i],
426
+ torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
427
+ ),
428
+ dim=0,
429
+ )
430
+ attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
431
+
432
+ # Concatenate tokens so that their lengths are less than max_tokens
433
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
434
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
435
+
436
+ for ids, mask in zip(input_ids, attention_mask):
437
+ if buffer_input_ids.numel() == max_tokens:
438
+ new_input_ids.append(buffer_input_ids)
439
+ new_attention_mask.append(buffer_attention_mask)
440
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
441
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
442
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
443
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
444
+ elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
445
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
446
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
447
+ else:
448
+ buffer_input_ids = torch.cat(
449
+ (
450
+ buffer_input_ids,
451
+ torch.full(
452
+ (max_tokens - buffer_input_ids.numel(),),
453
+ tokenizer.pad_token_id,
454
+ dtype=torch.long,
455
+ ),
456
+ ),
457
+ dim=0,
458
+ )
459
+ buffer_attention_mask = torch.cat(
460
+ (
461
+ buffer_attention_mask,
462
+ torch.full(
463
+ (max_tokens - buffer_attention_mask.numel(),),
464
+ 0,
465
+ dtype=torch.long,
466
+ ),
467
+ ),
468
+ dim=0,
469
+ )
470
+ new_input_ids.append(buffer_input_ids)
471
+ new_attention_mask.append(buffer_attention_mask)
472
+ buffer_input_ids = torch.tensor([], dtype=torch.long)
473
+ buffer_attention_mask = torch.tensor([], dtype=torch.long)
474
+
475
+ buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
476
+ buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
477
+
478
+ if buffer_input_ids.numel() > 0: # for any leftover tokens
479
+ while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
480
+ buffer_input_ids = torch.cat(
481
+ (
482
+ buffer_input_ids,
483
+ torch.full(
484
+ (max_tokens - buffer_input_ids.numel(),),
485
+ tokenizer.pad_token_id,
486
+ dtype=torch.long,
487
+ ),
488
+ ),
489
+ dim=0,
490
+ )
491
+ buffer_attention_mask = torch.cat(
492
+ (
493
+ buffer_attention_mask,
494
+ torch.full(
495
+ (max_tokens - buffer_attention_mask.numel(),),
496
+ 0,
497
+ dtype=torch.long,
498
+ ),
499
+ ),
500
+ dim=0,
501
+ )
502
+ new_input_ids.append(buffer_input_ids)
503
+ new_attention_mask.append(buffer_attention_mask)
504
+
505
+ ret = {
506
+ "input_ids": [seq.tolist() for seq in new_input_ids],
507
+ "labels": [seq.tolist() for seq in new_input_ids],
508
+ "attention_mask": [seq.tolist() for seq in new_attention_mask],
509
+ }
510
+
511
+ logging.debug(len(ret["input_ids"]))
512
+ return ret
513
+
514
+
515
+ def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
516
+ encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
517
+ dataset = load_dataset(path, streaming=True, split="train")
518
+ dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
519
+ # TODO dynamically figure out which columns/features to remove
520
+ dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
521
+ return dataset
src/axolotl/utils/models.py CHANGED
@@ -10,8 +10,9 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
 
13
  from transformers import PreTrainedModel # noqa: F401
14
- from transformers import ( # noqa: F401
15
  AutoConfig,
16
  AutoModelForCausalLM,
17
  AutoTokenizer,
@@ -121,9 +122,9 @@ def load_model(
121
  logging.info("patching with xpos rope")
122
  replace_llama_rope_with_xpos_rope()
123
 
124
- if cfg.bf16:
125
  torch_dtype = torch.bfloat16
126
- elif cfg.load_in_8bit or cfg.fp16:
127
  torch_dtype = torch.float16
128
  else:
129
  torch_dtype = torch.float32
@@ -287,6 +288,15 @@ def load_model(
287
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
288
  model.resize_token_embeddings(embeddings_len)
289
 
 
 
 
 
 
 
 
 
 
290
  if not cfg.gptq and (
291
  (cfg.adapter == "lora" and load_in_8bit)
292
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
@@ -332,6 +342,9 @@ def load_model(
332
  logging.warning("there are no parameters that require gradient updates")
333
  model.config.use_cache = False
334
 
 
 
 
335
  # TODO resume_from_checkpoint handling
336
  return model, lora_config
337
 
 
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
+ from optimum.bettertransformer import BetterTransformer
14
  from transformers import PreTrainedModel # noqa: F401
15
+ from transformers import (
16
  AutoConfig,
17
  AutoModelForCausalLM,
18
  AutoTokenizer,
 
122
  logging.info("patching with xpos rope")
123
  replace_llama_rope_with_xpos_rope()
124
 
125
+ if cfg.bf16 or cfg.bfloat16:
126
  torch_dtype = torch.bfloat16
127
+ elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
128
  torch_dtype = torch.float16
129
  else:
130
  torch_dtype = torch.float32
 
288
  embeddings_len = math.ceil(len(tokenizer) / 32) * 32
289
  model.resize_token_embeddings(embeddings_len)
290
 
291
+ if (
292
+ hasattr(model.config, "max_position_embeddings")
293
+ and cfg.sequence_len >= model.config.max_position_embeddings
294
+ ):
295
+ logging.warning(
296
+ f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
297
+ )
298
+ model.config.max_position_embeddings = cfg.sequence_len
299
+
300
  if not cfg.gptq and (
301
  (cfg.adapter == "lora" and load_in_8bit)
302
  or (cfg.adapter == "qlora" and cfg.load_in_4bit)
 
342
  logging.warning("there are no parameters that require gradient updates")
343
  model.config.use_cache = False
344
 
345
+ if cfg.flash_optimum:
346
+ model = BetterTransformer.transform(model)
347
+
348
  # TODO resume_from_checkpoint handling
349
  return model, lora_config
350
 
src/axolotl/utils/trainer.py CHANGED
@@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR
16
  from transformers import EarlyStoppingCallback, Trainer
17
  from transformers.trainer_pt_utils import get_parameter_names
18
 
19
- from axolotl.utils.callbacks import SavePeftModelCallback
 
 
 
20
  from axolotl.utils.schedulers import InterpolatingLogScheduler
21
 
22
 
@@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
228
  ]: # only save in rank 0
229
  callbacks.append(SavePeftModelCallback)
230
 
 
 
 
231
  data_collator_kwargs = {
232
  "padding": True,
233
  }
 
16
  from transformers import EarlyStoppingCallback, Trainer
17
  from transformers.trainer_pt_utils import get_parameter_names
18
 
19
+ from axolotl.utils.callbacks import (
20
+ SaveBetterTransformerModelCallback,
21
+ SavePeftModelCallback,
22
+ )
23
  from axolotl.utils.schedulers import InterpolatingLogScheduler
24
 
25
 
 
231
  ]: # only save in rank 0
232
  callbacks.append(SavePeftModelCallback)
233
 
234
+ if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
235
+ callbacks.append(SaveBetterTransformerModelCallback)
236
+
237
  data_collator_kwargs = {
238
  "padding": True,
239
  }
src/axolotl/utils/validation.py CHANGED
@@ -2,6 +2,8 @@
2
 
3
  import logging
4
 
 
 
5
 
6
  def validate_config(cfg):
7
  if cfg.gradient_accumulation_steps and cfg.batch_size:
@@ -62,7 +64,37 @@ def validate_config(cfg):
62
  ) and cfg.gradient_checkpointing:
63
  raise ValueError("gradient_checkpointing is not supported for MPT models")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # TODO
66
  # MPT 7b
67
  # https://github.com/facebookresearch/bitsandbytes/issues/25
68
- # no 8bit adamw w bf16
 
 
 
 
 
 
 
 
2
 
3
  import logging
4
 
5
+ import torch
6
+
7
 
8
  def validate_config(cfg):
9
  if cfg.gradient_accumulation_steps and cfg.batch_size:
 
64
  ) and cfg.gradient_checkpointing:
65
  raise ValueError("gradient_checkpointing is not supported for MPT models")
66
 
67
+ if cfg.flash_optimum is True:
68
+ if cfg.adapter:
69
+ logging.warning(
70
+ "BetterTransformers probably doesn't work with PEFT adapters"
71
+ )
72
+ if cfg.fp16 or cfg.bf16:
73
+ raise ValueError("AMP is not supported with BetterTransformer")
74
+ if cfg.float16 is not True and cfg.bloat16 is not True:
75
+ logging.warning(
76
+ "You should probably set bfloat16 or float16 to true to "
77
+ "load the model in float16 for BetterTransformers"
78
+ )
79
+ if int(torch.__version__.split(".")[0]) < 2:
80
+ logging.warning("torch>=2.0.0 required")
81
+ raise ValueError(
82
+ f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
83
+ )
84
+
85
+ if cfg.pretraining_dataset and cfg.group_by_length:
86
+ logging.warning(
87
+ "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
88
+ )
89
+
90
  # TODO
91
  # MPT 7b
92
  # https://github.com/facebookresearch/bitsandbytes/issues/25
93
+ # no 8bit adaAmw w bf16
94
+
95
+ # GPT-NeoX
96
+ # evals broken when extending context len
97
+ # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
98
+ # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product
99
+ # attention_mask = causal_mask + attention_mask
100
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3
tests/test_validation.py CHANGED
@@ -212,3 +212,54 @@ class ValidationTest(unittest.TestCase):
212
 
213
  with pytest.raises(ValueError, match=regex_exp):
214
  validate_config(cfg)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  with pytest.raises(ValueError, match=regex_exp):
214
  validate_config(cfg)
215
+
216
+ def test_flash_optimum(self):
217
+ cfg = DictDefault(
218
+ {
219
+ "flash_optimum": True,
220
+ "adapter": "lora",
221
+ }
222
+ )
223
+
224
+ with self._caplog.at_level(logging.WARNING):
225
+ validate_config(cfg)
226
+ assert any(
227
+ "BetterTransformers probably doesn't work with PEFT adapters"
228
+ in record.message
229
+ for record in self._caplog.records
230
+ )
231
+
232
+ cfg = DictDefault(
233
+ {
234
+ "flash_optimum": True,
235
+ }
236
+ )
237
+
238
+ with self._caplog.at_level(logging.WARNING):
239
+ validate_config(cfg)
240
+ assert any(
241
+ "probably set bfloat16 or float16" in record.message
242
+ for record in self._caplog.records
243
+ )
244
+
245
+ cfg = DictDefault(
246
+ {
247
+ "flash_optimum": True,
248
+ "fp16": True,
249
+ }
250
+ )
251
+ regex_exp = r".*AMP is not supported.*"
252
+
253
+ with pytest.raises(ValueError, match=regex_exp):
254
+ validate_config(cfg)
255
+
256
+ cfg = DictDefault(
257
+ {
258
+ "flash_optimum": True,
259
+ "bf16": True,
260
+ }
261
+ )
262
+ regex_exp = r".*AMP is not supported.*"
263
+
264
+ with pytest.raises(ValueError, match=regex_exp):
265
+ validate_config(cfg)