soumickmj commited on
Commit
3f260ac
·
verified ·
1 Parent(s): 59c10a4

Upload UNetMSS3D

Browse files
Files changed (5) hide show
  1. UNetConfigs.py +30 -0
  2. UNets.py +26 -0
  3. config.json +16 -0
  4. model.safetensors +3 -0
  5. unet3d.py +310 -0
UNetConfigs.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class UNet3DConfig(PretrainedConfig):
5
+ model_type = "UNet"
6
+ def __init__(
7
+ self,
8
+ in_ch=1,
9
+ out_ch=1,
10
+ init_features=64,
11
+ **kwargs):
12
+ self.in_ch = in_ch
13
+ self.out_ch = out_ch
14
+ self.init_features = init_features
15
+ super().__init__(**kwargs)
16
+
17
+ class UNetMSS3DConfig(PretrainedConfig):
18
+ model_type = "UNetMSS"
19
+ def __init__(
20
+ self,
21
+ in_ch=1,
22
+ out_ch=1,
23
+ output_dir=None,
24
+ init_features=64,
25
+ **kwargs):
26
+ self.in_ch = in_ch
27
+ self.out_ch = out_ch
28
+ self.output_dir = output_dir
29
+ self.init_features = init_features
30
+ super().__init__(**kwargs)
UNets.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .unet3d import U_Net, U_Net_DeepSup
3
+ from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig
4
+
5
+ class UNet3D(PreTrainedModel):
6
+ config_class = UNet3DConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = U_Net(
10
+ in_ch=config.in_ch,
11
+ out_ch=config.out_ch,
12
+ init_features=config.init_features)
13
+ def forward(self, x):
14
+ return self.model(x)
15
+
16
+ class UNetMSS3D(PreTrainedModel):
17
+ config_class = UNetMSS3DConfig
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ self.model = U_Net_DeepSup(
21
+ in_ch=config.in_ch,
22
+ out_ch=config.out_ch,
23
+ output_dir=config.output_dir,
24
+ init_features=config.init_features)
25
+ def forward(self, x):
26
+ return self.model(x)
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UNetMSS3D"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "UNetConfigs.UNetMSS3DConfig",
7
+ "AutoModel": "UNets.UNetMSS3D"
8
+ },
9
+ "in_ch": 1,
10
+ "init_features": 64,
11
+ "model_type": "UNetMSS",
12
+ "out_ch": 1,
13
+ "output_dir": null,
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.44.2"
16
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:283da7d3057bb1c50b293af02cdb4591a06b3274e715124c7cd38a5395e4f2a6
3
+ size 414260220
unet3d.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # from __future__ import print_function, division
4
+ '''
5
+
6
+ This script is from the DS6 (https://github.com/soumickmj/DS6/blob/main/Models/unet3d.py),
7
+ and then the SPOCKMIP repository (https://github.com/soumickmj/SPOCKMIP/blob/master/Models/unet3d.py)
8
+
9
+ Part of the DS6 paper:
10
+ "DS6, Deformation-Aware Semi-Supervised Learning: Application to Small Vessel Segmentation with Noisy Training Data"
11
+ (https://doi.org/10.3390/jimaging8100259)
12
+
13
+ and the SPOCKMIP paper:
14
+ "SPOCKMIP: Segmentation of Vessels in MRAs with Enhanced Continuity using Maximum Intensity Projection as Loss"
15
+ (https://doi.org/10.48550/arXiv.2407.08655)
16
+
17
+ '''
18
+
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.utils.data
23
+ import os
24
+
25
+ __author__ = "Kartik Prabhu, Mahantesh Pattadkal, and Soumick Chatterjee"
26
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
27
+ __credits__ = ["Kartik Prabhu", "Mahantesh Pattadkal", "Soumick Chatterjee"]
28
+ __license__ = "GPL"
29
+ __version__ = "1.0.0"
30
+ __maintainer__ = "Soumick Chatterjee"
31
+ __email__ = "[email protected]"
32
+ __status__ = "Production"
33
+
34
+ class conv_block(nn.Module):
35
+ """
36
+ Convolution Block
37
+ """
38
+
39
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
40
+ super(conv_block, self).__init__()
41
+ self.conv = nn.Sequential(
42
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
43
+ stride=stride, padding=padding, bias=bias),
44
+ nn.BatchNorm3d(num_features=out_channels),
45
+ nn.LeakyReLU(inplace=True),
46
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
47
+ stride=stride, padding=padding, bias=bias),
48
+ nn.BatchNorm3d(num_features=out_channels),
49
+ nn.LeakyReLU(inplace=True)
50
+ )
51
+
52
+ def forward(self, x):
53
+ x = self.conv(x)
54
+ return x
55
+
56
+
57
+ class up_conv(nn.Module):
58
+ """
59
+ Up Convolution Block
60
+ """
61
+
62
+ # def __init__(self, in_ch, out_ch):
63
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
64
+ super(up_conv, self).__init__()
65
+ self.up = nn.Sequential(
66
+ nn.Upsample(scale_factor=2),
67
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
68
+ stride=stride, padding=padding, bias=bias),
69
+ nn.BatchNorm3d(num_features=out_channels),
70
+ nn.LeakyReLU(inplace=True))
71
+
72
+ def forward(self, x):
73
+ x = self.up(x)
74
+ return x
75
+
76
+
77
+ class U_Net(nn.Module):
78
+ """
79
+ UNet - Basic Implementation
80
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
81
+ Paper : https://arxiv.org/abs/1505.04597
82
+ """
83
+
84
+ def __init__(self, in_ch=1, out_ch=1, init_features=64):
85
+ super(U_Net, self).__init__()
86
+
87
+ n1 = init_features
88
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
89
+
90
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
91
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
92
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
93
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
94
+
95
+ self.Conv1 = conv_block(in_ch, filters[0])
96
+ self.Conv2 = conv_block(filters[0], filters[1])
97
+ self.Conv3 = conv_block(filters[1], filters[2])
98
+ self.Conv4 = conv_block(filters[2], filters[3])
99
+ self.Conv5 = conv_block(filters[3], filters[4])
100
+
101
+ self.Up5 = up_conv(filters[4], filters[3])
102
+ self.Up_conv5 = conv_block(filters[4], filters[3])
103
+
104
+ self.Up4 = up_conv(filters[3], filters[2])
105
+ self.Up_conv4 = conv_block(filters[3], filters[2])
106
+
107
+ self.Up3 = up_conv(filters[2], filters[1])
108
+ self.Up_conv3 = conv_block(filters[2], filters[1])
109
+
110
+ self.Up2 = up_conv(filters[1], filters[0])
111
+ self.Up_conv2 = conv_block(filters[1], filters[0])
112
+
113
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
114
+
115
+ # self.active = torch.nn.Sigmoid()
116
+
117
+ def forward(self, x):
118
+ # print("unet")
119
+ # print(x.shape)
120
+ # print(padded.shape)
121
+
122
+ e1 = self.Conv1(x)
123
+ # print("conv1:")
124
+ # print(e1.shape)
125
+
126
+ e2 = self.Maxpool1(e1)
127
+ e2 = self.Conv2(e2)
128
+ # print("conv2:")
129
+ # print(e2.shape)
130
+
131
+ e3 = self.Maxpool2(e2)
132
+ e3 = self.Conv3(e3)
133
+ # print("conv3:")
134
+ # print(e3.shape)
135
+
136
+ e4 = self.Maxpool3(e3)
137
+ e4 = self.Conv4(e4)
138
+ # print("conv4:")
139
+ # print(e4.shape)
140
+
141
+ e5 = self.Maxpool4(e4)
142
+ e5 = self.Conv5(e5)
143
+ # print("conv5:")
144
+ # print(e5.shape)
145
+
146
+ d5 = self.Up5(e5)
147
+ # print("d5:")
148
+ # print(d5.shape)
149
+ # print("e4:")
150
+ # print(e4.shape)
151
+ d5 = torch.cat((e4, d5), dim=1)
152
+ d5 = self.Up_conv5(d5)
153
+ # print("upconv5:")
154
+ # print(d5.size)
155
+
156
+ d4 = self.Up4(d5)
157
+ # print("d4:")
158
+ # print(d4.shape)
159
+ d4 = torch.cat((e3, d4), dim=1)
160
+ d4 = self.Up_conv4(d4)
161
+ # print("upconv4:")
162
+ # print(d4.shape)
163
+ d3 = self.Up3(d4)
164
+ d3 = torch.cat((e2, d3), dim=1)
165
+ d3 = self.Up_conv3(d3)
166
+ # print("upconv3:")
167
+ # print(d3.shape)
168
+ d2 = self.Up2(d3)
169
+ d2 = torch.cat((e1, d2), dim=1)
170
+ d2 = self.Up_conv2(d2)
171
+ # print("upconv2:")
172
+ # print(d2.shape)
173
+ out = self.Conv(d2)
174
+ # print("out:")
175
+ # print(out.shape)
176
+ # d1 = self.active(out)
177
+
178
+ return [out]
179
+
180
+ class U_Net_DeepSup(nn.Module):
181
+ """
182
+ UNet - Basic Implementation
183
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
184
+ Paper : https://arxiv.org/abs/1505.04597
185
+ """
186
+
187
+ def __init__(self, in_ch=1, out_ch=1, output_dir=None, init_features=64):
188
+ super(U_Net_DeepSup, self).__init__()
189
+
190
+ self.output_dir = output_dir
191
+ n1 = init_features
192
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
193
+
194
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
195
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
196
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
197
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
198
+
199
+ self.Conv1 = conv_block(in_ch, filters[0])
200
+ self.Conv2 = conv_block(filters[0], filters[1])
201
+ self.Conv3 = conv_block(filters[1], filters[2])
202
+ self.Conv4 = conv_block(filters[2], filters[3])
203
+ self.Conv5 = conv_block(filters[3], filters[4])
204
+
205
+ #1x1x1 Convolution for Deep Supervision
206
+ self.Conv_d3 = conv_block(filters[1], 1)
207
+ self.Conv_d4 = conv_block(filters[2], 1)
208
+
209
+
210
+
211
+ self.Up5 = up_conv(filters[4], filters[3])
212
+ self.Up_conv5 = conv_block(filters[4], filters[3])
213
+
214
+ self.Up4 = up_conv(filters[3], filters[2])
215
+ self.Up_conv4 = conv_block(filters[3], filters[2])
216
+
217
+ self.Up3 = up_conv(filters[2], filters[1])
218
+ self.Up_conv3 = conv_block(filters[2], filters[1])
219
+
220
+ self.Up2 = up_conv(filters[1], filters[0])
221
+ self.Up_conv2 = conv_block(filters[1], filters[0])
222
+
223
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
224
+
225
+ for submodule in self.modules():
226
+ submodule.register_forward_hook(self.nan_hook)
227
+
228
+ # self.active = torch.nn.Sigmoid()
229
+
230
+ def nan_hook(self, module, inp, output):
231
+ for i, out in enumerate(output):
232
+ nan_mask = torch.isnan(out)
233
+ if nan_mask.any():
234
+ print("In", self.__class__.__name__)
235
+ torch.save(inp, os.path.join(self.output_dir, 'nan_values_ip.pt'))
236
+ module_params = module.named_parameters()
237
+ for name, param in module_params:
238
+ torch.save(param, os.path.join(self.output_dir, 'nan_{}_param.pt'.format(name)))
239
+ torch.save(self.input_to_net, os.path.join(self.output_dir, 'nan_ip_batch.pt'))
240
+ raise RuntimeError(" classname "+self.__class__.__name__+"i "+str(i)+f" module: {module} classname {self.__class__.__name__} Found NAN in output {i} at indices: ", nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])
241
+
242
+ def forward(self, x):
243
+ # print("unet")
244
+ # print(x.shape)
245
+ # print(padded.shape)
246
+ self.input_to_net = x
247
+ e1 = self.Conv1(x)
248
+ # print("conv1:")
249
+ # print(e1.shape)
250
+
251
+ e2 = self.Maxpool1(e1)
252
+ e2 = self.Conv2(e2)
253
+ # print("conv2:")
254
+ # print(e2.shape)
255
+
256
+ e3 = self.Maxpool2(e2)
257
+ e3 = self.Conv3(e3)
258
+ # print("conv3:")
259
+ # print(e3.shape)
260
+
261
+ e4 = self.Maxpool3(e3)
262
+ e4 = self.Conv4(e4)
263
+ # print("conv4:")
264
+ # print(e4.shape)
265
+
266
+ e5 = self.Maxpool4(e4)
267
+ e5 = self.Conv5(e5)
268
+ # print("conv5:")
269
+ # print(e5.shape)
270
+
271
+ d5 = self.Up5(e5)
272
+ # print("d5:")
273
+ # print(d5.shape)
274
+ # print("e4:")
275
+ # print(e4.shape)
276
+ d5 = torch.cat((e4, d5), dim=1)
277
+ d5 = self.Up_conv5(d5)
278
+ # print("upconv5:")
279
+ # print(d5.size)
280
+
281
+ d4 = self.Up4(d5)
282
+ # print("d4:")
283
+ # print(d4.shape)
284
+ d4 = torch.cat((e3, d4), dim=1)
285
+ d4 = self.Up_conv4(d4)
286
+ d4_out = self.Conv_d4(d4)
287
+
288
+
289
+ # print("upconv4:")
290
+ # print(d4.shape)
291
+ d3 = self.Up3(d4)
292
+ d3 = torch.cat((e2, d3), dim=1)
293
+ d3 = self.Up_conv3(d3)
294
+ d3_out = self.Conv_d3(d3)
295
+
296
+ # print("upconv3:")
297
+ # print(d3.shape)
298
+ d2 = self.Up2(d3)
299
+ d2 = torch.cat((e1, d2), dim=1)
300
+ d2 = self.Up_conv2(d2)
301
+ # print("upconv2:")
302
+ # print(d2.shape)
303
+ out = self.Conv(d2)
304
+ # print("out:")
305
+ # print(out.shape)
306
+ # d1 = self.active(out)
307
+
308
+ return [out, d3_out , d4_out]
309
+
310
+