Update model.py
model.py
CHANGED
@@ -66,6 +66,7 @@ class AttentionBlock(nn.Module):
         self.inner_mha_cls.rotary_emb.register_buffer("inv_freq", self.inner_mha_cls.rotary_emb.inv_freq)

         self.mlp = ParallelGatedMLP(config).to(dtype=mlp_dtype)
+        self.filter_output = None

     def forward(self, u, inference_params=None, padding_mask=None, *args, **kwargs):
         if (
@@ -288,6 +289,7 @@ class ParallelGatedConvBlock(nn.Module):

         self.proj_norm_fn = self.proj_norm
         self.res_mlp_norm_fn = self.res_mlp_norm
+        self.filter_output = None

         if self.config.get("compile", False):
             self.proj_norm_fn = torch.compile(self.proj_norm, fullgraph=True, dynamic=False, mode="reduce-overhead")