Jackmin108
commited on
Commit
·
d2c9d06
1
Parent(s):
7c7eafb
fix: check adapter
Browse filesSigned-off-by: Meow <[email protected]>
- xlm_padding.py +8 -6
xlm_padding.py
CHANGED
@@ -98,7 +98,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
|
|
98 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
|
100 |
|
101 |
-
def unpad_input(hidden_states, attention_mask, adapter_mask):
|
102 |
"""
|
103 |
Arguments:
|
104 |
hidden_states: (batch, seqlen, ...)
|
@@ -114,11 +114,13 @@ def unpad_input(hidden_states, attention_mask, adapter_mask):
|
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
116 |
|
117 |
-
cu_adapter_mask =
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
122 |
|
123 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
124 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|
|
|
98 |
index_first_axis_residual = IndexFirstAxisResidual.apply
|
99 |
|
100 |
|
101 |
+
def unpad_input(hidden_states, attention_mask, adapter_mask=None):
|
102 |
"""
|
103 |
Arguments:
|
104 |
hidden_states: (batch, seqlen, ...)
|
|
|
114 |
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
115 |
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
116 |
|
117 |
+
cu_adapter_mask = None
|
118 |
+
if adapter_mask:
|
119 |
+
cu_adapter_mask = torch.empty(cu_seqlens[-1], dtype=torch.int32)
|
120 |
+
for i in range(len(adapter_mask)):
|
121 |
+
start_idx = cu_seqlens[i]
|
122 |
+
end_idx = cu_seqlens[i + 1]
|
123 |
+
cu_adapter_mask[start_idx:end_idx] = adapter_mask[i]
|
124 |
|
125 |
# TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
|
126 |
# bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
|