Spaces:
Runtime error
Runtime error
from torch import nn | |
import segmentation_models_pytorch as smp | |
from segmentation_models_pytorch.losses import DiceLoss | |
# Backbone identifier handed to smp.Unet as `encoder_name` by SegmentationModel.
ENCODER = 'timm-efficientnet-b0'
# Pretrained-weight identifier handed to smp.Unet as `encoder_weights`.
WEIGHTS = 'imagenet'
class SegmentationModel(nn.Module):
    """U-Net segmentation network that can also return its training loss.

    Wraps ``smp.Unet`` configured with the module-level ``ENCODER`` and
    ``WEIGHTS`` settings, 3 input channels and a single output class.
    The network is built with ``activation=None``, so its output is the
    un-activated prediction (logits).
    """

    def __init__(self):
        """Build the underlying ``smp.Unet`` and store it as ``self.arc``."""
        super().__init__()
        self.arc = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=WEIGHTS,
            in_channels=3,
            classes=1,
            # No final activation: forward() feeds these raw logits to the losses.
            activation=None,
        )

    def forward(self, images, masks=None):
        """Run the network and, when targets are given, compute the loss.

        Args:
            images: Input batch passed straight to the U-Net. The network is
                built with ``in_channels=3``; presumably (N, 3, H, W) -- confirm
                against the caller.
            masks: Optional target masks. Passed unchanged to both loss
                functions together with the logits; presumably a float tensor
                shaped like the logits -- confirm against the caller.

        Returns:
            ``logits`` alone when ``masks`` is None, otherwise the tuple
            ``(logits, loss)`` where ``loss`` is the sum of the binary Dice
            loss and ``BCEWithLogitsLoss``.
        """
        logits = self.arc(images)

        # Identity test on purpose: `masks != None` is an equality comparison
        # against a tensor, not a presence check.
        if masks is None:
            return logits

        # Both criteria receive the raw logits (the U-Net has no activation).
        # NOTE(review): DiceLoss is assumed to accept logits by default --
        # confirm for the installed segmentation_models_pytorch version.
        dice_loss = DiceLoss(mode='binary')(logits, masks)
        bce_loss = nn.BCEWithLogitsLoss()(logits, masks)
        return logits, dice_loss + bce_loss