support galore once upstreamed into transformers (#1409)
Browse files* support galore once upstreamed into transformers
* update module name for llama in readme and fix typing for all linear
* bump trl for deprecation fixes from newer transformers
* include galore as an extra and install in docker image
* fix optim_args type
* fix optim_args
* update dependencies for galore
* add galore to cicd dockerfile
- README.md +19 -0
- cicd/Dockerfile.jinja +2 -2
- docker/Dockerfile +2 -2
- requirements.txt +2 -2
- setup.py +3 -0
- src/axolotl/core/trainer_builder.py +14 -1
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +9 -0
README.md
CHANGED
@@ -907,7 +907,26 @@ lr_div_factor: # Learning rate div factor
|
|
907 |
# - paged_adamw_8bit
|
908 |
# - paged_lion_32bit
|
909 |
# - paged_lion_8bit
|
|
|
|
|
|
|
|
|
|
|
|
|
910 |
optimizer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
# Specify weight decay
|
912 |
weight_decay:
|
913 |
# adamw hyperparams
|
|
|
907 |
# - paged_adamw_8bit
|
908 |
# - paged_lion_32bit
|
909 |
# - paged_lion_8bit
|
910 |
+
# - galore_adamw
|
911 |
+
# - galore_adamw_8bit
|
912 |
+
# - galore_adafactor
|
913 |
+
# - galore_adamw_layerwise
|
914 |
+
# - galore_adamw_8bit_layerwise
|
915 |
+
# - galore_adafactor_layerwise
|
916 |
optimizer:
|
917 |
+
# Dictionary of arguments to pass to the optimizer
|
918 |
+
optim_args:
|
919 |
+
# For Galore Optimizers the following optim_args are available
|
920 |
+
# rank: # type: int
|
921 |
+
# update_proj_gap # type: int
|
922 |
+
# scale # type: float
|
923 |
+
# proj_type: # type: str, default = std
|
924 |
+
|
925 |
+
# The target modules to optimize, i.e. the module names that you would like to train, right now this is used only for GaLore algorithm
|
926 |
+
optim_target_modules:
|
927 |
+
# - self_attn # for llama
|
928 |
+
# - mlp
|
929 |
+
|
930 |
# Specify weight decay
|
931 |
weight_decay:
|
932 |
# adamw hyperparams
|
cicd/Dockerfile.jinja
CHANGED
@@ -23,9 +23,9 @@ RUN git fetch origin +$GITHUB_REF && \
|
|
23 |
|
24 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
25 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
26 |
-
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
27 |
else \
|
28 |
-
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
29 |
fi
|
30 |
|
31 |
# So we can test the Docker image
|
|
|
23 |
|
24 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
25 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
26 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
27 |
else \
|
28 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
29 |
fi
|
30 |
|
31 |
# So we can test the Docker image
|
docker/Dockerfile
CHANGED
@@ -21,9 +21,9 @@ WORKDIR /workspace/axolotl
|
|
21 |
|
22 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
23 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
24 |
-
pip install -e .[deepspeed,flash-attn,mamba-ssm,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
25 |
else \
|
26 |
-
pip install -e .[deepspeed,flash-attn,mamba-ssm] $AXOLOTL_ARGS; \
|
27 |
fi
|
28 |
|
29 |
# So we can test the Docker image
|
|
|
21 |
|
22 |
# If AXOLOTL_EXTRAS is set, append it in brackets
|
23 |
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
24 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
25 |
else \
|
26 |
+
pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
|
27 |
fi
|
28 |
|
29 |
# So we can test the Docker image
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
packaging==23.2
|
3 |
peft==0.9.0
|
4 |
-
transformers
|
5 |
tokenizers==0.15.0
|
6 |
bitsandbytes>=0.43.0
|
7 |
accelerate==0.26.1
|
@@ -39,5 +39,5 @@ s3fs
|
|
39 |
gcsfs
|
40 |
# adlfs
|
41 |
|
42 |
-
trl
|
43 |
fastcore>=1.5.29
|
|
|
1 |
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
2 |
packaging==23.2
|
3 |
peft==0.9.0
|
4 |
+
transformers @ git+https://github.com/huggingface/transformers.git@f6261d7d81edd036fc53bfede65fe91f01a661aa
|
5 |
tokenizers==0.15.0
|
6 |
bitsandbytes>=0.43.0
|
7 |
accelerate==0.26.1
|
|
|
39 |
gcsfs
|
40 |
# adlfs
|
41 |
|
42 |
+
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
|
43 |
fastcore>=1.5.29
|
setup.py
CHANGED
@@ -89,5 +89,8 @@ setup(
|
|
89 |
"lion-pytorch": [
|
90 |
"lion-pytorch==0.1.2",
|
91 |
],
|
|
|
|
|
|
|
92 |
},
|
93 |
)
|
|
|
89 |
"lion-pytorch": [
|
90 |
"lion-pytorch==0.1.2",
|
91 |
],
|
92 |
+
"galore": [
|
93 |
+
"galore_torch",
|
94 |
+
],
|
95 |
},
|
96 |
)
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -220,7 +220,7 @@ class AxolotlTrainer(Trainer):
|
|
220 |
num_epochs=1,
|
221 |
bench_data_collator=None,
|
222 |
eval_data_collator=None,
|
223 |
-
**kwargs
|
224 |
):
|
225 |
self.num_epochs = num_epochs
|
226 |
self.bench_data_collator = bench_data_collator
|
@@ -239,6 +239,7 @@ class AxolotlTrainer(Trainer):
|
|
239 |
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
240 |
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
241 |
self.args,
|
|
|
242 |
)
|
243 |
|
244 |
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
@@ -1150,6 +1151,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
1150 |
training_arguments_kwargs["optim"] = (
|
1151 |
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
1152 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1153 |
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
1154 |
training_arguments_kwargs[
|
1155 |
"loraplus_lr_embedding"
|
|
|
220 |
num_epochs=1,
|
221 |
bench_data_collator=None,
|
222 |
eval_data_collator=None,
|
223 |
+
**kwargs,
|
224 |
):
|
225 |
self.num_epochs = num_epochs
|
226 |
self.bench_data_collator = bench_data_collator
|
|
|
239 |
if self.optimizer is None: # pylint: disable=access-member-before-definition
|
240 |
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
|
241 |
self.args,
|
242 |
+
opt_model,
|
243 |
)
|
244 |
|
245 |
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
|
|
|
1151 |
training_arguments_kwargs["optim"] = (
|
1152 |
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
1153 |
)
|
1154 |
+
if self.cfg.optim_args:
|
1155 |
+
if isinstance(self.cfg.optim_args, dict):
|
1156 |
+
optim_args = ",".join(
|
1157 |
+
[f"{key}={value}" for key, value in self.cfg.optim_args.items()]
|
1158 |
+
)
|
1159 |
+
else:
|
1160 |
+
optim_args = self.cfg.optim_args
|
1161 |
+
training_arguments_kwargs["optim_args"] = optim_args
|
1162 |
+
if self.cfg.optim_target_modules:
|
1163 |
+
training_arguments_kwargs[
|
1164 |
+
"optim_target_modules"
|
1165 |
+
] = self.cfg.optim_target_modules
|
1166 |
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
1167 |
training_arguments_kwargs[
|
1168 |
"loraplus_lr_embedding"
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -313,6 +313,15 @@ class HyperparametersConfig(BaseModel):
|
|
313 |
learning_rate: Union[str, float]
|
314 |
weight_decay: Optional[float] = None
|
315 |
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
torchdistx_path: Optional[str] = None
|
317 |
lr_scheduler: Optional[SchedulerType] = None
|
318 |
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
|
|
313 |
learning_rate: Union[str, float]
|
314 |
weight_decay: Optional[float] = None
|
315 |
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
316 |
+
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
317 |
+
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
|
318 |
+
)
|
319 |
+
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
|
320 |
+
default=None,
|
321 |
+
metadata={
|
322 |
+
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
323 |
+
},
|
324 |
+
)
|
325 |
torchdistx_path: Optional[str] = None
|
326 |
lr_scheduler: Optional[SchedulerType] = None
|
327 |
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|