winglian committed
Commit 6319da1 · unverified · 1 Parent(s): 132eb74

Unsloth gradient checkpointing offload (#1528)


* unsloth gradient checkpointing

* fix validation too

* fixes to make it work with mistral

* monkeypatch the checkpoint fn earlier

src/axolotl/monkeypatch/mistral_attn_hijack_flash.py CHANGED
@@ -516,24 +516,18 @@ def mistral_model_forward(
         past_key_value = past_key_values[idx] if past_key_values is not None else None
 
         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs)
-
-                return custom_forward
-
-            layer_outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(decoder_layer),
-                hidden_states,
-                attention_mask,
-                position_ids,
-                past_key_value,
-                output_attentions,
-                None,
-                cu_seqlens,
-                max_seqlen,
+            layer_outputs = (
+                self._gradient_checkpointing_func(  # pylint: disable=protected-access
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_value,
+                    output_attentions,
+                    None,
+                    cu_seqlens,
+                    max_seqlen,
+                )
             )
         else:
             layer_outputs = decoder_layer(
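
Context for the rewrite above: in recent transformers releases (assumption: 4.35 or newer), gradient_checkpointing_enable() installs a _gradient_checkpointing_func attribute on checkpointable modules, built as a functools.partial over the module-level checkpoint name in transformers.modeling_utils. Calling that attribute instead of torch.utils.checkpoint.checkpoint directly is what lets the wrapper patched into models.py below intercept every layer. A simplified sketch of that wiring, not the exact transformers source:

# simplified sketch of how _gradient_checkpointing_func is built
# (assumption: transformers >= 4.35 behaviour)
import functools
import transformers.modeling_utils as modeling_utils

def enable_gc(model, gradient_checkpointing_kwargs=None):
    # the `checkpoint` name is resolved when checkpointing is enabled, so a
    # module attribute rebound earlier (see models.py below) is picked up here
    model._gradient_checkpointing_func = functools.partial(
        modeling_utils.checkpoint, **(gradient_checkpointing_kwargs or {})
    )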
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -479,6 +479,7 @@ class AxolotlInputConfig(
     eval_causal_lm_metrics: Optional[List[str]] = None
     do_bench_eval: Optional[bool] = None
     bench_dataset: Optional[str] = None
+    bench_split: Optional[str] = None
     metric_for_best_model: Optional[str] = None
     greater_is_better: Optional[bool] = None
 
@@ -494,7 +495,9 @@ class AxolotlInputConfig(
 
     # torch_dtype: Optional[torch.dtype]
 
-    gradient_checkpointing: Optional[bool] = Field(default=False)
+    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
+        default=False
+    )
     gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
 
     unfrozen_parameters: Optional[List[str]] = None
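
With the widened field above, gradient_checkpointing accepts either a boolean or the literal string "unsloth". A minimal pydantic sketch of how the union validates; the standalone model below is only an illustration, not the real AxolotlInputConfig:

from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

class GcConfig(BaseModel):  # hypothetical stand-in for the axolotl config model
    gradient_checkpointing: Optional[Union[Literal["unsloth"], bool]] = Field(
        default=False
    )

print(GcConfig(gradient_checkpointing=True).gradient_checkpointing)       # True
print(GcConfig(gradient_checkpointing="unsloth").gradient_checkpointing)  # unsloth
print(GcConfig().gradient_checkpointing)                                  # False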
src/axolotl/utils/gradient_checkpointing/__init__.py ADDED
@@ -0,0 +1,13 @@
+"""custom checkpointing utils"""
+from axolotl.utils.gradient_checkpointing.unsloth import (
+    Unsloth_Offloaded_Gradient_Checkpointer,
+)
+
+
+def hf_grad_checkpoint_unsloth_wrapper(
+    decoder_layer, *args, use_reentrant=None
+):  # pylint: disable=unused-argument
+    return Unsloth_Offloaded_Gradient_Checkpointer.apply(
+        decoder_layer.__self__,
+        *args,
+    )
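
A note on decoder_layer.__self__ in the new wrapper: both the hijacked Mistral forward above and upstream transformers models pass decoder_layer.__call__, a bound method, so __self__ recovers the module object before it is handed to the autograd Function. A tiny illustration with a toy module (not axolotl code):

import torch.nn as nn

layer = nn.Linear(4, 4)
bound = layer.__call__          # what the patched forward passes to the wrapper
assert bound.__self__ is layer  # the wrapper unwraps the bound method back to the module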
src/axolotl/utils/gradient_checkpointing/unsloth.py ADDED
@@ -0,0 +1,52 @@
+"""Unsloth checkpointing"""
+
+# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+class Unsloth_Offloaded_Gradient_Checkpointer(  # pylint: disable=invalid-name
+    torch.autograd.Function
+):
+    """
+    Saves VRAM by smartly offloading to RAM.
+    Tiny hit to performance, since we mask the movement via non blocking calls.
+    """
+
+    @staticmethod
+    @torch.cuda.amp.custom_fwd
+    def forward(ctx, forward_function, hidden_states, *args):
+        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
+        with torch.no_grad():
+            output = forward_function(hidden_states, *args)
+        ctx.save_for_backward(saved_hidden_states)
+        ctx.forward_function = forward_function
+        ctx.args = args
+        return output
+
+    @staticmethod
+    @torch.cuda.amp.custom_bwd
+    def backward(ctx, dY):
+        (hidden_states,) = ctx.saved_tensors
+        hidden_states = hidden_states.to("cuda", non_blocking=True).detach()
+        hidden_states.requires_grad = True
+        with torch.enable_grad():
+            (output,) = ctx.forward_function(hidden_states, *ctx.args)
+        torch.autograd.backward(output, dY)
+        return (
+            None,
+            hidden_states.grad,
+        ) + (
+            None,
+        ) * len(ctx.args)
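
A rough usage sketch of the checkpointer above, under two assumptions: a CUDA device is available, and the checkpointed callable returns a one-element tuple, matching the (output,) = ... unpack in backward. Forward runs the layer without building a graph and offloads the saved input to CPU; backward copies it back, re-runs the layer under enable_grad, and returns the recomputed input gradient:

import torch
import torch.nn as nn
from axolotl.utils.gradient_checkpointing.unsloth import (
    Unsloth_Offloaded_Gradient_Checkpointer,
)

layer = nn.Linear(64, 64).cuda()

def layer_forward(hidden_states):
    # return a one-element tuple, like a decoder layer with no extra outputs
    return (layer(hidden_states),)

x = torch.randn(2, 16, 64, device="cuda", requires_grad=True)
(out,) = Unsloth_Offloaded_Gradient_Checkpointer.apply(layer_forward, x)
out.sum().backward()   # recomputation happens here, from the CPU-offloaded input
print(x.grad.shape)    # torch.Size([2, 16, 64])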
src/axolotl/utils/models.py CHANGED
@@ -11,6 +11,7 @@ import addict
 import bitsandbytes as bnb
 import torch
 import transformers
+import transformers.modeling_utils
 from accelerate import init_empty_weights
 from bitsandbytes.nn import Params4bit
 from peft import (
@@ -44,6 +45,7 @@ from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import chat_templates
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import zero_only
+from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
 
@@ -310,6 +312,9 @@ def load_model(
     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
 
+    if cfg.gradient_checkpointing == "unsloth":
+        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
+
     if hasattr(model_config, "model_type") and model_config.model_type == "btlm":
         if cfg.flash_attention:
             from axolotl.monkeypatch.btlm_attn_hijack_flash import (
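
For context on the last hunk: transformers.modeling_utils imports checkpoint from torch.utils.checkpoint at module level and, as sketched after the first file above, builds _gradient_checkpointing_func from that name when checkpointing is enabled. Rebinding the attribute at model-load time, before checkpointing is enabled, is the "monkeypatch the checkpoint fn earlier" step from the commit message. A hedged sketch of the intended ordering; cfg and the explicit enable call are illustrative, since the enable normally happens inside the HF Trainer:

# illustration of the ordering the patch relies on (not axolotl's actual flow)
import transformers.modeling_utils
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper

def apply_unsloth_gc(cfg, model):
    if cfg.gradient_checkpointing == "unsloth":
        # 1) rebind the name *before* checkpointing is enabled
        transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
    # 2) when checkpointing is later enabled, the partial built inside
    #    transformers now wraps the Unsloth offloaded checkpointer
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": True}
    )
    return model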