xinyu1205 commited on
Commit
8c25077
·
1 Parent(s): 64c1dc7

Update models/tag2text.py

Browse files
Files changed (1) hide show
  1. models/tag2text.py +407 -349
models/tag2text.py CHANGED
@@ -1,429 +1,487 @@
1
  '''
2
- * Tag2Text
3
  * Written by Xinyu Huang
4
  '''
 
 
 
5
  import warnings
6
- warnings.filterwarnings("ignore")
7
 
8
- from models.vit import VisionTransformer, interpolate_pos_embed
9
- from models.swin_transformer import SwinTransformer, interpolate_relative_pos_embed
10
- from models.med import BertConfig, BertModel, BertLMHeadModel
11
- from transformers import BertTokenizer
12
-
13
- import torch
14
  from torch import nn
15
- import torch.nn.functional as F
16
-
17
- import os
18
- from urllib.parse import urlparse
19
- from timm.models.hub import download_cached_file
20
- from data.tag_class import tra_array
21
- import json
22
- import math
23
- import numpy as np
24
 
25
- def read_json(rpath):
26
- with open(rpath, 'r') as f:
27
- return json.load(f)
28
 
29
- # delete some tags that may disturb captioning
30
- # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
31
- delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359]
32
 
33
- # adjust thresholds for some tags
34
- # default threshold: 0.68
35
- # 2701: "person"; 2828: "man"; 1167: "woman";
36
- tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7}
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- class Tag2Text_Caption(nn.Module):
39
- def __init__(self,
40
- med_config = 'configs/med_config.json',
41
- image_size = 384,
42
- vit = 'base',
43
- vit_grad_ckpt = False,
44
- vit_ckpt_layer = 0,
45
- prompt = 'a picture of ',
46
- threshold = 0.68,
47
- ):
48
- """
49
  Args:
50
  med_config (str): path for the mixture of encoder-decoder model's configuration file
51
  image_size (int): input image size
52
  vit (str): model size of vision transformer
53
- """
 
 
54
  super().__init__()
55
 
56
- if vit=='swin_b':
 
57
  if image_size == 224:
58
- vision_config_path = 'configs/swin/config_swinB_224.json'
59
  elif image_size == 384:
60
- vision_config_path = 'configs/swin/config_swinB_384.json'
61
  vision_config = read_json(vision_config_path)
62
  assert image_size == vision_config['image_res']
63
  # assert config['patch_size'] == 32
64
  vision_width = vision_config['vision_width']
65
 
66
- self.visual_encoder = SwinTransformer(img_size=vision_config['image_res'],
67
- patch_size=4,
68
- in_chans=3,
69
- embed_dim=vision_config['embed_dim'],
70
- depths=vision_config['depths'],
71
- num_heads=vision_config['num_heads'],
72
- window_size=vision_config['window_size'],
73
- mlp_ratio=4.,
74
- qkv_bias=True,
75
- drop_rate=0.0,
76
- drop_path_rate=0.1,
77
- ape=False,
78
- patch_norm=True,
79
- use_checkpoint=False)
80
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  else:
82
- self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
 
83
 
 
 
84
 
85
- self.tokenizer = init_tokenizer()
 
 
 
 
 
86
 
87
- # create the decoder
88
  decoder_config = BertConfig.from_json_file(med_config)
89
- decoder_config.encoder_width = 768
90
- self.text_decoder = BertLMHeadModel(config=decoder_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- # create encoder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  encoder_config = BertConfig.from_json_file(med_config)
94
  encoder_config.encoder_width = vision_width
95
- self.tag_encoder = BertModel(config=encoder_config, add_pooling_layer=False)
96
-
 
 
 
 
 
 
 
 
97
  self.prompt = prompt
98
- self.prompt_length = len(self.tokenizer(self.prompt).input_ids)-1
99
 
100
- self.threshold = threshold
101
- num_features = 768
102
- self.num_class = 3429
103
 
104
- q2l_config = BertConfig.from_json_file('configs/q2l_config.json')
 
 
 
105
  q2l_config.encoder_width = vision_width
106
- self.vision_multi = BertModel(config=q2l_config, add_pooling_layer=False)
107
- self.vision_multi.resize_token_embeddings(len(self.tokenizer))
 
108
  self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
109
- self.fc = GroupWiseLinear(self.num_class, num_features, bias=True)
 
 
110
  self.del_selfattention()
111
 
112
- tie_encoder_decoder_weights(self.tag_encoder,self.vision_multi,'',' ')
113
- self.tag_array = tra_array
 
114
 
 
 
 
 
115
  self.class_threshold = torch.ones(self.num_class) * self.threshold
116
  for key,value in tag_thrshold.items():
117
  self.class_threshold[key] = value
118
-
 
 
 
 
 
 
 
119
  def del_selfattention(self):
120
- del self.vision_multi.embeddings
121
- for layer in self.vision_multi.encoder.layer:
122
  del layer.attention
123
-
124
- def generate(self, image, sample=False, num_beams=3, max_length=30, min_length=10, top_p=0.9, repetition_penalty=1.0, tag_input = None, return_tag_predict = False):
 
 
 
 
 
 
 
 
 
 
125
  image_embeds = self.visual_encoder(image)
126
- image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)
 
