import torch.nn as nn
import torch_scatter

from pointcept.models.losses import build_criteria
from pointcept.models.utils.structure import Point
from .builder import MODELS, build_model

@MODELS.register_module()
class DefaultSegmentor(nn.Module):
    def __init__(self, backbone=None, criteria=None):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        if "condition" in input_dict.keys():
            # PPT (https://arxiv.org/abs/2308.09718)
            # currently only supports one condition per batch
            input_dict["condition"] = input_dict["condition"][0]
        seg_logits = self.backbone(input_dict)
        # train
        if self.training:
            loss = self.criteria(seg_logits, input_dict["segment"])
            return dict(loss=loss)
        # eval
        elif "segment" in input_dict.keys():
            loss = self.criteria(seg_logits, input_dict["segment"])
            return dict(loss=loss, seg_logits=seg_logits)
        # test
        else:
            return dict(seg_logits=seg_logits)
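
# Note on inputs: DefaultSegmentor consumes a plain dict. The only keys read
# here are "segment" (per-point labels, expected during train/eval) and the
# optional "condition" (dataset name consumed by PPT-style multi-dataset
# backbones); everything else in the dict is forwarded to the backbone as-is.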

@MODELS.register_module()
class DefaultSegmentorV2(nn.Module):
    def __init__(
        self,
        num_classes,
        backbone_out_channels,
        backbone=None,
        criteria=None,
    ):
        super().__init__()
        self.seg_head = (
            nn.Linear(backbone_out_channels, num_classes)
            if num_classes > 0
            else nn.Identity()  # num_classes <= 0: pass backbone features through unchanged
        )
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)

    def forward(self, input_dict):
        point = Point(input_dict)
        point = self.backbone(point)
        # Backbones added after v1.5.0 return a Point object instead of a raw
        # feature tensor and should be used with DefaultSegmentorV2.
        # TODO: remove this branch once all backbones return Point only.
        if isinstance(point, Point):
            feat = point.feat
        else:
            feat = point
        seg_logits = self.seg_head(feat)
        # train
        if self.training:
            loss = self.criteria(seg_logits, input_dict["segment"])
            return dict(loss=loss)
        # eval
        elif "segment" in input_dict.keys():
            loss = self.criteria(seg_logits, input_dict["segment"])
            return dict(loss=loss, seg_logits=seg_logits)
        # test
        else:
            return dict(seg_logits=seg_logits)

@MODELS.register_module()
class DefaultClassifier(nn.Module):
    def __init__(
        self,
        backbone=None,
        criteria=None,
        num_classes=40,
        backbone_embed_dim=256,
    ):
        super().__init__()
        self.backbone = build_model(backbone)
        self.criteria = build_criteria(criteria)
        self.num_classes = num_classes
        self.backbone_embed_dim = backbone_embed_dim
        self.cls_head = nn.Sequential(
            nn.Linear(backbone_embed_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(128, num_classes),
        )

    def forward(self, input_dict):
        point = Point(input_dict)
        point = self.backbone(point)
        # Backbones added after v1.5.0 return a Point object instead of a raw
        # feature tensor, and since v1.5.0 the per-cloud feature aggregation
        # for classification happens here in the classifier.
        # TODO: remove this branch once all backbones return Point only.
        if isinstance(point, Point):
            # Mean-pool point features within each sample; offset holds the
            # cumulative point counts, so a leading zero turns it into a CSR
            # index pointer for segment_csr.
            point.feat = torch_scatter.segment_csr(
                src=point.feat,
                indptr=nn.functional.pad(point.offset, (1, 0)),
                reduce="mean",
            )
            feat = point.feat
        else:
            feat = point
        cls_logits = self.cls_head(feat)
        # train
        if self.training:
            loss = self.criteria(cls_logits, input_dict["category"])
            return dict(loss=loss)
        # eval
        elif "category" in input_dict.keys():
            loss = self.criteria(cls_logits, input_dict["category"])
            return dict(loss=loss, cls_logits=cls_logits)
        # test
        else:
            return dict(cls_logits=cls_logits)
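

# ---------------------------------------------------------------------------
# Minimal sketch of the pooling used by DefaultClassifier above: segment_csr
# mean-pools per-point features into one vector per point cloud, using the
# cumulative point counts in `offset` (padded with a leading zero) as a CSR
# index pointer. The tensors below are toy data for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    feat = torch.randn(7, 4)  # 7 points with 4-dim features, two clouds batched together
    offset = torch.tensor([3, 7])  # cloud 0 -> points [0:3], cloud 1 -> points [3:7]
    indptr = nn.functional.pad(offset, (1, 0))  # -> tensor([0, 3, 7])
    pooled = torch_scatter.segment_csr(src=feat, indptr=indptr, reduce="mean")
    print(pooled.shape)  # torch.Size([2, 4]): one pooled feature per cloud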