""" | |
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
""" | |
import numpy as np
import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from Swin3D.sparse_dl.attn.attn_coff import (
    SelfAttnAIOFunction,
    PosEmb,
    TableDims,
    IndexMode,
    PrecisionMode,
)
import Swin3D.sparse_dl.knn
from Swin3D.sparse_dl.knn import KNN

from .mink_layers import (
    assign_feats,
    SparseTensorLayerNorm,
    SparseTensorLinear,
)
def query_knn_feature(
    K, src_xyz, query_xyz, src_feat, src_offset, query_offset, return_idx=False
):
    """
    Gather features from the K nearest neighbors of each query point.
    """
    if query_xyz is None:
        query_xyz = src_xyz
        query_offset = src_offset
    assert (
        src_xyz.is_contiguous()
        and query_xyz.is_contiguous()
        and src_feat.is_contiguous()
    )

    idx, _ = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)

    n, m, c = src_xyz.shape[0], query_xyz.shape[0], src_feat.shape[1]
    grouped_feat = src_feat[idx.view(-1).long(), :].view(m, K, c)

    if return_idx:
        return grouped_feat, idx
    else:
        return grouped_feat
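
# Usage sketch (hedged): `src_offset` / `query_offset` are cumulative per-batch
# point counts, as produced by `get_offset` further below; e.g. for two clouds
# of 100 and 150 points, offset == IntTensor([100, 250]). Under that
# assumption, gathering 16-NN features for each query point looks like:
#
#   grouped = query_knn_feature(
#       16, src_xyz, query_xyz, src_feat, src_offset, query_offset
#   )  # -> [num_query, 16, C]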
def knn_linear_interpolation(
    src_xyz, query_xyz, src_feat, src_offset, query_offset, K=3
):
    """
    Interpolate features with inverse-distance weights over the K nearest neighbors.
    """
    N, C = query_xyz.shape[0], src_feat.shape[1]
    assert (
        src_xyz.is_contiguous()
        and query_xyz.is_contiguous()
        and src_feat.is_contiguous()
    )
    # (N, K)
    idx, dist = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)
    weight = 1.0 / (dist + 1e-8)
    norm = torch.sum(weight, dim=1, keepdim=True)
    weight = weight / norm
    query_feat = torch.zeros((N, C), dtype=src_feat.dtype, device=src_feat.device)
    for i in range(K):
        query_feat += src_feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
    return query_feat
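
# Reference sketch (hedged): the inverse-distance weighting above as a
# standalone pure-torch helper, making the formula explicit:
#   w_i = (1 / (d_i + eps)) / sum_j (1 / (d_j + eps)),  so sum_i w_i = 1.
# Illustrative only; the forward path relies on the CUDA KNN op above.
def _idw_weights_example(dist: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # dist: [N, K] neighbor distances -> [N, K] normalized interpolation weights
    weight = 1.0 / (dist + eps)
    return weight / weight.sum(dim=1, keepdim=True)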
def sparse_self_attention(
    w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = "v1"
):
    """
    Args:
        w_w_id [torch.Tensor]: within-window index of each non-empty voxel, shape [N].
            Voxels are ordered by window_id (starting from 0); within_window_id is a
            sparse index indicating the offset inside the kernel_size ** 3 window.
        w_sizes [torch.Tensor]: number of non-empty voxels in each window, shape [W]
    Outputs:
        [M, 3]: sparse indices of the coefficient matrix (window_id, att_a_id, att_b_id);
            att_a_id and att_b_id are within_window_id values
        [M, 1]: the sparse coefficient matrix
    Spaces:
        W: total number of windows
        N: total number of input voxels
        M: total number of output coefficients
    """
    w_sizes_2 = w_sizes**2

    # w2n_indices - [W], mapping window index to window global offset in input
    # space
    w_cumsum = torch.cumsum(w_sizes, dim=-1)
    w2n_indices = torch.cat(
        [torch.zeros(1, dtype=w_cumsum.dtype, device=w_cumsum.device), w_cumsum[:-1]]
    )

    # w2m_indices - [W], mapping window index to window global offset in output
    # space
    w2_cumsum = torch.cumsum(w_sizes_2, dim=-1)
    w2m_indices = torch.cat(
        [torch.zeros(1, dtype=w2_cumsum.dtype, device=w2_cumsum.device), w2_cumsum[:-1]]
    )

    # m2w_indices - [M], mapping element global offset to the window index
    m2w_indices = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_offset = torch.zeros(
        [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device
    )
    m2w_indices[w2m_indices[1:]] = 1
    m2w_offset[w2m_indices[1:]] = w_sizes_2[:-1]
    m2w_indices = torch.cumsum(m2w_indices, dim=-1)
    m2w_offset = torch.cumsum(m2w_offset, dim=-1)

    # m_indices - [M], element global offset in output space
    m_indices = torch.arange(
        0, w2_cumsum[-1], dtype=w_sizes.dtype, device=w_sizes.device
    )

    # m2n_indices - [M], mapping element global offset to the window global offset
    # in input space
    m2n_indices = w2n_indices[m2w_indices]

    m_offset = m_indices - m2w_offset
    m2w_sizes = w_sizes[m2w_indices]

    y_offset = m2n_indices + m_offset % m2w_sizes
    x_offset = m2n_indices + torch.div(m_offset, m2w_sizes, rounding_mode="floor")

    if protocol == "v1":
        return x_offset, y_offset
    elif protocol == "v2":
        return x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, w2m_indices
    else:
        raise ValueError(f"unknown protocol: {protocol}")
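
# Worked sketch (hedged): for w_sizes = [2, 3] there are M = 2**2 + 3**2 = 13
# coefficient slots and m2w_indices = [0, 0, 0, 0, 1, ..., 1]; (x_offset,
# y_offset) then enumerate every (query, key) pair inside each window. The
# cumsum construction above is equivalent to this simpler (less fused) form:
def _m2w_indices_reference(w_sizes: torch.Tensor) -> torch.Tensor:
    # repeat each window id once per coefficient (w_size ** 2 pairs per window)
    w_ids = torch.arange(w_sizes.shape[0], device=w_sizes.device)
    return torch.repeat_interleave(w_ids, w_sizes**2)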
class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class GridCoordsDown(nn.Module):
    """
    Downsample the grid coordinates:
    keep the point nearest to the average point of each downsampled grid cell.
    """

    def __init__(self, stride):
        super().__init__()
        self.stride = stride
        self.avg_pool = ME.MinkowskiAvgPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )
        self.unpool = ME.MinkowskiPoolingTranspose(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.max_pool = ME.MinkowskiMaxPooling(
            kernel_size=self.stride, stride=self.stride, dimension=3
        )

    def forward(self, coords_sp, sp, return_map=False):
        device = sp.C.device
        # is_pool = True means pooling map
        # is_pool = False means conv map (query as center)
        N = sp.shape[0]
        avg_coords_sp = self.avg_pool(coords_sp)
        dist_sp = self.unpool(avg_coords_sp) - coords_sp
        dist = dist_sp.F
        dist = -torch.sqrt((dist**2).sum(dim=1)).unsqueeze(1)
        dist_sp = assign_feats(dist_sp, dist)
        min_dist_sp = self.max_pool(dist_sp)
        map_pair = sp.coordinate_manager.kernel_map(
            dist_sp.coordinate_map_key,
            min_dist_sp.coordinate_map_key,
            stride=self.stride,
            kernel_size=self.stride,
            is_pool=True,
        )[0]
        in_map, out_map = map_pair
        broad_min_dist_sp = self.unpool(min_dist_sp)
        mask = (broad_min_dist_sp.F == dist_sp.F).squeeze(1)
        in_map = in_map[mask].long()
        out_map = out_map[mask].long()
        downsample_map = torch.zeros(N, dtype=torch.long, device=device) - 1
        downsample_map[out_map] = in_map
        assert (downsample_map >= 0).all()
        assert (dist_sp.F[downsample_map] == min_dist_sp.F).all()
        new_coords = coords_sp.F[downsample_map]
        new_coords_sp = assign_feats(sp, new_coords)
        if return_map:
            return new_coords_sp, downsample_map
        else:
            return new_coords_sp
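
# 1-D sketch (hedged) of the selection trick used above: pick, per pooling
# window, the input element nearest to the window average by max-pooling the
# *negative* distance and matching it back. Plain torch, illustrative only.
def _nearest_to_mean_1d_example(x: torch.Tensor, stride: int = 2) -> torch.Tensor:
    xw = x.view(-1, stride)                         # group consecutive elements
    mean = xw.mean(dim=1, keepdim=True)             # per-window average
    best = (-(xw - mean).abs()).max(dim=1).indices  # argmax == nearest element
    return xw[torch.arange(xw.shape[0]), best]      # nearest value per window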
def get_offset(batch):
    offset = []
    bs = batch.max() + 1
    for i in range(bs):
        offset.append(torch.sum(batch == i))
    offset = torch.cuda.IntTensor(offset)
    offset = offset.cumsum(dim=0).int()
    return offset
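
# Equivalent sketch (hedged): the loop above can be vectorized with
# torch.bincount, assuming `batch` holds contiguous ids 0..bs-1 on the GPU.
# Shown for clarity; the loop version above remains the one in use.
def _get_offset_bincount_example(batch: torch.Tensor) -> torch.Tensor:
    counts = torch.bincount(batch.long())  # voxels per batch id
    return counts.cumsum(dim=0).int()      # cumulative offsets, as in get_offset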
class GridDownsample(nn.Module):
    """
    Downsample voxels with the given stride,
    using grid max-pooling with kernel_size.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=kernel_size, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = SparseTensorLayerNorm(in_channels)
        self.linear = SparseTensorLinear(in_channels, out_channels)

    def forward(self, sp, coords_sp):
        sp_down = self.sp_pool(self.linear(self.norm(sp)))
        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        return sp_down, coords_sp_down

    def extra_repr(self) -> str:
        return f"kernel_size={self.kernel_size}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"
class GridKNNDownsample(nn.Module):
    """
    Downsample voxels with the given stride,
    using max-pooling over the k nearest neighbors.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.k = 16
        self.sp_pool = ME.MinkowskiMaxPooling(
            kernel_size=stride, stride=stride, dimension=3
        )
        self.coords_pool = GridCoordsDown(stride=stride)
        self.norm = nn.LayerNorm(in_channels)
        self.linear = nn.Linear(in_channels, out_channels, bias=False)
        self.pool = nn.MaxPool1d(self.k)

    def forward(self, sp, coords_sp):
        # downsample the voxel grid
        sp_down = self.sp_pool(sp)
        # for downsampled cRSE
        coords_sp_down = self.coords_pool(coords_sp, sp_down)
        offset = get_offset(sp.C[:, 0])
        n_offset = get_offset(sp_down.C[:, 0])
        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        n_xyz = coords_sp_down.F[:, 1:4].detach().contiguous()
        feats = query_knn_feature(self.k, xyz, n_xyz, sp.F, offset, n_offset)
        m, k, c = feats.shape
        feats = (
            self.linear(self.norm(feats.view(m * k, c)).view(m, k, c))
            .transpose(1, 2)
            .contiguous()
        )
        feats = self.pool(feats).squeeze(-1)
        sp = assign_feats(sp_down, feats.float())
        coords_sp = coords_sp_down
        return sp, coords_sp

    def extra_repr(self) -> str:
        return f"kernel_size={self.k}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}"
class Upsample(nn.Module):
    """
    Upsample features via KNN inverse-distance (linear) interpolation,
    optionally followed by an attention block (see self.attn).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_heads,
        window_size,
        quant_size,
        attn=True,
        up_k=3,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.linear1 = nn.Sequential(
            nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)
        )
        self.linear2 = nn.Sequential(
            nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)
        )
        self.up_k = up_k
        self.attn = attn and window_size > 0
        if self.attn:
            self.block = BasicLayer(
                dim=out_channels,
                depth=1,
                num_heads=num_heads,
                window_size=window_size,
                quant_size=quant_size,
                drop_path=0.1,
                downsample=None,
                out_channels=None,
                cRSE=cRSE,
                fp16_mode=fp16_mode,
            )

    def forward(self, sp, coords_sp, sp_up, coords_sp_up):
        feats = sp.F
        support_feats = sp_up.F
        xyz = coords_sp.F[:, 1:4].detach().contiguous()
        support_xyz = coords_sp_up.F[:, 1:4].detach().contiguous()
        offset = get_offset(sp.C[:, 0])
        support_offset = get_offset(sp_up.C[:, 0])

        feats = self.linear1(support_feats) + knn_linear_interpolation(
            xyz, support_xyz, self.linear2(feats), offset, support_offset, K=self.up_k
        )
        sp_up = assign_feats(sp_up, feats)
        if self.attn:
            sp_up, _, _ = self.block(sp_up, coords_sp_up)
        return sp_up

    def extra_repr(self) -> str:
        return f"up_k={self.up_k}, in_channels={self.in_channels}, out_channels={self.out_channels}, attn={self.attn}"
class WindowAttention(nn.Module):
    """
    Window-based multi-head self-attention (W-MSA) module with cRSE.
    Designed for sparse structures; supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (int): The size of the cubic window.
        quant_size (int): quantization size for the finer cRSE table.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module. Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        window_size,
        quant_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # color in [-1, 1], color_windowsize = 2
        # normal in [-1, 1], normal_windowsize = 2
        self.color_windowsize = 2
        self.normal_windowsize = 2

        self.fp16_mode = fp16_mode

        table_offsets = []
        self.cRSE = cRSE
        if "XYZ" in cRSE:
            self.xyz_quant_size = quant_size
            quant_grid_length_xyz = window_size * self.xyz_quant_size
            table_shape_xyz = (3, 2 * quant_grid_length_xyz, num_heads, head_dim)
            self.query_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.query_xyz_table, std=0.02)
            self.key_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.key_xyz_table, std=0.02)
            self.value_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
            trunc_normal_(self.value_xyz_table, std=0.02)
            table_offsets += [np.prod(table_shape_xyz[1:])] * 3

        if "RGB" in cRSE:
            self.color_quant_size = quant_size * 2
            quant_grid_length_rgb = self.color_windowsize * self.color_quant_size
            table_shape_rgb = (3, 2 * quant_grid_length_rgb, num_heads, head_dim)
            self.query_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.query_rgb_table, std=0.02)
            self.key_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.key_rgb_table, std=0.02)
            self.value_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
            trunc_normal_(self.value_rgb_table, std=0.02)
            table_offsets += [np.prod(table_shape_rgb[1:])] * 3

        if "NORM" in cRSE:
            self.normal_quant_size = quant_size * 2
            quant_grid_length_norm = self.normal_windowsize * self.normal_quant_size
            table_shape_norm = (3, 2 * quant_grid_length_norm, num_heads, head_dim)
            self.query_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.query_norm_table, std=0.02)
            self.key_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.key_norm_table, std=0.02)
            self.value_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
            trunc_normal_(self.value_norm_table, std=0.02)
            table_offsets += [np.prod(table_shape_norm[1:])] * 3

        self.table_offsets = table_offsets
        self.quant_size = quant_size

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)

        self.softmax = nn.Softmax(dim=-1)
    def forward(self, feats: torch.Tensor, attn_args):
        """Forward function.
        Args:
            feats: [N, C] voxel features
            attn_args: arguments for computing attention
        """
        num_v, _ = feats.shape
        num_sc = self.dim // self.num_heads

        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        ) = attn_args

        # Query, Key, Value
        qkv = self.qkv(feats)
        qkv = (
            qkv.reshape(num_v, 3, self.num_heads, num_sc)
            .permute(1, 0, 2, 3)
            .contiguous()
        )
        query, key, value = qkv[0], qkv[1], qkv[2]  # [N, num_heads, C // num_heads]
        query = query * self.scale

        table_offsets = torch.IntTensor(self.table_offsets).cuda()
        query_table, key_table, value_table = [], [], []
        n_cRSE = []
        if "XYZ" in self.cRSE:
            n_xyz = n_coords[:, 0:3]
            n_xyz = n_xyz * self.quant_size
            n_cRSE.append(n_xyz)
            query_table.append(self.query_xyz_table.view(-1))
            key_table.append(self.key_xyz_table.view(-1))
            value_table.append(self.value_xyz_table.view(-1))
        if "RGB" in self.cRSE:
            n_rgb = n_coords[:, 3:6]
            n_rgb = n_rgb * self.color_quant_size
            n_cRSE.append(n_rgb)
            query_table.append(self.query_rgb_table.view(-1))
            key_table.append(self.key_rgb_table.view(-1))
            value_table.append(self.value_rgb_table.view(-1))
        if "NORM" in self.cRSE:
            n_norm = n_coords[:, 6:9]
            n_norm = n_norm * self.normal_quant_size
            n_cRSE.append(n_norm)
            query_table.append(self.query_norm_table.view(-1))
            key_table.append(self.key_norm_table.view(-1))
            value_table.append(self.value_norm_table.view(-1))

        n_cRSE = torch.cat(n_cRSE, dim=1)

        indices = [m2w_indices, w_sizes, w2m_indices, w2n_indices, n2n_indices, n_cRSE]
        query_table = torch.cat(query_table)
        key_table = torch.cat(key_table)
        value_table = torch.cat(value_table)

        if self.fp16_mode == 0:
            # do not use fp16
            # cast q, k, v to fp32 in forward and backward
            fp16_mode = PrecisionMode.HALF_NONE
        elif self.fp16_mode == 1:
            # use fp16 only in forward
            fp16_mode = PrecisionMode.HALF_FORWARD
        elif self.fp16_mode == 2:
            # use fp16 both in forward and backward
            fp16_mode = PrecisionMode.HALF_ALL
        else:
            raise ValueError(f"unsupported fp16_mode: {self.fp16_mode}")

        updated_values = SelfAttnAIOFunction.apply(
            query,
            key,
            value,
            query_table,
            key_table,
            value_table,
            table_offsets,
            indices,
            PosEmb.SEPARATE,
            TableDims.D0,
            IndexMode.INDIRECT,
            fp16_mode,
        )

        updated_values = updated_values.flatten(1)
        updated_feats = updated_values.view(num_v, self.dim)

        updated_feats = self.proj(updated_feats)
        updated_feats = self.proj_drop(updated_feats)  # [N, C]

        return updated_feats
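
# Sizing sketch (hedged): each cRSE table above has shape
# (3, 2 * quant_grid_length, num_heads, head_dim) -- one slice per axis of the
# signal (x/y/z, r/g/b, or nx/ny/nz). E.g. window_size=5, quant_size=4,
# num_heads=6, head_dim=16 gives 3 * 40 * 6 * 16 = 11520 parameters for the
# XYZ table of each of query / key / value. A quick way to count:
def _crse_table_numel_example(
    window_size: int, quant_size: int, num_heads: int, head_dim: int
) -> int:
    quant_grid_length = window_size * quant_size  # finer grid along the window
    return 3 * 2 * quant_grid_length * num_heads * head_dim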
class SwinTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        window_size,
        quant_size,
        drop_path=0.0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=self.window_size,
            quant_size=quant_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            cRSE=cRSE,
            fp16_mode=fp16_mode,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer
        )

    def forward(self, feats, attn_args):
        # feats: [N, c]
        short_cut = feats
        feats = self.norm1(feats)
        feats = self.attn(feats, attn_args)  # [N, c]

        feats = short_cut + self.drop_path(feats)
        feats = feats + self.drop_path(self.mlp(self.norm2(feats)))

        return feats
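
# Note (hedged): the block above follows the standard pre-norm Swin layout,
#   x = x + DropPath(W-MSA(LN(x)));  x = x + DropPath(MLP(LN(x))),
# with the sparse window attention standing in for dense W-MSA.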
class BasicLayer(nn.Module):
    """A basic Swin3D layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        quant_size (int): quantization size for the finer cRSE table.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        cRSE (str | 'XYZ', 'XYZ_RGB', 'XYZ_RGB_NORM'): cRSE mode. Default: 'XYZ_RGB'
        fp16_mode (int | 0, 1, 2): fp16 mode for the attention module. Default: 0
            0: fp32 forward and fp32 backward
            1: fp16 forward and fp32 backward
            2: fp16 forward and fp16 backward
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        window_size,
        quant_size,
        out_channels=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        down_stride=2,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        self.window_size = window_size
        self.depth = depth
        self.dim = dim
        self.num_heads = num_heads
        self.quant_size = quant_size
        self.cRSE = cRSE
        self.fp16_mode = fp16_mode

        self.shift_size = window_size // 2
        # build blocks
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim,
                    num_heads,
                    window_size,
                    quant_size,
                    drop_path=(
                        drop_path[i] if isinstance(drop_path, list) else drop_path
                    ),
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    norm_layer=norm_layer,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(depth)
            ]
        )

        self.pool = ME.MinkowskiMaxPooling(
            kernel_size=self.window_size, stride=self.window_size, dimension=3
        )

        if downsample is not None:
            if out_channels is None:
                out_channels = dim * 2
            self.downsample = downsample(
                dim, out_channels, kernel_size=down_stride, stride=down_stride
            )
        else:
            self.downsample = None
    def get_map_pair(self, sp):
        """
        Use Minkowski pooling to compute the windows
        and get the voxel-to-window mapping.
        """
        window_size = [self.window_size] * 3
        pool_sp = self.pool(sp)
        windows = pool_sp.C
        window_N = windows.shape[0]

        stride_in = sp.coordinate_map_key.get_tensor_stride()
        x, y, z = [
            torch.arange(window_size[i], device=self.device) * stride_in[i]
            for i in range(3)
        ]
        x, y, z = torch.meshgrid(x, y, z)
        i = torch.zeros_like(x, device=self.device)
        local_window = torch.stack([i, x, y, z], dim=-1).flatten(0, -2)
        all_windows = windows.unsqueeze(1) + local_window.unsqueeze(0)
        all_windows = all_windows.flatten(0, -2).int()
        cm = sp.coordinate_manager
        query_key, (map, inverse_map) = cm.insert_and_map(
            all_windows, tensor_stride=stride_in
        )
        map_pair = cm.kernel_map(query_key, sp.coordinate_map_key, kernel_size=1)[0]
        return map_pair, window_N

    def get_window_mapping(self, sp):
        """
        Calculate the relationships within each window:
            w_w_id: within-window index of each non-empty voxel (sorted by window)
            w_w_xyz: xyz offsets inside the window (sorted by window)
            nempty_num: number of non-empty voxels in each window
            sort_idx: sorts voxels by window_id, so that voxels in the same
                window are gathered together
            inv_sort_idx: inverse of sort_idx
        """
        map_pair, window_N = self.get_map_pair(sp)
        window_size = self.window_size
        nW = window_size**3
        in_map, out_map = map_pair
        in_map, sort_idx = torch.sort(in_map)
        # assert out_map == arange(out_map.shape[0])
        out_map = out_map[sort_idx]
        sort_idx = out_map.long()
        inv_sort_idx = torch.zeros_like(sort_idx)
        inv_sort_idx[sort_idx] = torch.arange(
            sort_idx.shape[0], dtype=sort_idx.dtype, device=self.device
        )
        N = window_N * nW
        v2w_mask = torch.zeros(N, dtype=torch.bool, device=self.device)
        w_id = (
            torch.arange(window_N, dtype=torch.long, device=self.device)
            .unsqueeze(1)
            .repeat(1, nW)
            .view(-1)
        )
        w_w_id = (
            torch.arange(nW, dtype=torch.long, device=self.device)
            .unsqueeze(0)
            .repeat(window_N, 1)
            .view(-1)
        )
        v2w_mask[in_map.long()] = True
        nempty_num = v2w_mask.view(-1, nW).sum(dim=-1)
        w_id = w_id[in_map.long()]
        w_w_id = w_w_id[in_map.long()]
        w_w_xyz = torch.stack(
            [
                w_w_id // window_size // window_size,
                w_w_id // window_size % window_size,
                w_w_id % window_size,
            ],
            dim=-1,
        )
        return w_w_id, w_w_xyz, nempty_num, sort_idx, inv_sort_idx
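
    # Decoding sketch (hedged): w_w_xyz above unpacks a flat within-window id
    # into (x, y, z) offsets of a cubic window; this static helper shows the
    # row-major inverse. Illustrative only, not used by the forward path.
    @staticmethod
    def _encode_window_id_example(xyz: torch.Tensor, window_size: int) -> torch.Tensor:
        # xyz: [..., 3] integer offsets in [0, window_size) -> flat id
        return (xyz[..., 0] * window_size + xyz[..., 1]) * window_size + xyz[..., 2]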
    def get_index01(self, sp, local_xyz, colors):
        """
        Calculate the arguments for sparse attention.
        """
        (
            w_w_id,
            w_w_xyz,
            nempty_num,
            n2n_indices,
            inv_sort_idx,
        ) = self.get_window_mapping(sp)
        local_xyz = local_xyz[n2n_indices]
        colors = colors[n2n_indices]
        # recover the relative position inside the voxel
        n_coords = w_w_xyz + local_xyz
        n_coords = torch.cat([n_coords, colors], dim=1)
        (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            w2m_indices,
        ) = sparse_self_attention(w_w_id, nempty_num, protocol="v2")
        return (
            x_offset,
            y_offset,
            m2w_indices,
            w_sizes,
            w2n_indices,
            n2n_indices,
            w2m_indices,
            n_coords,
        )

    def get_shifted_sp(self, sp):
        """
        Get the shifted sparse tensor for shifted-window attention.
        """
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        shift_size = self.shift_size * stride_in[0]
        shifted_C = sp.C.clone()
        shifted_C[:, 1:] += shift_size
        shifted_sp = SparseTensor(
            features=sp.F,
            coordinates=shifted_C,
            device=self.device,
            tensor_stride=stride_in,
        )
        return shifted_sp

    def get_window_pos(self, sp):
        stride_in = sp.coordinate_map_key.get_tensor_stride()
        return (sp.C[:, 1:] / stride_in[0]) % self.window_size

    def forward(self, sp, coords_sp):
        """
        xyz: position of the point inside the voxel
        colors: other signals for cRSE, including colors and normals
        local_xyz: relative position of the point inside its voxel
            (used for the finer cRSE table)
        """
        colors = coords_sp.F[:, 4:]
        xyz = coords_sp.F[:, :4]
        local_xyz = (xyz - coords_sp.C)[
            :, 1:
        ] / coords_sp.coordinate_map_key.get_tensor_stride()[0]
        self.device = sp.device
        sp_shift = self.get_shifted_sp(sp)

        attn_args = self.get_index01(sp, local_xyz, colors)
        attn_args_shift = self.get_index01(sp_shift, local_xyz, colors)

        feats = sp.F
        # regular and shifted windows alternate between consecutive blocks
        for i, blk in enumerate(self.blocks):
            attn_args_blk = attn_args if i % 2 == 0 else attn_args_shift
            feats = blk(feats, attn_args_blk)  # [N, C]

        sp = assign_feats(sp, feats)
        if self.downsample is not None:
            sp_down, coords_sp = self.downsample(sp, coords_sp)
            return sp, sp_down, coords_sp
        else:
            return sp, sp, coords_sp

    def extra_repr(self) -> str:
        return f"window_size={self.window_size}, depth={self.depth}, channel={self.dim}, num_heads={self.num_heads}, quant_size={self.quant_size}, cRSE={self.cRSE}, fp16_mode={self.fp16_mode}"