let hf trainer handle torch compile (#516)
Browse files* let hf trainer handle torch compile
* remove torch compile checks, include option for backend
* suppress torch errors to get further
* require min torch version of 2.1.0 for torch compile to work
---------
Co-authored-by: Aman Karmani <[email protected]>
- README.md +4 -0
- src/axolotl/train.py +0 -4
- src/axolotl/utils/trainer.py +16 -0
README.md
CHANGED
@@ -519,6 +519,10 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|
519 |
# where to save the finished model to
|
520 |
output_dir: ./completed-model
|
521 |
|
|
|
|
|
|
|
|
|
522 |
# training hyperparameters
|
523 |
gradient_accumulation_steps: 1
|
524 |
micro_batch_size: 2
|
|
|
519 |
# where to save the finished model to
|
520 |
output_dir: ./completed-model
|
521 |
|
522 |
+
# whether to use torch.compile and which backend to use
|
523 |
+
torch_compile: # bool
|
524 |
+
torch_compile_backend: # Optional[str]
|
525 |
+
|
526 |
# training hyperparameters
|
527 |
gradient_accumulation_steps: 1
|
528 |
micro_batch_size: 2
|
src/axolotl/train.py
CHANGED
@@ -80,10 +80,6 @@ def train(
|
|
80 |
|
81 |
model.config.use_cache = False
|
82 |
|
83 |
-
if torch.__version__ >= "2" and sys.platform != "win32":
|
84 |
-
LOG.info("Compiling torch model")
|
85 |
-
model = torch.compile(model)
|
86 |
-
|
87 |
# go ahead and presave, so we have the adapter config available to inspect
|
88 |
if peft_config:
|
89 |
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
|
|
80 |
|
81 |
model.config.use_cache = False
|
82 |
|
|
|
|
|
|
|
|
|
83 |
# go ahead and presave, so we have the adapter config available to inspect
|
84 |
if peft_config:
|
85 |
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
src/axolotl/utils/trainer.py
CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
|
|
11 |
from typing import Optional, Union
|
12 |
|
13 |
import numpy as np
|
|
|
14 |
import torch.cuda
|
15 |
import transformers
|
16 |
from datasets import Dataset, set_caching_enabled
|
@@ -604,6 +605,21 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
604 |
if cfg.greater_is_better:
|
605 |
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
606 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
607 |
# DDP Config
|
608 |
if cfg.ddp_timeout:
|
609 |
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
|
|
11 |
from typing import Optional, Union
|
12 |
|
13 |
import numpy as np
|
14 |
+
import torch
|
15 |
import torch.cuda
|
16 |
import transformers
|
17 |
from datasets import Dataset, set_caching_enabled
|
|
|
605 |
if cfg.greater_is_better:
|
606 |
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
607 |
|
608 |
+
if cfg.torch_compile:
|
609 |
+
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
610 |
+
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
611 |
+
else:
|
612 |
+
import torch._dynamo # pylint: disable=redefined-outer-name
|
613 |
+
|
614 |
+
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
615 |
+
True
|
616 |
+
)
|
617 |
+
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
|
618 |
+
if cfg.torch_compile_backend:
|
619 |
+
training_arguments_kwargs[
|
620 |
+
"torch_compile_backend"
|
621 |
+
] = cfg.torch_compile_backend
|
622 |
+
|
623 |
# DDP Config
|
624 |
if cfg.ddp_timeout:
|
625 |
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|