127
 
128
- #==============generate tag==============#
129
  if tag_input == None:
130
- image_spatial_embeds = image_embeds[:,1:,:]
131
- image_cls_embeds = image_embeds[:,0,:]
132
 
133
  bs = image_spatial_embeds.shape[0]
134
- label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs,1,1)
135
- mlr_tagembedding = self.vision_multi(encoder_embeds = label_embed,
136
- encoder_hidden_states = image_embeds,
137
- encoder_attention_mask = image_atts,
138
- return_dict = False,
139
- mode = 'mlr',
140
- )
141
-
142
- logits = self.fc(mlr_tagembedding[0])
143
-
144
- # targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
145
- targets = torch.where(torch.sigmoid(logits) > self.class_threshold.to(image.device) , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
 
 
 
146
 
147
  tag = targets.cpu().numpy()
148
- tag[:,delete_tag_index] = 0
149
- bs = image.size(0)
 
 
150
  tag_input = []
151
  for b in range(bs):
152
  index = np.argwhere(tag[b] == 1)
153
- token = self.tag_array[index].squeeze(axis = 1)
154
- tag_input.append(' | '.join(token))
155
- #========================================#
156
-
 
 
157
  if not sample:
158
- image_embeds = image_embeds.repeat_interleave(num_beams,dim=0)
159
  tag_input_temp = []
160
  for tag in tag_input:
161
  for i in range(num_beams):
162
  tag_input_temp.append(tag)
163
  tag_input = tag_input_temp
164
 
 
 
165
 
166
- tag_input_tokenzier = self.tokenizer(tag_input, padding='max_length', truncation=True, max_length=40,
167
- return_tensors="pt").to(image.device)
 
 
 
 
 
168
  encoder_input_ids = tag_input_tokenzier.input_ids
169
- encoder_input_ids[:,0] = self.tokenizer.enc_token_id
170
-
171
- output_tagembedding = self.tag_encoder(encoder_input_ids,
172
- attention_mask = tag_input_tokenzier.attention_mask,
173
- encoder_hidden_states = image_embeds,
174
- encoder_attention_mask = image_atts,
175
- return_dict = True,
176
- )
177
-
 
 
 
178
  prompt = [self.prompt] * image.size(0)
179
- input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(image.device)
180
- input_ids[:,0] = self.tokenizer.bos_token_id
181
- input_ids = input_ids[:, :-1]
 
182
 
183
  if sample:
184
- #nucleus sampling
185
- model_kwargs = {"encoder_hidden_states": output_tagembedding.last_hidden_state, "encoder_attention_mask":None}
186
- outputs = self.text_decoder.generate(input_ids=input_ids,
187
- max_length=max_length,
188
- min_length=min_length,
189
- do_sample=True,
190
- top_p=top_p,
191
- num_return_sequences=1,
192
- eos_token_id=self.tokenizer.sep_token_id,
193
- pad_token_id=self.tokenizer.pad_token_id,
194
- repetition_penalty=1.1,
195
- **model_kwargs)
 
 
 
 
196
  else:
197
- #beam search
198
- model_kwargs = {"encoder_hidden_states": output_tagembedding.last_hidden_state, "encoder_attention_mask":None}
199
- outputs = self.text_decoder.generate(input_ids=input_ids,
200
- max_length=max_length,
201
- min_length=min_length,
202
- num_beams=num_beams,
203
- eos_token_id=self.tokenizer.sep_token_id,
204
- pad_token_id=self.tokenizer.pad_token_id,
205
- repetition_penalty=repetition_penalty,
206
- **model_kwargs)
207
-
208
- captions = []
 
 
 
 
209
  for output in outputs:
210
- caption = self.tokenizer.decode(output, skip_special_tokens=True)
211
  captions.append(caption[len(self.prompt):])
212
  if return_tag_predict == True:
213
- if sample:
214
- return captions, tag_input
215
- else:
216
- return captions, tag_input[0:int(len(tag_input)/num_beams)]
217
  return captions
218
 
219
 
220
- def tag2text_caption(pretrained='',**kwargs):
 
221
  model = Tag2Text_Caption(**kwargs)
222
  if pretrained:
223
  if kwargs['vit'] == 'swin_b':
224
- model,msg = load_checkpoint_swinbase(model,pretrained,kwargs)
225
  else:
226
- model,msg = load_checkpoint(model,pretrained)
227
- print('vit:',kwargs['vit'])
228
- print('msg_v2',msg)
229
- return model
230
-
231
-
232
- from typing import List
233
- def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, skip_key:str):
234
- uninitialized_encoder_weights: List[str] = []
235
- if decoder.__class__ != encoder.__class__:
236
- logger.info(
237
- f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized."
238
- )
239
-
240
- def tie_encoder_to_decoder_recursively(
241
- decoder_pointer: nn.Module,
242
- encoder_pointer: nn.Module,
243
- module_name: str,
244
- uninitialized_encoder_weights: List[str],
245
- skip_key: str,
246
- depth=0,
247
- ):
248
- assert isinstance(decoder_pointer, nn.Module) and isinstance(
249
- encoder_pointer, nn.Module
250
- ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module"
251
- if hasattr(decoder_pointer, "weight") and skip_key not in module_name:
252
- assert hasattr(encoder_pointer, "weight")
253
- encoder_pointer.weight = decoder_pointer.weight
254
- if hasattr(decoder_pointer, "bias"):
255
- assert hasattr(encoder_pointer, "bias")
256
- encoder_pointer.bias = decoder_pointer.bias
257
- print(module_name+' is tied')
258
- return
259
-
260
- encoder_modules = encoder_pointer._modules
261
- decoder_modules = decoder_pointer._modules
262
- if len(decoder_modules) > 0:
263
- assert (
264
- len(encoder_modules) > 0
265
- ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
266
-
267
- all_encoder_weights = set([module_name + "/" + sub_name for sub_name in encoder_modules.keys()])
268
- encoder_layer_pos = 0
269
- for name, module in decoder_modules.items():
270
- if name.isdigit():
271
- encoder_name = str(int(name) + encoder_layer_pos)
272
- decoder_name = name
273
- if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
274
- encoder_modules
275
- ) != len(decoder_modules):
276
- # this can happen if the name corresponds to the position in a list module list of layers
277
- # in this case the decoder has added a cross-attention that the encoder does not have
278
- # thus skip this step and subtract one layer pos from encoder
279
- encoder_layer_pos -= 1
280
- continue
281
- elif name not in encoder_modules:
282
- continue
283
- elif depth > 500:
284
- raise ValueError(
285
- "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model."
286
- )
287
- else:
288
- decoder_name = encoder_name = name
289
- tie_encoder_to_decoder_recursively(
290
- decoder_modules[decoder_name],
291
- encoder_modules[encoder_name],
292
- module_name + "/" + name,
293
- uninitialized_encoder_weights,
294
- skip_key,
295
- depth=depth + 1,
296
- )
297
- all_encoder_weights.remove(module_name + "/" + encoder_name)
298
-
299
- uninitialized_encoder_weights += list(all_encoder_weights)
300
-
301
- # tie weights recursively
302
- tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights, skip_key)
303
-
304
-
305
- class GroupWiseLinear(nn.Module):
306
- # could be changed to:
307
- # output = torch.einsum('ijk,zjk->ij', x, self.W)
308
- # or output = torch.einsum('ijk,jk->ij', x, self.W[0])
309
- def __init__(self, num_class, hidden_dim, bias=True):
310
- super().__init__()
311
- self.num_class = num_class
312
- self.hidden_dim = hidden_dim
313
- self.bias = bias
314
-
315
- self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
316
- if bias:
317
- self.b = nn.Parameter(torch.Tensor(1, num_class))
318
- self.reset_parameters()
319
-
320
- def reset_parameters(self):
321
- stdv = 1. / math.sqrt(self.W.size(2))
322
- for i in range(self.num_class):
323
- self.W[0][i].data.uniform_(-stdv, stdv)
324
- if self.bias:
325
- for i in range(self.num_class):
326
- self.b[0][i].data.uniform_(-stdv, stdv)
327
-
328
- def forward(self, x):
329
- # x: B,K,d
330
- x = (self.W * x).sum(-1)
331
- if self.bias:
332
- x = x + self.b
333
- return x
334
-
335
-
336
- def init_tokenizer():
337
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
338
- tokenizer.add_special_tokens({'bos_token':'[DEC]'})
339
- tokenizer.add_special_tokens({'additional_special_tokens':['[ENC]']})
340
- tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0]
341
- return tokenizer
342
-
343
-
344
- def create_vit(vit, image_size, use_grad_checkpointing=False, ckpt_layer=0, drop_path_rate=0):
345
-
346
- assert vit in ['base', 'large'], "vit parameter must be base or large"
347
- if vit=='base':
348
- vision_width = 768
349
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=12,
350
- num_heads=12, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
351
- drop_path_rate=0 or drop_path_rate
352
- )
353
- elif vit=='large':
354
- vision_width = 1024
355
- visual_encoder = VisionTransformer(img_size=image_size, patch_size=16, embed_dim=vision_width, depth=24,
356
- num_heads=16, use_grad_checkpointing=use_grad_checkpointing, ckpt_layer=ckpt_layer,
357
- drop_path_rate=0.1 or drop_path_rate
358
- )
359
- return visual_encoder, vision_width
360
-
361
- def is_url(url_or_filename):
362
- parsed = urlparse(url_or_filename)
363
- return parsed.scheme in ("http", "https")
364
-
365
- def load_checkpoint(model,url_or_filename):
366
- if is_url(url_or_filename):
367
- cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
368
- checkpoint = torch.load(cached_file, map_location='cpu')
369
- elif os.path.isfile(url_or_filename):
370
- checkpoint = torch.load(url_or_filename, map_location='cpu')
371
- else:
372
- raise RuntimeError('checkpoint url or path is invalid')
373
-
374
- state_dict = checkpoint['model']
375
-
376
- state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder)
377
- if 'visual_encoder_m.pos_embed' in model.state_dict().keys():
378
- state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],
379
- model.visual_encoder_m)
380
- for key in model.state_dict().keys():
381
- if key in state_dict.keys():
382
- if state_dict[key].shape!=model.state_dict()[key].shape:
383
- del state_dict[key]
384
-
385
- msg = model.load_state_dict(state_dict,strict=False)
386
- print('load checkpoint from %s'%url_or_filename)
387
- return model,msg
388
-
389
-
390
- def load_checkpoint_swinbase(model,url_or_filename,kwargs):
391
- if kwargs['image_size'] == 224:
392
- vision_config_path = 'configs/swin/config_swinB_224.json'
393
- elif kwargs['image_size'] == 384:
394
- vision_config_path = 'configs/swin/config_swinB_384.json'
395
- elif kwargs['image_size'] == 480:
396
- vision_config_path = 'configs/swin/config_swinB_480.json'
397
- elif kwargs['image_size'] == 576:
398
- vision_config_path = 'configs/swin/config_swinB_576.json'
399
- elif kwargs['image_size'] == 608:
400
- vision_config_path = 'configs/swin/config_swinB_608.json'
401
- window_size = read_json(vision_config_path)['window_size']
402
- print('--------------')
403
- print(url_or_filename)
404
- print('--------------')
405
- if is_url(url_or_filename):
406
- cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
407
- checkpoint = torch.load(cached_file, map_location='cpu')
408
- elif os.path.isfile(url_or_filename):
409
- checkpoint = torch.load(url_or_filename, map_location='cpu')
410
- else:
411
- raise RuntimeError('checkpoint url or path is invalid')
412
-
413
- state_dict = checkpoint['model']
414
-
415
- for k in list(state_dict.keys()):
416
- if 'relative_position_bias_table' in k:
417
- dst_num_pos = (2 * window_size - 1) ** 2
418
- state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k)
419
- elif ('relative_position_index' in k) or ('attn_mask' in k):
420
- del state_dict[k]
421
-
422
- msg = model.load_state_dict(state_dict,strict=False)
423
- print('load checkpoint from %s'%url_or_filename)
424
- return model,msg
425
-
426
-
427
 
428
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
 
1
  '''
2
+ * The Recognize Anything Model (RAM) & Tag2Text Model
3
  * Written by Xinyu Huang
4
  '''
5
+ import numpy as np
6
+ import json
7
+ import torch
8
  import warnings
 
9
 
 
 
 
 
 
 
10
  from torch import nn
11
+ from models.bert import BertConfig, BertModel, BertLMHeadModel
12
+ from models.vit import VisionTransformer
13
+ from models.swin_transformer import SwinTransformer
14
+ from data.ram_tag_list_threshold import ram_class_threshold
 
 
 
 
 
15
 
16
+ from models.utils import *
 
 
17
 
18
+ warnings.filterwarnings("ignore")
 
 
19
 
20
+ class RAM(nn.Module):
21
+ def __init__(self,
22
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
23
+ image_size=384,
24
+ vit='base',
25
+ vit_grad_ckpt=False,
26
+ vit_ckpt_layer=0,
27
+ prompt='a picture of ',
28
+ threshold=0.68,
29
+ delete_tag_index=[],
30
+ tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt',
31
+ tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt'):
32
+ r""" The Recognize Anything Model (RAM) inference module.
33
+ RAM is a strong image tagging model, which can recognize any common category with high accuracy.
34
+ Described in the paper " Recognize Anything: A Strong Image Tagging Model" https://recognize-anything.github.io/
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  Args:
37
  med_config (str): path for the mixture of encoder-decoder model's configuration file
38
  image_size (int): input image size
39
  vit (str): model size of vision transformer
40
+ threshold (int): tagging threshold
41
+ delete_tag_index (list): delete some tags that may disturb captioning
42
+ """
43
  super().__init__()
44
 
45
+ # create image encoder
46
+ if vit == 'swin_b':
47
  if image_size == 224:
48
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
49
  elif image_size == 384:
50
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
51
  vision_config = read_json(vision_config_path)
52
  assert image_size == vision_config['image_res']
53
  # assert config['patch_size'] == 32
54
  vision_width = vision_config['vision_width']
55
 
56
+ self.visual_encoder = SwinTransformer(
57
+ img_size=vision_config['image_res'],
58
+ patch_size=4,
59
+ in_chans=3,
60
+ embed_dim=vision_config['embed_dim'],
61
+ depths=vision_config['depths'],
62
+ num_heads=vision_config['num_heads'],
63
+ window_size=vision_config['window_size'],
64
+ mlp_ratio=4.,
65
+ qkv_bias=True,
66
+ drop_rate=0.0,
67
+ drop_path_rate=0.1,
68
+ ape=False,
69
+ patch_norm=True,
70
+ use_checkpoint=False)
71
+
72
+ elif vit == 'swin_l':
73
+ if image_size == 224:
74
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json'
75
+ elif image_size == 384:
76
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json'
77
+ vision_config = read_json(vision_config_path)
78
+ assert image_size == vision_config['image_res']
79
+ # assert config['patch_size'] == 32
80
+ vision_width = vision_config['vision_width']
81
+
82
+ self.visual_encoder = SwinTransformer(
83
+ img_size=vision_config['image_res'],
84
+ patch_size=4,
85
+ in_chans=3,
86
+ embed_dim=vision_config['embed_dim'],
87
+ depths=vision_config['depths'],
88
+ num_heads=vision_config['num_heads'],
89
+ window_size=vision_config['window_size'],
90
+ mlp_ratio=4.,
91
+ qkv_bias=True,
92
+ drop_rate=0.0,
93
+ drop_path_rate=0.1,
94
+ ape=False,
95
+ patch_norm=True,
96
+ use_checkpoint=False)
97
+
98
  else:
99
+ self.visual_encoder, vision_width = create_vit(
100
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
101
 
102
+ # create tokenzier
103
+ self.tokenizer = init_tokenizer()
104
 
105
+ # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder
106
+ # create image-tag interaction encoder
107
+ encoder_config = BertConfig.from_json_file(med_config)
108
+ encoder_config.encoder_width = 512
109
+ self.tag_encoder = BertModel(config=encoder_config,
110
+ add_pooling_layer=False)
111
 
112
+ # create image-tag-text decoder
113
  decoder_config = BertConfig.from_json_file(med_config)
114
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
115
+
116
+ self.delete_tag_index = delete_tag_index
117
+ self.prompt = prompt
118
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
119
+
120
+ # load tag list
121
+ self.tag_list = self.load_tag_list(tag_list)
122
+ self.tag_list_chinese = self.load_tag_list(tag_list_chinese)
123
+
124
+ # create image-tag recognition decoder
125
+ self.threshold = threshold
126
+ self.num_class = len(self.tag_list)
127
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
128
+ q2l_config.encoder_width = 512
129
+ self.tagging_head = BertModel(config=q2l_config,
130
+ add_pooling_layer=False)
131
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
132
+ self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
133
+
134
+ if q2l_config.hidden_size != 512:
135
+ self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size)
136
+ else:
137
+ self.wordvec_proj = nn.Identity()
138
+
139
+ self.fc = nn.Linear(q2l_config.hidden_size, 1)
140
+
141
+ self.del_selfattention()
142
+
143
+ # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder"
144
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
145
+ ' ')
146
+ self.image_proj = nn.Linear(vision_width, 512)
147
+ self.label_embed = nn.Parameter(torch.load('data/textual_label_embedding.pth',map_location='cpu').float())
148
+
149
+ # adjust thresholds for some tags
150
+ self.class_threshold = torch.ones(self.num_class) * self.threshold
151
+ for key,value in enumerate(ram_class_threshold):
152
+ self.class_threshold[key] = value
153
+
154
+ def load_tag_list(self, tag_list_file):
155
+ with open(tag_list_file, 'r') as f:
156
+ tag_list = f.read().splitlines()
157
+ tag_list = np.array(tag_list)
158
+ return tag_list
159
+
160
+ # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label
161
+ def del_selfattention(self):
162
+ del self.tagging_head.embeddings
163
+ for layer in self.tagging_head.encoder.layer:
164
+ del layer.attention
165
+
166
+ def generate_tag(self,
167
+ image,
168
+ threshold=0.68,
169
+ tag_input=None,
170
+ ):
171
+
172
+ label_embed = torch.nn.functional.relu(self.wordvec_proj(self.label_embed))
173
+
174
+ image_embeds = self.image_proj(self.visual_encoder(image))
175
+ image_atts = torch.ones(image_embeds.size()[:-1],
176
+ dtype=torch.long).to(image.device)
177
+
178
+ # recognized image tags using image-tag recogntiion decoder
179
+ image_cls_embeds = image_embeds[:, 0, :]
180
+ image_spatial_embeds = image_embeds[:, 1:, :]
181
+
182
+ bs = image_spatial_embeds.shape[0]
183
+ label_embed = label_embed.unsqueeze(0).repeat(bs, 1, 1)
184
+ tagging_embed = self.tagging_head(
185
+ encoder_embeds=label_embed,
186
+ encoder_hidden_states=image_embeds,
187
+ encoder_attention_mask=image_atts,
188
+ return_dict=False,
189
+ mode='tagging',
190
+ )
191
+
192
+ logits = self.fc(tagging_embed[0]).squeeze(-1)
193
+
194
+ targets = torch.where(
195
+ torch.sigmoid(logits) > self.class_threshold.to(image.device),
196
+ torch.tensor(1.0).to(image.device),
197
+ torch.zeros(self.num_class).to(image.device))
198
+
199
+ tag = targets.cpu().numpy()
200
+ tag[:,self.delete_tag_index] = 0
201
+ tag_output = []
202
+ tag_output_chinese = []
203
+ for b in range(bs):
204
+ index = np.argwhere(tag[b] == 1)
205
+ token = self.tag_list[index].squeeze(axis=1)
206
+ tag_output.append(' | '.join(token))
207
+ token_chinese = self.tag_list_chinese[index].squeeze(axis=1)
208
+ tag_output_chinese.append(' | '.join(token_chinese))
209
+
210
+
211
+ return tag_output, tag_output_chinese
212
+
213
 
214
+ class Tag2Text_Caption(nn.Module):
215
+
216
+ def __init__(self,
217
+ med_config=f'{CONFIG_PATH}/configs/med_config.json',
218
+ image_size=384,
219
+ vit='base',
220
+ vit_grad_ckpt=False,
221
+ vit_ckpt_layer=0,
222
+ prompt='a picture of ',
223
+ threshold=0.68,
224
+ delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359],
225
+ tag_list=f'{CONFIG_PATH}/data/tag_list.txt'):
226
+ r""" Tag2Text inference module, both captioning and tagging are included.
227
+ Tag2Text is an efficient and controllable vision-language pre-training framework.
228
+ Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657
229
+
230
+ Args:
231
+ med_config (str): path for the mixture of encoder-decoder model's configuration file
232
+ image_size (int): input image size
233
+ vit (str): model size of vision transformer
234
+ threshold (int): tagging threshold
235
+ delete_tag_index (list): delete some tags that may disturb captioning
236
+ """
237
+ super().__init__()
238
+
239
+ # create image encoder
240
+ if vit == 'swin_b':
241
+ if image_size == 224:
242
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json'
243
+ elif image_size == 384:
244
+ vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json'
245
+ vision_config = read_json(vision_config_path)
246
+ assert image_size == vision_config['image_res']
247
+ # assert config['patch_size'] == 32
248
+ vision_width = vision_config['vision_width']
249
+
250
+ self.visual_encoder = SwinTransformer(
251
+ img_size=vision_config['image_res'],
252
+ patch_size=4,
253
+ in_chans=3,
254
+ embed_dim=vision_config['embed_dim'],
255
+ depths=vision_config['depths'],
256
+ num_heads=vision_config['num_heads'],
257
+ window_size=vision_config['window_size'],
258
+ mlp_ratio=4.,
259
+ qkv_bias=True,
260
+ drop_rate=0.0,
261
+ drop_path_rate=0.1,
262
+ ape=False,
263
+ patch_norm=True,
264
+ use_checkpoint=False)
265
+
266
+ else:
267
+ self.visual_encoder, vision_width = create_vit(
268
+ vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
269
+
270
+ # create tokenzier
271
+ self.tokenizer = init_tokenizer()
272
+
273
+ # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder
274
+ # create image-tag interaction encoder
275
  encoder_config = BertConfig.from_json_file(med_config)
276
  encoder_config.encoder_width = vision_width
277
+ self.tag_encoder = BertModel(config=encoder_config,
278
+ add_pooling_layer=False)
279
+
280
+ # create image-tag-text decoder
281
+ decoder_config = BertConfig.from_json_file(med_config)
282
+ self.text_decoder = BertLMHeadModel(config=decoder_config)
283
+
284
+ # delete some tags that may disturb captioning
285
+ # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one"
286
+ self.delete_tag_index = delete_tag_index
287
  self.prompt = prompt
288
+ self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1
289
 
290
+ # load tag list
291
+ self.tag_list = self.load_tag_list(tag_list)
 
292
 
293
+ # create image-tag recognition decoder
294
+ self.threshold = threshold
295
+ self.num_class = len(self.tag_list)
296
+ q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json')
297
  q2l_config.encoder_width = vision_width
298
+ self.tagging_head = BertModel(config=q2l_config,
299
+ add_pooling_layer=False)
300
+ self.tagging_head.resize_token_embeddings(len(self.tokenizer))
301
  self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size)
302
+ self.fc = GroupWiseLinear(self.num_class,
303
+ q2l_config.hidden_size,
304
+ bias=True)
305
  self.del_selfattention()
306
 
307
+ # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder"
308
+ tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '',
309
+ ' ')
310
 
311
+ # adjust thresholds for some tags
312
+ # default threshold: 0.68
313
+ # 2701: "person"; 2828: "man"; 1167: "woman";
314
+ tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7}
315
  self.class_threshold = torch.ones(self.num_class) * self.threshold
316
  for key,value in tag_thrshold.items():
317
  self.class_threshold[key] = value
318
+
319
+ def load_tag_list(self, tag_list_file):
320
+ with open(tag_list_file, 'r') as f:
321
+ tag_list = f.read().splitlines()
322
+ tag_list = np.array(tag_list)
323
+ return tag_list
324
+
325
+ # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label
326
  def del_selfattention(self):
327
+ del self.tagging_head.embeddings
328
+ for layer in self.tagging_head.encoder.layer:
329
  del layer.attention
330
+
331
+ def generate(self,
332
+ image,
333
+ sample=False,
334
+ num_beams=3,
335
+ max_length=30,
336
+ min_length=10,
337
+ top_p=0.9,
338
+ repetition_penalty=1.0,
339
+ tag_input=None,
340
+ return_tag_predict=False):
341
+
342
  image_embeds = self.visual_encoder(image)
343
+ image_atts = torch.ones(image_embeds.size()[:-1],
344
+ dtype=torch.long).to(image.device)
345
 
346
+ # if not user specified tags, recognized image tags using image-tag recogntiion decoder
347
  if tag_input == None:
348
+ image_cls_embeds = image_embeds[:, 0, :]
349
+ image_spatial_embeds = image_embeds[:, 1:, :]
350
 
351
  bs = image_spatial_embeds.shape[0]
352
+ label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
353
+ tagging_embed = self.tagging_head(
354
+ encoder_embeds=label_embed,
355
+ encoder_hidden_states=image_embeds,
356
+ encoder_attention_mask=image_atts,
357
+ return_dict=False,
358
+ mode='tagging',
359
+ )
360
+
361
+ logits = self.fc(tagging_embed[0])
362
+
363
+ targets = torch.where(
364
+ torch.sigmoid(logits) > self.class_threshold,
365
+ torch.tensor(1.0).to(image.device),
366
+ torch.zeros(self.num_class).to(image.device))
367
 
368
  tag = targets.cpu().numpy()
369
+
370
+ # delete some tags that may disturb captioning
371
+ tag[:, self.delete_tag_index] = 0
372
+
373
  tag_input = []
374
  for b in range(bs):
375
  index = np.argwhere(tag[b] == 1)
