maxall4 committed
Commit e7b43af · verified · 1 Parent(s): b07b34f

Update model.py

Files changed (1): model.py (+2 -0)
model.py CHANGED
@@ -66,6 +66,7 @@ class AttentionBlock(nn.Module):
         self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)
 
         self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
+        self.filter_output = None
 
     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
         if (
@@ -288,6 +289,7 @@ class ParallelGatedConvBlock(nn.Module):
 
         self.proj_norm_fn = self.proj_norm
         self.res_mlp_norm_fn = self.res_mlp_norm
+        self.filter_output = None
 
         if self.config.get("compile", False):
             self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")
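The commit message gives no rationale, but the pattern is a common one: declaring `self.filter_output = None` in `__init__` guarantees the attribute exists before the first `forward` call, and it keeps the module's attribute set fixed after construction, which is generally friendlier to `torch.compile`-wrapped methods like `proj_norm_fn` above. A minimal sketch of the idea, using a hypothetical `Block` class rather than the repo's actual `AttentionBlock`/`ParallelGatedConvBlock`:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    # Hypothetical stand-in for AttentionBlock / ParallelGatedConvBlock,
    # illustrating the pattern only; not the repo's actual code.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # Declaring the attribute up front means it exists before any
        # forward call; without this line, reading it on a fresh module
        # would raise AttributeError.
        self.filter_output = None

    def forward(self, u):
        y = self.proj(u)
        # Cache an intermediate so callers can inspect it after the call.
        self.filter_output = y.detach()
        return y

block = Block()
print(block.filter_output)        # None, not an AttributeError
_ = block(torch.randn(2, 8))
print(block.filter_output.shape)  # torch.Size([2, 8])
```

Caching into a pre-declared attribute also lets downstream code distinguish "no output yet" (`None`) from a populated value without `hasattr` checks.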