import torch.nn as nn
import torch_scatter

from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from .builder import MODELS, build_model


def _criteria_outputs(criteria, logits, input_dict, target_key, logits_key, training):
    """Build the output dict shared by the default models in this module.

    Args:
        criteria: Loss callable, invoked as ``criteria(logits, target)``.
        logits: Model prediction. It is passed to ``criteria`` and, outside of
            training, returned under ``logits_key``.
        input_dict: Batch mapping; the target is read from
            ``input_dict[target_key]``.
        target_key: Key of the ground-truth target (e.g. ``"segment"``).
        logits_key: Output key for the logits (e.g. ``"seg_logits"``).
        training: Whether the calling module is in training mode.

    Returns:
        - training: ``{"loss": loss}``;
        - evaluation with a target present: ``{"loss": loss, logits_key: logits}``;
        - inference without a target: ``{logits_key: logits}``.

    Raises:
        KeyError: in training mode when ``target_key`` is missing from
            ``input_dict``.
    """
    # train
    if training:
        return dict(loss=criteria(logits, input_dict[target_key]))
    # eval
    if target_key in input_dict:
        loss = criteria(logits, input_dict[target_key])
        return {"loss": loss, logits_key: logits}
    # test
    return {logits_key: logits}


@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    """Segmentor whose backbone directly produces the segmentation logits.

    Args:
        backbone: Backbone config, passed to ``build_model``.
        criteria: Loss config, passed to ``build_criteria``.
    """

    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Run the backbone on ``input_dict`` and compute the loss when possible.

        NOTE: when present, ``input_dict["condition"]`` is overwritten in place.

        Returns:
            ``{"loss"}`` while training; ``{"loss", "seg_logits"}`` in
            evaluation when ``input_dict`` contains ``"segment"``; otherwise
            ``{"seg_logits"}``.
        """
        if "condition" in input_dict:
            # PPT (https://arxiv.org/abs/2308.09718)
            # Only one condition per batch is supported: keep the first entry.
            # If the value is already a single string (e.g. the same dict is
            # forwarded twice), indexing again would truncate it to its first
            # character, so leave it untouched.
            condition = input_dict["condition"]
            if not isinstance(condition, str):
                input_dict["condition"] = condition[0]
        seg_logits = self.backbone(input_dict)
        return _criteria_outputs(
            self.criteria,
            seg_logits,
            input_dict,
            "segment",
            "seg_logits",
            self.training,
        )


@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    """Segmentor that applies a linear head on top of backbone features.

    Args:
        num_classes: Number of output classes. If not positive, the head is
            ``nn.Identity`` and the backbone features are used as logits.
        backbone_out_channels: Feature width produced by the backbone; used
            as the input size of the linear head.
        backbone: Backbone config, passed to ``build_model``.
        criteria: Loss config, passed to ``build_criteria``.
    """

    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        """Wrap ``input_dict`` in a ``Point``, run the backbone and the head.

        Returns:
            ``{"loss"}`` while training; ``{"loss", "seg_logits"}`` in
            evaluation when ``input_dict`` contains ``"segment"``; otherwise
            ``{"seg_logits"}``.
        """
        point = Point(input_dict)
        point = self.backbone(point)
        # Backbones added after v1.5.0 return a Point instead of a feature
        # tensor and are used with DefaultSegmentorV2.
        # TODO: remove the non-Point branch once all backbones return a Point.
        feat = point.feat if isinstance(point, Point) else point
        seg_logits = self.seg_head(feat)
        return _criteria_outputs(
            self.criteria,
            seg_logits,
            input_dict,
            "segment",
            "seg_logits",
            self.training,
        )


@MODELS.register_module()
class DefaultClassifier(nn.Module):
    """Classifier that pools backbone features and applies an MLP head.

    Args:
        backbone: Backbone config, passed to ``build_model``.
        criteria: Loss config, passed to ``build_criteria``.
        num_classes: Number of output classes of the head.
        backbone_embed_dim: Feature width produced by the backbone; used as
            the input size of the head.
    """

    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim
        self.cls_head = nn.Sequential(
            nn.Linear(backbone_embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, input_dict):
        """Run the backbone, pool features per sample, and classify.

        Returns:
            ``{"loss"}`` while training; ``{"loss", "cls_logits"}`` in
            evaluation when ``input_dict`` contains ``"category"``; otherwise
            ``{"cls_logits"}``.
        """
        point = Point(input_dict)
        point = self.backbone(point)
        # Backbones added after v1.5.0 return a Point instead of a feature
        # tensor, and the per-sample feature aggregation for classification
        # is done here in the classifier.
        # TODO: remove the non-Point branch once all backbones return a Point.
        if isinstance(point, Point):
            # Mean-pool point features per sample. ``offset`` is assumed to
            # hold the cumulative end index of each sample (TODO confirm), so
            # prepending 0 yields the CSR boundaries ``segment_csr`` expects.
            point.feat = torch_scatter.segment_csr(
                src=point.feat,
                indptr=nn.functional.pad(point.offset, (1, 0)),
                reduce="mean",
            )
            feat = point.feat
        else:
            feat = point
        cls_logits = self.cls_head(feat)
        return _criteria_outputs(
            self.criteria,
            cls_logits,
            input_dict,
            "category",
            "cls_logits",
            self.training,
        )