shwu committed on
Commit
769e287
·
1 Parent(s): d1844e4

feat: better modeling_blip2chatglm

Browse files
config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
  "_commit_hash": null,
3
  "architectures": [
4
- "BlipFor2ChatGLM"
5
  ],
6
  "initializer_factor": 1.0,
7
  "initializer_range": 0.02,
@@ -174,7 +174,7 @@
174
  "tie_word_embeddings": false,
175
  "torch_dtype": "float32",
176
  "transformers_version": null,
177
- "use_decoder_only_language_model": false,
178
  "vision_config": {
179
  "_name_or_path": "",
180
  "add_cross_attention": false,
@@ -248,7 +248,7 @@
248
  "tokenizer_class": null,
249
  "top_k": 50,
250
  "top_p": 1.0,
251
- "torch_dtype": null,
252
  "torchscript": false,
253
  "transformers_version": "4.27.3",
254
  "typical_p": 1.0,
@@ -256,7 +256,7 @@
256
  },
257
  "auto_map": {
258
  "AutoConfig": "configuration_blip2chatglm.Blip2ChatGLMConfig",
259
- "AutoModel": "modeling_blip2chatglm.Blip2ForChatGLM",
260
- "AutoModelForCausalLM": "modeling_blip2chatglm.Blip2ChatGLM"
261
  }
262
  }
 
1
  {
2
  "_commit_hash": null,
3
  "architectures": [
4
+ "Blip2ChatGLMForConditionalGeneration"
5
  ],
6
  "initializer_factor": 1.0,
7
  "initializer_range": 0.02,
 
174
  "tie_word_embeddings": false,
175
  "torch_dtype": "float32",
176
  "transformers_version": null,
177
+ "use_decoder_only_language_model": true,
178
  "vision_config": {
179
  "_name_or_path": "",
180
  "add_cross_attention": false,
 
248
  "tokenizer_class": null,
249
  "top_k": 50,
250
  "top_p": 1.0,
251
+ "torch_dtype": "float16",
252
  "torchscript": false,
253
  "transformers_version": "4.27.3",
254
  "typical_p": 1.0,
 
256
  },
257
  "auto_map": {
258
  "AutoConfig": "configuration_blip2chatglm.Blip2ChatGLMConfig",
259
+ "AutoModel": "modeling_blip2chatglm.Blip2ChatGLMForConditionalGeneration",
260
+ "AutoModelForCausalLM": "modeling_blip2chatglm.Blip2ChatGLMForConditionalGeneration"
261
  }
262
  }
configuration_blip2chatglm.py CHANGED
@@ -49,7 +49,7 @@ class Blip2ChatGLMConfig(PretrainedConfig):
49
  self.num_query_tokens = num_query_tokens
50
  self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
51
  # self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
52
- self.use_decoder_only_language_model = False # chatglm is an encoder-decoder model
53
  self.initializer_factor = 1.0
54
  self.initializer_range = 0.02
55
 
 
49
  self.num_query_tokens = num_query_tokens
50
  self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
51
  # self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
52
+ self.use_decoder_only_language_model = True # chatglm has no encoder
53
  self.initializer_factor = 1.0
54
  self.initializer_range = 0.02
55
 
generation_config.json ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "transformers_version": "4.27.3"
4
+ }
pytorch_model.bin → ice_text.model RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4f62d72c97cb28762f2fcb9e9b00e1d23c7d546da79fb4cfde386231b9b8d956
3
- size 4377310673
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e974d9a69c242ce014c88c2b26089270f6198f3c0b700a887666cd3e816f17e
3
+ size 2706249
modeling_blip2chatglm.py CHANGED
@@ -1,6 +1,8 @@
1
  import copy
 
2
  from typing import Callable, List, Optional, Tuple, Union
3
  import torch
 
4
  import warnings
5
  from torch import Tensor, nn
6
 
@@ -8,8 +10,14 @@ from transformers import (
8
  PreTrainedModel,
9
  Blip2VisionModel,
10
  Blip2QFormerModel,
 
 
 
11
  GenerationConfig,
12
  )
 
 
 
13
  from transformers.utils import logging
14
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
15
 
@@ -23,9 +31,13 @@ from .configuration_blip2chatglm import Blip2ChatGLMConfig
23
  logger = logging.get_logger(__name__)
24
 
25
 
26
- class Blip2ForChatGLM(PreTrainedModel):
 
 
27
  def __init__(self, config: Blip2ChatGLMConfig):
28
- super().__init__(config)
 
 
29
 
30
  self.vision_model = Blip2VisionModel(config.vision_config)
31
 
@@ -37,21 +49,65 @@ class Blip2ForChatGLM(PreTrainedModel):
37
  self.language_projection = nn.Linear(
38
  config.qformer_config.hidden_size, config.text_config.hidden_size
39
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  def forward(
42
  self,
43
  pixel_values: torch.FloatTensor,
 
 
 
44
  output_attentions: Optional[bool] = None,
45
  output_hidden_states: Optional[bool] = None,
 
46
  return_dict: Optional[bool] = None,
47
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return_dict = (
49
  return_dict if return_dict is not None else self.config.use_return_dict
50
  )
51
 
52
  # step 1: forward the images through the vision encoder,
53
  # to get image embeddings of shape (batch_size, seq_len, hidden_size)
54
- vision_outputs = self.vision_model.forward(
55
  pixel_values=pixel_values,
56
  output_attentions=output_attentions,
57
  output_hidden_states=output_hidden_states,
@@ -65,7 +121,7 @@ class Blip2ForChatGLM(PreTrainedModel):
65
  )
66
 
67
  query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
68
- query_outputs = self.qformer.forward(
69
  query_embeds=query_tokens,
70
  encoder_hidden_states=image_embeds,
71
  encoder_attention_mask=image_attention_mask,
@@ -76,23 +132,54 @@ class Blip2ForChatGLM(PreTrainedModel):
76
  query_output = query_outputs[0]
77
 
78
  # step 3: use the language model, conditioned on the query outputs and the prompt
79
- language_model_inputs = self.language_projection.forward(query_output)
80
-
81
- return vision_outputs, query_outputs, language_model_inputs
 
 
 
 
 
 
 
 
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- class Blip2ChatGLM(PreTrainedModel):
85
- config_class = Blip2ChatGLMConfig
 
86
 
87
- def __init__(
88
- self,
89
- config: Blip2ChatGLMConfig,
90
- blip2: Blip2ForChatGLM,
91
- lm: ChatGLMForConditionalGeneration,
92
- ) -> None:
93
- super().__init__(config)
94
- self.blip2 = blip2
95
- self.language = lm
96
 
97
  @torch.no_grad()
98
  def stream_chat(
@@ -106,12 +193,12 @@ class Blip2ChatGLM(PreTrainedModel):
106
  do_sample=True,
107
  temperature=1,
108
  ):
109
- device = self.blip2.device
110
  # 1. Prepare token ids
111
  images = []
112
  image_slots = []
113
 
114
- nvtokens = self.blip2.query_tokens.size(1)
115
  if history:
116
  input_ids = tokenizer(
117
  f"[Round {len(history)}]\n问:", add_special_tokens=False
@@ -181,27 +268,27 @@ class Blip2ChatGLM(PreTrainedModel):
181
  # 2. Prepare image embeddings
182
  if len(images) != 0:
183
  image = torch.cat(list(images), dim=0)
184
- vision_outputs = self.blip2.vision_model.forward(image)
185
  image_embeds = vision_outputs[0]
186
  image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
187
  device
188
  )
189
 
190
- query_tokens = self.blip2.query_tokens.expand(image_embeds.shape[0], -1, -1)
191
- query_outputs = self.blip2.qformer.forward(
192
  query_embeds=query_tokens,
193
  encoder_hidden_states=image_embeds,
194
  encoder_attention_mask=image_atts,
195
  )
196
  query_output = query_outputs[0]
197
 
198
- vtokens = self.blip2.language_projection(query_output)
199
  else:
200
  vtokens = []
201
 
202
  # 3. Place image embeddings into slots
203
  input_ids = torch.as_tensor(input_ids, dtype=torch.long).to(device).unsqueeze(0)
204
- inputs_embeds = self.language.transformer.word_embeddings(input_ids)
205
  for slot, vimg in zip(image_slots, vtokens):
206
  inputs_embeds[0][-slot : -slot + nvtokens, :] = vimg
207
 
@@ -216,17 +303,16 @@ class Blip2ChatGLM(PreTrainedModel):
216
  "logits_processor": logits_processor,
217
  }
218
 
219
- for outputs in self.mm_stream_generate(
220
  input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
221
  ):
222
  outputs = outputs.tolist()[0][len(input_ids[0]) :]
223
  response = tokenizer.decode(outputs)
224
- response = self.language.process_response(response)
225
- new_history = history + [(query, response)]
226
- yield response, new_history
227
 
228
  @torch.no_grad()
229
- def mm_stream_generate(
230
  self,
231
  input_ids,
232
  inputs_embeds,
@@ -238,10 +324,23 @@ class Blip2ChatGLM(PreTrainedModel):
238
  ] = None,
239
  **kwargs,
240
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
242
 
243
  if generation_config is None:
244
- generation_config = self.language.generation_config
245
  generation_config = copy.deepcopy(generation_config)
246
  model_kwargs = generation_config.update(**kwargs)
247
  bos_token_id, eos_token_id = (
@@ -279,7 +378,7 @@ class Blip2ChatGLM(PreTrainedModel):
279
  if input_ids_seq_length >= generation_config.max_length:
280
  input_ids_string = (
281
  "decoder_input_ids"
282
- if self.language.config.is_encoder_decoder
283
  else "input_ids"
284
  )
285
  logger.warning(
@@ -298,7 +397,7 @@ class Blip2ChatGLM(PreTrainedModel):
298
  else StoppingCriteriaList()
299
  )
300
 
301
- logits_processor = self.language._get_logits_processor(
302
  generation_config=generation_config,
303
  input_ids_seq_length=input_ids_seq_length,
304
  encoder_input_ids=input_ids,
@@ -306,19 +405,19 @@ class Blip2ChatGLM(PreTrainedModel):
306
  logits_processor=logits_processor,
307
  )
308
 
309
- stopping_criteria = self.language._get_stopping_criteria(
310
  generation_config=generation_config, stopping_criteria=stopping_criteria
311
  )
312
- logits_warper = self.language._get_logits_warper(generation_config)
313
 
314
  unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
315
  scores = None
316
  while True:
317
- model_inputs = self.language.prepare_inputs_for_generation(
318
  input_ids, inputs_embeds=inputs_embeds, **model_kwargs
319
  )
320
  # forward pass to get next token
321
- outputs = self.language(
322
  **model_inputs,
323
  return_dict=True,
324
  output_attentions=False,
@@ -343,14 +442,14 @@ class Blip2ChatGLM(PreTrainedModel):
343
  inputs_embeds = torch.cat(
344
  [
345
  inputs_embeds,
346
- self.language.get_input_embeddings()(next_tokens)[:, None, :],
347
  ],
348
  dim=1,
349
  )
350
- model_kwargs = self.language._update_model_kwargs_for_generation(
351
  outputs,
352
  model_kwargs,
353
- is_encoder_decoder=self.language.config.is_encoder_decoder,
354
  )
355
  unfinished_sequences = unfinished_sequences.mul(
356
  (sum(next_tokens != i for i in eos_token_id)).long()
@@ -360,3 +459,107 @@ class Blip2ChatGLM(PreTrainedModel):
360
  if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
361
  break
362
  yield input_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import copy
2
+ import os
3
  from typing import Callable, List, Optional, Tuple, Union
4
  import torch
5
+ from torch.nn import CrossEntropyLoss
6
  import warnings
7
  from torch import Tensor, nn
8
 
 
10
  PreTrainedModel,
11
  Blip2VisionModel,
12
  Blip2QFormerModel,
13
+ Blip2Model,
14
+ Blip2PreTrainedModel,
15
+ Blip2ForConditionalGeneration,
16
  GenerationConfig,
17
  )
18
+ from transformers.models.blip_2.modeling_blip_2 import (
19
+ Blip2ForConditionalGenerationModelOutput,
20
+ )
21
  from transformers.utils import logging
22
  from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
23
 
 
31
  logger = logging.get_logger(__name__)
32
 
33
 
34
+ class Blip2ChatGLMForConditionalGeneration(Blip2ForConditionalGeneration):
35
+ config_class = Blip2ChatGLMConfig
36
+
37
  def __init__(self, config: Blip2ChatGLMConfig):
38
+ Blip2PreTrainedModel.__init__(self, config)
39
+ # NOTE: we only initialize Blip2PreTrainedModel
40
+ # directly call super().__init__() will cause error since ChatGLM cannot be found by AutoModel
41
 
42
  self.vision_model = Blip2VisionModel(config.vision_config)
43
 
 
49
  self.language_projection = nn.Linear(
50
  config.qformer_config.hidden_size, config.text_config.hidden_size
51
  )
52
+ self.language_model = ChatGLMForConditionalGeneration(config.text_config)
53
+
54
+ # Initialize weights and apply final processing
55
+ # self.post_init()
56
+
57
+ def setup_dtype(self, vision_encoder_dtype: str = "fp32", lm_dtype: str = "fp16"):
58
+ if vision_encoder_dtype == "fp32":
59
+ self.vision_model = self.vision_model.float()
60
+ elif vision_encoder_dtype == "fp16":
61
+ self.vision_model = self.vision_model.half()
62
+ else:
63
+ raise NotImplementedError(
64
+ f"Unsupported vision_encoder_dtype: {vision_encoder_dtype}"
65
+ )
66
+
67
+ if lm_dtype == "fp32":
68
+ self.language_model = self.language_model.float()
69
+ elif lm_dtype == "fp16":
70
+ self.language_model = self.language_model.half()
71
+ elif lm_dtype == "int4":
72
+ self.language_model = self.language_model.half().quantize(4)
73
+ elif lm_dtype == "int8":
74
+ self.language_model = self.language_model.half().quantize(8)
75
+ else:
76
+ raise NotImplementedError(f"Unsupported lm_dtype: {lm_dtype}")
77
 
78
  def forward(
79
  self,
80
  pixel_values: torch.FloatTensor,
81
+ input_ids: torch.FloatTensor,
82
+ image_slot_offset: Optional[torch.LongTensor] = None,
83
+ attention_mask: Optional[torch.LongTensor] = None,
84
  output_attentions: Optional[bool] = None,
85
  output_hidden_states: Optional[bool] = None,
86
+ labels: Optional[torch.LongTensor] = None,
87
  return_dict: Optional[bool] = None,
88
+ ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
89
+ """_summary_
90
+
91
+ Args:
92
+ pixel_values (torch.FloatTensor): _description_
93
+ input_ids (torch.FloatTensor): input_ids[:, :num_query_tokens] should be filled with tokenizer.unk_token_id
94
+ image_slot_offset (Optional[torch.LongTensor], optional): if not set, all vtokens are placed as prefix (image_slot_offset = torch.zeros(bsz)). Defaults to None.
95
+ attention_mask (Optional[torch.LongTensor], optional): _description_. Defaults to None.
96
+ output_attentions (Optional[bool], optional): _description_. Defaults to None.
97
+ output_hidden_states (Optional[bool], optional): _description_. Defaults to None.
98
+ labels (Optional[torch.LongTensor], optional): _description_. Defaults to None.
99
+ return_dict (Optional[bool], optional): _description_. Defaults to None.
100
+
101
+ Returns:
102
+ Union[Tuple, Blip2ForConditionalGenerationModelOutput]: _description_
103
+ """
104
  return_dict = (
105
  return_dict if return_dict is not None else self.config.use_return_dict
106
  )
107
 
108
  # step 1: forward the images through the vision encoder,
109
  # to get image embeddings of shape (batch_size, seq_len, hidden_size)
110
+ vision_outputs = self.vision_model(
111
  pixel_values=pixel_values,
112
  output_attentions=output_attentions,
113
  output_hidden_states=output_hidden_states,
 
121
  )
122
 
123
  query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
124
+ query_outputs = self.qformer(
125
  query_embeds=query_tokens,
126
  encoder_hidden_states=image_embeds,
127
  encoder_attention_mask=image_attention_mask,
 
132
  query_output = query_outputs[0]
133
 
134
  # step 3: use the language model, conditioned on the query outputs and the prompt
135
+ language_model_inputs = self.language_projection(query_output)
136
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
137
+ if image_slot_offset is None:
138
+ # image as prefix
139
+ # update data to avoid inplace operation of leaf Variable
140
+ inputs_embeds.data[:, : self.config.num_query_tokens, :] = language_model_inputs
141
+ else:
142
+ for i, offset in enumerate(image_slot_offset):
143
+ inputs_embeds.data[i, offset : offset + self.config.num_query_tokens, :] = (
144
+ language_model_inputs[i]
145
+ )
146
 
147
+ outputs = self.language_model(
148
+ input_ids=input_ids,
149
+ inputs_embeds=inputs_embeds,
150
+ attention_mask=attention_mask,
151
+ output_attentions=output_attentions,
152
+ output_hidden_states=output_hidden_states,
153
+ return_dict=return_dict,
154
+ )
155
+ logits = outputs.logits if return_dict else outputs[0]
156
+ loss = None
157
+ # we compute the loss here since we need to take into account the sequence length of the query embeds
158
+ if labels is not None:
159
+ logits = logits[:, -labels.size(1) :, :]
160
+ # Shift so that tokens < n predict n
161
+ shift_logits = logits[..., :-1, :].contiguous()
162
+ shift_labels = labels[..., 1:].contiguous().to(logits.device)
163
+
164
+ # Flatten the tokens
165
+ loss_fct = CrossEntropyLoss(reduction="mean")
166
+
167
+ loss = loss_fct(
168
+ shift_logits.view(-1, self.config.text_config.vocab_size),
169
+ shift_labels.view(-1),
170
+ )
171
 
172
+ if not return_dict:
173
+ output = (logits, vision_outputs, query_outputs, outputs)
174
+ return ((loss,) + output) if loss is not None else output
175
 
176
+ return Blip2ForConditionalGenerationModelOutput(
177
+ loss=loss,
178
+ logits=logits,
179
+ vision_outputs=vision_outputs,
180
+ qformer_outputs=query_outputs,
181
+ language_model_outputs=outputs,
182
+ )
 
 
183
 
184
  @torch.no_grad()
185
  def stream_chat(
 
193
  do_sample=True,
194
  temperature=1,
195
  ):
196
+ device = self.device
197
  # 1. Prepare token ids
198
  images = []
199
  image_slots = []
200
 
201
+ nvtokens = self.config.num_query_tokens
202
  if history:
203
  input_ids = tokenizer(
204
  f"[Round {len(history)}]\n问:", add_special_tokens=False
 
268
  # 2. Prepare image embeddings
269
  if len(images) != 0:
270
  image = torch.cat(list(images), dim=0)
271
+ vision_outputs = self.vision_model.forward(image)
272
  image_embeds = vision_outputs[0]
273
  image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
274
  device
275
  )
276
 
277
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
278
+ query_outputs = self.qformer.forward(
279
  query_embeds=query_tokens,
280
  encoder_hidden_states=image_embeds,
281
  encoder_attention_mask=image_atts,
282
  )
283
  query_output = query_outputs[0]
284
 
285
+ vtokens = self.language_projection(query_output)
286
  else:
287
  vtokens = []
288
 
289
  # 3. Place image embeddings into slots
290
  input_ids = torch.as_tensor(input_ids, dtype=torch.long).to(device).unsqueeze(0)
291
+ inputs_embeds = self.language_model.transformer.word_embeddings(input_ids)
292
  for slot, vimg in zip(image_slots, vtokens):
293
  inputs_embeds[0][-slot : -slot + nvtokens, :] = vimg
294
 
 
303
  "logits_processor": logits_processor,
304
  }
305
 
306
+ for outputs in self.stream_generate(
307
  input_ids=input_ids, inputs_embeds=inputs_embeds, **gen_kwargs
308
  ):
309
  outputs = outputs.tolist()[0][len(input_ids[0]) :]
310
  response = tokenizer.decode(outputs)
311
+ response = self.language_model.process_response(response)
312
+ yield response
 
313
 
314
  @torch.no_grad()
315
+ def stream_generate(
316
  self,
317
  input_ids,
318
  inputs_embeds,
 
324
  ] = None,
325
  **kwargs,
326
  ):
327
+ """slightly modified from chatglm implementation to support inputs_embeds
328
+
329
+ Args:
330
+ input_ids (_type_): _description_
331
+ inputs_embeds (_type_): _description_
332
+ generation_config (Optional[GenerationConfig], optional): _description_. Defaults to None.
333
+ logits_processor (Optional[LogitsProcessorList], optional): _description_. Defaults to None.
334
+ stopping_criteria (Optional[StoppingCriteriaList], optional): _description_. Defaults to None.
335
+ prefix_allowed_tokens_fn (Optional[ Callable[[int, torch.Tensor], List[int]] ], optional): _description_. Defaults to None.
336
+
337
+ Yields:
338
+ _type_: _description_
339
+ """
340
  batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
341
 
342
  if generation_config is None:
343
+ generation_config = self.language_model.generation_config
344
  generation_config = copy.deepcopy(generation_config)
345
  model_kwargs = generation_config.update(**kwargs)
346
  bos_token_id, eos_token_id = (
 
378
  if input_ids_seq_length >= generation_config.max_length:
379
  input_ids_string = (
380
  "decoder_input_ids"
381
+ if self.language_model.config.is_encoder_decoder
382
  else "input_ids"
383
  )
384
  logger.warning(
 
397
  else StoppingCriteriaList()
398
  )
399
 
400
+ logits_processor = self.language_model._get_logits_processor(
401
  generation_config=generation_config,
402
  input_ids_seq_length=input_ids_seq_length,
403
  encoder_input_ids=input_ids,
 
405
  logits_processor=logits_processor,
406
  )
407
 
408
+ stopping_criteria = self.language_model._get_stopping_criteria(
409
  generation_config=generation_config, stopping_criteria=stopping_criteria
410
  )
411
+ logits_warper = self.language_model._get_logits_warper(generation_config)
412
 
413
  unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
414
  scores = None
415
  while True:
416
+ model_inputs = self.prepare_inputs_for_generation(
417
  input_ids, inputs_embeds=inputs_embeds, **model_kwargs
418
  )
419
  # forward pass to get next token
420
+ outputs = self.language_model(
421
  **model_inputs,
422
  return_dict=True,
423
  output_attentions=False,
 
442
  inputs_embeds = torch.cat(
443
  [
444
  inputs_embeds,
445
+ self.language_model.get_input_embeddings()(next_tokens)[:, None, :],
446
  ],
447
  dim=1,
448
  )
449
+ model_kwargs = self.language_model._update_model_kwargs_for_generation(
450
  outputs,
451
  model_kwargs,
452
+ is_encoder_decoder=self.language_model.config.is_encoder_decoder,
453
  )
454
  unfinished_sequences = unfinished_sequences.mul(
455
  (sum(next_tokens != i for i in eos_token_id)).long()
 
459
  if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
460
  break
461
  yield input_ids
462
+
463
+ def prepare_inputs_for_generation(
464
+ self,
465
+ input_ids: torch.LongTensor,
466
+ inputs_embeds: Optional[torch.Tensor] = None,
467
+ past: Optional[torch.Tensor] = None,
468
+ past_key_values: Optional[torch.Tensor] = None,
469
+ attention_mask: Optional[torch.Tensor] = None,
470
+ position_ids: Optional[torch.Tensor] = None,
471
+ **kwargs,
472
+ ) -> dict:
473
+ """slightly modified from chatglm implementation to support inputs_embeds
474
+
475
+ Args:
476
+ input_ids (torch.LongTensor): _description_
477
+ inputs_embeds (Optional[torch.Tensor], optional): _description_. Defaults to None.
478
+ past (Optional[torch.Tensor], optional): _description_. Defaults to None.
479
+ past_key_values (Optional[torch.Tensor], optional): _description_. Defaults to None.
480
+ attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None.
481
+ position_ids (Optional[torch.Tensor], optional): _description_. Defaults to None.
482
+
483
+ Returns:
484
+ dict: _description_
485
+ """
486
+ batch_size, seq_length = input_ids.shape
487
+ MASK, gMASK = self.language_model.config.mask_token_id, self.language_model.config.gmask_token_id
488
+ seqs = input_ids.tolist()
489
+ mask_positions, use_gmasks = [], []
490
+ for seq in seqs:
491
+ mask_token = gMASK if gMASK in seq else MASK
492
+ use_gmask = mask_token == gMASK
493
+ mask_positions.append(seq.index(mask_token))
494
+ use_gmasks.append(use_gmask)
495
+
496
+ # only last token for input_ids if past is not None
497
+ if past is not None or past_key_values is not None:
498
+ last_token = input_ids[:, -1].unsqueeze(-1)
499
+ if attention_mask is not None and attention_mask.dtype == torch.bool:
500
+ attention_mask = attention_mask[:, :, -1:]
501
+ else:
502
+ attention_mask = None
503
+ if position_ids is not None:
504
+ position_ids = position_ids[..., -1:]
505
+ else:
506
+ context_lengths = [seq.index(self.language_model.config.bos_token_id) for seq in seqs]
507
+ if self.language_model.position_encoding_2d:
508
+ position_ids = torch.tensor(
509
+ [
510
+ [mask_position, seq_length - context_length]
511
+ for mask_position, context_length in zip(
512
+ mask_positions, context_lengths
513
+ )
514
+ ],
515
+ dtype=torch.long,
516
+ device=input_ids.device,
517
+ ).unsqueeze(-1)
518
+ else:
519
+ position_ids = torch.tensor(
520
+ [mask_position for mask_position in mask_positions],
521
+ dtype=torch.long,
522
+ device=input_ids.device,
523
+ ).unsqueeze(-1)
524
+
525
+ if past is None:
526
+ past = past_key_values
527
+ return {
528
+ "input_ids": last_token,
529
+ "past_key_values": past,
530
+ "position_ids": position_ids,
531
+ "attention_mask": attention_mask,
532
+ }
533
+ else:
534
+ if attention_mask is not None and attention_mask.dtype != torch.bool:
535
+ logger.warning_once(
536
+ f"The dtype of attention mask ({attention_mask.dtype}) is not bool"
537
+ )
538
+ attention_mask = None
539
+ if attention_mask is None:
540
+ attention_mask = self.language_model.get_masks(input_ids, device=input_ids.device)
541
+ if position_ids is None:
542
+ position_ids = self.language_model.get_position_ids(
543
+ input_ids,
544
+ device=input_ids.device,
545
+ mask_positions=mask_positions,
546
+ use_gmasks=use_gmasks,
547
+ )
548
+
549
+ if inputs_embeds is not None:
550
+ assert input_ids.size(1) == inputs_embeds.size(
551
+ 1
552
+ ), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
553
+ return {
554
+ "inputs_embeds": inputs_embeds,
555
+ "past_key_values": past,
556
+ "position_ids": position_ids,
557
+ "attention_mask": attention_mask,
558
+ }
559
+ else:
560
+ return {
561
+ "input_ids": input_ids,
562
+ "past_key_values": past,
563
+ "position_ids": position_ids,
564
+ "attention_mask": attention_mask,
565
+ }
modeling_chatglm.py CHANGED
@@ -55,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
55
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
56
  if torch.isnan(scores).any() or torch.isinf(scores).any():
57
  scores.zero_()
58
- scores[..., 20005] = 5e4
59
  return scores
60
 
61
 
@@ -280,10 +280,8 @@ def attention_fn(
280
  # [sk, b, np, hn] -> [sk, b * np, hn]
281
  key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
282
 
283
- matmul_result = torch.empty(
284
- output_size[0] * output_size[1],
285
- output_size[2],
286
- output_size[3],
287
  dtype=query_layer.dtype,
288
  device=query_layer.device,
289
  )
@@ -348,10 +346,18 @@ def attention_fn(
348
  return outputs
349
 
350
 
 
 
 
 
351
  class SelfAttention(torch.nn.Module):
352
  def __init__(self, hidden_size, num_attention_heads,
353
  layer_id, hidden_size_per_attention_head=None, bias=True,
354
- params_dtype=torch.float, position_encoding_2d=True):
 
 
 
 
355
  super(SelfAttention, self).__init__()
356
 
357
  self.layer_id = layer_id
@@ -379,7 +385,7 @@ class SelfAttention(torch.nn.Module):
379
  self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
380
 
381
  # Strided linear layer.
382
- self.query_key_value = skip_init(
383
  torch.nn.Linear,
384
  hidden_size,
385
  3 * self.inner_hidden_size,
@@ -387,7 +393,7 @@ class SelfAttention(torch.nn.Module):
387
  dtype=params_dtype,
388
  )
389
 
390
- self.dense = skip_init(
391
  torch.nn.Linear,
392
  self.inner_hidden_size,
393
  hidden_size,
@@ -500,8 +506,12 @@ class GEGLU(torch.nn.Module):
500
 
501
  class GLU(torch.nn.Module):
502
  def __init__(self, hidden_size, inner_hidden_size=None,
503
- layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float):
504
  super(GLU, self).__init__()
 
 
 
 
505
  self.layer_id = layer_id
506
  self.activation_func = activation_func
507
 
@@ -510,7 +520,7 @@ class GLU(torch.nn.Module):
510
  if inner_hidden_size is None:
511
  inner_hidden_size = 4 * hidden_size
512
  self.inner_hidden_size = inner_hidden_size
513
- self.dense_h_to_4h = skip_init(
514
  torch.nn.Linear,
515
  self.hidden_size,
516
  self.inner_hidden_size,
@@ -518,7 +528,7 @@ class GLU(torch.nn.Module):
518
  dtype=params_dtype,
519
  )
520
  # Project back to h.
521
- self.dense_4h_to_h = skip_init(
522
  torch.nn.Linear,
523
  self.inner_hidden_size,
524
  self.hidden_size,
@@ -554,7 +564,8 @@ class GLMBlock(torch.nn.Module):
554
  use_bias=True,
555
  params_dtype=torch.float,
556
  num_layers=28,
557
- position_encoding_2d=True
 
558
  ):
559
  super(GLMBlock, self).__init__()
560
  # Set output layer initialization if not provided.
@@ -574,7 +585,8 @@ class GLMBlock(torch.nn.Module):
574
  hidden_size_per_attention_head=hidden_size_per_attention_head,
575
  bias=use_bias,
576
  params_dtype=params_dtype,
577
- position_encoding_2d=self.position_encoding_2d
 
578
  )
579
 
580
  # Layernorm on the input data.
@@ -589,6 +601,7 @@ class GLMBlock(torch.nn.Module):
589
  bias=use_bias,
590
  layer_id=layer_id,
591
  params_dtype=params_dtype,
 
592
  )
593
 
594
  def forward(
@@ -676,8 +689,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
676
 
677
  return attention_mask
678
 
679
- def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
680
  batch_size, seq_length = input_ids.shape
 
 
681
  context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
682
  if self.position_encoding_2d:
683
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -691,8 +706,8 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
691
  position_ids = torch.stack((position_ids, block_position_ids), dim=1)
692
  else:
693
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
694
- if not gmask:
695
- for i, context_length in enumerate(context_lengths):
696
  position_ids[context_length:] = mask_positions[i]
697
 
698
  return position_ids
@@ -783,9 +798,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
783
  `encoder_hidden_states` is then expected as an input to the forward pass.
784
  """
785
 
786
- def __init__(self, config: ChatGLMConfig):
787
  super().__init__(config)
788
-
 
 
 
789
  # recording parameters
790
  self.max_sequence_length = config.max_sequence_length
791
  self.hidden_size = config.hidden_size
@@ -800,7 +818,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
800
  self.pre_seq_len = config.pre_seq_len
801
  self.prefix_projection = config.prefix_projection
802
 
803
- self.word_embeddings = skip_init(
804
  torch.nn.Embedding,
805
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
806
  dtype=self.params_dtype
@@ -819,6 +837,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
819
  use_bias=True,
820
  params_dtype=self.params_dtype,
821
  position_encoding_2d=self.position_encoding_2d,
 
822
  )
823
 
824
  self.layers = torch.nn.ModuleList(
@@ -894,12 +913,18 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
894
  )
895
  use_cache = False
896
 
897
- if input_ids is not None and inputs_embeds is not None:
898
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
899
- elif input_ids is not None:
 
 
 
 
 
 
 
900
  batch_size, seq_length = input_ids.shape[:2]
901
  elif inputs_embeds is not None:
902
- # NOTE: fix
903
  batch_size, seq_length = inputs_embeds.shape[:2]
904
  else:
905
  raise ValueError("You have to specify either input_ids or inputs_embeds")
@@ -923,15 +948,20 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
923
 
924
  if position_ids is None:
925
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
926
- mask_token = gMASK if gMASK in input_ids else MASK
927
- use_gmask = True if gMASK in input_ids else False
 
 
 
 
 
 
928
 
929
- mask_positions = [seq.tolist().index(mask_token) for seq in input_ids]
930
  position_ids = self.get_position_ids(
931
  input_ids,
932
  mask_positions=mask_positions,
933
  device=input_ids.device,
934
- gmask=use_gmask
935
  )
936
 
937
  if self.pre_seq_len is not None and attention_mask is not None:
@@ -950,10 +980,10 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
950
  if attention_mask is None:
951
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
952
 
953
- else:
954
- pass
955
- # NOTE: this is a hack to make the code work with the LAVIS training
956
- # attention_mask = attention_mask.to(input_ids.device)
957
 
958
  for i, layer in enumerate(self.layers):
959
 
@@ -1009,8 +1039,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
1009
 
1010
 
1011
  class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1012
- def __init__(self, config: ChatGLMConfig):
1013
  super().__init__(config)
 
 
 
 
1014
 
1015
  # self.hidden_size = config.hidden_size
1016
  # self.params_dtype = torch.half
@@ -1019,9 +1053,9 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1019
 
1020
  self.position_encoding_2d = config.position_encoding_2d
1021
 
1022
- self.transformer = ChatGLMModel(config)
1023
 
1024
- self.lm_head = skip_init(
1025
  nn.Linear,
1026
  config.hidden_size,
1027
  config.vocab_size,
@@ -1080,7 +1114,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1080
  def prepare_inputs_for_generation(
1081
  self,
1082
  input_ids: torch.LongTensor,
1083
- inputs_embeds: Optional[torch.Tensor] = None,
1084
  past: Optional[torch.Tensor] = None,
1085
  past_key_values: Optional[torch.Tensor] = None,
1086
  attention_mask: Optional[torch.Tensor] = None,
@@ -1089,10 +1122,13 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1089
  ) -> dict:
1090
  batch_size, seq_length = input_ids.shape
1091
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
1092
- mask_token = gMASK if gMASK in input_ids else MASK
1093
- use_gmask = True if gMASK in input_ids else False
1094
  seqs = input_ids.tolist()
1095
- mask_positions = [seq.index(mask_token) for seq in seqs]
 
 
 
 
 
1096
 
1097
  # only last token for input_ids if past is not None
1098
  if past is not None or past_key_values is not None:
@@ -1135,23 +1171,15 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1135
  input_ids,
1136
  device=input_ids.device,
1137
  mask_positions=mask_positions,
1138
- gmask=use_gmask
1139
  )
1140
- if inputs_embeds is not None:
1141
- assert input_ids.size(1) == inputs_embeds.size(1), f"Make sure that both input_ids ({input_ids.size(1)}) and inputs_embeds ({inputs_embeds.size(1)}) have the same length."
1142
- return {
1143
- "inputs_embeds": inputs_embeds,
1144
- "past_key_values": past,
1145
- "position_ids": position_ids,
1146
- "attention_mask": attention_mask
1147
- }
1148
- else:
1149
- return {
1150
- "input_ids": input_ids,
1151
- "past_key_values": past,
1152
- "position_ids": position_ids,
1153
- "attention_mask": attention_mask
1154
- }
1155
 
1156
  def forward(
1157
  self,
 
55
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
56
  if torch.isnan(scores).any() or torch.isinf(scores).any():
57
  scores.zero_()
58
+ scores[..., 5] = 5e4
59
  return scores
60
 
61
 
 
280
  # [sk, b, np, hn] -> [sk, b * np, hn]
281
  key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)
282
 
283
+ matmul_result = torch.zeros(
284
+ 1, 1, 1,
 
 
285
  dtype=query_layer.dtype,
286
  device=query_layer.device,
287
  )
 
346
  return outputs
347
 
348
 
349
def default_init(cls, *args, **kwargs):
    """Instantiate ``cls`` normally, forwarding every argument unchanged.

    Shares the ``init_method(cls, *args, **kwargs)`` calling convention used
    when a module is built with ``empty_init=False``.
    """
    instance = cls(*args, **kwargs)
    return instance
351
+
352
+
353
  class SelfAttention(torch.nn.Module):
354
  def __init__(self, hidden_size, num_attention_heads,
355
  layer_id, hidden_size_per_attention_head=None, bias=True,
356
+ params_dtype=torch.float, position_encoding_2d=True, empty_init=True):
357
+ if empty_init:
358
+ init_method = skip_init
359
+ else:
360
+ init_method = default_init
361
  super(SelfAttention, self).__init__()
362
 
363
  self.layer_id = layer_id
 
385
  self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head
386
 
387
  # Strided linear layer.
388
+ self.query_key_value = init_method(
389
  torch.nn.Linear,
390
  hidden_size,
391
  3 * self.inner_hidden_size,
 
393
  dtype=params_dtype,
394
  )
395
 
396
+ self.dense = init_method(
397
  torch.nn.Linear,
398
  self.inner_hidden_size,
399
  hidden_size,
 
506
 
507
  class GLU(torch.nn.Module):
508
  def __init__(self, hidden_size, inner_hidden_size=None,
509
+ layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True):
510
  super(GLU, self).__init__()
511
+ if empty_init:
512
+ init_method = skip_init
513
+ else:
514
+ init_method = default_init
515
  self.layer_id = layer_id
516
  self.activation_func = activation_func
517
 
 
520
  if inner_hidden_size is None:
521
  inner_hidden_size = 4 * hidden_size
522
  self.inner_hidden_size = inner_hidden_size
523
+ self.dense_h_to_4h = init_method(
524
  torch.nn.Linear,
525
  self.hidden_size,
526
  self.inner_hidden_size,
 
528
  dtype=params_dtype,
529
  )
530
  # Project back to h.
531
+ self.dense_4h_to_h = init_method(
532
  torch.nn.Linear,
533
  self.inner_hidden_size,
534
  self.hidden_size,
 
564
  use_bias=True,
565
  params_dtype=torch.float,
566
  num_layers=28,
567
+ position_encoding_2d=True,
568
+ empty_init=True
569
  ):
570
  super(GLMBlock, self).__init__()
571
  # Set output layer initialization if not provided.
 
585
  hidden_size_per_attention_head=hidden_size_per_attention_head,
586
  bias=use_bias,
587
  params_dtype=params_dtype,
588
+ position_encoding_2d=self.position_encoding_2d,
589
+ empty_init=empty_init
590
  )
591
 
592
  # Layernorm on the input data.
 
601
  bias=use_bias,
602
  layer_id=layer_id,
603
  params_dtype=params_dtype,
604
+ empty_init=empty_init
605
  )
606
 
607
  def forward(
 
689
 
690
  return attention_mask
691
 
692
+ def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None):
693
  batch_size, seq_length = input_ids.shape
694
+ if use_gmasks is None:
695
+ use_gmasks = [False] * batch_size
696
  context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
697
  if self.position_encoding_2d:
698
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
 
706
  position_ids = torch.stack((position_ids, block_position_ids), dim=1)
707
  else:
708
  position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
709
+ for i, context_length in enumerate(context_lengths):
710
+ if not use_gmasks[i]:
711
  position_ids[context_length:] = mask_positions[i]
712
 
713
  return position_ids
 
798
  `encoder_hidden_states` is then expected as an input to the forward pass.
799
  """
800
 
801
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
802
  super().__init__(config)
803
+ if empty_init:
804
+ init_method = skip_init
805
+ else:
806
+ init_method = default_init
807
  # recording parameters
808
  self.max_sequence_length = config.max_sequence_length
809
  self.hidden_size = config.hidden_size
 
818
  self.pre_seq_len = config.pre_seq_len
819
  self.prefix_projection = config.prefix_projection
820
 
821
+ self.word_embeddings = init_method(
822
  torch.nn.Embedding,
823
  num_embeddings=self.vocab_size, embedding_dim=self.hidden_size,
824
  dtype=self.params_dtype
 
837
  use_bias=True,
838
  params_dtype=self.params_dtype,
839
  position_encoding_2d=self.position_encoding_2d,
840
+ empty_init=empty_init
841
  )
842
 
843
  self.layers = torch.nn.ModuleList(
 
913
  )
914
  use_cache = False
915
 
916
+ # if input_ids is not None and inputs_embeds is not None:
917
+ # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
918
+ # elif input_ids is not None:
919
+ # batch_size, seq_length = input_ids.shape[:2]
920
+ # elif inputs_embeds is not None:
921
+ # batch_size, seq_length. _ = inputs_embeds.shape[:2]
922
+ # else:
923
+ # raise ValueError("You have to specify either input_ids or inputs_embeds")
924
+
925
+ if input_ids is not None:
926
  batch_size, seq_length = input_ids.shape[:2]
927
  elif inputs_embeds is not None:
 
928
  batch_size, seq_length = inputs_embeds.shape[:2]
929
  else:
930
  raise ValueError("You have to specify either input_ids or inputs_embeds")
 
948
 
949
  if position_ids is None:
950
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
951
+ seqs = input_ids.tolist()
952
+
953
+ mask_positions, use_gmasks = [], []
954
+ for seq in seqs:
955
+ mask_token = gMASK if gMASK in seq else MASK
956
+ use_gmask = mask_token == gMASK
957
+ mask_positions.append(seq.index(mask_token))
958
+ use_gmasks.append(use_gmask)
959
 
 
960
  position_ids = self.get_position_ids(
961
  input_ids,
962
  mask_positions=mask_positions,
963
  device=input_ids.device,
964
+ use_gmasks=use_gmasks
965
  )
966
 
967
  if self.pre_seq_len is not None and attention_mask is not None:
 
980
  if attention_mask is None:
981
  attention_mask = torch.zeros(1, 1, device=input_ids.device).bool()
982
 
983
+ # NOTE: this is a hack to make the code work with the LAVIS training
984
+ # else:
985
+ # pass
986
+ # attention_mask = attention_mask.to(input_ids.device)
987
 
988
  for i, layer in enumerate(self.layers):
989
 
 
1039
 
1040
 
1041
  class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
1042
+ def __init__(self, config: ChatGLMConfig, empty_init=True):
1043
  super().__init__(config)
1044
+ if empty_init:
1045
+ init_method = skip_init
1046
+ else:
1047
+ init_method = default_init
1048
 
1049
  # self.hidden_size = config.hidden_size
1050
  # self.params_dtype = torch.half
 
1053
 
1054
  self.position_encoding_2d = config.position_encoding_2d
1055
 
1056
+ self.transformer = ChatGLMModel(config, empty_init=empty_init)
1057
 
1058
+ self.lm_head = init_method(
1059
  nn.Linear,
1060
  config.hidden_size,
1061
  config.vocab_size,
 
1114
  def prepare_inputs_for_generation(
1115
  self,
1116
  input_ids: torch.LongTensor,
 
1117
  past: Optional[torch.Tensor] = None,
1118
  past_key_values: Optional[torch.Tensor] = None,
1119
  attention_mask: Optional[torch.Tensor] = None,
 
1122
  ) -> dict:
1123
  batch_size, seq_length = input_ids.shape
1124
  MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id
 
 
1125
  seqs = input_ids.tolist()
1126
+ mask_positions, use_gmasks = [], []
1127
+ for seq in seqs:
1128
+ mask_token = gMASK if gMASK in seq else MASK
1129
+ use_gmask = mask_token == gMASK
1130
+ mask_positions.append(seq.index(mask_token))
1131
+ use_gmasks.append(use_gmask)
1132
 
1133
  # only last token for input_ids if past is not None
1134
  if past is not None or past_key_values is not None:
 
1171
  input_ids,
1172
  device=input_ids.device,
1173
  mask_positions=mask_positions,
1174
+ use_gmasks=use_gmasks
1175
  )
1176
+
1177
+ return {
1178
+ "input_ids": input_ids,
1179
+ "past_key_values": past,
1180
+ "position_ids": position_ids,
1181
+ "attention_mask": attention_mask
1182
+ }
 
 
 
 
 
 
 
 
1183
 
1184
  def forward(
1185
  self,
preprocessor_config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.48145466,
8
+ 0.4578275,
9
+ 0.40821073
10
+ ],
11
+ "image_processor_type": "BlipImageProcessor",
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "processor_class": "Blip2Processor",
18
+ "resample": 3,
19
+ "rescale_factor": 0.00392156862745098,
20
+ "size": {
21
+ "height": 224,
22
+ "width": 224
23
+ }
24
+ }
pytorch_model-00001-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81ec9cc9a7e6034300a115898aac9fda06c69cf15d1b3c470d633ae7ce0ad3c9
3
+ size 1995030990
pytorch_model-00002-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de3166cf720b1a7cf6be0872f773d1f5e587e109435540f6511c37391827f1d6
3
+ size 1983142386
pytorch_model-00003-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a21fec7efe30123a73cd4bc77a4f8bf58c26e808743d47c606950215afd5e6c
3
+ size 1913134013
pytorch_model-00004-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b31bbd7aa605cde4258795220732aa24505ef451bf7e86a434c23c7fb75207e3
3
+ size 1879578439
pytorch_model-00005-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d45c57fd01a8e6b10a5d31d01af88580212e991705c5308ffcfe76bce8eb9df1
3
+ size 1879571453
pytorch_model-00006-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d1032cc31e7f2cda475a12f3a4016934c7d1c82c35b1cec93e159f0bbbc428c
3
+ size 1980242201
pytorch_model-00007-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0088535e5adf2f7b2cc2064aff91ffb979fb895a8cb2e2eee14e97a358c192a
3
+ size 1913134077
pytorch_model-00008-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a45b9e99ae15bad8d1722abd5f6e441c6cb4fe87bfa32f6c01e5b0a58409ec5d
3
+ size 1208293115
pytorch_model-00009-of-00009.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a6323bcad07ce5cc7934323c438abe9a8f45029553cd29098fe22314b14edb9a
3
+ size 1069286314
pytorch_model.bin.index.json ADDED
The diff for this file is too large to render. See raw diff
 
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "<sop>",
3
+ "eos_token": "<eop>",
4
+ "mask_token": "[MASK]",
5
+ "pad_token": "<pad>",
6
+ "unk_token": "<unk>"
7
+ }
tokenization_chatglm.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tokenization classes for ChatGLM."""
2
+ from typing import List, Optional, Union
3
+ import os
4
+
5
+ from transformers.tokenization_utils import PreTrainedTokenizer
6
+ from transformers.utils import logging, PaddingStrategy
7
+ from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
8
+ from typing import Dict
9
+ import sentencepiece as spm
10
+ import numpy as np
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
15
+ "THUDM/chatglm-6b": 2048,
16
+ }
17
+
18
+
19
class TextTokenizer:
    """Thin wrapper around a SentencePiece processor loaded from ``model_path``."""

    def __init__(self, model_path):
        processor = spm.SentencePieceProcessor()
        processor.Load(model_path)
        self.sp = processor
        # Cached once so ``len(self)`` does not have to ask the processor again.
        self.num_tokens = processor.vocab_size()

    def encode(self, text):
        """Return the ids the processor assigns to ``text``."""
        token_ids = self.sp.EncodeAsIds(text)
        return token_ids

    def decode(self, ids: List[int]):
        """Return the text the processor reconstructs from ``ids``."""
        text = self.sp.DecodeIds(ids)
        return text

    def tokenize(self, text):
        """Return the pieces the processor splits ``text`` into."""
        pieces = self.sp.EncodeAsPieces(text)
        return pieces

    def convert_tokens_to_ids(self, tokens):
        """Map each piece in ``tokens`` to its id, preserving order."""
        piece_ids = []
        for piece in tokens:
            piece_ids.append(self.sp.PieceToId(piece))
        return piece_ids

    def convert_token_to_id(self, token):
        """Return the id of a single piece."""
        return self.sp.PieceToId(token)

    def convert_id_to_token(self, idx):
        """Return the piece stored under id ``idx``."""
        return self.sp.IdToPiece(idx)

    def __len__(self):
        """Return the vocabulary size recorded at construction time."""
        return self.num_tokens
45
+
46
+
47
class SPTokenizer:
    """Text tokenizer whose ids are shifted past a reserved block of image-token ids.

    Ids ``0 .. num_image_tokens - 1`` are rendered as ``<image_k>``; every id
    produced by the wrapped ``TextTokenizer`` is offset by ``num_image_tokens``.
    Before encoding, newlines, tabs and runs of blanks are rewritten to the
    placeholder strings ``<n>``, ``<|tab|>`` and ``<|blank_k|>``; ``decode``
    reverses that rewriting.
    """

    def __init__(
            self,
            vocab_file,
            num_image_tokens=20000,
            max_blank_length=80,
            byte_fallback=True,
    ):
        assert vocab_file is not None
        self.vocab_file = vocab_file
        self.num_image_tokens = num_image_tokens
        self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
        self.max_blank_length = max_blank_length
        # Stored only; nothing in this class reads it back.
        self.byte_fallback = byte_fallback
        self.text_tokenizer = TextTokenizer(vocab_file)

    def _get_text_tokenizer(self):
        """Return the wrapped text tokenizer."""
        return self.text_tokenizer

    @staticmethod
    def get_blank_token(length: int):
        """Return the placeholder for a run of ``length`` (>= 2) blanks."""
        assert length >= 2
        return f"<|blank_{length}|>"

    @staticmethod
    def get_tab_token():
        """Return the placeholder that stands for a tab character."""
        return "<|tab|>"

    @property
    def num_text_tokens(self):
        """Size of the wrapped text vocabulary."""
        return self.text_tokenizer.num_tokens

    @property
    def num_tokens(self):
        """Total id space: image ids plus text ids."""
        return self.num_image_tokens + self.num_text_tokens

    @staticmethod
    def _encode_whitespaces(text: str, max_len: int = 80):
        """Replace tabs and runs of 2..``max_len`` blanks with placeholders."""
        text = text.replace("\t", SPTokenizer.get_tab_token())
        # Longest runs first, so a long run is not broken into shorter placeholders.
        for i in range(max_len, 1, -1):
            text = text.replace(" " * i, SPTokenizer.get_blank_token(i))
        return text

    def _preprocess(self, text: str, linebreak=True, whitespaces=True):
        """Apply the newline and whitespace placeholder rewriting to ``text``."""
        if linebreak:
            text = text.replace("\n", "<n>")
        if whitespaces:
            text = self._encode_whitespaces(text, max_len=self.max_blank_length)
        return text

    def encode(
            self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
    ) -> List[int]:
        """
        Encode ``text`` into ids offset by ``num_image_tokens``.

        @param text: Text to encode.
        @param linebreak: Whether to encode newline (``\\n``) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
        """
        text = self._preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # A throwaway "<n>" is prepended and the first two ids are dropped
            # below -- presumably the dummy-prefix piece and "<n>" itself;
            # verify against the SentencePiece model in use.
            text = "<n>" + text
        tmp = self._get_text_tokenizer().encode(text)
        tokens = [x + self.num_image_tokens for x in tmp]
        return tokens if add_dummy_prefix else tokens[2:]

    def decode(self, text_ids: List[int]) -> str:
        """Decode ids back to text; ids below ``num_image_tokens`` are dropped."""
        ids = [int(_id) - self.num_image_tokens for _id in text_ids]
        ids = [_id for _id in ids if _id >= 0]
        text = self._get_text_tokenizer().decode(ids)
        text = text.replace("<n>", "\n")
        text = text.replace(SPTokenizer.get_tab_token(), "\t")
        for i in range(2, self.max_blank_length + 1):
            text = text.replace(self.get_blank_token(i), " " * i)
        return text

    def tokenize(
            self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True
    ) -> List[str]:
        """
        Split ``text`` into pieces after the same preprocessing as ``encode``.

        @param text: Text to encode.
        @param linebreak: Whether to encode newline (``\\n``) in text.
        @param whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
        @param add_dummy_prefix: Whether to add dummy blank space in the beginning.
        """
        text = self._preprocess(text, linebreak, whitespaces)
        if not add_dummy_prefix:
            # Same prepend-then-drop-two scheme as in ``encode``.
            text = "<n>" + text
        tokens = self._get_text_tokenizer().tokenize(text)
        return tokens if add_dummy_prefix else tokens[2:]

    def __getitem__(self, x: Union[int, str]):
        """Convert an id to its token string, or a token string to its id.

        Raises:
            ValueError: if ``x`` is neither ``int`` nor ``str``.
        """
        if isinstance(x, int):
            if x < self.num_image_tokens:
                return "<image_{}>".format(x)
            else:
                return self.text_tokenizer.convert_id_to_token(x - self.num_image_tokens)
        elif isinstance(x, str):
            if x.startswith("<image_") and x.endswith(">") and x[7:-1].isdigit():
                return int(x[7:-1])
            else:
                return self.text_tokenizer.convert_token_to_id(x) + self.num_image_tokens
        else:
            raise ValueError("The key should be str or int.")
153
+
154
+
155
class ChatGLMTokenizer(PreTrainedTokenizer):
    """
    Construct a ChatGLM tokenizer backed by a SentencePiece model (see `SPTokenizer`).

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
    """

    vocab_files_names = {"vocab_file": "ice_text.model"}
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(
            self,
            vocab_file,
            do_lower_case=False,
            remove_space=False,
            bos_token='<sop>',
            eos_token='<eop>',
            end_token='</s>',
            mask_token='[MASK]',
            gmask_token='[gMASK]',
            padding_side="left",
            num_image_tokens=20000,
            **kwargs
    ) -> None:
        # Built before super().__init__(): newer transformers releases query the
        # vocabulary (vocab_size / token<->id conversion) from inside the
        # base-class constructor, which needs the SentencePiece wrapper to exist.
        self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens)

        super().__init__(
            do_lower_case=do_lower_case,
            remove_space=remove_space,
            padding_side=padding_side,
            bos_token=bos_token,
            eos_token=eos_token,
            end_token=end_token,
            mask_token=mask_token,
            gmask_token=gmask_token,
            num_image_tokens=num_image_tokens,
            **kwargs
        )

        self.do_lower_case = do_lower_case
        self.remove_space = remove_space
        self.vocab_file = vocab_file

        self.bos_token = bos_token
        self.eos_token = eos_token
        self.end_token = end_token
        self.mask_token = mask_token
        self.gmask_token = gmask_token

    @property
    def gmask_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the gMASK token in the vocabulary. Returns `None` if the token has not been set.
        """
        if self.gmask_token is None:
            return None
        return self.convert_tokens_to_ids(self.gmask_token)

    @property
    def end_token_id(self) -> Optional[int]:
        """
        `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been
        set.
        """
        if self.end_token is None:
            return None
        return self.convert_tokens_to_ids(self.end_token)

    @property
    def vocab_size(self):
        """ Returns vocab size """
        return self.sp_tokenizer.num_tokens

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def preprocess_text(self, inputs):
        """Optionally collapse whitespace (`remove_space`) and lower-case (`do_lower_case`) the text."""
        if self.remove_space:
            outputs = " ".join(inputs.strip().split())
        else:
            outputs = inputs

        if self.do_lower_case:
            outputs = outputs.lower()

        return outputs

    def _tokenize(self, text, **kwargs):
        """ Returns a tokenized string. """
        text = self.preprocess_text(text)

        seq = self.sp_tokenizer.tokenize(text)

        return seq

    def _decode(
            self,
            token_ids: Union[int, List[int]],
            skip_special_tokens: bool = False,
            clean_up_tokenization_spaces: bool = True,
            **kwargs
    ) -> str:
        """
        Decode ids to text after removing pad ids.

        `skip_special_tokens` and `clean_up_tokenization_spaces` are accepted for interface compatibility but are
        not consulted here.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if len(token_ids) == 0:
            return ""
        pad_id = self.pad_token_id
        if pad_id in token_ids:  # remove pad
            # Plain `!=` rather than `pad_id.__ne__`: the bound dunder returns the
            # truthy `NotImplemented` for non-int elements, which kept the pads.
            token_ids = [token_id for token_id in token_ids if token_id != pad_id]
        return self.sp_tokenizer.decode(token_ids)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.sp_tokenizer[token]

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.sp_tokenizer[index]

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary. If it is not an existing directory, the value is used
                as the output file path itself.
            filename_prefix (`str`, *optional*):
                Accepted for interface compatibility; not used when naming the saved file.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by appending ChatGLM's special tokens:

        - `[gMASK]` is appended to the first sequence if it contains neither `[MASK]` nor `[gMASK]`;
        - the end token is appended if the first sequence does not now end with `[MASK]`/`[gMASK]`;
        - the bos token is appended;
        - for a pair, the second sequence follows, terminated by the eos token unless it already ends with it.

        So a plain single sequence becomes `X [gMASK] <bos>` and a plain pair `A [gMASK] <bos> B <eos>`.

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: A new list of [input IDs](../glossary#input-ids) with the appropriate special tokens. The
            argument lists are left untouched.
        """
        mask_id = self.sp_tokenizer[self.mask_token]
        gmask_id = self.sp_tokenizer[self.gmask_token]
        eos_id = self.sp_tokenizer[self.eos_token]

        # Work on a copy: extending the caller's list in place made a second call
        # on the same list append the special tokens again.
        sequence = list(token_ids_0)

        if mask_id not in sequence and gmask_id not in sequence:
            sequence.append(gmask_id)

        if sequence[-1] != mask_id and sequence[-1] != gmask_id:
            sequence.append(self.sp_tokenizer[self.end_token])

        sequence.append(self.sp_tokenizer[self.bos_token])

        if token_ids_1 is not None:
            sequence.extend(token_ids_1)
            if not token_ids_1 or token_ids_1[-1] != eos_id:
                sequence.append(eos_id)

        return sequence

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on the left, up to predefined length or max length in the batch) and, when they are
        missing, build the `attention_mask` and the two-row `position_ids` for the example.

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                Only `self.padding_side == "left"` is supported; anything else fails the assertion below.
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                Accepted for interface compatibility; not consulted by this implementation.
        """
        # Load from model defaults
        bos_token_id = self.sp_tokenizer[self.bos_token]
        mask_token_id = self.sp_tokenizer[self.mask_token]
        gmask_token_id = self.sp_tokenizer[self.gmask_token]
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask and position ids if not present.
        if max_length is not None:
            needs_attention_mask = "attention_mask" not in encoded_inputs
            needs_position_ids = "position_ids" not in encoded_inputs

            if needs_attention_mask or needs_position_ids:
                # Resolved once for both branches. It used to be bound only while
                # building the attention mask, so supplying a mask without
                # position ids raised UnboundLocalError.
                if bos_token_id in required_input:
                    context_length = required_input.index(bos_token_id)
                else:
                    context_length = seq_length

            if needs_attention_mask:
                # Lower-triangular matrix with the context columns fully opened;
                # the stored boolean mask is True wherever that matrix is zero.
                attention_mask = np.ones((1, seq_length, seq_length))
                attention_mask = np.tril(attention_mask)
                attention_mask[:, :, :context_length] = 1
                attention_mask = np.bool_(attention_mask < 0.5)
                encoded_inputs["attention_mask"] = attention_mask

            if needs_position_ids:
                position_ids = np.arange(seq_length, dtype=np.int64)
                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
                if mask_token in required_input:
                    # Everything from the bos position onwards shares the mask's position.
                    mask_position = required_input.index(mask_token)
                    position_ids[context_length:] = mask_position
                # Second row: 0 over the context, then 1, 2, ... from the bos position.
                block_position_ids = np.concatenate(
                    [np.zeros(context_length, dtype=np.int64),
                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            if "attention_mask" in encoded_inputs:
                # Pad rows and columns on the left with True.
                encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"],
                                                          pad_width=[(0, 0), (difference, 0), (difference, 0)],
                                                          mode='constant', constant_values=True)
            if "token_type_ids" in encoded_inputs:
                encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[
                    "token_type_ids"
                ]
            if "special_tokens_mask" in encoded_inputs:
                encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"],
                                                        pad_width=[(0, 0), (difference, 0)])
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
tokenizer_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoTokenizer": [
4
+ "tokenization_chatglm.ChatGLMTokenizer",
5
+ null
6
+ ]
7
+ },
8
+ "bos_token": "<sop>",
9
+ "do_lower_case": false,
10
+ "end_token": "</s>",
11
+ "eos_token": "<eop>",
12
+ "gmask_token": "[gMASK]",
13
+ "mask_token": "[MASK]",
14
+ "model_max_length": 1000000000000000019884624838656,
15
+ "num_image_tokens": 0,
16
+ "pad_token": "<pad>",
17
+ "padding_side": "left",
18
+ "processor_class": "Blip2Processor",
19
+ "remove_space": false,
20
+ "special_tokens_map_file": null,
21
+ "tokenizer_class": "ChatGLMTokenizer",
22
+ "unk_token": "<unk>"
23
+ }