maxall4 commited on
Commit
ea9486a
·
verified ·
1 Parent(s): 281dab8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +3 -3
model.py CHANGED
@@ -355,7 +355,7 @@ class StripedHyena(nn.Module):
355
  self.gradient_checkpointing = False
356
  self._gradient_checkpointing_func = None
357
 
358
- def forward(self, x, inference_params_dict=None, padding_mask=None):
359
  L = x.shape[1]
360
  x = self.embedding_layer.embed(x)
361
  if inference_params_dict is not None:
@@ -370,7 +370,7 @@ class StripedHyena(nn.Module):
370
  x = self.unembed.unembed(x)
371
  return x, inference_params_dict_out
372
 
373
- def stateful_forward(self, x, inference_params_dict=None):
374
  for block_idx, block in enumerate(self.blocks):
375
  block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
376
  inference_params = inference_params_dict[block_name]
@@ -378,7 +378,7 @@ class StripedHyena(nn.Module):
378
 
379
  return x, inference_params_dict
380
 
381
- def stateless_forward(self, x, padding_mask=None):
382
  if type(padding_mask) == torch.Tensor:
383
  x = x * padding_mask[..., None]
384
 
 
355
  self.gradient_checkpointing = False
356
  self._gradient_checkpointing_func = None
357
 
358
+ def forward(self, input_ids, inference_params_dict=None, padding_mask=None):
359
  L = x.shape[1]
360
  x = self.embedding_layer.embed(x)
361
  if inference_params_dict is not None:
 
370
  x = self.unembed.unembed(x)
371
  return x, inference_params_dict_out
372
 
373
+ def stateful_forward(self, input_ids, inference_params_dict=None):
374
  for block_idx, block in enumerate(self.blocks):
375
  block_name = "mha" if block_idx in self.config.attn_layer_idxs else "hyena"
376
  inference_params = inference_params_dict[block_name]
 
378
 
379
  return x, inference_params_dict
380
 
381
+ def stateless_forward(self, input_ids, padding_mask=None):
382
  if type(padding_mask) == torch.Tensor:
383
  x = x * padding_mask[..., None]
384