winglian committed on
Commit
788649f
·
unverified ·
1 Parent(s): 9be92d1

attempt to also run e2e tests that need gpus (#1070)

Browse files

* attempt to also run e2e tests that need gpus

* fix stray quote

* checkout specific github ref

* dockerfile for tests with proper checkout

ensure wandb is disabled for docker pytests
clear wandb env after testing
clear wandb env after testing
make sure to provide a default val for pop
try skipping wandb validation tests
explicitly disable wandb in the e2e tests
explicitly report_to None to see if that fixes the docker e2e tests
split gpu from non-gpu unit tests
skip bf16 check in test for now
build docker w/o cache since it uses branch name ref
revert some changes now that caching is fixed
skip bf16 check if on gpu w support

* pytest skip for auto-gptq requirements

* skip mamba tests for now, split multipack and non packed lora llama tests

* split tests that use monkeypatches

* fix relative import for prev commit

* move other tests using monkeypatches to the correct run

.github/workflows/tests-docker.yml CHANGED
@@ -36,11 +36,19 @@ jobs:
36
  PYTORCH_VERSION="${{ matrix.pytorch }}"
37
  # Build the Docker image
38
  docker build . \
39
- --file ./docker/Dockerfile \
40
  --build-arg BASE_TAG=$BASE_TAG \
41
  --build-arg CUDA=$CUDA \
 
42
  --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
43
- --tag test-axolotl
 
44
  - name: Unit Tests w docker image
45
  run: |
46
  docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
 
 
 
 
 
 
 
36
  PYTORCH_VERSION="${{ matrix.pytorch }}"
37
  # Build the Docker image
38
  docker build . \
39
+ --file ./docker/Dockerfile-tests \
40
  --build-arg BASE_TAG=$BASE_TAG \
41
  --build-arg CUDA=$CUDA \
42
+ --build-arg GITHUB_REF=$GITHUB_REF \
43
  --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
44
+ --tag test-axolotl \
45
+ --no-cache
46
  - name: Unit Tests w docker image
47
  run: |
48
  docker run --rm test-axolotl pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
49
+ - name: GPU Unit Tests w docker image
50
+ run: |
51
+ docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest --ignore=tests/e2e/patched/ /workspace/axolotl/tests/e2e/
52
+ - name: GPU Unit Tests monkeypatched w docker image
53
+ run: |
54
+ docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm test-axolotl pytest /workspace/axolotl/tests/e2e/patched/
docker/Dockerfile-tests ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_TAG=main-base
2
+ FROM winglian/axolotl-base:$BASE_TAG
3
+
4
+ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
5
+ ARG AXOLOTL_EXTRAS=""
6
+ ARG CUDA="118"
7
+ ENV BNB_CUDA_VERSION=$CUDA
8
+ ARG PYTORCH_VERSION="2.0.1"
9
+ ARG GITHUB_REF="main"
10
+
11
+ ENV PYTORCH_VERSION=$PYTORCH_VERSION
12
+
13
+ RUN apt-get update && \
14
+ apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
15
+
16
+ WORKDIR /workspace
17
+
18
+ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
19
+
20
+ WORKDIR /workspace/axolotl
21
+
22
+ RUN git fetch origin +$GITHUB_REF && \
23
+ git checkout FETCH_HEAD
24
+
25
+ # If AXOLOTL_EXTRAS is set, append it in brackets
26
+ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
27
+ pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
28
+ else \
29
+ pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
30
+ fi
31
+
32
+ # So we can test the Docker image
33
+ RUN pip install pytest
34
+
35
+ # fix so that git fetch/pull from remote works
36
+ RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
37
+ git config --get remote.origin.fetch
38
+
39
+ # helper for huggingface-login cli
40
+ RUN git config --global credential.helper store
tests/e2e/patched/__init__.py ADDED
File without changes
tests/e2e/{test_fused_llama.py → patched/test_fused_llama.py} RENAMED
@@ -15,7 +15,7 @@ from axolotl.train import train
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
- from .utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
 
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
+ from ..utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
tests/e2e/patched/test_lora_llama_multipack.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ E2E tests for lora llama
3
+ """
4
+
5
+ import logging
6
+ import os
7
+ import unittest
8
+ from pathlib import Path
9
+
10
+ import pytest
11
+ from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
12
+
13
+ from axolotl.cli import load_datasets
14
+ from axolotl.common.cli import TrainerCliArgs
15
+ from axolotl.train import train
16
+ from axolotl.utils.config import normalize_config
17
+ from axolotl.utils.dict import DictDefault
18
+
19
+ from ..utils import with_temp_dir
20
+
21
+ LOG = logging.getLogger("axolotl.tests.e2e")
22
+ os.environ["WANDB_DISABLED"] = "true"
23
+
24
+
25
+ class TestLoraLlama(unittest.TestCase):
26
+ """
27
+ Test case for Llama models using LoRA w multipack
28
+ """
29
+
30
+ @with_temp_dir
31
+ def test_lora_packing(self, temp_dir):
32
+ # pylint: disable=duplicate-code
33
+ cfg = DictDefault(
34
+ {
35
+ "base_model": "JackFram/llama-68m",
36
+ "tokenizer_type": "LlamaTokenizer",
37
+ "sequence_len": 1024,
38
+ "sample_packing": True,
39
+ "flash_attention": True,
40
+ "load_in_8bit": True,
41
+ "adapter": "lora",
42
+ "lora_r": 32,
43
+ "lora_alpha": 64,
44
+ "lora_dropout": 0.05,
45
+ "lora_target_linear": True,
46
+ "val_set_size": 0.1,
47
+ "special_tokens": {
48
+ "unk_token": "<unk>",
49
+ "bos_token": "<s>",
50
+ "eos_token": "</s>",
51
+ },
52
+ "datasets": [
53
+ {
54
+ "path": "mhenrichsen/alpaca_2k_test",
55
+ "type": "alpaca",
56
+ },
57
+ ],
58
+ "num_epochs": 2,
59
+ "micro_batch_size": 8,
60
+ "gradient_accumulation_steps": 1,
61
+ "output_dir": temp_dir,
62
+ "learning_rate": 0.00001,
63
+ "optimizer": "adamw_torch",
64
+ "lr_scheduler": "cosine",
65
+ }
66
+ )
67
+ if is_torch_bf16_gpu_available():
68
+ cfg.bf16 = True
69
+ else:
70
+ cfg.fp16 = True
71
+
72
+ normalize_config(cfg)
73
+ cli_args = TrainerCliArgs()
74
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
75
+
76
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
77
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
78
+
79
+ @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
80
+ @with_temp_dir
81
+ def test_lora_gptq_packed(self, temp_dir):
82
+ # pylint: disable=duplicate-code
83
+ cfg = DictDefault(
84
+ {
85
+ "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
86
+ "model_type": "AutoModelForCausalLM",
87
+ "tokenizer_type": "LlamaTokenizer",
88
+ "sequence_len": 1024,
89
+ "sample_packing": True,
90
+ "flash_attention": True,
91
+ "load_in_8bit": True,
92
+ "adapter": "lora",
93
+ "gptq": True,
94
+ "gptq_disable_exllama": True,
95
+ "lora_r": 32,
96
+ "lora_alpha": 64,
97
+ "lora_dropout": 0.05,
98
+ "lora_target_linear": True,
99
+ "val_set_size": 0.1,
100
+ "special_tokens": {
101
+ "unk_token": "<unk>",
102
+ "bos_token": "<s>",
103
+ "eos_token": "</s>",
104
+ },
105
+ "datasets": [
106
+ {
107
+ "path": "mhenrichsen/alpaca_2k_test",
108
+ "type": "alpaca",
109
+ },
110
+ ],
111
+ "num_epochs": 2,
112
+ "save_steps": 0.5,
113
+ "micro_batch_size": 8,
114
+ "gradient_accumulation_steps": 1,
115
+ "output_dir": temp_dir,
116
+ "learning_rate": 0.00001,
117
+ "optimizer": "adamw_torch",
118
+ "lr_scheduler": "cosine",
119
+ }
120
+ )
121
+ normalize_config(cfg)
122
+ cli_args = TrainerCliArgs()
123
+ dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
124
+
125
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
126
+ assert (Path(temp_dir) / "adapter_model.bin").exists()
tests/e2e/{test_mistral_samplepack.py → patched/test_mistral_samplepack.py} RENAMED
@@ -15,7 +15,7 @@ from axolotl.train import train
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
- from .utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
 
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
+ from ..utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
tests/e2e/{test_mixtral_samplepack.py → patched/test_mixtral_samplepack.py} RENAMED
@@ -15,7 +15,7 @@ from axolotl.train import train
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
- from .utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
 
15
  from axolotl.utils.config import normalize_config
16
  from axolotl.utils.dict import DictDefault
17
 
18
+ from ..utils import with_temp_dir
19
 
20
  LOG = logging.getLogger("axolotl.tests.e2e")
21
  os.environ["WANDB_DISABLED"] = "true"
tests/e2e/{test_model_patches.py → patched/test_model_patches.py} RENAMED
@@ -9,7 +9,7 @@ from axolotl.utils.config import normalize_config
9
  from axolotl.utils.dict import DictDefault
10
  from axolotl.utils.models import load_model, load_tokenizer
11
 
12
- from .utils import with_temp_dir
13
 
14
 
15
  class TestModelPatches(unittest.TestCase):
 
9
  from axolotl.utils.dict import DictDefault
10
  from axolotl.utils.models import load_model, load_tokenizer
11
 
12
+ from ..utils import with_temp_dir
13
 
14
 
15
  class TestModelPatches(unittest.TestCase):
tests/e2e/{test_resume.py → patched/test_resume.py} RENAMED
@@ -17,7 +17,7 @@ from axolotl.train import train
17
  from axolotl.utils.config import normalize_config
18
  from axolotl.utils.dict import DictDefault
19
 
20
- from .utils import most_recent_subdir, with_temp_dir
21
 
22
  LOG = logging.getLogger("axolotl.tests.e2e")
23
  os.environ["WANDB_DISABLED"] = "true"
@@ -29,7 +29,7 @@ class TestResumeLlama(unittest.TestCase):
29
  """
30
 
31
  @with_temp_dir
32
- def test_resume_qlora(self, temp_dir):
33
  # pylint: disable=duplicate-code
34
  cfg = DictDefault(
35
  {
 
17
  from axolotl.utils.config import normalize_config
18
  from axolotl.utils.dict import DictDefault
19
 
20
+ from ..utils import most_recent_subdir, with_temp_dir
21
 
22
  LOG = logging.getLogger("axolotl.tests.e2e")
23
  os.environ["WANDB_DISABLED"] = "true"
 
29
  """
30
 
31
  @with_temp_dir
32
+ def test_resume_qlora_packed(self, temp_dir):
33
  # pylint: disable=duplicate-code
34
  cfg = DictDefault(
35
  {
tests/e2e/test_lora_llama.py CHANGED
@@ -65,96 +65,3 @@ class TestLoraLlama(unittest.TestCase):
65
 
66
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
67
  assert (Path(temp_dir) / "adapter_model.bin").exists()
68
-
69
- @with_temp_dir
70
- def test_lora_packing(self, temp_dir):
71
- # pylint: disable=duplicate-code
72
- cfg = DictDefault(
73
- {
74
- "base_model": "JackFram/llama-68m",
75
- "tokenizer_type": "LlamaTokenizer",
76
- "sequence_len": 1024,
77
- "sample_packing": True,
78
- "flash_attention": True,
79
- "load_in_8bit": True,
80
- "adapter": "lora",
81
- "lora_r": 32,
82
- "lora_alpha": 64,
83
- "lora_dropout": 0.05,
84
- "lora_target_linear": True,
85
- "val_set_size": 0.1,
86
- "special_tokens": {
87
- "unk_token": "<unk>",
88
- "bos_token": "<s>",
89
- "eos_token": "</s>",
90
- },
91
- "datasets": [
92
- {
93
- "path": "mhenrichsen/alpaca_2k_test",
94
- "type": "alpaca",
95
- },
96
- ],
97
- "num_epochs": 2,
98
- "micro_batch_size": 8,
99
- "gradient_accumulation_steps": 1,
100
- "output_dir": temp_dir,
101
- "learning_rate": 0.00001,
102
- "optimizer": "adamw_torch",
103
- "lr_scheduler": "cosine",
104
- "bf16": True,
105
- }
106
- )
107
- normalize_config(cfg)
108
- cli_args = TrainerCliArgs()
109
- dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
110
-
111
- train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
112
- assert (Path(temp_dir) / "adapter_model.bin").exists()
113
-
114
- @with_temp_dir
115
- def test_lora_gptq(self, temp_dir):
116
- # pylint: disable=duplicate-code
117
- cfg = DictDefault(
118
- {
119
- "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
120
- "model_type": "AutoModelForCausalLM",
121
- "tokenizer_type": "LlamaTokenizer",
122
- "sequence_len": 1024,
123
- "sample_packing": True,
124
- "flash_attention": True,
125
- "load_in_8bit": True,
126
- "adapter": "lora",
127
- "gptq": True,
128
- "gptq_disable_exllama": True,
129
- "lora_r": 32,
130
- "lora_alpha": 64,
131
- "lora_dropout": 0.05,
132
- "lora_target_linear": True,
133
- "val_set_size": 0.1,
134
- "special_tokens": {
135
- "unk_token": "<unk>",
136
- "bos_token": "<s>",
137
- "eos_token": "</s>",
138
- },
139
- "datasets": [
140
- {
141
- "path": "mhenrichsen/alpaca_2k_test",
142
- "type": "alpaca",
143
- },
144
- ],
145
- "num_epochs": 2,
146
- "save_steps": 0.5,
147
- "micro_batch_size": 8,
148
- "gradient_accumulation_steps": 1,
149
- "output_dir": temp_dir,
150
- "learning_rate": 0.00001,
151
- "optimizer": "adamw_torch",
152
- "lr_scheduler": "cosine",
153
- }
154
- )
155
- normalize_config(cfg)
156
- cli_args = TrainerCliArgs()
157
- dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
158
-
159
- train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
160
- assert (Path(temp_dir) / "adapter_model.bin").exists()
 
65
 
66
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
67
  assert (Path(temp_dir) / "adapter_model.bin").exists()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tests/e2e/test_mamba.py CHANGED
@@ -7,6 +7,8 @@ import os
7
  import unittest
8
  from pathlib import Path
9
 
 
 
10
  from axolotl.cli import load_datasets
11
  from axolotl.common.cli import TrainerCliArgs
12
  from axolotl.train import train
@@ -19,9 +21,10 @@ LOG = logging.getLogger("axolotl.tests.e2e")
19
  os.environ["WANDB_DISABLED"] = "true"
20
 
21
 
22
- class TestMistral(unittest.TestCase):
 
23
  """
24
- Test case for Llama models using LoRA
25
  """
26
 
27
  @with_temp_dir
 
7
  import unittest
8
  from pathlib import Path
9
 
10
+ import pytest
11
+
12
  from axolotl.cli import load_datasets
13
  from axolotl.common.cli import TrainerCliArgs
14
  from axolotl.train import train
 
21
  os.environ["WANDB_DISABLED"] = "true"
22
 
23
 
24
+ @pytest.mark.skip(reason="skipping until upstreamed into transformers")
25
+ class TestMamba(unittest.TestCase):
26
  """
27
+ Test case for Mamba models
28
  """
29
 
30
  @with_temp_dir
tests/e2e/test_phi.py CHANGED
@@ -8,6 +8,7 @@ import unittest
8
  from pathlib import Path
9
 
10
  import pytest
 
11
 
12
  from axolotl.cli import load_datasets
13
  from axolotl.common.cli import TrainerCliArgs
@@ -59,7 +60,6 @@ class TestPhi(unittest.TestCase):
59
  "learning_rate": 0.00001,
60
  "optimizer": "paged_adamw_8bit",
61
  "lr_scheduler": "cosine",
62
- "bf16": True,
63
  "flash_attention": True,
64
  "max_steps": 10,
65
  "save_steps": 10,
@@ -67,6 +67,10 @@ class TestPhi(unittest.TestCase):
67
  "save_safetensors": True,
68
  }
69
  )
 
 
 
 
70
  normalize_config(cfg)
71
  cli_args = TrainerCliArgs()
72
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
@@ -110,9 +114,13 @@ class TestPhi(unittest.TestCase):
110
  "learning_rate": 0.00001,
111
  "optimizer": "adamw_bnb_8bit",
112
  "lr_scheduler": "cosine",
113
- "bf16": True,
114
  }
115
  )
 
 
 
 
 
116
  normalize_config(cfg)
117
  cli_args = TrainerCliArgs()
118
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
8
  from pathlib import Path
9
 
10
  import pytest
11
+ from transformers.utils import is_torch_bf16_gpu_available
12
 
13
  from axolotl.cli import load_datasets
14
  from axolotl.common.cli import TrainerCliArgs
 
60
  "learning_rate": 0.00001,
61
  "optimizer": "paged_adamw_8bit",
62
  "lr_scheduler": "cosine",
 
63
  "flash_attention": True,
64
  "max_steps": 10,
65
  "save_steps": 10,
 
67
  "save_safetensors": True,
68
  }
69
  )
70
+ if is_torch_bf16_gpu_available():
71
+ cfg.bf16 = True
72
+ else:
73
+ cfg.fp16 = True
74
  normalize_config(cfg)
75
  cli_args = TrainerCliArgs()
76
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
114
  "learning_rate": 0.00001,
115
  "optimizer": "adamw_bnb_8bit",
116
  "lr_scheduler": "cosine",
 
117
  }
118
  )
119
+ if is_torch_bf16_gpu_available():
120
+ cfg.bf16 = True
121
+ else:
122
+ cfg.fp16 = True
123
+
124
  normalize_config(cfg)
125
  cli_args = TrainerCliArgs()
126
  dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
tests/test_validation.py CHANGED
@@ -6,6 +6,7 @@ import unittest
6
  from typing import Optional
7
 
8
  import pytest
 
9
 
10
  from axolotl.utils.config import validate_config
11
  from axolotl.utils.dict import DictDefault
@@ -354,6 +355,10 @@ class ValidationTest(unittest.TestCase):
354
  with pytest.raises(ValueError, match=regex_exp):
355
  validate_config(cfg)
356
 
 
 
 
 
357
  def test_merge_lora_no_bf16_fail(self):
358
  """
359
  This is assumed to be run on a CPU machine, so bf16 is not supported.
@@ -778,6 +783,15 @@ class ValidationWandbTest(ValidationTest):
778
  assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
779
  assert os.environ.get("WANDB_DISABLED", "") != "true"
780
 
 
 
 
 
 
 
 
 
 
781
  def test_wandb_set_disabled(self):
782
  cfg = DictDefault({})
783
 
@@ -798,3 +812,6 @@ class ValidationWandbTest(ValidationTest):
798
  setup_wandb_env_vars(cfg)
799
 
800
  assert os.environ.get("WANDB_DISABLED", "") != "true"
 
 
 
 
6
  from typing import Optional
7
 
8
  import pytest
9
+ from transformers.utils import is_torch_bf16_gpu_available
10
 
11
  from axolotl.utils.config import validate_config
12
  from axolotl.utils.dict import DictDefault
 
355
  with pytest.raises(ValueError, match=regex_exp):
356
  validate_config(cfg)
357
 
358
+ @pytest.mark.skipif(
359
+ is_torch_bf16_gpu_available(),
360
+ reason="test should only run on gpus w/o bf16 support",
361
+ )
362
  def test_merge_lora_no_bf16_fail(self):
363
  """
364
  This is assumed to be run on a CPU machine, so bf16 is not supported.
 
783
  assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
784
  assert os.environ.get("WANDB_DISABLED", "") != "true"
785
 
786
+ os.environ.pop("WANDB_PROJECT", None)
787
+ os.environ.pop("WANDB_NAME", None)
788
+ os.environ.pop("WANDB_RUN_ID", None)
789
+ os.environ.pop("WANDB_ENTITY", None)
790
+ os.environ.pop("WANDB_MODE", None)
791
+ os.environ.pop("WANDB_WATCH", None)
792
+ os.environ.pop("WANDB_LOG_MODEL", None)
793
+ os.environ.pop("WANDB_DISABLED", None)
794
+
795
  def test_wandb_set_disabled(self):
796
  cfg = DictDefault({})
797
 
 
812
  setup_wandb_env_vars(cfg)
813
 
814
  assert os.environ.get("WANDB_DISABLED", "") != "true"
815
+
816
+ os.environ.pop("WANDB_PROJECT", None)
817
+ os.environ.pop("WANDB_DISABLED", None)