Jackmin108 commited on
Commit
d2c9d06
·
1 Parent(s): 7c7eafb

fix: check adapter

Browse files

Signed-off-by: Meow <[email protected]>

Files changed (1) hide show
  1. xlm_padding.py +8 -6
xlm_padding.py CHANGED
@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
- def unpad_input(hidden_states, attention_mask, adapter_mask):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
@@ -114,11 +114,13 @@ def unpad_input(hidden_states, attention_mask, adapter_mask):
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
 
117
- cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
118
- for i in range(len(adapter_mask)):
119
- start_idx = cu_seqlens[i]
120
- end_idx = cu_seqlens[i + 1]
121
- cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
 
 
122
 
123
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
124
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
 
98
  index_first_axis_residual = IndexFirstAxisResidual.apply
99
 
100
 
101
+ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
102
  """
103
  Arguments:
104
  hidden_states: (batch, seqlen, ...)
 
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
  cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
116
 
117
+ cu_adapter_mask = None
118
+ if adapter_mask:
119
+ cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
120
+ for i in range(len(adapter_mask)):
121
+ start_idx = cu_seqlens[i]
122
+ end_idx = cu_seqlens[i + 1]
123
+ cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
124
 
125
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
126
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim