Spaces:
Running
on
Zero
Running
on
Zero
""" | |
Point Transformer - V3 Mode1 | |
Pointcept detached version | |
Author: Xiaoyang Wu ([email protected]) | |
Please cite our work if the code is helpful to you. | |
""" | |
import sys | |
from functools import partial | |
from addict import Dict | |
import math | |
import torch | |
import torch.nn as nn | |
import spconv.pytorch as spconv | |
import torch_scatter | |
from timm.models.layers import DropPath | |
from collections import OrderedDict | |
import numpy as np | |
import torch.nn.functional as F | |
try: | |
import flash_attn | |
except ImportError: | |
flash_attn = None | |
from model.serialization import encode | |
from huggingface_hub import PyTorchModelHubMixin | |
def offset2bincount(offset):
    """Convert cumulative batch offsets into per-sample point counts.

    ``offset`` holds the running end index of each sample in the batch
    (e.g. ``[3, 5, 9]``). The result is its first difference with an implicit
    leading zero (``[3, 2, 4]``).
    """
    # A single int64 zero on the same device acts as the start of sample 0.
    leading_zero = torch.zeros(1, dtype=torch.long, device=offset.device)
    return torch.diff(offset, prepend=leading_zero)
def offset2batch(offset):
    """Expand cumulative batch offsets into a per-point batch index vector.

    For ``offset = [3, 5]`` the result is ``[0, 0, 0, 1, 1]``: sample ``i`` is
    repeated as many times as it has points.
    """
    counts = offset2bincount(offset)
    sample_ids = torch.arange(
        counts.shape[0], dtype=torch.long, device=offset.device
    )
    return torch.repeat_interleave(sample_ids, counts)
def batch2offset(batch):
    """Convert a per-point batch index vector into cumulative batch offsets.

    For ``batch = [0, 0, 0, 1, 1, 2]`` the result is ``[3, 5, 6]``.
    """
    points_per_sample = torch.bincount(batch)
    return points_per_sample.cumsum(dim=0).long()
