Remove hardcode bos_token_id
modeling_chatglm.py  CHANGED  (+11 -13)

@@ -754,9 +754,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
     def set_input_embeddings(self, new_embeddings: torch.Tensor):
         self.word_embeddings = new_embeddings
 
-
-
-        context_length = seq.index(130004) + 1
+    def get_masks(self, seq, device):
+        context_length = seq.index(self.config.bos_token_id) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -767,9 +766,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(130004) + 1
+        context_length = seq.index(self.config.bos_token_id) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(130004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
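
Both helpers above now locate the end of the prompt by searching for the configured bos token rather than the literal id 130004. As an illustration of what that index is used for, here is a minimal standalone sketch of a prefix-LM mask of this kind: full attention over the prompt tokens before bos, causal attention from bos onwards. The function name and toy token ids are made up for this note, and the lines between tril_() and the final comparison are not part of this diff, so treat it as a sketch rather than the model's exact code.

import torch

def build_prefix_lm_mask(seq, bos_token_id, device="cpu"):
    # Prompt = everything up to and including bos.
    context_length = seq.index(bos_token_id) + 1
    mask = torch.ones((1, len(seq), len(seq)), device=device)
    mask.tril_()                          # causal part for generated tokens
    mask[..., :context_length - 1] = 1    # prompt columns visible to every position
    mask.unsqueeze_(1)                    # add a head/broadcast dimension
    return (mask < 0.5).bool()            # True marks positions to be masked out

# Toy sequence: four prompt tokens, bos (id 5 here), two generated tokens.
mask = build_prefix_lm_mask([11, 12, 13, 14, 5, 21, 22], bos_token_id=5)
print(mask.shape)   # torch.Size([1, 1, 7, 7])
print(mask[0, 0])   # False where attention is allowed
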
@@ -824,14 +823,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
-
-        MASK, gMASK = 130000, 130001
-        mask_token = MASK if MASK in input_ids else gMASK
-        use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
 
-        mask_position = seq.index(mask_token)
-
         if attention_mask is None:
             attention_mask = self.get_masks(
                 seq=seq,
@@ -839,6 +832,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             )
 
         if position_ids is None:
+            MASK, gMASK = 130000, 130001
+            mask_token = MASK if MASK in input_ids else gMASK
+            use_gmask = False if MASK in input_ids else gMASK
+
+            mask_position = seq.index(mask_token)
             position_ids = self.get_position_ids(
                 seq=seq,
                 mask_position=mask_position,
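
The two hunks above move the [MASK]/[gMASK] bookkeeping under if position_ids is None:, so it only runs when forward has to derive position ids itself; a caller that supplies position_ids no longer needs a mask token in the input, whereas the old placement called seq.index(mask_token) unconditionally, which raises ValueError when neither token is present. A toy version of that lookup (130000 and 130001 are the hardcoded ids this commit keeps; use_gmask is written as a plain boolean here, while the original stores the gMASK id itself):

MASK, gMASK = 130000, 130001   # still hardcoded after this commit

def find_mask(seq):
    """Return (mask_position, use_gmask) roughly the way forward() picks them."""
    mask_token = MASK if MASK in seq else gMASK
    use_gmask = mask_token == gMASK
    return seq.index(mask_token), use_gmask

print(find_mask([5, 6, 130001, 7]))   # (2, True): gMASK found at position 2
print(find_mask([130000, 8, 9]))      # (0, False): MASK found at position 0
# find_mask([5, 6, 7]) would raise ValueError -- no mask token in the sequence
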
@@ -942,7 +940,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(130004)
+            seq_length = seq.index(self.config.bos_token_id)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -980,7 +978,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(130004)
+            context_length = seq.index(self.config.bos_token_id)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
|