ccdv commited on
Commit
7a01dba
·
1 Parent(s): f160ab4

replace 1e4 mask

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. modeling_lsg_camembert.py +11 -7
README.md CHANGED
@@ -46,6 +46,7 @@ You can change various parameters like :
46
  * local block size (block_size=128)
47
  * sparse block size (sparse_block_size=128)
48
  * sparsity factor (sparsity_factor=2)
 
49
  * see config.json file
50
 
51
  Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
 
46
  * local block size (block_size=128)
47
  * sparse block size (sparse_block_size=128)
48
  * sparsity factor (sparsity_factor=2)
49
+ * mask_first_token (mask first token since it is redundant with the first global token)
50
  * see config.json file
51
 
52
  Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
modeling_lsg_camembert.py CHANGED
@@ -182,7 +182,11 @@ class CausalAttentionProduct(nn.Module):
182
 
183
  # Add causal mask
184
  causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
185
- causal_mask = torch.tril(torch.ones(*causal_shape, device=attention_mask.device), diagonal=-1).T * (-10000)
 
 
 
 
186
  attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
187
 
188
  del attention_mask
@@ -300,7 +304,7 @@ class LSGAttentionProduct(nn.Module):
300
 
301
  # Pad before block reshaping
302
  if is_attn_mask:
303
- pad_value = -10000
304
  hidden_states = hidden_states.transpose(-1, -2)
305
  else:
306
  pad_value = 0
@@ -333,7 +337,7 @@ class LSGAttentionProduct(nn.Module):
333
 
334
  # Pad before block reshaping
335
  if is_attn_mask:
336
- pad_value = -10000
337
  hidden_states = hidden_states.transpose(-1, -2)
338
  else:
339
  pad_value = 0
@@ -557,7 +561,7 @@ class LSGSelfAttention(BaseSelfAttention):
557
  keys = keys.sum(dim=-2) / (mask + 1e-6)
558
  values = values.sum(dim=-2) / (mask + 1e-6)
559
 
560
- mask = - (1. - mask.clamp(0, 1)) * 1e4
561
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
562
 
563
  def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -622,7 +626,7 @@ class LSGSelfAttention(BaseSelfAttention):
622
  keys /= mask + 1e-8
623
  values /= mask + 1e-8
624
 
625
- mask = -10000 * (1. - mask.clamp(0, 1))
626
 
627
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
628
 
@@ -988,7 +992,7 @@ class LSGCamembertModel(LSGCamembertPreTrainedModel, RobertaModel):
988
  n, t = inputs_.size()[:2]
989
 
990
  if attention_mask is None:
991
- attention_mask = torch.ones(n, t, device=inputs_.device)
992
  if self.mask_first_token:
993
  attention_mask[:,0] = 0
994
 
@@ -1069,7 +1073,7 @@ class LSGCamembertModel(LSGCamembertPreTrainedModel, RobertaModel):
1069
  )
1070
 
1071
  extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1072
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1073
 
1074
  return extended_attention_mask
1075
 
 
182
 
183
  # Add causal mask
184
  causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
185
+ causal_mask = torch.tril(
186
+ torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
187
+ diagonal=-1
188
+ )
189
+ causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
190
  attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
191
 
192
  del attention_mask
 
304
 
305
  # Pad before block reshaping
306
  if is_attn_mask:
307
+ pad_value = torch.finfo(hidden_states.dtype).min
308
  hidden_states = hidden_states.transpose(-1, -2)
309
  else:
310
  pad_value = 0
 
337
 
338
  # Pad before block reshaping
339
  if is_attn_mask:
340
+ pad_value = torch.finfo(hidden_states.dtype).min
341
  hidden_states = hidden_states.transpose(-1, -2)
342
  else:
343
  pad_value = 0
 
561
  keys = keys.sum(dim=-2) / (mask + 1e-6)
562
  values = values.sum(dim=-2) / (mask + 1e-6)
563
 
564
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
565
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
566
 
567
  def get_sparse_tokens_with_stride(self, keys, values, mask):
 
626
  keys /= mask + 1e-8
627
  values /= mask + 1e-8
628
 
629
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
630
 
631
  return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
632
 
 
992
  n, t = inputs_.size()[:2]
993
 
994
  if attention_mask is None:
995
+ attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
996
  if self.mask_first_token:
997
  attention_mask[:,0] = 0
998
 
 
1073
  )
1074
 
1075
  extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
1076
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(extended_attention_mask.dtype).min
1077
 
1078
  return extended_attention_mask
1079