class Point(Dict):
    """
    Point Structure of Pointcept

    A Point (point cloud) in Pointcept is a dictionary that contains various
    properties of a batched point cloud. The properties with the following
    names have a specific meaning:

    - "coord": original coordinate of point cloud;
    - "grid_coord": grid coordinate for specific grid size (related to GridSampling);

    Point also supports the following optional attributes:

    - "offset": cumulative point count per sample; derived from "batch" at
      construction when only "batch" is given;
    - "batch": per-point sample index; derived from "offset" at construction
      when only "offset" is given;
    - "feat": feature of point cloud, default input of model;
    - "grid_size": grid size of point cloud (related to GridSampling);

    (related to Serialization)

    - "serialized_depth": depth of serialization, 2 ** depth * grid_size
      describes the maximum of point cloud range;
    - "serialized_code": a stack of serialization codes, one row per order;
    - "serialized_order": sorting permutation determined by each code row;
    - "serialized_inverse": inverse of each sorting permutation;

    (related to Sparsify: SpConv)

    - "sparse_shape": sparse shape for the SparseConvTensor;
    - "sparse_conv_feat": SparseConvTensor built from the information in Point.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If exactly one of "offset" / "batch" exists, derive the other from it.
        has_batch = "batch" in self.keys()
        has_offset = "offset" in self.keys()
        if has_offset and not has_batch:
            self["batch"] = offset2batch(self.offset)
        elif has_batch and not has_offset:
            self["offset"] = batch2offset(self.batch)

    def _ensure_grid_coord(self):
        """Populate "grid_coord" from "coord" and "grid_size" when it is missing.

        Raises:
            AssertionError: if "grid_coord" is absent and "coord"/"grid_size"
                are not both available.
        """
        if "grid_coord" in self.keys():
            return
        # if you don't want to operate GridSampling in data augmentation,
        # please add the following augmentation into your pipeline:
        # dict(type="Copy", keys_dict={"grid_size": 0.01}),
        # (adjust `grid_size` to what you want)
        assert {"grid_size", "coord"}.issubset(self.keys())
        self["grid_coord"] = torch.div(
            self.coord - self.coord.min(0)[0], self.grid_size, rounding_mode="trunc"
        ).int()

    def serialization(self, order="z", depth=None, shuffle_orders=False):
        """
        Point Cloud Serialization

        Relies on ["grid_coord" or "coord" + "grid_size", "batch"]. Writes
        "serialized_depth", "serialized_code", "serialized_order" and
        "serialized_inverse" into the Point.

        Args:
            order: a single order name or an iterable of order names; each one
                is forwarded to ``encode`` and yields one row of codes.
            depth: serialization depth; when ``None`` it is the bit length of
                the largest grid coordinate.
            shuffle_orders: randomly permute the rows (orders) of the result.
        """
        assert "batch" in self.keys()
        self._ensure_grid_coord()

        # A bare string would otherwise be iterated character by character
        # (e.g. "hilbert" -> "h", "i", ...), so wrap it in a list.
        if isinstance(order, str):
            order = [order]

        if depth is None:
            # Adaptive measure the depth of serialization cube (length = 2 ^ depth)
            depth = int(self.grid_coord.max()).bit_length()
        self["serialized_depth"] = depth
        # Maximum bit length for serialization code is 63 (int64)
        assert depth * 3 + len(self.offset).bit_length() <= 63
        # Here we follow OCNN and set the depth limitation to 16 (48bit) for the point position.
        # Although depth is limited to less than 16, we can encode a 655.36^3 (2^16 * 0.01) meter^3
        # cube with a grid size of 0.01 meter. We consider it is enough for the current stage.
        # We can unlock the limitation by optimizing the z-order encoding function if necessary.
        assert depth <= 16

        # The serialization codes are arranged as following structures:
        # [Order1 ([n]),
        #  Order2 ([n]),
        #   ...
        #  OrderN ([n])] (k, n)
        code = torch.stack(
            [
                encode(self.grid_coord, self.batch, depth, order=order_)
                for order_ in order
            ]
        )
        sort_order = torch.argsort(code)
        # inverse[k, sort_order[k, j]] = j, i.e. the rank of every point per order.
        inverse = torch.zeros_like(sort_order).scatter_(
            dim=1,
            index=sort_order,
            src=torch.arange(0, code.shape[1], device=sort_order.device).repeat(
                code.shape[0], 1
            ),
        )

        if shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            sort_order = sort_order[perm]
            inverse = inverse[perm]

        self["serialized_code"] = code
        self["serialized_order"] = sort_order
        self["serialized_inverse"] = inverse

    def sparsify(self, pad=96):
        """
        Point Cloud Sparsification

        Point cloud is sparse, here we use "sparsify" to specifically refer to
        preparing "spconv.SparseConvTensor" for SpConv.

        Relies on ["grid_coord" or "coord" + "grid_size", "batch", "feat"].
        Writes "sparse_shape" and "sparse_conv_feat" into the Point.

        Args:
            pad: margin added to the per-axis maximum grid coordinate when
                "sparse_shape" is not already stored in the Point.
        """
        assert {"feat", "batch"}.issubset(self.keys())
        self._ensure_grid_coord()

        if "sparse_shape" in self.keys():
            sparse_shape = self.sparse_shape
        else:
            sparse_shape = torch.add(
                torch.max(self.grid_coord, dim=0).values, pad
            ).tolist()
        sparse_conv_feat = spconv.SparseConvTensor(
            features=self.feat,
            # Indices are (batch index, grid coordinates) per point.
            indices=torch.cat(
                [self.batch.unsqueeze(-1).int(), self.grid_coord.int()], dim=1
            ).contiguous(),
            spatial_shape=sparse_shape,
            # The last batch index + 1 is used as the batch size.
            batch_size=self.batch[-1].tolist() + 1,
        )
        self["sparse_shape"] = sparse_shape
        self["sparse_conv_feat"] = sparse_conv_feat
class PointModule(nn.Module):
    r"""PointModule

    Marker base class: ``PointSequential`` hands a whole ``Point`` to every
    child that derives from it. It adds nothing to ``nn.Module`` and inherits
    its constructor unchanged.
    """
class PointSequential(PointModule):
    r"""A sequential container.

    Modules are registered in the order they are passed to the constructor;
    a single ``OrderedDict`` of named modules or keyword arguments may be used
    instead. ``forward`` dispatches on the kind of each child (PointModule,
    spconv module, or plain PyTorch module) and on the kind of the input.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            named_modules = args[0].items()
        else:
            named_modules = ((str(pos), mod) for pos, mod in enumerate(args))
        for key, mod in named_modules:
            self.add_module(key, mod)

        for name, mod in kwargs.items():
            if sys.version_info < (3, 6):
                raise ValueError("kwargs only supported in py36+")
            if name in self._modules:
                raise ValueError("name exists.")
            self.add_module(name, mod)

    def __getitem__(self, idx):
        size = len(self)
        if idx < -size or idx >= size:
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += size
        return list(self._modules.values())[idx]

    def __len__(self):
        return len(self._modules)

    def add(self, module, name=None):
        """Append ``module``; the default name is the current child count."""
        key = str(len(self._modules)) if name is None else name
        if key in self._modules:
            raise KeyError("name exists")
        self.add_module(key, module)

    def forward(self, input):
        for layer in self._modules.values():
            # Point module: receives and returns the whole input.
            if isinstance(layer, PointModule):
                input = layer(input)
                continue

            # Spconv module: operates on the sparse tensor.
            if spconv.modules.is_spconv_module(layer):
                if not isinstance(input, Point):
                    input = layer(input)
                    continue
                input.sparse_conv_feat = layer(input.sparse_conv_feat)
                input.feat = input.sparse_conv_feat.features
                continue

            # PyTorch module: operates on the dense feature matrix.
            if isinstance(input, Point):
                input.feat = layer(input.feat)
                # Keep the sparse tensor's features in sync when it exists.
                if "sparse_conv_feat" in input.keys():
                    input.sparse_conv_feat = input.sparse_conv_feat.replace_feature(
                        input.feat
                    )
            elif isinstance(input, spconv.SparseConvTensor):
                # An empty sparse tensor is passed through untouched.
                if input.indices.shape[0] != 0:
                    input = input.replace_feature(layer(input.features))
            else:
                input = layer(input)
        return input
class PDNorm(PointModule):
    """Normalization selected by a dataset condition, with optional modulation.

    Args:
        num_features: channel count handed to ``norm_layer``.
        norm_layer: factory called as ``norm_layer(num_features)``. When
            ``decouple`` is False an already constructed ``nn.Module`` is also
            accepted and used as is.
        context_channels: input width of the adaptive modulation layer.
        conditions: condition names; with ``decouple`` one norm is built per
            entry and picked by the position of the point's condition.
        decouple: build a separate norm per condition instead of sharing one.
        adaptive: additionally scale and shift the normalized features with
            values predicted from ``point.context``.
    """

    def __init__(
        self,
        num_features,
        norm_layer,
        context_channels=256,
        conditions=("ScanNet", "S3DIS", "Structured3D"),
        decouple=True,
        adaptive=False,
    ):
        super().__init__()
        self.conditions = conditions
        self.decouple = decouple
        self.adaptive = adaptive
        if self.decouple:
            self.norm = nn.ModuleList([norm_layer(num_features) for _ in conditions])
        elif isinstance(norm_layer, nn.Module):
            # Caller supplied a ready-made shared norm.
            self.norm = norm_layer
        else:
            # Storing the factory itself would make forward() construct a new
            # module from the feature tensor instead of normalizing it.
            self.norm = norm_layer(num_features)
        if self.adaptive:
            self.modulation = nn.Sequential(
                nn.SiLU(), nn.Linear(context_channels, 2 * num_features, bias=True)
            )

    def forward(self, point):
        """Normalize ``point.feat`` in place and return ``point``.

        ``point.condition`` may be a string or an indexable whose first item
        is used. With ``adaptive``, "context" must also be present.
        """
        assert {"feat", "condition"}.issubset(point.keys())
        if isinstance(point.condition, str):
            condition = point.condition
        else:
            condition = point.condition[0]
        if self.decouple:
            assert condition in self.conditions
            norm = self.norm[self.conditions.index(condition)]
        else:
            norm = self.norm
        point.feat = norm(point.feat)
        if self.adaptive:
            assert "context" in point.keys()
            # The modulation output is split in half along dim 1: shift, scale.
            shift, scale = self.modulation(point.context).chunk(2, dim=1)
            point.feat = point.feat * (1.0 + scale) + shift
        return point
class RPE(torch.nn.Module):
    """Learned relative position bias looked up from clamped integer offsets.

    One table of ``3 * rpe_num`` rows holds the x, y and z sections back to
    back; the per-axis lookups are summed to give one bias per head.
    """

    def __init__(self, patch_size, num_heads):
        super().__init__()
        self.patch_size = patch_size
        self.num_heads = num_heads
        # Offsets are clamped to [-pos_bnd, pos_bnd] before the lookup.
        self.pos_bnd = int((4 * patch_size) ** (1 / 3) * 2)
        self.rpe_num = 2 * self.pos_bnd + 1
        table = torch.zeros(3 * self.rpe_num, num_heads)
        self.rpe_table = torch.nn.Parameter(table)
        torch.nn.init.trunc_normal_(self.rpe_table, std=0.02)

    def forward(self, coord):
        bound = self.pos_bnd
        # Start row of the x, y and z sections of the table.
        axis_start = torch.arange(3, device=coord.device) * self.rpe_num
        # Clamp into the bound, shift to a non-negative index, add axis stride.
        idx = coord.clamp(-bound, bound) + bound + axis_start
        bias = self.rpe_table.index_select(0, idx.reshape(-1))
        # Restore the index shape plus a head axis, then sum over the 3 axes.
        bias = bias.view(idx.shape + (-1,)).sum(3)
        # (N, K, K, H) -> (N, H, K, K)
        return bias.permute(0, 3, 1, 2)
class SerializedAttention(PointModule):
    """Multi-head self-attention over patches of a serialized point cloud.

    Points are gathered in the order stored in
    ``point.serialized_order[order_index]``, padded so every sample splits into
    whole patches, attended patch by patch, then scattered back.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        order_index=0,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        assert channels % num_heads == 0
        self.channels = channels
        self.num_heads = num_heads
        self.scale = qk_scale or (channels // num_heads) ** -0.5
        self.order_index = order_index
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.enable_rpe = enable_rpe
        self.enable_flash = enable_flash
        if not enable_flash:
            # when disable flash attention, we still don't want to use mask
            # consequently, patch size will auto set to the
            # min number of patch_size_max and number of points
            self.patch_size_max = patch_size
            self.patch_size = 0
            self.attn_drop = torch.nn.Dropout(attn_drop)
        else:
            assert (
                enable_rpe is False
            ), "Set enable_rpe to False when enable Flash Attention"
            assert (
                upcast_attention is False
            ), "Set upcast_attention to False when enable Flash Attention"
            assert (
                upcast_softmax is False
            ), "Set upcast_softmax to False when enable Flash Attention"
            # NOTE: availability of the flash_attn package is not checked
            # here; a missing package only surfaces when forward() takes the
            # flash path.
            self.patch_size = patch_size
            # Kept as a plain probability: it is passed to the flash kernel.
            self.attn_drop = attn_drop

        self.qkv = torch.nn.Linear(channels, channels * 3, bias=qkv_bias)
        self.proj = torch.nn.Linear(channels, channels)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.softmax = torch.nn.Softmax(dim=-1)
        self.rpe = RPE(patch_size, num_heads) if self.enable_rpe else None

    def get_rel_pos(self, point, order):
        """Return (and cache on ``point``) pairwise grid offsets per patch."""
        cache_key = f"rel_pos_{self.order_index}"
        if cache_key in point.keys():
            return point[cache_key]
        patches = point.grid_coord[order].reshape(-1, self.patch_size, 3)
        point[cache_key] = patches.unsqueeze(2) - patches.unsqueeze(1)
        return point[cache_key]

    def get_padding_and_inverse(self, point):
        """Return ``(pad, unpad, cu_seqlens)``, computing and caching them once.

        ``pad`` maps padded positions to unpadded positions, ``unpad`` maps
        unpadded positions to padded positions, and ``cu_seqlens`` lists the
        patch boundaries in the padded layout.
        """
        pad_key = "pad"
        unpad_key = "unpad"
        cu_seqlens_key = "cu_seqlens_key"
        keys_present = all(
            key in point.keys() for key in (pad_key, unpad_key, cu_seqlens_key)
        )
        if not keys_present:
            patch = self.patch_size
            offset = point.offset
            bincount = offset2bincount(offset)
            rounded_up = (
                torch.div(bincount + patch - 1, patch, rounding_mode="trunc") * patch
            )
            # only pad point when num of points larger than patch_size
            needs_pad = bincount > patch
            bincount_pad = torch.where(needs_pad, rounded_up, bincount)

            starts = nn.functional.pad(offset, (1, 0))
            starts_pad = nn.functional.pad(torch.cumsum(bincount_pad, dim=0), (1, 0))
            pad = torch.arange(starts_pad[-1], device=offset.device)
            unpad = torch.arange(starts[-1], device=offset.device)
            boundaries = []
            for i in range(len(offset)):
                lo, hi = starts[i], starts[i + 1]
                lo_pad, hi_pad = starts_pad[i], starts_pad[i + 1]
                shift = lo_pad - lo
                unpad[lo:hi] += shift
                if bincount[i] != bincount_pad[i]:
                    # Fill the padded tail of the last patch with indices
                    # copied from the same slots of the preceding patch.
                    tail = bincount[i] % patch
                    pad[hi_pad - patch + tail : hi_pad] = pad[
                        hi_pad - 2 * patch + tail : hi_pad - patch
                    ]
                pad[lo_pad:hi_pad] -= shift
                boundaries.append(
                    torch.arange(
                        lo_pad,
                        hi_pad,
                        step=patch,
                        dtype=torch.int32,
                        device=offset.device,
                    )
                )
            point[pad_key] = pad
            point[unpad_key] = unpad
            # Close the boundary list with the total padded length.
            point[cu_seqlens_key] = nn.functional.pad(
                torch.concat(boundaries), (0, 1), value=starts_pad[-1]
            )
        return point[pad_key], point[unpad_key], point[cu_seqlens_key]

    def _dense_attention(self, point, qkv, order):
        """Attention computed with explicit matmuls (no flash kernel)."""
        H = self.num_heads
        K = self.patch_size
        C = self.channels
        # encode and reshape qkv: (N', K, 3, H, C') => (3, N', H, K, C')
        q, k, v = (
            qkv.reshape(-1, K, 3, H, C // H).permute(2, 0, 3, 1, 4).unbind(dim=0)
        )
        if self.upcast_attention:
            q = q.float()
            k = k.float()
        attn = (q * self.scale) @ k.transpose(-2, -1)  # (N', H, K, K)
        if self.enable_rpe:
            attn = attn + self.rpe(self.get_rel_pos(point, order))
        if self.upcast_softmax:
            attn = attn.float()
        attn = self.softmax(attn)
        attn = self.attn_drop(attn).to(qkv.dtype)
        return (attn @ v).transpose(1, 2).reshape(-1, C)

    def _flash_attention(self, qkv, cu_seqlens):
        """Attention computed by the variable-length flash kernel in fp16."""
        H = self.num_heads
        C = self.channels
        out = flash_attn.flash_attn_varlen_qkvpacked_func(
            qkv.half().reshape(-1, 3, H, C // H),
            cu_seqlens,
            max_seqlen=self.patch_size,
            dropout_p=self.attn_drop if self.training else 0,
            softmax_scale=self.scale,
        ).reshape(-1, C)
        return out.to(qkv.dtype)

    def forward(self, point):
        if not self.enable_flash:
            # Patch size follows the smallest sample in the batch, capped at
            # the configured maximum.
            self.patch_size = min(
                offset2bincount(point.offset).min().tolist(), self.patch_size_max
            )

        pad, unpad, cu_seqlens = self.get_padding_and_inverse(point)

        order = point.serialized_order[self.order_index][pad]
        inverse = unpad[point.serialized_inverse[self.order_index]]

        # padding and reshape feat and batch for serialized point patch
        qkv = self.qkv(point.feat)[order]

        if self.enable_flash:
            feat = self._flash_attention(qkv, cu_seqlens)
        else:
            feat = self._dense_attention(point, qkv, order)
        # Undo padding and serialization to restore the original point order.
        feat = feat[inverse]

        # ffn
        point.feat = self.proj_drop(self.proj(feat))
        return point
class MLP(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    ``hidden_channels`` and ``out_channels`` fall back to ``in_channels`` when
    they are not given (or are otherwise falsy).
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_channels=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        width_out = out_channels or in_channels
        width_hidden = hidden_channels or in_channels
        self.fc1 = nn.Linear(in_channels, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        # One dropout module is shared by both positions in forward().
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Block(PointModule):
    """Transformer block: conditional positional encoding, attention and MLP.

    Each of the three parts is wrapped in a residual connection. ``pre_norm``
    decides whether the norm runs before the attention/MLP branch or after the
    residual sum.
    """

    def __init__(
        self,
        channels,
        num_heads,
        patch_size=48,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        act_layer=nn.GELU,
        pre_norm=True,
        order_index=0,
        cpe_indice_key=None,
        enable_rpe=False,
        enable_flash=True,
        upcast_attention=True,
        upcast_softmax=True,
    ):
        super().__init__()
        self.channels = channels
        self.pre_norm = pre_norm

        # Conditional positional encoding: sparse conv, linear, norm.
        cpe_conv = spconv.SubMConv3d(
            channels,
            channels,
            kernel_size=3,
            bias=True,
            indice_key=cpe_indice_key,
        )
        self.cpe = PointSequential(
            cpe_conv,
            nn.Linear(channels, channels),
            norm_layer(channels),
        )

        self.norm1 = PointSequential(norm_layer(channels))
        self.attn = SerializedAttention(
            channels=channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order_index=order_index,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )
        self.norm2 = PointSequential(norm_layer(channels))
        feed_forward = MLP(
            in_channels=channels,
            hidden_channels=int(channels * mlp_ratio),
            out_channels=channels,
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.mlp = PointSequential(feed_forward)
        if drop_path > 0.0:
            stochastic_depth = DropPath(drop_path)
        else:
            stochastic_depth = nn.Identity()
        self.drop_path = PointSequential(stochastic_depth)

    def _residual_branch(self, point, norm, branch):
        """Run ``branch`` with drop-path and add the skip connection."""
        skip = point.feat
        if self.pre_norm:
            point = norm(point)
        point = self.drop_path(branch(point))
        point.feat = skip + point.feat
        if not self.pre_norm:
            point = norm(point)
        return point

    def forward(self, point: Point):
        # Positional encoding has its own residual and no drop-path.
        skip = point.feat
        point = self.cpe(point)
        point.feat = skip + point.feat

        point = self._residual_branch(point, self.norm1, self.attn)
        point = self._residual_branch(point, self.norm2, self.mlp)

        # replace_feature returns a new tensor, so the result is assigned back.
        point.sparse_conv_feat = point.sparse_conv_feat.replace_feature(point.feat)
        return point
class SerializedPooling(PointModule):
    """Downsample a serialized point cloud by merging points that share a code prefix.

    Args:
        in_channels: feature width of the incoming points.
        out_channels: feature width after the linear projection.
        stride: pooling stride; must be a power of two (2, 4, 8, ...).
        norm_layer: optional factory ``norm_layer(out_channels)`` applied
            after pooling.
        act_layer: optional factory ``act_layer()`` applied after the norm.
        reduce: feature reduction within a cluster ("sum", "mean", "min", "max").
        shuffle_orders: randomly permute the serialization orders of the output.
        traceable: record the parent point and the cluster assignment so the
            pooling can later be undone.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=2,
        norm_layer=None,
        act_layer=None,
        reduce="max",
        shuffle_orders=True,
        traceable=True,  # record parent and cluster
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert stride == 2 ** (math.ceil(stride) - 1).bit_length()  # 2, 4, 8
        # TODO: add support to grid pool (any stride)
        self.stride = stride
        assert reduce in ["sum", "mean", "min", "max"]
        self.reduce = reduce
        self.shuffle_orders = shuffle_orders
        self.traceable = traceable

        self.proj = nn.Linear(in_channels, out_channels)
        # Always define both attributes: forward() reads them unconditionally,
        # so leaving them unset raised AttributeError for the default arguments.
        self.norm = (
            PointSequential(norm_layer(out_channels)) if norm_layer is not None else None
        )
        self.act = PointSequential(act_layer()) if act_layer is not None else None

    def forward(self, point: Point):
        """Return a new, sparsified ``Point`` holding the pooled cloud.

        Raises:
            AssertionError: if ``point`` has not been serialized.
        """
        # Validate first so a missing serialization yields the message below
        # rather than a failure while reading ``serialized_depth``.
        assert {
            "serialized_code",
            "serialized_order",
            "serialized_inverse",
            "serialized_depth",
        }.issubset(
            point.keys()
        ), "Run point.serialization() point cloud before SerializedPooling"

        pooling_depth = (math.ceil(self.stride) - 1).bit_length()
        if pooling_depth > point.serialized_depth:
            pooling_depth = 0

        # Drop 3 bits per pooling level, i.e. divide the code by
        # 8 ** pooling_depth; points sharing the remaining prefix are merged.
        code = point.serialized_code >> pooling_depth * 3
        # Clusters are defined by the first serialization order only.
        code_, cluster, counts = torch.unique(
            code[0],
            sorted=True,
            return_inverse=True,
            return_counts=True,
        )
        # indices of point sorted by cluster, for torch_scatter.segment_csr
        _, indices = torch.sort(cluster)
        # index pointer for sorted point, for torch_scatter.segment_csr
        idx_ptr = torch.cat([counts.new_zeros(1), torch.cumsum(counts, dim=0)])
        # head_indices of each cluster, for reduce attr e.g. code, batch
        head_indices = indices[idx_ptr[:-1]]
        # generate down code, order, inverse (one representative per cluster)
        code = code[:, head_indices]
        order = torch.argsort(code)
        inverse = torch.zeros_like(order).scatter_(
            dim=1,
            index=order,
            src=torch.arange(0, code.shape[1], device=order.device).repeat(
                code.shape[0], 1
            ),
        )

        if self.shuffle_orders:
            perm = torch.randperm(code.shape[0])
            code = code[perm]
            order = order[perm]
            inverse = inverse[perm]

        # collect information
        point_dict = Dict(
            feat=torch_scatter.segment_csr(
                self.proj(point.feat)[indices], idx_ptr, reduce=self.reduce
            ),
            coord=torch_scatter.segment_csr(
                point.coord[indices], idx_ptr, reduce="mean"
            ),
            # Grid coordinates lose one bit per pooling level as well.
            grid_coord=point.grid_coord[head_indices] >> pooling_depth,
            serialized_code=code,
            serialized_order=order,
            serialized_inverse=inverse,
            serialized_depth=point.serialized_depth - pooling_depth,
            batch=point.batch[head_indices],
        )

        if "condition" in point.keys():
            point_dict["condition"] = point.condition
        if "context" in point.keys():
            point_dict["context"] = point.context

        if self.traceable:
            point_dict["pooling_inverse"] = cluster
            point_dict["pooling_parent"] = point
        point = Point(point_dict)
        if self.norm is not None:
            point = self.norm(point)
        if self.act is not None:
            point = self.act(point)
        point.sparsify()
        return point
