winglian committed
Commit 5894f0e
1 Parent(s): 5cf226e

make mlflow optional (#1317)


* make mlflow optional

* fix xformers

don't patch swiglu if xformers not working
fix the check for xformers swiglu

* fix install of xformers with extra index url for docker builds

* fix docker build arg quoting

.github/workflows/main.yml CHANGED
@@ -18,6 +18,7 @@ jobs:
           python_version: "3.10"
           pytorch: 2.1.2
           axolotl_extras:
+          axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
           is_latest: true
         - cuda: 121
           cuda_version: 12.1.0
@@ -54,6 +55,7 @@ jobs:
             BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
             CUDA=${{ matrix.cuda }}
             PYTORCH_VERSION=${{ matrix.pytorch }}
+            AXOLOTL_ARGS=${{ matrix.axolotl_args }}
           file: ./docker/Dockerfile
           push: ${{ github.event_name != 'pull_request' }}
           tags: |
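Only the cu118 matrix entry sets the new axolotl_args key: per the commit message, xformers needs to be installed from PyTorch's CUDA 11.8 wheel index for that image, so the value is an --extra-index-url flag for pip. It flows from the matrix into the image through the new AXOLOTL_ARGS build arg declared in docker/Dockerfile below.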
.github/workflows/tests.yml CHANGED
@@ -70,6 +70,7 @@ jobs:
           cuda_version: 11.8.0
           python_version: "3.10"
           pytorch: 2.1.2
+          axolotl_args: "--extra-index-url https://download.pytorch.org/whl/cu118"
         - cuda: 121
           cuda_version: 12.1.0
           python_version: "3.10"
@@ -87,11 +88,13 @@ jobs:
           # Set up build arguments
           BASE_TAG="main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}"
           CUDA="${{ matrix.cuda }}"
+          AXOLOTL_ARGS="${{ matrix.axolotl_args }}"
           PYTORCH_VERSION="${{ matrix.pytorch }}"
           # Build the Docker image
           docker build . \
             --file ./docker/Dockerfile-tests \
             --build-arg BASE_TAG=$BASE_TAG \
+            --build-arg AXOLOTL_ARGS="$AXOLOTL_ARGS" \
             --build-arg CUDA=$CUDA \
             --build-arg GITHUB_REF=$GITHUB_REF \
             --build-arg PYTORCH_VERSION=$PYTORCH_VERSION \
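The quoting here is the "fix docker build arg quoting" bullet from the commit message: AXOLOTL_ARGS holds a two-token value (--extra-index-url plus the URL), so it is captured in quotes and forwarded as --build-arg AXOLOTL_ARGS="$AXOLOTL_ARGS". Unquoted, the shell would split the value at the space and docker build would see a stray positional argument. Matrix entries that do not set axolotl_args simply expand to an empty string, which the Dockerfiles tolerate.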
docker/Dockerfile CHANGED
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
+ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
 ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.0.1"
@@ -20,9 +21,9 @@ WORKDIR /workspace/axolotl
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
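Inside the image, $AXOLOTL_ARGS is deliberately left unquoted at the end of the pip command, so a multi-token value splits back into separate arguments and an empty value disappears entirely. For the cu118 build, the else branch therefore effectively runs pip install -e .[deepspeed,flash-attn,mamba-ssm] --extra-index-url https://download.pytorch.org/whl/cu118.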
docker/Dockerfile-tests CHANGED
@@ -3,6 +3,7 @@ FROM winglian/axolotl-base:$BASE_TAG
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 ARG AXOLOTL_EXTRAS=""
+ARG AXOLOTL_ARGS=""
 ARG CUDA="118"
 ENV BNB_CUDA_VERSION=$CUDA
 ARG PYTORCH_VERSION="2.0.1"
@@ -24,9 +25,9 @@ RUN git fetch origin +$GITHUB_REF && \
 
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm]; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
requirements.txt CHANGED
@@ -21,7 +21,6 @@ hf_transfer
 colorama
 numba
 numpy>=1.24.4
-mlflow
 # qlora things
 evaluate==0.4.1
 scipy
setup.py CHANGED
@@ -82,5 +82,8 @@ setup(
         "auto-gptq": [
             "auto-gptq==0.5.1",
         ],
+        "mlflow": [
+            "mlflow",
+        ],
     },
 )
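Together with the requirements.txt change above, this turns mlflow from a hard dependency into an opt-in extra: users who want MLflow logging install it explicitly, e.g. pip install -e ".[mlflow]", while everyone else gets a lighter install. The runtime check added in trainer_builder.py below only asks whether an mlflow module is importable, so any install method works.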
src/axolotl/core/trainer_builder.py CHANGED
@@ -5,6 +5,7 @@ Builder for the training args and trainer
 
 import abc
 import importlib
+import importlib.util
 import logging
 import math
 import sys
@@ -34,7 +35,6 @@ from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
     GPUStatsCallback,
     LossWatchDogCallback,
-    SaveAxolotlConfigtoMlflowCallback,
     SaveAxolotlConfigtoWandBCallback,
     SaveBetterTransformerModelCallback,
     bench_eval_callback_factory,
@@ -62,6 +62,10 @@ except ImportError:
 LOG = logging.getLogger("axolotl.core.trainer_builder")
 
 
+def is_mlflow_available():
+    return importlib.util.find_spec("mlflow") is not None
+
+
 def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
     if isinstance(tag_names, str):
         tag_names = [tag_names]
@@ -648,7 +652,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             callbacks.append(
                 SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
             )
-        if self.cfg.use_mlflow:
+        if self.cfg.use_mlflow and is_mlflow_available():
+            from axolotl.utils.callbacks.mlflow_ import (
+                SaveAxolotlConfigtoMlflowCallback,
+            )
+
             callbacks.append(
                 SaveAxolotlConfigtoMlflowCallback(self.cfg.axolotl_config_path)
            )
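The gating has two parts: importlib.util.find_spec checks whether mlflow is importable without actually importing it, and the callback import is deferred into the guarded branch so that axolotl.utils.callbacks.mlflow_ (which does import mlflow at module top) is never loaded when the package is absent. A minimal sketch of the same pattern, assuming axolotl is installed; the config path is illustrative:

    import importlib.util


    def is_mlflow_available() -> bool:
        # find_spec consults the import machinery without executing the
        # module, so the check is cheap and side-effect free.
        return importlib.util.find_spec("mlflow") is not None


    callbacks = []
    if is_mlflow_available():
        # Deferred import: only runs when mlflow actually exists on the path.
        from axolotl.utils.callbacks.mlflow_ import SaveAxolotlConfigtoMlflowCallback

        callbacks.append(SaveAxolotlConfigtoMlflowCallback("axolotl_config.yml"))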
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -44,6 +44,18 @@ except ImportError:
 LOG = logging.getLogger("axolotl")
 
 
+def is_xformers_swiglu_available() -> bool:
+    from xformers.ops.common import get_xformers_operator
+
+    try:
+        get_xformers_operator("swiglu_packedw")()
+        return True
+    except RuntimeError as exc:
+        if "No such operator xformers::swiglu_packedw " in str(exc):
+            return False
+        return True
+
+
 def replace_llama_mlp_with_swiglu(model):
     for name, module in model.named_modules():
         if isinstance(module, LlamaMLP):
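The probe works by asking xformers to resolve its packed-SwiGLU operator: when xformers is installed without working compiled extensions (for example a wheel built against a different CUDA toolkit), instantiating the operator raises a RuntimeError naming xformers::swiglu_packedw. Only that specific message is treated as "unavailable"; any other RuntimeError is assumed to mean the operator exists, so unrelated failures still surface later instead of silently disabling the fused path.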
src/axolotl/utils/{callbacks.py → callbacks/__init__.py} RENAMED
@@ -9,7 +9,6 @@ from tempfile import NamedTemporaryFile
 from typing import TYPE_CHECKING, Dict, List
 
 import evaluate
-import mlflow
 import numpy as np
 import pandas as pd
 import torch
@@ -42,8 +41,8 @@ from axolotl.utils.distributed import (
 if TYPE_CHECKING:
     from axolotl.core.trainer_builder import AxolotlTrainingArguments
 
-LOG = logging.getLogger("axolotl.callbacks")
 IGNORE_INDEX = -100
+LOG = logging.getLogger("axolotl.callbacks")
 
 
 class EvalFirstStepCallback(
@@ -756,31 +755,3 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
         except (FileNotFoundError, ConnectionError) as err:
             LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
         return control
-
-
-class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
-    """Callback to save axolotl config to mlflow"""
-
-    def __init__(self, axolotl_config_path):
-        self.axolotl_config_path = axolotl_config_path
-
-    def on_train_begin(
-        self,
-        args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
-        state: TrainerState,  # pylint: disable=unused-argument
-        control: TrainerControl,
-        **kwargs,  # pylint: disable=unused-argument
-    ):
-        if is_main_process():
-            try:
-                with NamedTemporaryFile(
-                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
-                ) as temp_file:
-                    copyfile(self.axolotl_config_path, temp_file.name)
-                    mlflow.log_artifact(temp_file.name, artifact_path="")
-                LOG.info(
-                    "The Axolotl config has been saved to the MLflow artifacts."
-                )
-            except (FileNotFoundError, ConnectionError) as err:
-                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
-        return control
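Splitting callbacks.py into a callbacks/ package is what makes the import graph safe: the MLflow-specific callback moves into its own submodule (added below), so importing axolotl.utils.callbacks no longer pulls in mlflow. The module body is otherwise unchanged apart from the dropped import and the reordered LOG/IGNORE_INDEX lines.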
src/axolotl/utils/callbacks/mlflow_.py ADDED
@@ -0,0 +1,44 @@
+"""MLFlow module for trainer callbacks"""
+import logging
+from shutil import copyfile
+from tempfile import NamedTemporaryFile
+from typing import TYPE_CHECKING
+
+import mlflow
+from transformers import TrainerCallback, TrainerControl, TrainerState
+
+from axolotl.utils.distributed import is_main_process
+
+if TYPE_CHECKING:
+    from axolotl.core.trainer_builder import AxolotlTrainingArguments
+
+LOG = logging.getLogger("axolotl.callbacks")
+
+
+class SaveAxolotlConfigtoMlflowCallback(TrainerCallback):
+    # pylint: disable=duplicate-code
+    """Callback to save axolotl config to mlflow"""
+
+    def __init__(self, axolotl_config_path):
+        self.axolotl_config_path = axolotl_config_path
+
+    def on_train_begin(
+        self,
+        args: "AxolotlTrainingArguments",  # pylint: disable=unused-argument
+        state: TrainerState,  # pylint: disable=unused-argument
+        control: TrainerControl,
+        **kwargs,  # pylint: disable=unused-argument
+    ):
+        if is_main_process():
+            try:
+                with NamedTemporaryFile(
+                    mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
+                ) as temp_file:
+                    copyfile(self.axolotl_config_path, temp_file.name)
+                    mlflow.log_artifact(temp_file.name, artifact_path="")
+                LOG.info(
+                    "The Axolotl config has been saved to the MLflow artifacts."
+                )
+            except (FileNotFoundError, ConnectionError) as err:
+                LOG.warning(f"Error while saving Axolotl config to MLflow: {err}")
+        return control
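For reference, a minimal sketch of exercising the callback directly, outside a Trainer; this assumes mlflow is installed, and the config path is a placeholder (in axolotl itself the registration happens in HFCausalTrainerBuilder as shown above):

    import mlflow

    from axolotl.utils.callbacks.mlflow_ import SaveAxolotlConfigtoMlflowCallback

    cb = SaveAxolotlConfigtoMlflowCallback("path/to/axolotl_config.yml")  # placeholder path
    with mlflow.start_run():
        # on_train_begin copies the config to a temp .yml file and logs it as an
        # MLflow artifact; a missing file is caught and logged as a warning.
        cb.on_train_begin(args=None, state=None, control=None)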
src/axolotl/utils/models.py CHANGED
@@ -512,11 +512,12 @@ def load_model(
 
     if cfg.flash_attention and not inference:
         from axolotl.monkeypatch.llama_attn_hijack_flash import (
+            is_xformers_swiglu_available,
             replace_llama_mlp_with_swiglu,
             replace_llama_qkv_with_fused,
         )
 
-        if cfg.flash_attn_fuse_mlp:
+        if cfg.flash_attn_fuse_mlp and is_xformers_swiglu_available():
             LOG.info("patching with SwiGLU")
             replace_llama_mlp_with_swiglu(model)
 
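With the extra guard, setting flash_attn_fuse_mlp: true now degrades gracefully: if the xformers SwiGLU operator is unusable, model loading simply skips the patch rather than swapping in a fused MLP that would fail at runtime.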
tests/e2e/patched/test_fused_llama.py CHANGED
@@ -57,9 +57,9 @@ class TestFusedLlama(unittest.TestCase):
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
-                "max_steps": 20,
-                "save_steps": 10,
-                "eval_steps": 10,
+                "max_steps": 10,
+                "save_steps": 5,
+                "eval_steps": 5,
             }
         )
         if is_torch_bf16_gpu_available():
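The fused-Llama e2e test keeps the same save/eval checkpoints relative to training length; halving max_steps, save_steps, and eval_steps just shortens the run, presumably to cut CI time.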