zxdu20 committed on
Commit
2460dc2
·
1 Parent(s): 42095d4

Remove hardcoded bos_token_id

Browse files
Files changed (1) hide show
  1. modeling_chatglm.py +11 -13
modeling_chatglm.py CHANGED
@@ -753,9 +753,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
753
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
754
  self.word_embeddings = new_embeddings
755
 
756
- @staticmethod
757
- def get_masks(seq, device):
758
- context_length = seq.index(150004) + 1
759
 
760
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
761
  attention_mask.tril_()
@@ -766,9 +765,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
766
  return attention_mask
767
 
768
  def get_position_ids(self, seq, mask_position, device, gmask=False):
769
- context_length = seq.index(150004) + 1
770
  if self.position_encoding_2d:
771
- seq_length = seq.index(150004)
772
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
773
  if not gmask:
774
  position_ids[seq_length:] = mask_position
@@ -823,14 +822,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
823
 
824
  if past_key_values is None:
825
  past_key_values = tuple([None] * len(self.layers))
826
-
827
- MASK, gMASK = 150000, 150001
828
- mask_token = MASK if MASK in input_ids else gMASK
829
- use_gmask = False if MASK in input_ids else gMASK
830
  seq = input_ids[0].tolist()
831
 
832
- mask_position = seq.index(mask_token)
833
-
834
  if attention_mask is None:
835
  attention_mask = self.get_masks(
836
  seq=seq,
@@ -838,6 +831,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
838
  )
839
 
840
  if position_ids is None:
 
 
 
 
 
841
  position_ids = self.get_position_ids(
842
  seq=seq,
843
  mask_position=mask_position,
@@ -941,7 +939,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
941
  attention_mask = (attention_mask < 0.5).bool()
942
 
943
  if self.position_encoding_2d:
944
- seq_length = seq.index(150004)
945
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
946
  if not gmask:
947
  position_ids[seq_length:] = mask_position
@@ -979,7 +977,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
979
 
980
  # only last token for input_ids if past is not None
981
  if past is not None or past_key_values is not None:
982
- context_length = seq.index(150004)
983
  last_token = input_ids[:, -1].unsqueeze(-1)
984
  if self.position_encoding_2d:
985
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
 
753
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
754
  self.word_embeddings = new_embeddings
755
 
756
+ def get_masks(self, seq, device):
757
+ context_length = seq.index(self.config.bos_token_id) + 1
 
758
 
759
  attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
760
  attention_mask.tril_()
 
765
  return attention_mask
766
 
767
  def get_position_ids(self, seq, mask_position, device, gmask=False):
768
+ context_length = seq.index(self.config.bos_token_id) + 1
769
  if self.position_encoding_2d:
770
+ seq_length = seq.index(self.config.bos_token_id)
771
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
772
  if not gmask:
773
  position_ids[seq_length:] = mask_position
 
822
 
823
  if past_key_values is None:
824
  past_key_values = tuple([None] * len(self.layers))
 
 
 
 
825
  seq = input_ids[0].tolist()
826
 
 
 
827
  if attention_mask is None:
828
  attention_mask = self.get_masks(
829
  seq=seq,
 
831
  )
832
 
833
  if position_ids is None:
834
+ MASK, gMASK = 150000, 150001
835
+ mask_token = MASK if MASK in input_ids else gMASK
836
+ use_gmask = False if MASK in input_ids else gMASK
837
+
838
+ mask_position = seq.index(mask_token)
839
  position_ids = self.get_position_ids(
840
  seq=seq,
841
  mask_position=mask_position,
 
939
  attention_mask = (attention_mask < 0.5).bool()
940
 
941
  if self.position_encoding_2d:
942
+ seq_length = seq.index(self.config.bos_token_id)
943
  position_ids = torch.arange(context_length, dtype=torch.long, device=device)
944
  if not gmask:
945
  position_ids[seq_length:] = mask_position
 
977
 
978
  # only last token for input_ids if past is not None
979
  if past is not None or past_key_values is not None:
980
+ context_length = seq.index(self.config.bos_token_id)
981
  last_token = input_ids[:, -1].unsqueeze(-1)
982
  if self.position_encoding_2d:
983
  position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,