winglian commited on
Commit
da265dd
·
unverified ·
1 Parent(s): e07347b

fix for accelerate env var for auto bf16, add new base image and expand torch_cuda_arch_list support (#1413)

Browse files
.github/workflows/base.yml CHANGED
@@ -16,17 +16,22 @@ jobs:
16
  cuda_version: 11.8.0
17
  python_version: "3.10"
18
  pytorch: 2.1.2
19
- torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
20
  - cuda: "121"
21
  cuda_version: 12.1.0
22
  python_version: "3.10"
23
  pytorch: 2.1.2
24
- torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
25
  - cuda: "121"
26
  cuda_version: 12.1.0
27
  python_version: "3.11"
28
  pytorch: 2.1.2
29
- torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
 
 
 
 
 
30
  steps:
31
  - name: Checkout
32
  uses: actions/checkout@v3
 
16
  cuda_version: 11.8.0
17
  python_version: "3.10"
18
  pytorch: 2.1.2
19
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
20
  - cuda: "121"
21
  cuda_version: 12.1.0
22
  python_version: "3.10"
23
  pytorch: 2.1.2
24
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
25
  - cuda: "121"
26
  cuda_version: 12.1.0
27
  python_version: "3.11"
28
  pytorch: 2.1.2
29
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
30
+ - cuda: "121"
31
+ cuda_version: 12.1.0
32
+ python_version: "3.11"
33
+ pytorch: 2.2.1
34
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
35
  steps:
36
  - name: Checkout
37
  uses: actions/checkout@v3
src/axolotl/utils/trainer.py CHANGED
@@ -11,6 +11,7 @@ import torch.cuda
11
  from accelerate.logging import get_logger
12
  from datasets import set_caching_enabled
13
  from torch.utils.data import DataLoader, RandomSampler
 
14
 
15
  from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
16
  from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
@@ -324,6 +325,11 @@ def prepare_optim_env(cfg):
324
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
325
  os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
326
 
 
 
 
 
 
327
 
328
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
329
  if cfg.rl in ["dpo", "ipo", "kto_pair"]:
 
11
  from accelerate.logging import get_logger
12
  from datasets import set_caching_enabled
13
  from torch.utils.data import DataLoader, RandomSampler
14
+ from transformers.utils import is_torch_bf16_gpu_available
15
 
16
  from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
17
  from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 
325
  os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
326
  os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
327
 
328
+ if (cfg.bf16 == "auto" and is_torch_bf16_gpu_available()) or cfg.bf16 is True:
329
+ os.environ["ACCELERATE_MIXED_PRECISION"] = "bf16"
330
+ elif cfg.fp16:
331
+ os.environ["ACCELERATE_MIXED_PRECISION"] = "fp16"
332
+
333
 
334
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
335
  if cfg.rl in ["dpo", "ipo", "kto_pair"]: