zxdu20 commited on
Commit
8c4ad86
·
1 Parent(s): 258e6db

Remove hardcode bos_token_id

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +11 -13
modeling_chatglm.py CHANGED
@@ -754,9 +754,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
754
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
755
  self.word_embeddings = new_embeddings
756
 
757
- @staticmethod
758
- def get_masks(seq, device):
759
- context_length = seq.index(130004) + 1
760
 
761
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
762
  attention_mask.tril_()
@@ -767,9 +766,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
767
  return attention_mask
768
 
769
  def get_position_ids(self, seq, mask_position, device, gmask=False):
770
- context_length = seq.index(130004) + 1
771
  if self.position_encoding_2d:
772
- seq_length = seq.index(130004)
773
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
774
  if not gmask:
775
  position_ids[seq_length:] = mask_position
@@ -824,14 +823,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
824
 
825
  if past_key_values is None:
826
  past_key_values = tuple([None] * len(self.layers))
827
-
828
- MASK, gMASK = 130000, 130001
829
- mask_token = MASK if MASK in input_ids else gMASK
830
- use_gmask = False if MASK in input_ids else gMASK
831
  seq = input_ids[0].tolist()
832
 
833
- mask_position = seq.index(mask_token)
834
-
835
  if attention_mask is None:
836
  attention_mask = self.get_masks(
837
  seq=seq,
@@ -839,6 +832,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
839
  )
840
 
841
  if position_ids is None:
 
 
 
 
 
842
  position_ids = self.get_position_ids(
843
  seq=seq,
844
  mask_position=mask_position,
@@ -942,7 +940,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
942
  attention_mask = (attention_mask < 0.5).bool()
943
 
944
  if self.position_encoding_2d:
945
- seq_length = seq.index(130004)
946
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
947
  if not gmask:
948
  position_ids[seq_length:] = mask_position
@@ -980,7 +978,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
980
 
981
  # only last token for input_ids if past is not None
982
  if past is not None or past_key_values is not None:
983
- context_length = seq.index(130004)
984
  last_token = input_ids[:, -1].unsqueeze(-1)
985
  if self.position_encoding_2d:
986
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
 
754
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
755
  self.word_embeddings = new_embeddings
756
 
757
+ def get_masks(self, seq, device):
758
+ context_length = seq.index(self.config.bos_token_id) + 1
 
759
 
760
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
761
  attention_mask.tril_()
 
766
  return attention_mask
767
 
768
  def get_position_ids(self, seq, mask_position, device, gmask=False):
769
+ context_length = seq.index(self.config.bos_token_id) + 1
770
  if self.position_encoding_2d:
771
+ seq_length = seq.index(self.config.bos_token_id)
772
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
773
  if not gmask:
774
  position_ids[seq_length:] = mask_position
 
823
 
824
  if past_key_values is None:
825
  past_key_values = tuple([None] * len(self.layers))
 
 
 
 
826
  seq = input_ids[0].tolist()
827
 
 
 
828
  if attention_mask is None:
829
  attention_mask = self.get_masks(
830
  seq=seq,
 
832
  )
833
 
834
  if position_ids is None:
835
+ MASK, gMASK = 130000, 130001
836
+ mask_token = MASK if MASK in input_ids else gMASK
837
+ use_gmask = False if MASK in input_ids else gMASK
838
+
839
+ mask_position = seq.index(mask_token)
840
  position_ids = self.get_position_ids(
841
  seq=seq,
842
  mask_position=mask_position,
 
940
  attention_mask = (attention_mask < 0.5).bool()
941
 
942
  if self.position_encoding_2d:
943
+ seq_length = seq.index(self.config.bos_token_id)
944
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
945
  if not gmask:
946
  position_ids[seq_length:] = mask_position
 
978
 
979
  # only last token for input_ids if past is not None
980
  if past is not None or past_key_values is not None:
981
+ context_length = seq.index(self.config.bos_token_id)
982
  last_token = input_ids[:, -1].unsqueeze(-1)
983
  if self.position_encoding_2d:
984
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,