376
+ token = self.tag_list[index].squeeze(axis=1)
377
+ tag_input.append(' | '.join(token))
378
+
379
+ tag_output = tag_input
380
+
381
+ # beam search for text generation(default)
382
  if not sample:
383
+ image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
384
  tag_input_temp = []
385
  for tag in tag_input:
386
  for i in range(num_beams):
387
  tag_input_temp.append(tag)
388
  tag_input = tag_input_temp
389
 
390
+ image_atts = torch.ones(image_embeds.size()[:-1],
391
+ dtype=torch.long).to(image.device)
392
 
393
+ # tokenizer input tags
394
+ tag_input_tokenzier = self.tokenizer(tag_input,
395
+ padding='max_length',
396
+ truncation=True,
397
+ max_length=40,
398
+ return_tensors="pt").to(
399
+ image.device)
400
  encoder_input_ids = tag_input_tokenzier.input_ids
401
+ encoder_input_ids[:, 0] = self.tokenizer.enc_token_id
402
+
403
+ # put input tag into image-tag interaction encoder to interact with image embeddings
404
+ output_tagembedding = self.tag_encoder(
405
+ encoder_input_ids,
406
+ attention_mask=tag_input_tokenzier.attention_mask,
407
+ encoder_hidden_states=image_embeds,
408
+ encoder_attention_mask=image_atts,
409
+ return_dict=True,
410
+ )
411
+
412
+ # prompt trick for better captioning, followed BLIP
413
  prompt = [self.prompt] * image.size(0)
