English
File size: 1,048 Bytes
63a9590
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
# @Time    : 2024/7/21 下午5:11
# @Author  : xiaoshun
# @Email   : [email protected]
# @File    : scnn.py
# @Software: PyCharm

# Paper: https://www.sciencedirect.com/science/article/abs/pii/S0924271624000352?via%3Dihub#fn1

import torch
import torch.nn as nn
import torch.nn.functional as F


class SCNN(nn.Module):
    """Small fully convolutional network producing per-pixel class score maps.

    Layer pipeline (every convolution uses the default stride of 1, and the
    only non-1x1 kernel is padded by 1, so the input's spatial size H x W is
    preserved):

        conv1 (1x1, in_channels -> hidden_channels) -> ReLU -> Dropout2d
        -> conv2 (1x1, hidden_channels -> num_classes)
        -> conv3 (3x3, padding=1, num_classes -> num_classes)

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels (one score map per class).
        dropout_p: Probability used by ``nn.Dropout2d`` to zero whole
            channels; dropout is only applied while the module is in
            training mode.
        hidden_channels: Width of the intermediate 1x1 layer. Defaults to 64,
            the value that was previously hard-coded, so existing callers and
            checkpoints are unaffected.
    """

    def __init__(self, in_channels=3, num_classes=2, dropout_p=0.5, hidden_channels=64):
        super().__init__()
        # Registration order (conv1, conv2, conv3, dropout) is kept unchanged so
        # state_dict keys and parameter ordering match existing checkpoints.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden_channels, num_classes, kernel_size=1)
        # 3x3 kernel with padding=1: each pixel's class scores are combined with
        # its 3x3 neighbourhood while keeping H x W unchanged.
        self.conv3 = nn.Conv2d(num_classes, num_classes, kernel_size=3, padding=1)
        self.dropout = nn.Dropout2d(p=dropout_p)

    def forward(self, x):
        """Compute class score maps for a batch of images.

        Args:
            x: Input tensor, expected as (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes, H, W) holding raw scores; no
            softmax or other activation is applied after the last conv.
        """
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x


if __name__ == '__main__':
    # Smoke test: push a random batch through the network and report the output shape.
    net = SCNN(num_classes=7)
    dummy_batch = torch.randn((2, 3, 224, 224))
    print(net(dummy_batch).shape)  # expected: torch.Size([2, 7, 224, 224])