class SerializedUnpooling(PointModule):
    """Undo a traceable pooling step and fuse in the skip connection.

    The pooled features and the parent features are each projected to
    ``out_channels``; pooled features are then gathered back onto the parent
    points via the stored cluster index and added.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        norm_layer=None,
        act_layer=None,
        traceable=False,  # record parent and cluster
    ):
        super().__init__()
        self.proj = PointSequential(nn.Linear(in_channels, out_channels))
        self.proj_skip = PointSequential(nn.Linear(skip_channels, out_channels))
        branches = (self.proj, self.proj_skip)

        if norm_layer is not None:
            for branch in branches:
                branch.add(norm_layer(out_channels))
        if act_layer is not None:
            for branch in branches:
                branch.add(act_layer())

        self.traceable = traceable

    def forward(self, point):
        for required in ("pooling_parent", "pooling_inverse"):
            assert required in point.keys()
        # Both entries are removed from the pooled point.
        parent = point.pop("pooling_parent")
        inverse = point.pop("pooling_inverse")

        point = self.proj(point)
        parent = self.proj_skip(parent)
        # ``inverse`` holds, for every parent point, its cluster index.
        parent.feat = parent.feat + point.feat[inverse]

        if self.traceable:
            parent["unpooling_parent"] = point
        return parent
class Embedding(PointModule):
    """Input stem: a submanifold sparse convolution with optional norm and activation."""

    def __init__(
        self,
        in_channels,
        embed_channels,
        norm_layer=None,
        act_layer=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels

        # TODO: check remove spconv
        stem_conv = spconv.SubMConv3d(
            in_channels,
            embed_channels,
            kernel_size=5,
            padding=1,
            bias=False,
            indice_key="stem",
        )
        self.stem = PointSequential(conv=stem_conv)
        if norm_layer is not None:
            self.stem.add(norm_layer(embed_channels), name="norm")
        if act_layer is not None:
            self.stem.add(act_layer(), name="act")

    def forward(self, point: Point):
        return self.stem(point)
class PointTransformerV3(PointModule):
    """Point Transformer V3 backbone: embedding, staged encoder, optional decoder.

    The encoder has ``len(enc_depths)`` stages; every stage after the first
    starts with a ``SerializedPooling``. Unless ``cls_mode`` is set, a decoder
    with one stage fewer unpools back up using the encoder skip connections.
    """

    def __init__(
        self,
        in_channels=6,
        order=("z", "z-trans", "hilbert", "hilbert-trans"),
        stride=(2, 2, 2, 2),
        enc_depths=(2, 2, 2, 6, 2),
        enc_channels=(32, 64, 128, 256, 512),
        enc_num_head=(2, 4, 8, 16, 32),
        enc_patch_size=(1024, 1024, 1024, 1024, 1024),
        dec_depths=(2, 2, 2, 2),
        dec_channels=(64, 64, 128, 256),
        dec_num_head=(4, 4, 8, 16),
        dec_patch_size=(1024, 1024, 1024, 1024),
        mlp_ratio=4,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        drop_path=0.3,
        pre_norm=True,
        shuffle_orders=True,
        enable_rpe=False,
        enable_flash=False,
        upcast_attention=False,
        upcast_softmax=False,
        cls_mode=False,
        pdnorm_bn=False,
        pdnorm_ln=False,
        pdnorm_decouple=True,
        pdnorm_adaptive=False,
        pdnorm_affine=True,
        pdnorm_conditions=("ScanNet", "S3DIS", "Structured3D"),
    ):
        super().__init__()
        self.num_stages = len(enc_depths)
        self.order = [order] if isinstance(order, str) else order
        self.cls_mode = cls_mode
        self.shuffle_orders = shuffle_orders

        assert self.num_stages == len(stride) + 1
        assert self.num_stages == len(enc_depths)
        assert self.num_stages == len(enc_channels)
        assert self.num_stages == len(enc_num_head)
        assert self.num_stages == len(enc_patch_size)
        assert self.cls_mode or self.num_stages == len(dec_depths) + 1
        assert self.cls_mode or self.num_stages == len(dec_channels) + 1
        assert self.cls_mode or self.num_stages == len(dec_num_head) + 1
        assert self.cls_mode or self.num_stages == len(dec_patch_size) + 1

        # norm layers
        bn_layer = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        if pdnorm_bn:
            bn_layer = partial(
                PDNorm,
                norm_layer=partial(
                    nn.BatchNorm1d, eps=1e-3, momentum=0.01, affine=pdnorm_affine
                ),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        ln_layer = nn.LayerNorm
        if pdnorm_ln:
            ln_layer = partial(
                PDNorm,
                norm_layer=partial(nn.LayerNorm, elementwise_affine=pdnorm_affine),
                conditions=pdnorm_conditions,
                decouple=pdnorm_decouple,
                adaptive=pdnorm_adaptive,
            )
        # activation layers
        act_layer = nn.GELU

        # Settings shared by every transformer block in encoder and decoder.
        block_common = dict(
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=ln_layer,
            act_layer=act_layer,
            pre_norm=pre_norm,
            enable_rpe=enable_rpe,
            enable_flash=enable_flash,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
        )

        self.embedding = Embedding(
            in_channels=in_channels,
            embed_channels=enc_channels[0],
            norm_layer=bn_layer,
            act_layer=act_layer,
        )

        # encoder: drop-path rate rises linearly from 0 to ``drop_path``
        enc_rates = torch.linspace(0, drop_path, sum(enc_depths)).tolist()
        self.enc = PointSequential()
        for s in range(self.num_stages):
            first = sum(enc_depths[:s])
            stage_rates = enc_rates[first : first + enc_depths[s]]
            stage = PointSequential()
            if s > 0:
                stage.add(
                    SerializedPooling(
                        in_channels=enc_channels[s - 1],
                        out_channels=enc_channels[s],
                        stride=stride[s - 1],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="down",
                )
            for i in range(enc_depths[s]):
                stage.add(
                    Block(
                        channels=enc_channels[s],
                        num_heads=enc_num_head[s],
                        patch_size=enc_patch_size[s],
                        drop_path=stage_rates[i],
                        # Blocks cycle through the serialization orders.
                        order_index=i % len(self.order),
                        cpe_indice_key=f"stage{s}",
                        **block_common,
                    ),
                    name=f"block{i}",
                )
            if len(stage) != 0:
                self.enc.add(module=stage, name=f"enc{s}")

        # decoder
        if not self.cls_mode:
            dec_rates = torch.linspace(0, drop_path, sum(dec_depths)).tolist()
            self.dec = PointSequential()
            # The deepest decoder stage takes the last encoder stage's channels.
            dec_channels = list(dec_channels) + [enc_channels[-1]]
            for s in reversed(range(self.num_stages - 1)):
                first = sum(dec_depths[:s])
                # Rates are reversed within each decoder stage.
                stage_rates = dec_rates[first : first + dec_depths[s]][::-1]
                stage = PointSequential()
                stage.add(
                    SerializedUnpooling(
                        in_channels=dec_channels[s + 1],
                        skip_channels=enc_channels[s],
                        out_channels=dec_channels[s],
                        norm_layer=bn_layer,
                        act_layer=act_layer,
                    ),
                    name="up",
                )
                for i in range(dec_depths[s]):
                    stage.add(
                        Block(
                            channels=dec_channels[s],
                            num_heads=dec_num_head[s],
                            patch_size=dec_patch_size[s],
                            drop_path=stage_rates[i],
                            order_index=i % len(self.order),
                            cpe_indice_key=f"stage{s}",
                            **block_common,
                        ),
                        name=f"block{i}",
                    )
                self.dec.add(module=stage, name=f"dec{s}")

    def forward(self, data_dict):
        """
        A data_dict is a dictionary containing properties of a batched point cloud.
        It should contain the following properties for PTv3:
        1. "feat": feature of point cloud
        2. "grid_coord": discrete coordinate after grid sampling (voxelization) or "coord" + "grid_size"
        3. "offset" or "batch": https://github.com/Pointcept/Pointcept?tab=readme-ov-file#offset

        Returns the resulting ``Point``: the encoder output in ``cls_mode``,
        otherwise the decoder output.
        """
        point = Point(data_dict)
        point.serialization(order=self.order, shuffle_orders=self.shuffle_orders)
        point.sparsify()

        point = self.embedding(point)
        point = self.enc(point)
        if self.cls_mode:
            return point
        return self.dec(point)
class PointSemSeg(nn.Module):
    """Per-point head on top of a default ``PointTransformerV3`` extractor.

    The extractor's "feat" output is fed through three ReLU-activated linear
    layers of width ``emb`` and a final linear layer of width ``dim_output``.
    ``args`` is accepted but not used by this class.
    """

    def __init__(self, args, dim_output, emb=64, init_logit_scale=np.log(1 / 0.07)):
        super().__init__()
        self.dim_output = dim_output

        # Backbone built with its default configuration.
        self.extractor = PointTransformerV3()

        # Learnable scalar initialised to ``init_logit_scale``.
        self.ln_logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        # Head layers.
        self.fc1 = nn.Linear(emb, emb)
        self.fc2 = nn.Linear(emb, emb)
        self.fc3 = nn.Linear(emb, emb)
        self.fc4 = nn.Linear(emb, dim_output)

    def distillation_head(self, x):
        """Map features of width ``emb`` to outputs of width ``dim_output``."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)

    def freeze_extractor(self):
        """Disable gradients for every extractor parameter."""
        for weight in self.extractor.parameters():
            weight.requires_grad = False

    def forward(self, x, return_pts_feat=False):
        """Return head outputs, optionally with the extractor features."""
        feature = self.extractor(x)["feat"]
        logits = self.distillation_head(feature)
        if return_pts_feat:
            return logits, feature
        return logits
