import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.registry import BUILDER
from xtuner.model.utils import LoadWoInit, guess_load_checkpoint
from xtuner.model.llava import LLaVAModel
from mmengine.model import BaseModel
from mmengine import print_log
from projects.glamm.utils import prepare_inputs_labels_for_multimodal
from projects.glamm.utils import DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class GLaMM(LLaVAModel):
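    """LLaVA-style multimodal LLM extended for grounded conversation generation.

    On top of the base ``LLaVAModel``, this class adds a SAM-style grounding
    encoder (frozen except for its mask decoder) whose prompts are produced
    from the LLM hidden states at ``[SEG]`` token positions, plus an optional
    region encoder that injects box/point region features at ``<bbox>``
    placeholder tokens. ``loss_mask`` and ``loss_dice`` supervise the
    predicted segmentation masks.
    """
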
    def __init__(self,
                 use_activation_checkpointing=True,
                 tokenizer=None,
                 grounding_encoder=None,
                 region_encoder=None,
                 loss_mask=None,
                 loss_dice=None,
                 *args, **kwargs):
        super(GLaMM, self).__init__(
            *args, use_activation_checkpointing=use_activation_checkpointing, **kwargs)
        self.use_activation_checkpointing = use_activation_checkpointing

        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()

        # SAM-style grounding encoder: frozen except for its mask decoder.
        self.grounding_encoder = BUILDER.build(grounding_encoder)
        self.grounding_encoder.requires_grad_(False)
        self.grounding_encoder.mask_decoder.requires_grad_(True)

        if region_encoder is not None:
            self.region_encoder = BUILDER.build(region_encoder)

        # Projection from LLM hidden states to the mask decoder's prompt space.
        in_dim = self.config.hidden_size
        out_dim = self.grounding_encoder.mask_decoder.transformer_dim
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

    def _add_special_tokens(self):
        reg_tokens = ['<im_start>', '<im_end>', '<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens

        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            # Resize the embedding tables and initialise the new rows with the
            # mean of the existing embeddings.
            self.llm.resize_token_embeddings(len(self.tokenizer))
            input_embeddings = self.llm.get_input_embeddings().weight.data
            output_embeddings = self.llm.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

        # Cache the ids of the special tokens used during loss computation.
        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
        self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
        self.bbox_token_idx = self.tokenizer("<bbox>", add_special_tokens=False).input_ids[0]

        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
            self.llm.enable_input_require_grads()

    def forward(self, data, data_samples=None, mode='loss'):
        if 'pixel_values' in data:
            # Encode the image with the visual encoder and project the selected
            # hidden layer into the LLM embedding space.
            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
            data['pixel_values'] = pixel_values

        bboxes = data.pop('bboxes', None)
        if bboxes is not None:
            # Gather multi-level visual features (every third hidden layer,
            # counting back from the second-to-last) and encode the box regions.
            select_hidden_state_layer = -2
            num_level_reg_features = 4
            mlvl_reg_features = visual_outputs.hidden_states[select_hidden_state_layer::-3]
            mlvl_reg_features = mlvl_reg_features[::-1]
            mlvl_reg_features = mlvl_reg_features[-num_level_reg_features:]
            mlvl_reg_features = [item[:, 1:] for item in mlvl_reg_features]  # drop CLS token
            mlvl_reg_features = self.region_encoder(mlvl_reg_features, bboxes)

        data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

        if bboxes is not None:
            # Overwrite the embeddings of <bbox> placeholder tokens with the
            # encoded region features.
            inputs_embeds = data['inputs_embeds']
            for i, reg_feat in enumerate(mlvl_reg_features):
                reg_mask = data['new_input_ids'][i] == self.bbox_token_idx
                inputs_embeds[i][reg_mask] = reg_feat
            data['inputs_embeds'] = inputs_embeds

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def compute_loss(self, data, data_samples=None):
        g_pixel_values = data.pop('g_pixel_values', None)
        gt_masks = data.pop('masks', None)
        new_input_ids = data.pop('new_input_ids', None)

        output = self.llm(output_hidden_states=True, **data)
        if gt_masks is None:
            # Text-only batch: no segmentation supervision.
            return {'llm_loss': output.loss}

        resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
        ori_size_list = [mask.shape[-2:] for mask in gt_masks]
        g_pixel_values = torch.stack([
            self.grounding_encoder.preprocess(pixel) for pixel in g_pixel_values
        ])
        image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)

        # Project the last-layer LLM hidden states at [SEG] token positions
        # into prompt embeddings for the mask decoder, grouped per sample.
        seg_token_mask = new_input_ids == self.seg_token_idx
        hidden_states = output.hidden_states
        hidden_states = self.text_hidden_fcs(hidden_states[-1])
        pred_embeddings = hidden_states[seg_token_mask]

        seg_token_counts = seg_token_mask.int().sum(-1)
        pred_embeddings_list = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)

        pred_masks = self._generate_and_postprocess_masks(
            pred_embeddings_list, image_embeddings, resize_list, ori_size_list)

        # Accumulate the mask losses and binary accuracy, then average over the batch.
        bs = len(pred_masks)
        loss_mask, loss_dice, accuracy = 0, 0, 0
        for i in range(bs):
            pred_mask = pred_masks[i]
            gt_mask = gt_masks[i]
            sam_loss_mask = self.loss_mask(pred_mask, gt_mask)
            sam_loss_dice = self.loss_dice(pred_mask, gt_mask)
            accuracy += torch.eq((pred_mask.sigmoid() > 0.5), gt_mask).to(pred_mask).mean()
            loss_mask += sam_loss_mask
            loss_dice += sam_loss_dice

        loss_dict = {
            'loss_mask': loss_mask / bs,
            'loss_dice': loss_dice / bs,
            'accuracy': accuracy / bs,
            'llm_loss': output.loss,
        }
        return loss_dict

    def _generate_and_postprocess_masks(self, pred_embeddings, image_embeddings, resize_list=None, orig_size_list=None, infer=False):
        pred_masks = []
        for i, pred_embedding in enumerate(pred_embeddings):
            # Each [SEG] embedding acts as a sparse text prompt for the mask decoder.
            sparse_embeddings, dense_embeddings = self.grounding_encoder.prompt_encoder(
                points=None, boxes=None, masks=None, text_embeds=pred_embedding.unsqueeze(1)
            )
            sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)

            low_res_masks, _ = self.grounding_encoder.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.grounding_encoder.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False)

            # Upsample the low-resolution logits back to the original image size.
            pred_mask = self.grounding_encoder.postprocess_masks(
                low_res_masks, input_size=resize_list[i], original_size=orig_size_list[i])
            pred_masks.append(pred_mask[:, 0])
        return pred_masks

    def predict(self, data, data_samples=None):
        pass

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs
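

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical; in practice the xtuner/mmengine runner
# builds this model from a config via BUILDER and drives the loss computation):
#
#   model = BUILDER.build(glamm_cfg)           # `glamm_cfg` is assumed here
#   losses = model(data_batch, mode='loss')    # `data_batch` from a dataloader
#
# With segmentation targets present, `losses` holds 'llm_loss', 'loss_mask',
# 'loss_dice' and 'accuracy'; the runner is responsible for weighting and
# summing the loss terms.
# ---------------------------------------------------------------------------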