import math

import torch
import torch.nn as nn
import torchvision
from timm.models.vision_transformer import Block

import gazelle.utils as utils
from gazelle.backbone import DinoV2Backbone


class GazeLLE(nn.Module):
    """Per-person gaze-target heatmap predictor built on an image backbone.

    Pipeline (see ``forward``):
      1. Backbone features are projected to ``dim`` channels with a 1x1 conv.
      2. A fixed 2D sinusoidal position encoding is added.
      3. The features are repeated once per person.
      4. Each person's copy is tagged with a learned ``head_token``, applied
         only inside that person's head-box mask.
      5. The result runs through ``num_layers`` transformer blocks and is
         decoded into a sigmoid-activated heatmap resized to ``out_size``.

    With ``inout=True`` an extra learned token is prepended to the sequence.
    It is decoded into one additional sigmoid score per person (presumably
    in-frame vs. out-of-frame gaze -- confirm against training code).

    Args:
        backbone: feature extractor. It must provide ``get_out_size(in_size)``,
            ``get_dimension()`` and ``forward(images)``.
        inout: whether to build and run the in/out prediction branch.
        dim: transformer width.
        num_layers: number of transformer ``Block`` layers.
        in_size: input image size passed to ``backbone.get_out_size``.
        out_size: spatial size the predicted heatmaps are resized to.
    """

    def __init__(self, backbone, inout=False, dim=256, num_layers=3, in_size=(448, 448), out_size=(64, 64)):
        super().__init__()
        self.backbone = backbone
        self.dim = dim
        self.num_layers = num_layers
        self.featmap_h, self.featmap_w = backbone.get_out_size(in_size)
        self.in_size = in_size
        self.out_size = out_size
        self.inout = inout

        # 1x1 conv projecting backbone channels down to the transformer width.
        self.linear = nn.Conv2d(backbone.get_dimension(), self.dim, 1)
        # Fixed (non-learned) position encoding of shape [dim, featmap_h, featmap_w].
        # Registered as a buffer so it follows .to(device) and appears in the state dict.
        self.register_buffer(
            "pos_embed",
            positionalencoding2d(self.dim, self.featmap_h, self.featmap_w).squeeze(dim=0).squeeze(dim=0),
        )
        self.transformer = nn.Sequential(*[
            Block(dim=self.dim, num_heads=8, mlp_ratio=4, drop_path=0.1)
            for _ in range(num_layers)
        ])
        # The 2x upsample is followed by a 1-channel projection and a sigmoid.
        self.heatmap_head = nn.Sequential(
            nn.ConvTranspose2d(dim, dim, kernel_size=2, stride=2),
            nn.Conv2d(dim, 1, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )
        # Learned vector added to scene features inside each person's head box.
        self.head_token = nn.Embedding(1, self.dim)
        if self.inout:
            self.inout_head = nn.Sequential(
                nn.Linear(self.dim, 128),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(128, 1),
                nn.Sigmoid(),
            )
            self.inout_token = nn.Embedding(1, self.dim)

    def forward(self, input):
        """Predict one heatmap (and optionally one in/out score) per person.

        Args:
            input: dict with
                - ``"images"``: [B, 3, H, W] tensor of images.
                - ``"bboxes"``: list (one entry per image) of lists of head
                  bboxes ``(xmin, ymin, xmax, ymax)`` in normalized image
                  coordinates. An entry may be ``None``, meaning no head box
                  is given; that person gets an empty head map.

        Returns:
            dict with:
                - ``"heatmap"``: per-person heatmaps, regrouped per image by
                  ``utils.split_tensors``.
                - ``"inout"``: per-person scores, regrouped the same way, or
                  ``None`` when the model was built with ``inout=False``.
        """
        num_ppl_per_img = [len(bbox_list) for bbox_list in input["bboxes"]]

        x = self.backbone.forward(input["images"])
        x = self.linear(x)
        x = x + self.pos_embed
        # Repeat image features along the people dimension, once per person in each image.
        x = utils.repeat_tensors(x, num_ppl_per_img)

        # Shape: [sum(N_p), featmap_h, featmap_w].
        head_maps = torch.cat(self.get_input_head_maps(input["bboxes"]), dim=0).to(x.device)
        # Broadcast [N, 1, h, w] * [1, dim, 1, 1] -> [N, dim, h, w]:
        # head_token is added only where the mask is 1.
        head_map_embeddings = head_maps.unsqueeze(dim=1) * self.head_token.weight.unsqueeze(-1).unsqueeze(-1)
        x = x + head_map_embeddings
        x = x.flatten(start_dim=2).permute(0, 2, 1)  # "b c h w -> b (h w) c"

        if self.inout:
            # Prepend one learned in/out token to every person's token sequence.
            inout_token = self.inout_token.weight.unsqueeze(dim=0).repeat(x.shape[0], 1, 1)
            x = torch.cat([inout_token, x], dim=1)

        x = self.transformer(x)

        if self.inout:
            inout_tokens = x[:, 0, :]
            inout_preds = self.inout_head(inout_tokens).squeeze(dim=-1)
            inout_preds = utils.split_tensors(inout_preds, num_ppl_per_img)
            x = x[:, 1:, :]  # slice off the in/out token, keeping only scene tokens

        # "b (h w) c -> b c h w"
        x = x.reshape(x.shape[0], self.featmap_h, self.featmap_w, x.shape[2]).permute(0, 3, 1, 2)
        x = self.heatmap_head(x).squeeze(dim=1)
        x = torchvision.transforms.functional.resize(x, self.out_size)
        heatmap_preds = utils.split_tensors(x, num_ppl_per_img)  # re-split per image

        return {"heatmap": heatmap_preds, "inout": inout_preds if self.inout else None}

    def get_input_head_maps(self, bboxes):
        """Rasterize head bboxes into binary masks on the feature-map grid.

        Args:
            bboxes: list (one entry per image) of lists of
                ``(xmin, ymin, xmax, ymax)`` tuples in normalized [0, 1]
                coordinates. An entry may be ``None``, which yields an
                all-zero map.

        Returns:
            List with one float tensor per image, each of shape
            [num_people, featmap_h, featmap_w]. Cells inside the box are 1
            and all other cells are 0.

        Note:
            An image with an empty bbox list makes ``torch.stack`` raise,
            because it cannot stack zero tensors.
        """
        height, width = self.featmap_h, self.featmap_w

        def to_index(coord, size):
            # Map a normalized coordinate to a grid index clamped to [0, size].
            # Without the clamp, a slightly negative coordinate becomes a negative
            # slice index. Python reads that from the end of the axis, giving an
            # empty or wrong box.
            return min(max(round(coord * size), 0), size)

        head_maps = []
        for bbox_list in bboxes:
            img_head_maps = []
            for bbox in bbox_list:
                head_map = torch.zeros((height, width))
                if bbox is not None:
                    xmin, ymin, xmax, ymax = bbox
                    x0, x1 = to_index(xmin, width), to_index(xmax, width)
                    y0, y1 = to_index(ymin, height), to_index(ymax, height)
                    head_map[y0:y1, x0:x1] = 1
                img_head_maps.append(head_map)
            head_maps.append(torch.stack(img_head_maps))
        return head_maps

    def get_gazelle_state_dict(self, include_backbone=False):
        """Return the model state dict, dropping ``backbone*`` keys unless ``include_backbone``."""
        if include_backbone:
            return self.state_dict()
        return {k: v for k, v in self.state_dict().items() if not k.startswith("backbone")}

    def load_gazelle_state_dict(self, ckpt_state_dict, include_backbone=False):
        """Load matching entries of ``ckpt_state_dict`` into this model.

        Keys starting with ``"backbone"`` are ignored unless
        ``include_backbone`` is true. Keys present on only one side are
        reported with a printed warning rather than raising. Only keys
        present on both sides are copied, and the result is loaded with
        ``strict=False``.
        """
        current_state_dict = self.state_dict()
        keys1 = set(current_state_dict.keys())
        keys2 = set(ckpt_state_dict.keys())
        if not include_backbone:
            keys1 = {k for k in keys1 if not k.startswith("backbone")}
            keys2 = {k for k in keys2 if not k.startswith("backbone")}

        unused = keys2 - keys1
        if unused:
            print("WARNING unused keys in provided state dict: ", unused)
        missing = keys1 - keys2
        if missing:
            print("WARNING provided state dict does not have values for keys: ", missing)

        for k in keys1 & keys2:
            current_state_dict[k] = ckpt_state_dict[k]

        self.load_state_dict(current_state_dict, strict=False)


# From https://github.com/wzlxjtu/PositionalEncoding2D/blob/master/positionalembedding2d.py
def positionalencoding2d(d_model, height, width):
    """Build a fixed 2D sinusoidal positional encoding.

    The first half of the channels encodes the column (width) position and
    the second half encodes the row (height) position. Each half uses
    interleaved sin/cos channels.

    :param d_model: dimension of the model (must be divisible by 4)
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    :raises ValueError: if ``d_model`` is not divisible by 4
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial axis uses half of d_model.
    half = d_model // 2
    div_term = torch.exp(torch.arange(0., half, 2) * -(math.log(10000.0) / half))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:half:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:half:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[half::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[half + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe


# models
def get_gazelle_model(model_name):
    """Build a named GazeLLE variant; returns ``(model, transform)``."""
    factory = {
        "gazelle_dinov2_vitb14": gazelle_dinov2_vitb14,
        "gazelle_dinov2_vitl14": gazelle_dinov2_vitl14,
        "gazelle_dinov2_vitb14_inout": gazelle_dinov2_vitb14_inout,
        "gazelle_dinov2_vitl14_inout": gazelle_dinov2_vitl14_inout,
    }
    # NOTE(review): assert is stripped under `python -O`; kept to preserve the
    # AssertionError that callers may currently rely on.
    assert model_name in factory, "invalid model name"
    return factory[model_name]()


def _make_dinov2_gazelle(backbone_name, inout=False):
    """Shared builder: a DINOv2 backbone, its (448, 448) transform, and a GazeLLE head."""
    backbone = DinoV2Backbone(backbone_name)
    transform = backbone.get_transform((448, 448))
    model = GazeLLE(backbone, inout=inout)
    return model, transform


def gazelle_dinov2_vitb14():
    """GazeLLE on DINOv2 ViT-B/14 without the in/out branch."""
    return _make_dinov2_gazelle('dinov2_vitb14')


def gazelle_dinov2_vitl14():
    """GazeLLE on DINOv2 ViT-L/14 without the in/out branch."""
    return _make_dinov2_gazelle('dinov2_vitl14')


def gazelle_dinov2_vitb14_inout():
    """GazeLLE on DINOv2 ViT-B/14 with the in/out branch enabled."""
    return _make_dinov2_gazelle('dinov2_vitb14', inout=True)


def gazelle_dinov2_vitl14_inout():
    """GazeLLE on DINOv2 ViT-L/14 with the in/out branch enabled."""
    return _make_dinov2_gazelle('dinov2_vitl14', inout=True)