fix: remove cleaving

#13
by Markus28 - opened
Files changed (1) hide show
  1. modeling_bert.py +4 -23
modeling_bert.py CHANGED
@@ -166,25 +166,6 @@ class BertEncoder(nn.Module):
166
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
167
  )
168
  self._grad_checkpointing = False
169
- self._last_layer_idx = len(self.layers) - 1
170
-
171
- @property
172
- def last_layer_idx(self):
173
- return self._last_layer_idx
174
-
175
- @last_layer_idx.setter
176
- def last_layer_idx(self, idx: int):
177
- assert 0 <= idx < len(self.layers)
178
- self._last_layer_idx = idx
179
-
180
- @property
181
- def cleaved_layers(self):
182
- return len(self.layers) - self.last_layer_idx - 1
183
-
184
- @cleaved_layers.setter
185
- def cleaved_layers(self, n: int):
186
- assert 0 <= n < len(self.layers)
187
- self.last_layer_idx = len(self.layers) - n - 1
188
 
189
  @property
190
  def gradient_checkpointing(self):
@@ -205,7 +186,7 @@ class BertEncoder(nn.Module):
205
  mixer_kwargs = (
206
  {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
207
  )
208
- for layer in self.layers[:self.last_layer_idx + 1]:
209
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
210
  if subset_mask is not None:
211
  hidden_states = hidden_states[subset_mask]
@@ -216,11 +197,11 @@ class BertEncoder(nn.Module):
216
  )
217
  mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
218
  if subset_mask is None:
219
- for layer in self.layers[:self.last_layer_idx + 1]:
220
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
221
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
222
  else:
223
- for layer in self.layers[:self.last_layer_idx]:
224
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
225
  if key_padding_mask is not None:
226
  subset_idx = torch.nonzero(
@@ -247,7 +228,7 @@ class BertEncoder(nn.Module):
247
  "cu_seqlens_k": cu_seqlens,
248
  "max_seqlen_k": max_seqlen_in_batch,
249
  }
250
- hidden_states = self.layers[self.last_layer_idx](hidden_states_subset, mixer_kwargs=mixer_kwargs)
251
  return hidden_states
252
 
253
 
 
166
  [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
167
  )
168
  self._grad_checkpointing = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
  @property
171
  def gradient_checkpointing(self):
 
186
  mixer_kwargs = (
187
  {"key_padding_mask": key_padding_mask.bool()} if key_padding_mask is not None else None
188
  )
189
+ for layer in self.layers:
190
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
191
  if subset_mask is not None:
192
  hidden_states = hidden_states[subset_mask]
 
197
  )
198
  mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch}
199
  if subset_mask is None:
200
+ for layer in self.layers:
201
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
202
  hidden_states = pad_input(hidden_states, indices, batch, seqlen)
203
  else:
204
+ for layer in self.layers[:-1]:
205
  hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
206
  if key_padding_mask is not None:
207
  subset_idx = torch.nonzero(
 
228
  "cu_seqlens_k": cu_seqlens,
229
  "max_seqlen_k": max_seqlen_in_batch,
230
  }
231
+ hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
232
  return hidden_states
233
 
234