pdiscoformer / utils /get_landmark_coordinates.py
ananthu-aniraj's picture
add initial files
20239f9
raw
history blame
1.56 kB
# This file contains the function to generate the center coordinates as tensor for the current net.
import torch
def landmark_coordinates(maps, grid_x=None, grid_y=None):
"""
Generate the center coordinates as tensor for the current net.
Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19
Parameters
----------
maps: torch.Tensor
Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability
grid_x: torch.Tensor
The grid x coordinates
grid_y: torch.Tensor
The grid y coordinates
Returns
----------
loc_x: Tensor
The centroid x coordinates
loc_y: Tensor
The centroid y coordinates
grid_x: Tensor
grid_y: Tensor
"""
return_grid = False
if grid_x is None or grid_y is None:
return_grid = True
grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]),
torch.arange(maps.shape[3]), indexing='ij')
grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True)
map_sums = maps.sum(3).sum(2).detach()
maps_x = grid_x * maps
maps_y = grid_y * maps
loc_x = maps_x.sum(3).sum(2) / map_sums
loc_y = maps_y.sum(3).sum(2) / map_sums
if return_grid:
return loc_x, loc_y, grid_x, grid_y
else:
return loc_x, loc_y