414
+ input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
415
+ image.device)
416
+ input_ids[:, 0] = self.tokenizer.bos_token_id
417
+ input_ids = input_ids[:, :-1]
418
 
419
  if sample:
420
+ # nucleus sampling
421
+ model_kwargs = {
422
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
423
+ "encoder_attention_mask": None
424
+ }
425
+ outputs = self.text_decoder.generate(
426
+ input_ids=input_ids,
427
+ max_length=max_length,
428
+ min_length=min_length,
429
+ do_sample=True,
430
+ top_p=top_p,
431
+ num_return_sequences=1,
432
+ eos_token_id=self.tokenizer.sep_token_id,
433
+ pad_token_id=self.tokenizer.pad_token_id,
434
+ repetition_penalty=1.1,
435
+ **model_kwargs)
436
  else:
437
+ # beam search (default)
438
+ model_kwargs = {
439
+ "encoder_hidden_states": output_tagembedding.last_hidden_state,
440
+ "encoder_attention_mask": None
441
+ }
442
+ outputs = self.text_decoder.generate(
443
+ input_ids=input_ids,
444
+ max_length=max_length,
445
+ min_length=min_length,
446
+ num_beams=num_beams,
447
+ eos_token_id=self.tokenizer.sep_token_id,
448
+ pad_token_id=self.tokenizer.pad_token_id,
449
+ repetition_penalty=repetition_penalty,
450
+ **model_kwargs)
451
+
452
+ captions = []
453
  for output in outputs:
454
+ caption = self.tokenizer.decode(output, skip_special_tokens=True)
455
  captions.append(caption[len(self.prompt):])
456
  if return_tag_predict == True:
457
+ return captions, tag_output
 
 
 
458
  return captions
459
 
460
 
461
+ # load Tag2Text pretrained model parameters
462
+ def tag2text_caption(pretrained='', **kwargs):
463
  model = Tag2Text_Caption(**kwargs)
464
  if pretrained:
465
  if kwargs['vit'] == 'swin_b':
466
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
467
  else:
468
+ model, msg = load_checkpoint(model, pretrained)
469
+ print('vit:', kwargs['vit'])
470
+ print('msg', msg)
471
+ return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
 
473
 
474
+ # load RAM pretrained model parameters
475
+ def ram(pretrained='', **kwargs):
476
+ model = RAM(**kwargs)
477
+ if pretrained:
478
+ if kwargs['vit'] == 'swin_b':
479
+ model, msg = load_checkpoint_swinbase(model, pretrained, kwargs)
480
+ elif kwargs['vit'] == 'swin_l':
481
+ model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs)
482
+ else:
483
+ model, msg = load_checkpoint(model, pretrained)
484
+ print('vit:', kwargs['vit'])
485
+ print('msg', msg)
486
+ return model
487