ltg
/

PyTorch
English
custom_code
davda54 committed on
Commit
cf8d8a4
·
verified ·
1 Parent(s): 1998a54

fix causalLM

Browse files
Files changed (1) hide show
  1. modeling_ltgbert.py +4 -3
modeling_ltgbert.py CHANGED
@@ -318,6 +318,7 @@ class LtgbertModel(LtgbertPreTrainedModel):
318
  self.transformer = Encoder(config)
319
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
320
 
 
321
  def get_input_embeddings(self):
322
  return self.embedding.word_embedding
323
 
@@ -414,7 +415,7 @@ class LtgbertForMaskedLM(LtgbertModel):
414
 
415
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
416
  subword_prediction = self.classifier(sequence_output)
417
- subword_prediction[:, :, :16+1] = float("-inf")
418
 
419
  masked_lm_loss = None
420
  if labels is not None:
@@ -443,7 +444,6 @@ class Classifier(nn.Module):
443
  super().__init__()
444
 
445
  self.temperature = config.temperature
446
-
447
  drop_out = getattr(config, "cls_dropout", None)
448
  drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
449
 
@@ -494,6 +494,7 @@ class LtgbertForCausalLM(LtgbertModel):
494
  input_ids: torch.LongTensor = None,
495
  attention_mask: Optional[torch.Tensor] = None,
496
  position_ids: Optional[torch.LongTensor] = None,
 
497
  past_key_values = None,
498
  inputs_embeds: Optional[torch.FloatTensor] = None,
499
  labels: Optional[torch.LongTensor] = None,
@@ -511,7 +512,7 @@ class LtgbertForCausalLM(LtgbertModel):
511
 
512
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
513
  subword_prediction = self.classifier(sequence_output)
514
- subword_prediction[:, :, :16+1] = float("-inf")
515
 
516
  masked_lm_loss = None
517
  if labels is not None:
 
318
  self.transformer = Encoder(config)
319
  self.classifier = MaskClassifier(config, self.embedding.word_embedding.weight) if add_mlm_layer else None
320
 
321
+
322
  def get_input_embeddings(self):
323
  return self.embedding.word_embedding
324
 
 
415
 
416
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
417
  subword_prediction = self.classifier(sequence_output)
418
+ # subword_prediction[:, :, :16+1] = float("-inf")
419
 
420
  masked_lm_loss = None
421
  if labels is not None:
 
444
  super().__init__()
445
 
446
  self.temperature = config.temperature
 
447
  drop_out = getattr(config, "cls_dropout", None)
448
  drop_out = config.hidden_dropout_prob if drop_out is None else drop_out
449
 
 
494
  input_ids: torch.LongTensor = None,
495
  attention_mask: Optional[torch.Tensor] = None,
496
  position_ids: Optional[torch.LongTensor] = None,
497
+ token_type_ids: Optional[torch.Tensor] = None,
498
  past_key_values = None,
499
  inputs_embeds: Optional[torch.FloatTensor] = None,
500
  labels: Optional[torch.LongTensor] = None,
 
512
 
513
  sequence_output, contextualized_embeddings, attention_probs = self.get_contextualized_embeddings(input_ids, attention_mask)
514
  subword_prediction = self.classifier(sequence_output)
515
+ # subword_prediction[:, :, :16+1] = float("-inf")
516
 
517
  masked_lm_loss = None
518
  if labels is not None: