remove image tokens from chatglm-6b

Files changed:
- config.json +4 -4
- configuration_chatglm.py +4 -3
- modeling_chatglm.py +14 -13
- pytorch_model-00001-of-00008.bin → pytorch_model-00001-of-00008-slim.bin +2 -2
- pytorch_model-00008-of-00008.bin → pytorch_model-00008-of-00008-slim.bin +2 -2
- pytorch_model.bin.index.json +26 -26
- tokenization_chatglm.py +6 -16
- tokenizer_config.json +1 -1
config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "_name_or_path": "THUDM/chatglm-6b",
+  "_name_or_path": "silver/chatglm-6b-slim",
   "architectures": [
     "ChatGLMModel"
   ],
@@ -8,8 +8,8 @@
     "AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration",
     "AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"
   },
-  "bos_token_id": 150004,
-  "eos_token_id": 150005,
+  "bos_token_id": 130004,
+  "eos_token_id": 130005,
   "hidden_size": 4096,
   "inner_hidden_size": 16384,
   "layernorm_epsilon": 1e-05,
@@ -21,5 +21,5 @@
   "torch_dtype": "float16",
   "transformers_version": "4.23.1",
   "use_cache": true,
-  "vocab_size": 150528
+  "vocab_size": 130528
 }
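Every token-related value drops by exactly 20000, the number of image tokens removed from the icetk vocabulary. A quick standalone sanity check of that offset (a sketch, not part of the commit):

NUM_IMAGE_TOKENS = 20000  # image tokens occupied ids 0..19999 in the original vocab

# original THUDM/chatglm-6b values vs. the slim values above
original = {"vocab_size": 150528, "bos_token_id": 150004, "eos_token_id": 150005}
slim = {"vocab_size": 130528, "bos_token_id": 130004, "eos_token_id": 130005}

for key in original:
    assert original[key] - NUM_IMAGE_TOKENS == slim[key]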
configuration_chatglm.py
CHANGED
@@ -12,6 +12,7 @@ class ChatGLMConfig(PretrainedConfig):
     It is used to instantiate an ChatGLM model according to the specified arguments, defining the model
     architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
     the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture.
+    We remove 20K image tokens on top of ChatGLM-6B to save memory.
 
     Configuration objects inherit from [`PretrainedConfig`] and can be used
     to control the model outputs. Read the documentation from [`PretrainedConfig`]
@@ -58,14 +59,14 @@ class ChatGLMConfig(PretrainedConfig):
 
     def __init__(
             self,
-            vocab_size=150528,
+            vocab_size=130528,
             hidden_size=4096,
             num_layers=28,
             num_attention_heads=32,
             layernorm_epsilon=1e-5,
             use_cache=False,
-            bos_token_id=150004,
-            eos_token_id=150005,
+            bos_token_id=130004,
+            eos_token_id=130005,
            pad_token_id=0,
            max_sequence_length=2048,
            inner_hidden_size=16384,
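Assuming a local clone of silver/chatglm-6b-slim so that configuration_chatglm.py is importable, the new defaults line up with config.json:

from configuration_chatglm import ChatGLMConfig

config = ChatGLMConfig()  # defaults only, no kwargs
assert config.vocab_size == 130528
assert (config.bos_token_id, config.eos_token_id) == (130004, 130005)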
modeling_chatglm.py
CHANGED
@@ -28,7 +28,7 @@ from transformers.utils import logging
 from transformers.generation.logits_process import LogitsProcessor
 from transformers.generation.utils import LogitsProcessorList
 
-from .configuration_chatglm import ChatGLMConfig
+from configuration_chatglm import ChatGLMConfig
 
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -38,12 +38,13 @@ torch._C._jit_override_can_fuse_on_gpu(True)
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM-6B"
+_CHECKPOINT_FOR_DOC = "silver/ChatGLM-6B"
 _CONFIG_FOR_DOC = "ChatGLM6BConfig"
 
 CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/chatglm-6b",
+    "silver/chatglm-6b-slim",
     # See all ChatGLM-6B models at https://huggingface.co/models?filter=chatglm
+    # See the slim model at https://huggingface.co/silver/chatglm-6b-slim
 ]
 
 
@@ -51,7 +52,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[..., 20005] = 5e4
+            scores[..., 5] = 5e4
         return scores
 
 
@@ -755,7 +756,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
 
     @staticmethod
     def get_masks(seq, device):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(130004) + 1
 
         attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
         attention_mask.tril_()
@@ -766,9 +767,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         return attention_mask
 
     def get_position_ids(self, seq, mask_position, device, gmask=False):
-        context_length = seq.index(150004) + 1
+        context_length = seq.index(130004) + 1
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(130004)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -824,7 +825,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if past_key_values is None:
             past_key_values = tuple([None] * len(self.layers))
 
-        MASK, gMASK = 150000, 150001
+        MASK, gMASK = 130000, 130001
         mask_token = MASK if MASK in input_ids else gMASK
         use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
@@ -941,7 +942,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         attention_mask = (attention_mask < 0.5).bool()
 
         if self.position_encoding_2d:
-            seq_length = seq.index(150004)
+            seq_length = seq.index(130004)
             position_ids = torch.arange(context_length, dtype=torch.long, device=device)
             if not gmask:
                 position_ids[seq_length:] = mask_position
@@ -968,7 +969,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         **kwargs
     ) -> dict:
 
-        MASK, gMASK = 150000, 150001
+        MASK, gMASK = 130000, 130001
         mask_token = MASK if MASK in input_ids else gMASK
         use_gmask = False if MASK in input_ids else gMASK
         seq = input_ids[0].tolist()
@@ -979,7 +980,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
 
         # only last token for input_ids if past is not None
         if past is not None or past_key_values is not None:
-            context_length = seq.index(150004)
+            context_length = seq.index(130004)
             last_token = input_ids[:, -1].unsqueeze(-1)
             if self.position_encoding_2d:
                 position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
@@ -1119,8 +1120,8 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
         self,
         **kwargs,
     ):
-        MASK, gMASK = 150000, 150001
-        bos, eos = 150004, 150005
+        MASK, gMASK = 130000, 130001
+        bos, eos = 130004, 130005
 
         if "eos_token_id" not in kwargs:
             kwargs["eos_token_id"] = eos
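Every hard-coded id in the modeling code (MASK, gMASK, bos, eos) shifts down by the 20000 removed image tokens, so ids produced for the original model can be translated with a simple offset. A hypothetical helper, not part of the commit:

import torch

NUM_IMAGE_TOKENS = 20000  # ids 0..19999 were image tokens in the original vocab

def remap_to_slim(input_ids: torch.Tensor) -> torch.Tensor:
    """Map ids from the original 150528-token vocab to the slim 130528-token one."""
    if (input_ids < NUM_IMAGE_TOKENS).any():
        raise ValueError("image-token ids have no equivalent in the slim vocab")
    return input_ids - NUM_IMAGE_TOKENS

assert remap_to_slim(torch.tensor([150000, 150004])).tolist() == [130000, 130004]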
pytorch_model-00001-of-00008.bin → pytorch_model-00001-of-00008-slim.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c85647a7f3c817274a767dbee01428a9f1b3eb855cfd7849625b8ad7753e4dbf
+size 1904493208
pytorch_model-00008-of-00008.bin → pytorch_model-00008-of-00008-slim.bin
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:36e8039413913b7326c4fc5fcbcd2bf4c03b03bea1ff1bdaf6b74b46df0053e9
+size 1233126329
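Both renamed shards are Git LFS pointer files: three lines recording the spec version, the SHA-256 of the real weight file, and its size in bytes (the old oid/size values were lost in extraction above). A sketch of how such a pointer can be regenerated locally, standard library only:

import hashlib
import os

def lfs_pointer(path: str) -> str:
    """Build a Git LFS pointer (spec v1) for a local file."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
    return (
        "version https://git-lfs.github.com/spec/v1\n"
        f"oid sha256:{digest.hexdigest()}\n"
        f"size {os.path.getsize(path)}\n"
    )

print(lfs_pointer("pytorch_model-00001-of-00008-slim.bin"))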
pytorch_model.bin.index.json
CHANGED
@@ -3,35 +3,35 @@
     "total_size": 13744473856
   },
   "weight_map": {
-    "lm_head.weight": "pytorch_model-00008-of-00008.bin",
+    "lm_head.weight": "pytorch_model-00008-of-00008-slim.bin",
     "transformer.final_layernorm.bias": "pytorch_model-00007-of-00008.bin",
     "transformer.final_layernorm.weight": "pytorch_model-00007-of-00008.bin",
-    "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008.bin",
+    "transformer.layers.0.attention.dense.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.attention.dense.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.attention.query_key_value.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.attention.query_key_value.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.input_layernorm.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.mlp.dense_4h_to_h.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.mlp.dense_4h_to_h.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.post_attention_layernorm.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.attention.dense.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.attention.dense.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.attention.query_key_value.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.attention.query_key_value.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.attention.rotary_emb.inv_freq": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.input_layernorm.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00008-slim.bin",
     "transformer.layers.1.mlp.dense_4h_to_h.bias": "pytorch_model-00002-of-00008.bin",
     "transformer.layers.1.mlp.dense_4h_to_h.weight": "pytorch_model-00002-of-00008.bin",
-    "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008.bin",
-    "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008.bin",
+    "transformer.layers.1.mlp.dense_h_to_4h.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.mlp.dense_h_to_4h.weight": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.post_attention_layernorm.bias": "pytorch_model-00001-of-00008-slim.bin",
+    "transformer.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00008-slim.bin",
     "transformer.layers.10.attention.dense.bias": "pytorch_model-00003-of-00008.bin",
     "transformer.layers.10.attention.dense.weight": "pytorch_model-00003-of-00008.bin",
     "transformer.layers.10.attention.query_key_value.bias": "pytorch_model-00003-of-00008.bin",
@@ -370,6 +370,6 @@
     "transformer.layers.9.mlp.dense_h_to_4h.weight": "pytorch_model-00003-of-00008.bin",
     "transformer.layers.9.post_attention_layernorm.bias": "pytorch_model-00003-of-00008.bin",
     "transformer.layers.9.post_attention_layernorm.weight": "pytorch_model-00003-of-00008.bin",
-    "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008.bin"
+    "transformer.word_embeddings.weight": "pytorch_model-00001-of-00008-slim.bin"
   }
 }
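Only shards 1 and 8 are renamed because they are the only ones holding the vocab-sized tensors (transformer.word_embeddings.weight and lm_head.weight); the weight map just redirects every tensor stored in those two files. A hypothetical script for this bulk rename, not part of the commit:

import json

with open("pytorch_model.bin.index.json") as f:
    index = json.load(f)

# map old shard file names to their slim replacements
renamed = {
    "pytorch_model-00001-of-00008.bin": "pytorch_model-00001-of-00008-slim.bin",
    "pytorch_model-00008-of-00008.bin": "pytorch_model-00008-of-00008-slim.bin",
}
index["weight_map"] = {name: renamed.get(shard, shard)
                       for name, shard in index["weight_map"].items()}

with open("pytorch_model.bin.index.json", "w") as f:
    json.dump(index, f, indent=2)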
tokenization_chatglm.py
CHANGED
@@ -16,7 +16,7 @@ from transformers.utils import logging
 logger = logging.get_logger(__name__)
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    "THUDM/chatglm-6b": 2048,
+    "silver/chatglm-6b-slim": 2048,
 }
 
 
@@ -85,17 +85,13 @@ class SPTokenizer:
     def get_tab_token():
         return f"<|tab|>"
 
-    @property
-    def num_image_tokens(self):
-        return 20000
-
     @property
     def num_text_tokens(self):
         return self.text_tokenizer.num_tokens
 
     @property
     def num_tokens(self):
-        return self.num_image_tokens + self.num_text_tokens
+        return self.num_text_tokens
 
     @staticmethod
     def _encode_whitespaces(text: str, max_len: int = 80):
@@ -125,11 +121,11 @@ class SPTokenizer:
         if not add_dummy_prefix:
             text = "<n>" + text
         tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text)
-        tokens = [x + self.num_image_tokens for x in tmp]
+        tokens = [x for x in tmp]
         return tokens if add_dummy_prefix else tokens[2:]
 
     def decode(self, text_ids: List[int], special_tokens=False) -> str:
-        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
+        ids = [int(_id) for _id in text_ids]
         ids = [_id for _id in ids if _id >= 0]
         text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids)
         text = text.replace("<n>", "\n")
@@ -156,15 +152,9 @@ class SPTokenizer:
 
     def __getitem__(self, x: Union[int, str]):
         if isinstance(x, int):
-            if x < self.num_image_tokens:
-                return "<image_{}>".format(x)
-            else:
-                return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
+            return self.text_tokenizer.convert_id_to_token(x)
         elif isinstance(x, str):
-            if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
-                return int(x[7:-1])
-            else:
-                return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
+            return self.text_tokenizer.convert_token_to_id(x)
         else:
             raise ValueError("The key should be str or int.")
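With the <image_*> handling gone, ids map one-to-one onto the SentencePiece vocabulary. The matching weight surgery behind the two renamed shards is presumably just dropping the first 20000 rows of the two vocab-sized tensors; the commit does not include the conversion script, so this is a guessed reconstruction:

import torch

NUM_IMAGE_TOKENS = 20000  # rows 0..19999 belonged to image tokens

# shard 1 holds transformer.word_embeddings.weight, shape [150528, 4096]
shard1 = torch.load("pytorch_model-00001-of-00008.bin", map_location="cpu")
emb = shard1["transformer.word_embeddings.weight"]
shard1["transformer.word_embeddings.weight"] = emb[NUM_IMAGE_TOKENS:].clone()
torch.save(shard1, "pytorch_model-00001-of-00008-slim.bin")

# shard 8 holds lm_head.weight, slimmed the same way
shard8 = torch.load("pytorch_model-00008-of-00008.bin", map_location="cpu")
shard8["lm_head.weight"] = shard8["lm_head.weight"][NUM_IMAGE_TOKENS:].clone()
torch.save(shard8, "pytorch_model-00008-of-00008-slim.bin")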
tokenizer_config.json
CHANGED
@@ -1,5 +1,5 @@
 {
-  "name_or_path": "THUDM/chatglm-6b",
+  "name_or_path": "silver/chatglm-6b-slim",
   "bos_token": "<sop>",
   "eop_token": "<eop>",
   "eos_token": "</s>",
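After these changes the slim checkpoint should load like the original; an untested sketch via the usual transformers remote-code path:

from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("silver/chatglm-6b-slim", trust_remote_code=True)
model = AutoModel.from_pretrained("silver/chatglm-6b-slim", trust_remote_code=True).half().cuda()

response, history = model.chat(tokenizer, "Hello", history=[])
print(response)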