Spaces:
Sleeping
Sleeping
# This file contains the function to generate the center coordinates as tensor for the current net. | |
import torch | |
def landmark_coordinates(maps, grid_x=None, grid_y=None): | |
""" | |
Generate the center coordinates as tensor for the current net. | |
Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19 | |
Parameters | |
---------- | |
maps: torch.Tensor | |
Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability | |
grid_x: torch.Tensor | |
The grid x coordinates | |
grid_y: torch.Tensor | |
The grid y coordinates | |
Returns | |
---------- | |
loc_x: Tensor | |
The centroid x coordinates | |
loc_y: Tensor | |
The centroid y coordinates | |
grid_x: Tensor | |
grid_y: Tensor | |
""" | |
return_grid = False | |
if grid_x is None or grid_y is None: | |
return_grid = True | |
grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]), | |
torch.arange(maps.shape[3]), indexing='ij') | |
grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) | |
grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) | |
map_sums = maps.sum(3).sum(2).detach() | |
maps_x = grid_x * maps | |
maps_y = grid_y * maps | |
loc_x = maps_x.sum(3).sum(2) / map_sums | |
loc_y = maps_y.sum(3).sum(2) / map_sums | |
if return_grid: | |
return loc_x, loc_y, grid_x, grid_y | |
else: | |
return loc_x, loc_y | |