import torch
import torch.nn as nn
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from timm.models.layers import trunc_normal_

from .mink_layers import MinkConvBNRelu, MinkResBlock
from .swin3d_layers import GridDownsample, GridKNNDownsample, BasicLayer, Upsample
from pointcept.models.builder import MODELS
from pointcept.models.utils import offset2batch, batch2offset


@MODELS.register_module("Swin3D-v1m1")
class Swin3DUNet(nn.Module):
    """Swin3D encoder with a UNet-style decoder for point cloud segmentation."""

    def __init__(
        self,
        in_channels,
        num_classes,
        base_grid_size,
        depths,
        channels,
        num_heads,
        window_sizes,
        quant_size,
        drop_path_rate=0.2,
        up_k=3,
        num_layers=5,
        stem_transformer=True,
        down_stride=2,
        upsample="linear",
        knn_down=True,
        cRSE="XYZ_RGB",
        fp16_mode=0,
    ):
        super().__init__()
        # Stochastic depth decay rule: drop path rate grows linearly with depth.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        downsample = GridKNNDownsample if knn_down else GridDownsample

        self.cRSE = cRSE
        if stem_transformer:
            # Lightweight stem; the first transformer stage runs at full resolution.
            self.stem_layer = MinkConvBNRelu(
                in_channels=in_channels,
                out_channels=channels[0],
                kernel_size=3,
                stride=1,
            )
            self.layer_start = 0
        else:
            # Convolutional stem with a residual block, followed by an immediate
            # downsample, so transformer stages start at stride `down_stride`.
            self.stem_layer = nn.Sequential(
                MinkConvBNRelu(
                    in_channels=in_channels,
                    out_channels=channels[0],
                    kernel_size=3,
                    stride=1,
                ),
                MinkResBlock(in_channels=channels[0], out_channels=channels[0]),
            )
            self.downsample = downsample(
                channels[0], channels[1], kernel_size=down_stride, stride=down_stride
            )
            self.layer_start = 1
        self.layers = nn.ModuleList(
            [
                BasicLayer(
                    dim=channels[i],
                    depth=depths[i],
                    num_heads=num_heads[i],
                    window_size=window_sizes[i],
                    quant_size=quant_size,
                    drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                    downsample=downsample if i < num_layers - 1 else None,
                    down_stride=down_stride if i == 0 else 2,
                    out_channels=channels[i + 1] if i < num_layers - 1 else None,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(self.layer_start, num_layers)
            ]
        )

        up_attn = "attn" in upsample
        self.upsamples = nn.ModuleList(
            [
                Upsample(
                    channels[i],
                    channels[i - 1],
                    num_heads[i - 1],
                    window_sizes[i - 1],
                    quant_size,
                    attn=up_attn,
                    up_k=up_k,
                    cRSE=cRSE,
                    fp16_mode=fp16_mode,
                )
                for i in range(num_layers - 1, 0, -1)
            ]
        )

        self.classifier = nn.Sequential(
            nn.Linear(channels[0], channels[0]),
            nn.BatchNorm1d(channels[0]),
            nn.ReLU(inplace=True),
            nn.Linear(channels[0], num_classes),
        )
        self.num_classes = num_classes
        self.base_grid_size = base_grid_size
        self.init_weights()

    def forward(self, data_dict):
        grid_coord = data_dict["grid_coord"]
        feat = data_dict["feat"]
        coord_feat = data_dict["coord_feat"]
        coord = data_dict["coord"]
        offset = data_dict["offset"]
        batch = offset2batch(offset)
        # Pack [batch index | normalized xyz | cRSE signal | features] into one
        # TensorField so that voxelization averages all of them consistently.
        in_field = ME.TensorField(
            features=torch.cat(
                [
                    batch.unsqueeze(-1),
                    coord / self.base_grid_size,
                    coord_feat / 1.001,  # a signal in [0, 1] stays strictly below 1.0
                    feat,
                ],
                dim=1,
            ),
            coordinates=torch.cat([batch.unsqueeze(-1).int(), grid_coord.int()], dim=1),
            quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
            minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
            device=feat.device,
        )

        sp = in_field.sparse()
        # Split the voxelized features: the first (4 + |coord_feat|) channels
        # (batch, xyz, cRSE signal) feed the contextual relative signal encoding...
        coords_sp = SparseTensor(
            features=sp.F[:, : coord_feat.shape[-1] + 4],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )
        # ...while the remaining channels are the network input features.
        sp = SparseTensor(
            features=sp.F[:, coord_feat.shape[-1] + 4 :],
            coordinate_map_key=sp.coordinate_map_key,
            coordinate_manager=sp.coordinate_manager,
        )
        sp_stack = []
        coords_sp_stack = []
        sp = self.stem_layer(sp)
        if self.layer_start > 0:
            sp_stack.append(sp)
            coords_sp_stack.append(coords_sp)
            sp, coords_sp = self.downsample(sp, coords_sp)

        # Encoder: run each Swin3D stage and stash its output for skip connections.
        for layer in self.layers:
            coords_sp_stack.append(coords_sp)
            sp, sp_down, coords_sp = layer(sp, coords_sp)
            sp_stack.append(sp)
            assert (coords_sp.C == sp_down.C).all()
            sp = sp_down

        # Decoder: pop the skip stacks and upsample back to full resolution.
        sp = sp_stack.pop()
        coords_sp = coords_sp_stack.pop()
        for upsample in self.upsamples:
            sp_i = sp_stack.pop()
            coords_sp_i = coords_sp_stack.pop()
            sp = upsample(sp, coords_sp, sp_i, coords_sp_i)
            coords_sp = coords_sp_i

        # slice() maps voxel features back onto the original (un-quantized) points.
        output = self.classifier(sp.slice(in_field).F)
        return output

    def init_weights(self):
        """Initialize the weights in backbone."""

        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

        self.apply(_init_weights)
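

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The hyperparameter
# values below are illustrative assumptions loosely following a Swin3D-S style
# configuration; consult the repository configs for the settings actually used.
# Because this file uses relative imports, run it as a module with the package
# on the path (e.g. `python -m pointcept.models.swin3d.swin3d_v1m1_base`), and
# note that MinkowskiEngine requires a CUDA device.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda")
    model = Swin3DUNet(
        in_channels=9,  # e.g. color (3) + normal (3) + extra (3); an assumption
        num_classes=13,
        base_grid_size=0.02,  # voxel size used to normalize raw coordinates
        depths=(2, 4, 9, 4, 4),
        channels=(48, 96, 192, 384, 384),
        num_heads=(6, 6, 12, 24, 24),
        window_sizes=(5, 7, 7, 7, 7),
        quant_size=4,
        cRSE="XYZ_RGB",  # coord_feat must then carry the 3 RGB channels
    ).to(device)
    model.eval()

    n = 4096  # points in a single-sample batch
    coord = torch.rand(n, 3, device=device)  # raw xyz positions
    data_dict = {
        "coord": coord,
        "grid_coord": (coord / 0.02).int(),  # voxelized xyz indices
        "feat": torch.rand(n, 9, device=device),  # per-point input features
        "coord_feat": torch.rand(n, 3, device=device),  # RGB signal for cRSE
        "offset": torch.tensor([n], device=device),  # cumulative batch sizes
    }
    with torch.no_grad():
        logits = model(data_dict)  # (n, num_classes) per-point logits
    print(logits.shape)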