|
import numpy as np |
|
from transformers.models.deformable_detr.modeling_deformable_detr import DeformableDetrMLPPredictionHead |
|
import torch.nn as nn |
|
import torch |
|
def PairDetr(model, num_queries, num_classes,
             query_embed_dim=512, head_dim=256, box_output_dim=8,
             num_mlp_layers=3, num_decoder_layers=6):
    """Adapt a (Deformable-)DETR-style model into a pair detector.

    Replaces the query position embeddings and the classification /
    box-regression heads so the model predicts `num_classes` classes and
    an 8-dim box output (presumably two 4-dim boxes per query — confirm
    against the matcher/loss).

    Args:
        model: HF Deformable-DETR-style model exposing `model.class_embed`
            (list of Linear heads) and `model.model.query_position_embeddings`.
        num_queries: number of object queries for the new embedding table.
        num_classes: output size of the new classification head.
        query_embed_dim: width of the query position embedding
            (default 512 = 2 * hidden, the HF convention of concatenated
            query + position halves — TODO confirm).
        head_dim: input/hidden width of the box MLP head (default 256).
        box_output_dim: box head output size (default 8, a box pair).
        num_mlp_layers: depth of the box MLP head (default 3).
        num_decoder_layers: number of decoder levels the heads are
            replicated over (default 6).

    Returns:
        The same `model` instance, modified in place.
    """
    # Reuse the input width of the existing classification head.
    in_features = model.class_embed[0].in_features

    model.model.query_position_embeddings = nn.Embedding(num_queries, query_embed_dim)

    class_embed = nn.Linear(in_features, num_classes)
    bbox_embed = DeformableDetrMLPPredictionHead(
        input_dim=head_dim, hidden_dim=head_dim,
        output_dim=box_output_dim, num_layers=num_mlp_layers,
    )

    # NOTE: every list entry is the SAME module instance, i.e. the heads
    # are weight-tied across decoder layers. This mirrors HF Deformable
    # DETR without box refinement and is intentional — use copies if
    # per-layer heads are ever needed.
    model.class_embed = nn.ModuleList([class_embed for _ in range(num_decoder_layers)])
    model.bbox_embed = nn.ModuleList([bbox_embed for _ in range(num_decoder_layers)])

    return model
|
|
|
def inverse_sigmoid(x, eps=1e-5):
    """Numerically stable inverse of the sigmoid (logit) function.

    The input is first clipped to [0, 1]; both the numerator and the
    complement are then floored at `eps` so the log-ratio stays finite
    at the endpoints.
    """
    prob = x.clamp(min=0, max=1)
    numerator = prob.clamp(min=eps)
    denominator = (1 - prob).clamp(min=eps)
    return (numerator / denominator).log()
|
|
|
def forward(model,
            pixel_values,
            pixel_mask=None,
            decoder_attention_mask=None,
            encoder_outputs=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,) -> torch.Tensor:
    """Run the backbone/transformer and apply the pair-detection heads.

    Applies `model.class_embed` / `model.bbox_embed` to every decoder
    level's hidden states, anchoring the predicted box deltas on the
    (inverse-sigmoid of the) reference points, and returns only the
    last level's predictions.

    Args:
        model: pair-adapted Deformable-DETR-style model (see `PairDetr`).
        pixel_values: input images tensor.
        pixel_mask ... return_dict: forwarded verbatim to `model.model`.
        labels: accepted but UNUSED here — no loss is computed in this
            function.

    Returns:
        dict with "logits", "pred_boxes" (both from the final decoder
        level) and "init_reference_points".
    """
    return_dict = return_dict if return_dict is not None else model.config.use_return_dict

    outputs = model.model(
        pixel_values,
        pixel_mask=pixel_mask,
        decoder_attention_mask=decoder_attention_mask,
        encoder_outputs=encoder_outputs,
        inputs_embeds=inputs_embeds,
        decoder_inputs_embeds=decoder_inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # Tuple indices mirror the HF DeformableDetrModel output ordering —
    # TODO confirm against the installed transformers version.
    # hidden_states: (batch, num_levels, num_queries, hidden)
    hidden_states = outputs.intermediate_hidden_states if return_dict else outputs[2]
    init_reference = outputs.init_reference_points if return_dict else outputs[0]
    inter_references = outputs.intermediate_reference_points if return_dict else outputs[3]

    outputs_classes = []
    outputs_coords = []

    # Logit-space initial reference, reused at every level to anchor the
    # second box of the pair (channels 4:6 below).
    cons = inverse_sigmoid(init_reference)

    # One class/box prediction per decoder level (auxiliary-loss style).
    for level in range(hidden_states.shape[1]):
        if level == 0:
            reference = init_reference
        else:
            # Level > 0 uses the refined reference from the previous level.
            reference = inter_references[:, level - 1]
        reference = inverse_sigmoid(reference)

        outputs_class = model.class_embed[level](hidden_states[:, level])
        # delta_bbox: 8 channels — presumably two (cx, cy, w, h) boxes;
        # offsets are added in logit space, then squashed by sigmoid.
        delta_bbox = model.bbox_embed[level](hidden_states[:, level])

        if reference.shape[-1] == 4:
            # Full-box reference: only the first box (channels 0:4) is
            # anchored; channels 4:8 stay raw deltas. NOTE(review):
            # presumed intentional for the pair design — confirm.
            delta_bbox[..., :4] += reference
            outputs_coord_logits = delta_bbox
        elif reference.shape[-1] == 2:
            # Point reference: anchor the first box's center on the
            # refined reference and the second box's center (4:6) on the
            # initial reference `cons`.
            delta_bbox[..., :2] += reference
            delta_bbox[..., 4:6] += cons
            outputs_coord_logits = delta_bbox
        else:
            raise ValueError(f"reference.shape[-1] should be 4 or 2, but got {reference.shape[-1]}")

        outputs_coord = outputs_coord_logits.sigmoid()
        outputs_classes.append(outputs_class)
        outputs_coords.append(outputs_coord)

    # Stack per-level predictions: (batch, num_levels, queries, ...).
    outputs_class = torch.stack(outputs_classes, dim=1)
    outputs_coord = torch.stack(outputs_coords, dim=1)

    # Keep only the last decoder level as the final prediction.
    logits = outputs_class[:, -1]
    pred_boxes = outputs_coord[:, -1]

    dict_outputs = {
        "logits":logits,
        "pred_boxes": pred_boxes,
        "init_reference_points": outputs.init_reference_points,
    }
    return dict_outputs
|
|