Support gradient checkpointing

#3
by maxall4 - opened
Files changed (3)
  1. config.json +1 -1
  2. model.py +18 -7
  3. modeling_hyena.py +25 -0
config.json CHANGED
@@ -87,4 +87,4 @@
   "use_flashfft": false,
   "use_interpolated_rotary_pos_emb": true,
   "vocab_size": 512
-}
+}
model.py CHANGED
@@ -22,6 +22,8 @@ try:
 except ImportError:
     "could not import swap_mha_rope from positional_embeddings.py"
 
+from flashfftconv import FlashDepthWiseConv1d
+
 # dummy import to force huggingface to bundle the tokenizer
 from .tokenizer import ByteTokenizer
 
@@ -64,6 +66,7 @@ class AttentionBlock(nn.Module):
         self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
 
         self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
+        self.filter_output = None
 
     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
         if (
@@ -71,13 +74,12 @@ class AttentionBlock(nn.Module):
         ):  # workaround for masking bug in FA. This works because Wqkv does not have bias
             # and attention scores will be also automatically zeroed.
             u = u * padding_mask[..., None]
-        u = (
-            self.inner_mha_cls(
+        w = self.inner_mha_cls(
             self.pre_norm(u),
             inference_params=inference_params,
-            )
-            + u
         )
+        self.filter_output = w
+        u = w + u
         if type(padding_mask) == torch.Tensor:  # guard against bias
             u = u * padding_mask[..., None]
         u = self.mlp(self.post_norm(u)) + u
@@ -120,7 +122,7 @@ class ParallelHyenaFilter(nn.Module):
         self.data_dtype = None
 
         if self.use_flash_depthwise:
-            self.fir_fn = FlashDepthwiseConv1d(
+            self.fir_fn = FlashDepthWiseConv1d(
                 channels=3 * self.hidden_size,
                 kernel_size=self.short_filter_length,
                 padding=self.short_filter_length - 1,
@@ -287,6 +289,7 @@ class ParallelGatedConvBlock(nn.Module):
 
         self.proj_norm_fn = self.proj_norm
         self.res_mlp_norm_fn = self.res_mlp_norm
+        self.filter_output = None
 
         if self.config.get("compile", False):
             self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
@@ -308,6 +311,8 @@ class ParallelGatedConvBlock(nn.Module):
 
         z, inference_params = self.filter(z, inference_params=inference_params, padding_mask=padding_mask)
 
+        self.filter_output = z
+
         z_in = self.out_filter_dense(z) + u
 
         if type(padding_mask) == torch.Tensor:  # guard against bias
@@ -343,13 +348,15 @@ class StripedHyena(nn.Module):
                 from flashfftconv import FlashFFTConv
             except:
                 raise ImportError
-            self.flash_fft = FlashFFTConv(2 * config.seqlen, dtype=torch.bfloat16)
+            self.flash_fft = FlashFFTConv(2 * config.max_seqlen, dtype=torch.bfloat16)
         else:
             self.flash_fft = None
 
         self.blocks = nn.ModuleList(
             get_block(config, layer_idx, flash_fft=self.flash_fft) for layer_idx in range(config.num_layers)
         )
+        self.gradient_checkpointing = False
+        self._gradient_checkpointing_func = None
 
     def forward(self, x, inference_params_dict=None, padding_mask=None):
         L = x.shape[1]
@@ -379,7 +386,11 @@ class StripedHyena(nn.Module):
             x = x * padding_mask[..., None]
 
         for _, block in enumerate(self.blocks):
-            x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+            if self.gradient_checkpointing and self.training:
+                x, _ = self._gradient_checkpointing_func(block.__call__, x, None, padding_mask)
+            else:
+                x, _ = block(x, inference_params=None, padding_mask=padding_mask)
+
         return x, None
 
     def initialize_inference_params(self):
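
For reference, the checkpointed branch in StripedHyena.forward amounts to calling torch.utils.checkpoint.checkpoint on each block. A minimal sketch, assuming the default use_reentrant=True wrapper that gradient_checkpointing_enable installs in modeling_hyena.py below; only block, x, and padding_mask come from this diff, the rest is illustrative:

import torch.utils.checkpoint

# Rough expansion of the checkpointed branch for a single block:
x, _ = torch.utils.checkpoint.checkpoint(
    block.__call__,      # the block's forward is re-run during backward to recompute activations
    x,                   # hidden states; only block inputs are stored, intermediate activations are freed
    None,                # inference_params is always None during training
    padding_mask,        # passed positionally: the reentrant checkpoint does not forward keyword arguments
    use_reentrant=True,
)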
modeling_hyena.py CHANGED
@@ -2,6 +2,7 @@
 """StripedHyena custom code port for the Hugging Face Hub"""
 
 import torch
+import functools
 from torch.nn import functional as F
 from .configuration_hyena import StripedHyenaConfig
 from transformers import PreTrainedModel
@@ -50,8 +51,32 @@ class StripedHyenaModelForCausalLM(StripedHyenaPreTrainedModel):
     def force_dtype(self):
         self.backbone.to_bfloat16_except_poles_residues()
 
+    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
+        if not self.supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+
+        if gradient_checkpointing_kwargs is None:
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
+
+        # TODO support deepspeed checkpoint
+        gradient_checkpointing_func = functools.partial(
+            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
+        )
+
+        self._set_gradient_checkpointing(
+            enable=True, gradient_checkpointing_func=gradient_checkpointing_func
+        )
+
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True.
+            # We do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only LoRA layers will have requires_grad set to True, but the output of frozen
+            # layers needs to propagate the gradients to make sure the gradient flows.
+            self.enable_input_require_grads()
+
     def _set_gradient_checkpointing(self, enable, gradient_checkpointing_func):
         self.backbone.gradient_checkpointing = enable
+        self.backbone._gradient_checkpointing_func = gradient_checkpointing_func
 
     def get_input_embeddings(self):
         return self.backbone.embedding_layer
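
With these changes in place, checkpointing can be toggled through the standard Hugging Face API. A minimal usage sketch, assuming the model is loaded with trust_remote_code=True; the repository id and the surrounding training setup are placeholders, not part of this PR:

from transformers import AutoModelForCausalLM

# hypothetical repository id; substitute the actual StripedHyena checkpoint
model = AutoModelForCausalLM.from_pretrained(
    "org/striped-hyena-checkpoint",
    trust_remote_code=True,  # required: modeling_hyena.py ships as custom model code
)

model.gradient_checkpointing_enable()  # installs the torch.utils.checkpoint wrapper on the backbone
model.train()                          # blocks are only checkpointed when self.training is True

Passing gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) would select the non-reentrant checkpoint implementation instead of the default set in this diff.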