debisoft commited on
Commit
62621c0
·
1 Parent(s): 162bf01

Patch for Gradient Checkpointing

Browse files
Files changed (1) hide show
  1. modeling_mpt.py +39 -14
modeling_mpt.py CHANGED
@@ -34,12 +34,19 @@ class MPTPreTrainedModel(PreTrainedModel):
34
  config_class = MPTConfig
35
  base_model_prefix = 'model'
36
  _no_split_modules = ['MPTBlock']
 
 
 
 
 
 
37
 
38
  class MPTModel(MPTPreTrainedModel):
39
 
40
  def __init__(self, config: MPTConfig):
41
  config._validate_config()
42
  super().__init__(config)
 
43
  self.attn_impl = config.attn_config['attn_impl']
44
  self.prefix_lm = config.attn_config['prefix_lm']
45
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
@@ -80,10 +87,10 @@ class MPTModel(MPTPreTrainedModel):
80
  log.debug(self)
81
  log.debug(f"Using {self.config.init_config['name']} initialization.")
82
 
83
- def get_input_embeddings(self) -> nn.Embedding:
84
  return self.wte
85
 
86
- def set_input_embeddings(self, value: nn.Embedding) -> None:
87
  self.wte = value
88
 
89
  @torch.no_grad()
@@ -143,7 +150,7 @@ class MPTModel(MPTPreTrainedModel):
143
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
144
  return attn_bias
145
 
146
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
147
  return_dict = return_dict if return_dict is not None else self.config.return_dict
148
  use_cache = use_cache if use_cache is not None else self.config.use_cache
149
  if attention_mask is not None:
@@ -159,13 +166,15 @@ class MPTModel(MPTPreTrainedModel):
159
  raise NotImplementedError('MPT does not support training with left padding.')
160
  if self.prefix_lm and prefix_mask is None:
161
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
162
- if inputs_embeds is not None:
163
- raise NotImplementedError('inputs_embeds is not implemented for MPT.')
164
  if self.training:
165
  if self.attn_uses_sequence_id and sequence_id is None:
166
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
167
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
168
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
 
 
 
 
169
  S = input_ids.size(1)
170
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
171
  tok_emb = self.wte(input_ids)
@@ -203,7 +212,25 @@ class MPTModel(MPTPreTrainedModel):
203
  assert all_hidden_states is not None
204
  all_hidden_states = all_hidden_states + (x,)
205
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
206
- (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  if presents is not None:
208
  presents += (present,)
209
  if output_attentions:
@@ -232,7 +259,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
232
  if not config.tie_word_embeddings:
233
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
234
  log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
235
- self.transformer: MPTModel = MPTModel(config)
236
  for child in self.transformer.children():
237
  if isinstance(child, torch.nn.ModuleList):
238
  continue
@@ -266,11 +293,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
266
  def get_decoder(self) -> MPTModel:
267
  return self.transformer
268
 
269
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
270
  return_dict = return_dict if return_dict is not None else self.config.return_dict
271
  use_cache = use_cache if use_cache is not None else self.config.use_cache
272
- if inputs_embeds is not None:
273
- raise NotImplementedError('inputs_embeds has to be None (for hf/peft support).')
274
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
275
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
276
  if self.logit_scale is not None:
@@ -279,9 +304,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
279
  logits *= self.logit_scale
280
  loss = None
281
  if labels is not None:
282
- _labels = torch.roll(labels, shifts=-1)
283
- _labels[:, -1] = -100
284
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
285
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
286
 
287
  def param_init_fn(self, module: nn.Module) -> None:
@@ -324,4 +349,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
324
  reordered_past = []
325
  for layer_past in past_key_values:
326
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
327
- return reordered_past
 
34
  config_class = MPTConfig
35
  base_model_prefix = 'model'
36
  _no_split_modules = ['MPTBlock']
37
+
38
+ supports_gradient_checkpointing = True
39
+
40
+ def _set_gradient_checkpointing(self, module, value=False):
41
+ if isinstance(module, MPTModel):
42
+ module.gradient_checkpointing = value
43
 
44
  class MPTModel(MPTPreTrainedModel):
45
 
46
  def __init__(self, config: MPTConfig):
47
  config._validate_config()
48
  super().__init__(config)
49
+ self.gradient_checkpointing = False
50
  self.attn_impl = config.attn_config['attn_impl']
51
  self.prefix_lm = config.attn_config['prefix_lm']
52
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
 
87
  log.debug(self)
88
  log.debug(f"Using {self.config.init_config['name']} initialization.")
89
 
90
+ def get_input_embeddings(self)
91
  return self.wte
92
 
93
+ def set_input_embeddings(self, value) -> None:
94
  self.wte = value
95
 
96
  @torch.no_grad()
 
150
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
151
  return attn_bias
152
 
153
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None) -> BaseModelOutputWithPast:
154
  return_dict = return_dict if return_dict is not None else self.config.return_dict
155
  use_cache = use_cache if use_cache is not None else self.config.use_cache
156
  if attention_mask is not None:
 
166
  raise NotImplementedError('MPT does not support training with left padding.')
167
  if self.prefix_lm and prefix_mask is None:
168
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
 
 
169
  if self.training:
170
  if self.attn_uses_sequence_id and sequence_id is None:
171
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
172
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
173
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
174
+ if self.gradient_checkpointing and self.training:
175
+ if use_cache:
176
+ use_cache = False
177
+
178
  S = input_ids.size(1)
179
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
180
  tok_emb = self.wte(input_ids)
 
212
  assert all_hidden_states is not None
213
  all_hidden_states = all_hidden_states + (x,)
214
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
215
+ if self.gradient_checkpointing and self.training:
216
+
217
+ def create_custom_forward(module):
218
+ def custom_forward(*inputs):
219
+ # None for past_key_value
220
+ return module(*inputs)
221
+
222
+ return custom_forward
223
+
224
+ (x, attn_weights, past_key_value) = torch.utils.checkpoint.checkpoint(
225
+ create_custom_forward(block),
226
+ x,
227
+ past_key_value,
228
+ attn_bias,
229
+ attention_mask,
230
+ self.is_causal,
231
+ )
232
+ else:
233
+ (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
234
  if presents is not None:
235
  presents += (present,)
236
  if output_attentions:
 
259
  if not config.tie_word_embeddings:
260
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
261
  log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
262
+ self.transformer = MPTModel(config)
263
  for child in self.transformer.children():
264
  if isinstance(child, torch.nn.ModuleList):
265
  continue
 
293
  def get_decoder(self) -> MPTModel:
294
  return self.transformer
295
 
296
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None) -> CausalLMOutputWithPast:
297
  return_dict = return_dict if return_dict is not None else self.config.return_dict
298
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
 
299
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
300
  logits = self.transformer.wte(outputs.last_hidden_state.to(self.transformer.wte.weight.device), True)
301
  if self.logit_scale is not None:
 
304
  logits *= self.logit_scale
305
  loss = None
306
  if labels is not None:
307
+ labels = torch.roll(labels, shifts=-1)
308
+ labels[:, -1] = -100
309
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
310
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
311
 
312
  def param_init_fn(self, module: nn.Module) -> None:
 
349
  reordered_past = []
350
  for layer_past in past_key_values:
351
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
352
+ return reordered_past