class Find3D(nn.Module, PyTorchModelHubMixin):
    """Per-point head on top of a default ``PointTransformerV3`` extractor.

    The extractor's "feat" output is fed through three ReLU-activated linear
    layers of width ``emb`` and a final linear layer of width ``dim_output``.
    """

    def __init__(self, dim_output, emb=64, init_logit_scale=np.log(1 / 0.07)):
        super().__init__()
        self.dim_output = dim_output

        # Backbone built with its default configuration.
        self.extractor = PointTransformerV3()

        # Learnable scalar initialised to ``init_logit_scale``.
        self.ln_logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale)

        # Head layers.
        self.fc1 = nn.Linear(emb, emb)
        self.fc2 = nn.Linear(emb, emb)
        self.fc3 = nn.Linear(emb, emb)
        self.fc4 = nn.Linear(emb, dim_output)

    def distillation_head(self, x):
        """Map features of width ``emb`` to outputs of width ``dim_output``."""
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
        return self.fc4(x)

    def freeze_extractor(self):
        """Disable gradients for every extractor parameter."""
        for weight in self.extractor.parameters():
            weight.requires_grad = False

    def forward(self, x, return_pts_feat=False):
        """Return head outputs, optionally with the extractor features."""
        feature = self.extractor(x)["feat"]
        logits = self.distillation_head(feature)
        if return_pts_feat:
            return logits, feature
        return logits