soumickmj commited on
Commit
75e7505
·
verified ·
1 Parent(s): f306d18

Upload ProbUNet

Browse files
Files changed (6) hide show
  1. PULASki.py +48 -0
  2. PULASkiConfigs.py +26 -0
  3. ProbUNet_model.py +731 -0
  4. ProbUNet_utils.py +224 -0
  5. config.json +21 -0
  6. model.safetensors +3 -0
PULASki.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import sys
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from transformers import PreTrainedModel
8
+
9
+ from .ProbUNet_model import InjectionConvEncoder2D, InjectionUNet2D, InjectionConvEncoder3D, InjectionUNet3D, ProbabilisticSegmentationNet
10
+ from .PULASkiConfigs import ProbUNetConfig
11
+
12
+ class ProbUNet(PreTrainedModel):
13
+ config_class = ProbUNetConfig
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+
17
+ if config.dim == 2:
18
+ task_op = InjectionUNet2D
19
+ prior_op = InjectionConvEncoder2D
20
+ posterior_op = InjectionConvEncoder2D
21
+ elif config.dim == 3:
22
+ task_op = InjectionUNet3D
23
+ prior_op = InjectionConvEncoder3D
24
+ posterior_op = InjectionConvEncoder3D
25
+ else:
26
+ sys.exit("Invalid dim! Only configured for dim 2 and 3.")
27
+
28
+ if config.latent_distribution == "normal":
29
+ latent_distribution = torch.distributions.Normal
30
+ else:
31
+ sys.exit("Invalid latent_distribution. Only normal has been implemented.")
32
+
33
+ self.model = ProbabilisticSegmentationNet(in_channels=config.in_channels,
34
+ out_channels=config.out_channels,
35
+ num_feature_maps=config.num_feature_maps,
36
+ latent_size=config.latent_size,
37
+ depth=config.depth,
38
+ latent_distribution=latent_distribution,
39
+ task_op=task_op,
40
+ task_kwargs={"output_activation_op": nn.Identity if config.no_outact_op else nn.Sigmoid,
41
+ "activation_kwargs": {"inplace": True}, "injection_at": config.prob_injection_at},
42
+ prior_op=prior_op,
43
+ prior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
44
+ posterior_op=posterior_op,
45
+ posterior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
46
+ )
47
+ def forward(self, x):
48
+ return self.model(x)
PULASkiConfigs.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ class ProbUNetConfig(PretrainedConfig):
4
+ model_type = "ProbUNet"
5
+ def __init__(
6
+ self,
7
+ dim=2,
8
+ in_channels=1,
9
+ out_channels=1,
10
+ num_feature_maps=24,
11
+ latent_size=3,
12
+ depth=5,
13
+ latent_distribution="normal",
14
+ no_outact_op=False,
15
+ prob_injection_at="end",
16
+ **kwargs):
17
+ self.dim = dim
18
+ self.in_channels = in_channels
19
+ self.out_channels = out_channels
20
+ self.num_feature_maps = num_feature_maps
21
+ self.latent_size = latent_size
22
+ self.depth = depth
23
+ self.latent_distribution = latent_distribution
24
+ self.no_outact_op = no_outact_op
25
+ self.prob_injection_at = prob_injection_at
26
+ super().__init__(**kwargs)
ProbUNet_model.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .ProbUNet_utils import make_onehot as make_onehot_segmentation, make_slices, match_to
5
+
6
+
7
+ def is_conv(op):
8
+ conv_types = (nn.Conv1d,
9
+ nn.Conv2d,
10
+ nn.Conv3d,
11
+ nn.ConvTranspose1d,
12
+ nn.ConvTranspose2d,
13
+ nn.ConvTranspose3d)
14
+ if type(op) == type and issubclass(op, conv_types):
15
+ return True
16
+ elif type(op) in conv_types:
17
+ return True
18
+ else:
19
+ return False
20
+
21
+
22
+
23
+ class ConvModule(nn.Module):
24
+
25
+ def __init__(self, *args, **kwargs):
26
+
27
+ super(ConvModule, self).__init__()
28
+
29
+ def init_weights(self, init_fn, *args, **kwargs):
30
+
31
+ class init_(object):
32
+
33
+ def __init__(self):
34
+ self.fn = init_fn
35
+ self.args = args
36
+ self.kwargs = kwargs
37
+
38
+ def __call__(self, module):
39
+ if is_conv(type(module)):
40
+ module.weight = self.fn(module.weight, *self.args, **self.kwargs)
41
+
42
+ _init_ = init_()
43
+ self.apply(_init_)
44
+
45
+ def init_bias(self, init_fn, *args, **kwargs):
46
+
47
+ class init_(object):
48
+
49
+ def __init__(self):
50
+ self.fn = init_fn
51
+ self.args = args
52
+ self.kwargs = kwargs
53
+
54
+ def __call__(self, module):
55
+ if is_conv(type(module)) and module.bias is not None:
56
+ module.bias = self.fn(module.bias, *self.args, **self.kwargs)
57
+
58
+ _init_ = init_()
59
+ self.apply(_init_)
60
+
61
+
62
+
63
+ class ConcatCoords(nn.Module):
64
+
65
+ def forward(self, input_):
66
+
67
+ dim = input_.dim() - 2
68
+ coord_channels = []
69
+ for i in range(dim):
70
+ view = [1, ] * dim
71
+ view[i] = -1
72
+ repeat = list(input_.shape[2:])
73
+ repeat[i] = 1
74
+ coord_channels.append(
75
+ torch.linspace(-0.5, 0.5, input_.shape[i+2])
76
+ .view(*view)
77
+ .repeat(*repeat)
78
+ .to(device=input_.device, dtype=input_.dtype))
79
+ coord_channels = torch.stack(coord_channels).unsqueeze(0)
80
+ repeat = [1, ] * input_.dim()
81
+ repeat[0] = input_.shape[0]
82
+ coord_channels = coord_channels.repeat(*repeat).contiguous()
83
+
84
+ return torch.cat([input_, coord_channels], 1)
85
+
86
+
87
+
88
+ class InjectionConvEncoder(ConvModule):
89
+
90
+ _default_activation_kwargs = dict(inplace=True)
91
+ _default_norm_kwargs = dict()
92
+ _default_conv_kwargs = dict(kernel_size=3, padding=1)
93
+ _default_pool_kwargs = dict(kernel_size=2)
94
+ _default_dropout_kwargs = dict()
95
+ _default_global_pool_kwargs = dict()
96
+
97
+ def __init__(self,
98
+ in_channels=1,
99
+ out_channels=6,
100
+ depth=4,
101
+ injection_depth="last",
102
+ injection_channels=0,
103
+ block_depth=2,
104
+ num_feature_maps=24,
105
+ feature_map_multiplier=2,
106
+ activation_op=nn.LeakyReLU,
107
+ activation_kwargs=None,
108
+ norm_op=nn.InstanceNorm2d,
109
+ norm_kwargs=None,
110
+ norm_depth=0,
111
+ conv_op=nn.Conv2d,
112
+ conv_kwargs=None,
113
+ pool_op=nn.AvgPool2d,
114
+ pool_kwargs=None,
115
+ dropout_op=None,
116
+ dropout_kwargs=None,
117
+ global_pool_op=nn.AdaptiveAvgPool2d,
118
+ global_pool_kwargs=None,
119
+ **kwargs):
120
+
121
+ super(InjectionConvEncoder, self).__init__(**kwargs)
122
+
123
+ self.in_channels = in_channels
124
+ self.out_channels = out_channels
125
+ self.depth = depth
126
+ self.injection_depth = depth - 1 if injection_depth == "last" else injection_depth
127
+ self.injection_channels = injection_channels
128
+ self.block_depth = block_depth
129
+ self.num_feature_maps = num_feature_maps
130
+ self.feature_map_multiplier = feature_map_multiplier
131
+
132
+ self.activation_op = activation_op
133
+ self.activation_kwargs = self._default_activation_kwargs
134
+ if activation_kwargs is not None:
135
+ self.activation_kwargs.update(activation_kwargs)
136
+
137
+ self.norm_op = norm_op
138
+ self.norm_kwargs = self._default_norm_kwargs
139
+ if norm_kwargs is not None:
140
+ self.norm_kwargs.update(norm_kwargs)
141
+ self.norm_depth = depth if norm_depth == "full" else norm_depth
142
+
143
+ self.conv_op = conv_op
144
+ self.conv_kwargs = self._default_conv_kwargs
145
+ if conv_kwargs is not None:
146
+ self.conv_kwargs.update(conv_kwargs)
147
+
148
+ self.pool_op = pool_op
149
+ self.pool_kwargs = self._default_pool_kwargs
150
+ if pool_kwargs is not None:
151
+ self.pool_kwargs.update(pool_kwargs)
152
+
153
+ self.dropout_op = dropout_op
154
+ self.dropout_kwargs = self._default_dropout_kwargs
155
+ if dropout_kwargs is not None:
156
+ self.dropout_kwargs.update(dropout_kwargs)
157
+
158
+ self.global_pool_op = global_pool_op
159
+ self.global_pool_kwargs = self._default_global_pool_kwargs
160
+ if global_pool_kwargs is not None:
161
+ self.global_pool_kwargs.update(global_pool_kwargs)
162
+
163
+ for d in range(self.depth):
164
+
165
+ in_ = self.in_channels if d == 0 else self.num_feature_maps * (self.feature_map_multiplier**(d-1))
166
+ out_ = self.num_feature_maps * (self.feature_map_multiplier**d)
167
+
168
+ if d == self.injection_depth + 1:
169
+ in_ += self.injection_channels
170
+
171
+ layers = []
172
+ if d > 0:
173
+ layers.append(self.pool_op(**self.pool_kwargs))
174
+ for b in range(self.block_depth):
175
+ current_in = in_ if b == 0 else out_
176
+ layers.append(self.conv_op(current_in, out_, **self.conv_kwargs))
177
+ if self.norm_op is not None and d < self.norm_depth:
178
+ layers.append(self.norm_op(out_, **self.norm_kwargs))
179
+ if self.activation_op is not None:
180
+ layers.append(self.activation_op(**self.activation_kwargs))
181
+ if self.dropout_op is not None:
182
+ layers.append(self.dropout_op(**self.dropout_kwargs))
183
+ if d == self.depth - 1:
184
+ current_conv_kwargs = self.conv_kwargs.copy()
185
+ current_conv_kwargs["kernel_size"] = 1
186
+ current_conv_kwargs["padding"] = 0
187
+ current_conv_kwargs["bias"] = False
188
+ layers.append(self.conv_op(out_, out_channels, **current_conv_kwargs))
189
+
190
+ self.add_module("encode_{}".format(d), nn.Sequential(*layers))
191
+
192
+ if self.global_pool_op is not None:
193
+ self.add_module("global_pool", self.global_pool_op(1, **self.global_pool_kwargs))
194
+
195
+ def forward(self, x, injection=None):
196
+
197
+ for d in range(self.depth):
198
+ x = self._modules["encode_{}".format(d)](x)
199
+ if d == self.injection_depth and self.injection_channels > 0:
200
+ injection = match_to(injection, x, self.injection_channels)
201
+ x = torch.cat([x, injection], 1)
202
+ if hasattr(self, "global_pool"):
203
+ x = self.global_pool(x)
204
+
205
+ return x
206
+
207
+
208
+ class InjectionConvEncoder3D(InjectionConvEncoder):
209
+
210
+ def __init__(self, *args, **kwargs):
211
+
212
+ update_kwargs = dict(
213
+ norm_op=nn.InstanceNorm3d,
214
+ conv_op=nn.Conv3d,
215
+ pool_op=nn.AvgPool3d,
216
+ global_pool_op=nn.AdaptiveAvgPool3d
217
+ )
218
+
219
+ for (arg, val) in update_kwargs.items():
220
+ if arg not in kwargs: kwargs[arg] = val
221
+
222
+ super(InjectionConvEncoder3D, self).__init__(*args, **kwargs)
223
+
224
+ class InjectionConvEncoder2D(InjectionConvEncoder): #Created by Soumick
225
+
226
+ def __init__(self, *args, **kwargs):
227
+
228
+ update_kwargs = dict(
229
+ norm_op=nn.InstanceNorm2d,
230
+ conv_op=nn.Conv2d,
231
+ pool_op=nn.AvgPool2d,
232
+ global_pool_op=nn.AdaptiveAvgPool2d
233
+ )
234
+
235
+ for (arg, val) in update_kwargs.items():
236
+ if arg not in kwargs: kwargs[arg] = val
237
+
238
+ super(InjectionConvEncoder2D, self).__init__(*args, **kwargs)
239
+
240
+ class InjectionUNet(ConvModule):
241
+
242
+ def __init__(
243
+ self,
244
+ depth=5,
245
+ in_channels=4,
246
+ out_channels=4,
247
+ kernel_size=3,
248
+ dilation=1,
249
+ num_feature_maps=24,
250
+ block_depth=2,
251
+ num_1x1_at_end=3,
252
+ injection_channels=3,
253
+ injection_at="end",
254
+ activation_op=nn.LeakyReLU,
255
+ activation_kwargs=None,
256
+ pool_op=nn.AvgPool2d,
257
+ pool_kwargs=dict(kernel_size=2),
258
+ dropout_op=None,
259
+ dropout_kwargs=None,
260
+ norm_op=nn.InstanceNorm2d,
261
+ norm_kwargs=None,
262
+ conv_op=nn.Conv2d,
263
+ conv_kwargs=None,
264
+ upconv_op=nn.ConvTranspose2d,
265
+ upconv_kwargs=None,
266
+ output_activation_op=None,
267
+ output_activation_kwargs=None,
268
+ return_bottom=False,
269
+ coords=False,
270
+ coords_dim=2,
271
+ **kwargs
272
+ ):
273
+
274
+ super(InjectionUNet, self).__init__(**kwargs)
275
+
276
+ self.depth = depth
277
+ self.in_channels = in_channels
278
+ self.out_channels = out_channels
279
+ self.kernel_size = kernel_size
280
+ self.dilation = dilation
281
+ self.padding = (self.kernel_size + (self.kernel_size-1) * (self.dilation-1)) // 2
282
+ self.num_feature_maps = num_feature_maps
283
+ self.block_depth = block_depth
284
+ self.num_1x1_at_end = num_1x1_at_end
285
+ self.injection_channels = injection_channels
286
+ self.injection_at = injection_at
287
+ self.activation_op = activation_op
288
+ self.activation_kwargs = {} if activation_kwargs is None else activation_kwargs
289
+ self.pool_op = pool_op
290
+ self.pool_kwargs = {} if pool_kwargs is None else pool_kwargs
291
+ self.dropout_op = dropout_op
292
+ self.dropout_kwargs = {} if dropout_kwargs is None else dropout_kwargs
293
+ self.norm_op = norm_op
294
+ self.norm_kwargs = {} if norm_kwargs is None else norm_kwargs
295
+ self.conv_op = conv_op
296
+ self.conv_kwargs = {} if conv_kwargs is None else conv_kwargs
297
+ self.upconv_op = upconv_op
298
+ self.upconv_kwargs = {} if upconv_kwargs is None else upconv_kwargs
299
+ self.output_activation_op = output_activation_op
300
+ self.output_activation_kwargs = {} if output_activation_kwargs is None else output_activation_kwargs
301
+ self.return_bottom = return_bottom
302
+ if not coords:
303
+ self.coords = [[], []]
304
+ elif coords is True:
305
+ self.coords = [list(range(depth)), []]
306
+ else:
307
+ self.coords = coords
308
+ self.coords_dim = coords_dim
309
+
310
+ self.last_activations = None
311
+
312
+ # BUILD ENCODER
313
+ for d in range(self.depth):
314
+
315
+ block = []
316
+ if d > 0:
317
+ block.append(self.pool_op(**self.pool_kwargs))
318
+
319
+ for i in range(self.block_depth):
320
+
321
+ # bottom block fixed to have depth 1
322
+ if d == self.depth - 1 and i > 0:
323
+ continue
324
+
325
+ out_size = self.num_feature_maps * 2**d
326
+ if d == 0 and i == 0:
327
+ in_size = self.in_channels
328
+ elif i == 0:
329
+ in_size = self.num_feature_maps * 2**(d - 1)
330
+ else:
331
+ in_size = out_size
332
+
333
+ # check for coord appending at this depth
334
+ if d in self.coords[0] and i == 0:
335
+ block.append(ConcatCoords())
336
+ in_size += self.coords_dim
337
+
338
+ block.append(self.conv_op(in_size,
339
+ out_size,
340
+ self.kernel_size,
341
+ padding=self.padding,
342
+ dilation=self.dilation,
343
+ **self.conv_kwargs))
344
+ if self.dropout_op is not None:
345
+ block.append(self.dropout_op(**self.dropout_kwargs))
346
+ if self.norm_op is not None:
347
+ block.append(self.norm_op(out_size, **self.norm_kwargs))
348
+ block.append(self.activation_op(**self.activation_kwargs))
349
+
350
+ self.add_module("encode-{}".format(d), nn.Sequential(*block))
351
+
352
+ # BUILD DECODER
353
+ for d in reversed(range(self.depth)):
354
+
355
+ block = []
356
+
357
+ for i in range(self.block_depth):
358
+
359
+ # bottom block fixed to have depth 1
360
+ if d == self.depth - 1 and i > 0:
361
+ continue
362
+
363
+ out_size = self.num_feature_maps * 2**(d)
364
+ if i == 0 and d < self.depth - 1:
365
+ in_size = self.num_feature_maps * 2**(d+1)
366
+ elif i == 0 and self.injection_at == "bottom":
367
+ in_size = out_size + self.injection_channels
368
+ else:
369
+ in_size = out_size
370
+
371
+ # check for coord appending at this depth
372
+ if d in self.coords[0] and i == 0 and d < self.depth - 1:
373
+ block.append(ConcatCoords())
374
+ in_size += self.coords_dim
375
+
376
+ block.append(self.conv_op(in_size,
377
+ out_size,
378
+ self.kernel_size,
379
+ padding=self.padding,
380
+ dilation=self.dilation,
381
+ **self.conv_kwargs))
382
+ if self.dropout_op is not None:
383
+ block.append(self.dropout_op(**self.dropout_kwargs))
384
+ if self.norm_op is not None:
385
+ block.append(self.norm_op(out_size, **self.norm_kwargs))
386
+ block.append(self.activation_op(**self.activation_kwargs))
387
+
388
+ if d > 0:
389
+ block.append(self.upconv_op(out_size,
390
+ out_size // 2,
391
+ self.kernel_size,
392
+ 2,
393
+ padding=self.padding,
394
+ dilation=self.dilation,
395
+ output_padding=1,
396
+ **self.upconv_kwargs))
397
+
398
+ self.add_module("decode-{}".format(d), nn.Sequential(*block))
399
+
400
+ if self.injection_at == "end":
401
+ out_size += self.injection_channels
402
+ in_size = out_size
403
+ for i in range(self.num_1x1_at_end):
404
+ if i == self.num_1x1_at_end - 1:
405
+ out_size = self.out_channels
406
+ current_conv_kwargs = self.conv_kwargs.copy()
407
+ current_conv_kwargs["bias"] = True
408
+ self.add_module("reduce-{}".format(i), self.conv_op(in_size, out_size, 1, **current_conv_kwargs))
409
+ if i != self.num_1x1_at_end - 1:
410
+ self.add_module("reduce-{}-nonlin".format(i), self.activation_op(**self.activation_kwargs))
411
+ if self.output_activation_op is not None:
412
+ self.add_module("output-activation", self.output_activation_op(**self.output_activation_kwargs))
413
+
414
+ def reset(self):
415
+
416
+ self.last_activations = None
417
+
418
+ def forward(self, x, injection=None, reuse_last_activations=False, store_activations=False):
419
+
420
+ if self.injection_at == "bottom": # not worth it for now
421
+ reuse_last_activations = False
422
+ store_activations = False
423
+
424
+ if self.last_activations is None or reuse_last_activations is False:
425
+
426
+ enc = [x]
427
+
428
+ for i in range(self.depth - 1):
429
+ enc.append(self._modules["encode-{}".format(i)](enc[-1]))
430
+
431
+ bottom_rep = self._modules["encode-{}".format(self.depth - 1)](enc[-1])
432
+
433
+ if self.injection_at == "bottom" and self.injection_channels > 0:
434
+ injection = match_to(injection, bottom_rep, (0, 1))
435
+ bottom_rep = torch.cat((bottom_rep, injection), 1)
436
+
437
+ x = self._modules["decode-{}".format(self.depth - 1)](bottom_rep)
438
+
439
+ for i in reversed(range(self.depth - 1)):
440
+ x = self._modules["decode-{}".format(i)](torch.cat((enc[-(self.depth - 1 - i)], x), 1))
441
+
442
+ if store_activations:
443
+ self.last_activations = x.detach()
444
+
445
+ else:
446
+
447
+ x = self.last_activations
448
+
449
+ if self.injection_at == "end" and self.injection_channels > 0:
450
+ injection = match_to(injection, x, (0, 1))
451
+ x = torch.cat((x, injection), 1)
452
+
453
+ for i in range(self.num_1x1_at_end):
454
+ x = self._modules["reduce-{}".format(i)](x)
455
+ if self.output_activation_op is not None:
456
+ x = self._modules["output-activation"](x)
457
+
458
+ if self.return_bottom and not reuse_last_activations:
459
+ return x, bottom_rep
460
+ else:
461
+ return x
462
+
463
+
464
+
465
+ class InjectionUNet3D(InjectionUNet):
466
+
467
+ def __init__(self, *args, **kwargs):
468
+
469
+ update_kwargs = dict(
470
+ pool_op=nn.AvgPool3d,
471
+ norm_op=nn.InstanceNorm3d,
472
+ conv_op=nn.Conv3d,
473
+ upconv_op=nn.ConvTranspose3d,
474
+ coords_dim=3
475
+ )
476
+
477
+ for (arg, val) in update_kwargs.items():
478
+ if arg not in kwargs: kwargs[arg] = val
479
+
480
+ super(InjectionUNet3D, self).__init__(*args, **kwargs)
481
+
482
+ class InjectionUNet2D(InjectionUNet): #Created by Soumick
483
+
484
+ def __init__(self, *args, **kwargs):
485
+
486
+ update_kwargs = dict(
487
+ pool_op=nn.AvgPool2d,
488
+ norm_op=nn.InstanceNorm2d,
489
+ conv_op=nn.Conv2d,
490
+ upconv_op=nn.ConvTranspose2d,
491
+ coords_dim=2
492
+ )
493
+
494
+ for (arg, val) in update_kwargs.items():
495
+ if arg not in kwargs: kwargs[arg] = val
496
+
497
+ super(InjectionUNet2D, self).__init__(*args, **kwargs)
498
+
499
+ class ProbabilisticSegmentationNet(ConvModule):
500
+
501
+ def __init__(self,
502
+ in_channels=4,
503
+ out_channels=4,
504
+ num_feature_maps=24,
505
+ latent_size=3,
506
+ depth=5,
507
+ latent_distribution=torch.distributions.Normal,
508
+ task_op=InjectionUNet3D,
509
+ task_kwargs=None,
510
+ prior_op=InjectionConvEncoder3D,
511
+ prior_kwargs=None,
512
+ posterior_op=InjectionConvEncoder3D,
513
+ posterior_kwargs=None,
514
+ **kwargs):
515
+
516
+ super(ProbabilisticSegmentationNet, self).__init__(**kwargs)
517
+
518
+ self.task_op = task_op
519
+ self.task_kwargs = {} if task_kwargs is None else task_kwargs
520
+ self.prior_op = prior_op
521
+ self.prior_kwargs = {} if prior_kwargs is None else prior_kwargs
522
+ self.posterior_op = posterior_op
523
+ self.posterior_kwargs = {} if posterior_kwargs is None else posterior_kwargs
524
+
525
+ default_task_kwargs = dict(
526
+ in_channels=in_channels,
527
+ out_channels=out_channels,
528
+ num_feature_maps=num_feature_maps,
529
+ injection_size=latent_size,
530
+ depth=depth
531
+ )
532
+
533
+ default_prior_kwargs = dict(
534
+ in_channels=in_channels,
535
+ out_channels=latent_size*2, #Soumick
536
+ num_feature_maps=num_feature_maps,
537
+ z_dim=latent_size,
538
+ depth=depth
539
+ )
540
+
541
+ default_posterior_kwargs = dict(
542
+ in_channels=in_channels+out_channels,
543
+ out_channels=latent_size*2, #Soumick
544
+ num_feature_maps=num_feature_maps,
545
+ z_dim=latent_size,
546
+ depth=depth
547
+ )
548
+
549
+ default_task_kwargs.update(self.task_kwargs)
550
+ self.task_kwargs = default_task_kwargs
551
+ default_prior_kwargs.update(self.prior_kwargs)
552
+ self.prior_kwargs = default_prior_kwargs
553
+ default_posterior_kwargs.update(self.posterior_kwargs)
554
+ self.posterior_kwargs = default_posterior_kwargs
555
+
556
+ self.latent_distribution = latent_distribution
557
+ self._prior = None
558
+ self._posterior = None
559
+
560
+ self.make_modules()
561
+
562
+ def make_modules(self):
563
+
564
+ if type(self.task_op) == type:
565
+ self.add_module("task_net", self.task_op(**self.task_kwargs))
566
+ else:
567
+ self.add_module("task_net", self.task_op)
568
+ if type(self.prior_op) == type:
569
+ self.add_module("prior_net", self.prior_op(**self.prior_kwargs))
570
+ else:
571
+ self.add_module("prior_net", self.prior_op)
572
+ if type(self.posterior_op) == type:
573
+ self.add_module("posterior_net", self.posterior_op(**self.posterior_kwargs))
574
+ else:
575
+ self.add_module("posterior_net", self.posterior_op)
576
+
577
+ @property
578
+ def prior(self):
579
+ return self._prior
580
+
581
+ @property
582
+ def posterior(self):
583
+ return self._posterior
584
+
585
+ @property
586
+ def last_activations(self):
587
+ return self.task_net.last_activations
588
+
589
+ def train(self, mode=True):
590
+
591
+ super(ProbabilisticSegmentationNet, self).train(mode)
592
+ self.reset()
593
+
594
+ def reset(self):
595
+
596
+ self.task_net.reset()
597
+ self._prior = None
598
+ self._posterior = None
599
+
600
+ def forward(self, input_, seg=None, make_onehot=True, make_onehot_classes=None, newaxis=False, distlossN=0):
601
+ """Forward pass includes reparametrization sampling during training, otherwise it'll just take the prior mean."""
602
+
603
+ self.encode_prior(input_)
604
+
605
+ if distlossN == 0:
606
+ if self.training:
607
+ self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis)
608
+ sample = self.posterior.rsample()
609
+ else:
610
+ sample = self.prior.loc
611
+ return self.task_net(input_, sample, store_activations=not self.training)
612
+ else:
613
+ if self.training:
614
+ self.encode_posterior(input_, seg, make_onehot, make_onehot_classes, newaxis)
615
+ segs = []
616
+ for i in range(distlossN):
617
+ sample = self.posterior.rsample()
618
+ segs.append(self.task_net(input_, sample, store_activations=not self.training))
619
+ return segs #torch.concat(segs, dim=0)
620
+ else: #I'm not totally sure about this!!
621
+ sample = self.prior.loc
622
+ return self.task_net(input_, sample, store_activations=not self.training)
623
+
624
+
625
+ def encode_prior(self, input_):
626
+
627
+ rep = self.prior_net(input_)
628
+ if isinstance(rep, tuple):
629
+ mean, logvar = rep
630
+ elif torch.is_tensor(rep):
631
+ mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1)
632
+ self._prior = self.latent_distribution(mean, logvar.mul(0.5).exp())
633
+ return self._prior
634
+
635
+ def encode_posterior(self, input_, seg, make_onehot=True, make_onehot_classes=None, newaxis=False):
636
+
637
+ if make_onehot:
638
+ if make_onehot_classes is None:
639
+ make_onehot_classes = tuple(range(self.posterior_net.in_channels - input_.shape[1]))
640
+ seg = make_onehot_segmentation(seg, make_onehot_classes, newaxis=newaxis)
641
+ rep = self.posterior_net(torch.cat((input_, seg.float()), 1))
642
+ if isinstance(rep, tuple):
643
+ mean, logvar = rep
644
+ elif torch.is_tensor(rep):
645
+ mean, logvar = torch.split(rep, rep.shape[1] // 2, dim=1)
646
+ self._posterior = self.latent_distribution(mean, logvar.mul(0.5).exp())
647
+ return self._posterior
648
+
649
+ def sample_prior(self, N=1, out_device=None, input_=None, pred_with_mean=False):
650
+ """Draw multiple samples from the current prior.
651
+
652
+ * input_ is required if no activations are stored in task_net.
653
+ * If input_ is given, prior will automatically be encoded again.
654
+ * Returns either a single sample or a list of samples.
655
+
656
+ """
657
+
658
+ if out_device is None:
659
+ if self.last_activations is not None:
660
+ out_device = self.last_activations.device
661
+ elif input_ is not None:
662
+ out_device = input_.device
663
+ else:
664
+ out_device = next(self.task_net.parameters()).device
665
+ with torch.no_grad():
666
+ if self.prior is None or input_ is not None:
667
+ self.encode_prior(input_)
668
+ result = []
669
+
670
+ if input_ is not None:
671
+ result.append(self.task_net(input_, self.prior.sample(), reuse_last_activations=False, store_activations=True).to(device=out_device))
672
+ while len(result) < N:
673
+ result.append(self.task_net(input_,
674
+ self.prior.sample(),
675
+ reuse_last_activations=self.last_activations is not None,
676
+ store_activations=False).to(device=out_device))
677
+ if pred_with_mean:
678
+ result.append(self.task_net(input_, self.prior.mean, reuse_last_activations=False, store_activations=True).to(device=out_device))
679
+
680
+ if len(result) == 1:
681
+ return result[0]
682
+ else:
683
+ return result
684
+
685
+ def reconstruct(self, sample=None, use_posterior_mean=True, out_device=None, input_=None):
686
+ """Reconstruct a sample or the current posterior mean. Will not compute gradients!"""
687
+
688
+ if self.posterior is None and sample is None:
689
+ raise ValueError("'posterior' is currently None. Please pass an input and a segmentation first.")
690
+ if out_device is None:
691
+ out_device = next(self.task_net.parameters()).device
692
+ if sample is None:
693
+ if use_posterior_mean:
694
+ sample = self.posterior.loc
695
+ else:
696
+ sample = self.posterior.sample()
697
+ else:
698
+ sample = sample.to(next(self.task_net.parameters()).device)
699
+ with torch.no_grad():
700
+ return self.task_net(input_, sample, reuse_last_activations=True).to(device=out_device)
701
+
702
+ def kl_divergence(self):
703
+ """Compute current KL, requires existing prior and posterior."""
704
+
705
+ if self.posterior is None or self.prior is None:
706
+ raise ValueError("'prior' and 'posterior' must not be None, but prior={} and posterior={}".format(self.prior, self.posterior))
707
+ return torch.distributions.kl_divergence(self.posterior, self.prior).sum()
708
+
709
+ def elbo(self, seg, input_=None, nll_reduction="sum", beta=1.0, make_onehot=True, make_onehot_classes=None, newaxis=False):
710
+ """Compute the ELBO with seg as ground truth.
711
+
712
+ * Prior is expected and will not be encoded.
713
+ * If input_ is given, posterior will automatically be encoded.
714
+ * Either input_ or stored activations must be available.
715
+
716
+ """
717
+
718
+ if self.last_activations is None:
719
+ raise ValueError("'last_activations' is currently None. Please pass an input first.")
720
+ if input_ is not None:
721
+ with torch.no_grad():
722
+ self.encode_posterior(input_, seg, make_onehot=make_onehot, make_onehot_classes=make_onehot_classes, newaxis=newaxis)
723
+ if make_onehot and newaxis:
724
+ pass # seg will already be (B x SPACE)
725
+ elif make_onehot and not newaxis:
726
+ seg = seg[:, 0] # in this case seg will hopefully be (B x 1 x SPACE)
727
+ else:
728
+ seg = torch.argmax(seg, 1, keepdim=False) # seg is already onehot
729
+ kl = self.kl_divergence()
730
+ nll = nn.NLLLoss(reduction=nll_reduction)(self.reconstruct(sample=None, use_posterior_mean=True, out_device=None), seg.long())
731
+ return - (beta * nll + kl)
ProbUNet_utils.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import shutil
4
+ import numpy as np
5
+ import torch
6
+ # from trixi.util import Config, GridSearch
7
+
8
+
9
+ def check_attributes(object_, attributes):
10
+
11
+ missing = []
12
+ for attr in attributes:
13
+ if not hasattr(object_, attr):
14
+ missing.append(attr)
15
+ if len(missing) > 0:
16
+ return False
17
+ else:
18
+ return True
19
+
20
+
21
+ def set_seeds(seed, cuda=True):
22
+ if not hasattr(seed, "__iter__"):
23
+ seed = (seed, seed, seed)
24
+ np.random.seed(seed[0])
25
+ torch.manual_seed(seed[1])
26
+ if cuda: torch.cuda.manual_seed_all(seed[2])
27
+
28
+
29
+ def make_onehot(array, labels=None, axis=1, newaxis=False):
30
+
31
+ # get labels if necessary
32
+ if labels is None:
33
+ labels = np.unique(array)
34
+ labels = list(map(lambda x: x.item(), labels))
35
+
36
+ # get target shape
37
+ new_shape = list(array.shape)
38
+ if newaxis:
39
+ new_shape.insert(axis, len(labels))
40
+ else:
41
+ new_shape[axis] = new_shape[axis] * len(labels)
42
+
43
+ # make zero array
44
+ if type(array) == np.ndarray:
45
+ new_array = np.zeros(new_shape, dtype=array.dtype)
46
+ elif torch.is_tensor(array):
47
+ new_array = torch.zeros(new_shape, dtype=array.dtype, device=array.device)
48
+ else:
49
+ raise TypeError("Onehot conversion undefined for object of type {}".format(type(array)))
50
+
51
+ # fill new array
52
+ n_seg_channels = 1 if newaxis else array.shape[axis]
53
+ for seg_channel in range(n_seg_channels):
54
+ for l, label in enumerate(labels):
55
+ new_slc = [slice(None), ] * len(new_shape)
56
+ slc = [slice(None), ] * len(array.shape)
57
+ new_slc[axis] = seg_channel * len(labels) + l
58
+ if not newaxis:
59
+ slc[axis] = seg_channel
60
+ new_array[tuple(new_slc)] = array[tuple(slc)] == label
61
+
62
+ return new_array
63
+
64
+
65
+ def match_to(x, ref, keep_axes=(1,)):
66
+
67
+ target_shape = list(ref.shape)
68
+ for i in keep_axes:
69
+ target_shape[i] = x.shape[i]
70
+ target_shape = tuple(target_shape)
71
+ if x.shape == target_shape:
72
+ pass
73
+ if x.dim() == 1:
74
+ x = x.unsqueeze(0)
75
+ if x.dim() == 2:
76
+ while x.dim() < len(target_shape):
77
+ x = x.unsqueeze(-1)
78
+
79
+ x = x.expand(*target_shape)
80
+ x = x.to(device=ref.device, dtype=ref.dtype)
81
+
82
+ return x
83
+
84
+
85
+ def make_slices(original_shape, patch_shape):
86
+
87
+ working_shape = original_shape[-len(patch_shape):]
88
+ splits = []
89
+ for i in range(len(working_shape)):
90
+ splits.append([])
91
+ for j in range(working_shape[i] // patch_shape[i]):
92
+ splits[i].append(slice(j*patch_shape[i], (j+1)*patch_shape[i]))
93
+ rest = working_shape[i] % patch_shape[i]
94
+ if rest > 0:
95
+ splits[i].append(slice((j+1)*patch_shape[i], (j+1)*patch_shape[i] + rest))
96
+
97
+ # now we have all slices for the individual dimensions
98
+ # we need their combinatorial combinations
99
+ slices = list(itertools.product(*splits))
100
+ for i in range(len(slices)):
101
+ slices[i] = [slice(None), ] * (len(original_shape) - len(patch_shape)) + list(slices[i])
102
+
103
+ return slices
104
+
105
+
106
+ def coordinate_grid_samples(mean, std, factor_std=5, scale_std=1.):
107
+
108
+ relative = np.linspace(-scale_std*factor_std, scale_std*factor_std, 2*factor_std+1)
109
+ positions = np.array([mean + i * std for i in relative]).T
110
+ axes = np.meshgrid(*positions)
111
+ axes = map(lambda x: list(x.ravel()), axes)
112
+ samples = list(zip(*axes))
113
+ samples = list(map(np.array, samples))
114
+
115
+ return samples
116
+
117
+
118
+ def get_default_experiment_parser():
119
+
120
+ parser = argparse.ArgumentParser()
121
+ parser.add_argument("base_dir", type=str, help="Working directory for experiment.")
122
+ parser.add_argument("-c", "--config", type=str, default=None, help="Path to a config file.")
123
+ parser.add_argument("-v", "--visdomlogger", action="store_true", help="Use visdomlogger.")
124
+ parser.add_argument("-tx", "--tensorboardxlogger", type=str, default=None)
125
+ parser.add_argument("-tl", "--telegramlogger", action="store_true")
126
+ parser.add_argument("-dc", "--default_config", type=str, default="DEFAULTS", help="Select a default Config")
127
+ parser.add_argument("-ad", "--automatic_description", action="store_true")
128
+ parser.add_argument("-r", "--resume", type=str, default=None, help="Path to resume from")
129
+ parser.add_argument("-irc", "--ignore_resume_config", action="store_true", help="Ignore Config in experiment we resume from.")
130
+ parser.add_argument("-test", "--test", action="store_true", help="Run test instead of training")
131
+ parser.add_argument("-g", "--grid", type=str, help="Path to a config for grid search")
132
+ parser.add_argument("-s", "--skip_existing", action="store_true", help="Skip configs for which an experiment exists, only for grid search")
133
+ parser.add_argument("-m", "--mods", type=str, nargs="+", default=None, help="Mods are Config stubs to update only relevant parts for a certain setup.")
134
+ parser.add_argument("-ct", "--copy_test", action="store_true", help="Copy test files to original experiment.")
135
+
136
+ return parser
137
+
138
+
139
+ def run_experiment(experiment, configs, args, mods=None, **kwargs):
140
+
141
+ # set a few defaults
142
+ if "explogger_kwargs" not in kwargs:
143
+ kwargs["explogger_kwargs"] = dict(folder_format="{experiment_name}_%Y%m%d-%H%M%S")
144
+ if "explogger_freq" not in kwargs:
145
+ kwargs["explogger_freq"] = 1
146
+ if "resume_save_types" not in kwargs:
147
+ kwargs["resume_save_types"] = ("model", "simple", "th_vars", "results")
148
+
149
+ config = Config(file_=args.config) if args.config is not None else Config()
150
+ config.update_missing(configs[args.default_config].deepcopy())
151
+ if args.mods is not None and mods is not None:
152
+ for mod in args.mods:
153
+ config.update(mods[mod])
154
+ config = Config(config=config, update_from_argv=True)
155
+
156
+ # GET EXISTING EXPERIMENTS TO BE ABLE TO SKIP CERTAIN CONFIGS
157
+ if args.skip_existing:
158
+ existing_configs = []
159
+ for exp in os.listdir(args.base_dir):
160
+ try:
161
+ existing_configs.append(Config(file_=os.path.join(args.base_dir, exp, "config", "config.json")))
162
+ except Exception as e:
163
+ pass
164
+
165
+ if args.grid is not None:
166
+ grid = GridSearch().read(args.grid)
167
+ else:
168
+ grid = [{}]
169
+
170
+ for combi in grid:
171
+
172
+ config.update(combi)
173
+
174
+ if args.skip_existing:
175
+ skip_this = False
176
+ for existing_config in existing_configs:
177
+ if existing_config.contains(config):
178
+ skip_this = True
179
+ break
180
+ if skip_this:
181
+ continue
182
+
183
+ if "backup_every" in config:
184
+ kwargs["save_checkpoint_every_epoch"] = config["backup_every"]
185
+
186
+ loggers = {}
187
+ if args.visdomlogger:
188
+ loggers["v"] = ("visdom", {}, 1)
189
+ if args.tensorboardxlogger is not None:
190
+ if args.tensorboardxlogger == "same":
191
+ loggers["tx"] = ("tensorboard", {}, 1)
192
+ else:
193
+ loggers["tx"] = ("tensorboard", {"target_dir": args.tensorboardxlogger}, 1)
194
+
195
+ if args.telegramlogger:
196
+ kwargs["use_telegram"] = True
197
+
198
+ if args.automatic_description:
199
+ difference_to_default = Config.difference_config_static(config, configs["DEFAULTS"]).flat(keep_lists=True, max_split_size=0, flatten_int=True)
200
+ description_str = ""
201
+ for key, val in difference_to_default.items():
202
+ val = val[0]
203
+ description_str = "{} = {}\n{}".format(key, val, description_str)
204
+ config.description = description_str
205
+
206
+ exp = experiment(config=config,
207
+ base_dir=args.base_dir,
208
+ resume=args.resume,
209
+ ignore_resume_config=args.ignore_resume_config,
210
+ loggers=loggers,
211
+ **kwargs)
212
+
213
+ trained = False
214
+ if args.resume is None or args.test is False:
215
+ exp.run()
216
+ trained = True
217
+ if args.test:
218
+ exp.run_test(setup=not trained)
219
+ if isinstance(args.resume, str) and exp.elog is not None and args.copy_test:
220
+ for f in glob.glob(os.path.join(exp.elog.save_dir, "test*")):
221
+ if os.path.isdir(f):
222
+ shutil.copytree(f, os.path.join(args.resume, "save", os.path.basename(f)))
223
+ else:
224
+ shutil.copy(f, os.path.join(args.resume, "save"))
config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProbUNet"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "PULASkiConfigs.ProbUNetConfig",
7
+ "AutoModel": "PULASki.ProbUNet"
8
+ },
9
+ "depth": 5,
10
+ "dim": 3,
11
+ "in_channels": 1,
12
+ "latent_distribution": "normal",
13
+ "latent_size": 3,
14
+ "model_type": "ProbUNet",
15
+ "no_outact_op": false,
16
+ "num_feature_maps": 24,
17
+ "out_channels": 1,
18
+ "prob_injection_at": "end",
19
+ "torch_dtype": "float32",
20
+ "transformers_version": "4.44.2"
21
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2b0be92591502420a34cf1163afde3843199892b43750c711166a5e5beb92973
3
+ size 121870000