import torch.nn as nn


class DCGAN3D_G(nn.Module):
    """3-D DCGAN generator built from a stack of transposed 3-D convolutions.

    The network starts with a transposed conv (kernel 4, stride 1, padding 0),
    which turns a 1x1x1 input into a 4x4x4 volume. It then doubles the spatial
    edge length with kernel-4 / stride-2 / padding-1 transposed convs until it
    reaches ``isize``. Every hidden stage is followed by BatchNorm3d and ReLU.
    The output stage ends in Tanh, so outputs lie in [-1, 1].

    Args:
        isize: Edge length of the generated cube. Must be a power of two
            that is >= 16.
        nz: Number of channels of the latent input.
        nc: Number of output channels.
        ngf: Width scale. The last hidden stage has ``2 * (ngf // 2)``
            channels, and each earlier stage has twice as many as the next.
        ngpu: Stored as ``self.ngpu``. It is not used by ``forward`` in this
            class.
        n_extra_layers: Number of size-preserving Conv3d(3, 1, 1) + BatchNorm3d
            + ReLU blocks inserted just before the output layer.

    Raises:
        AssertionError: If ``isize`` is not a multiple of 16. This is the
            original contract; the check is stripped under ``python -O``.
        ValueError: If ``isize`` is not a power of two >= 16. Such sizes
            previously made the size loops below spin forever.
    """

    def __init__(self, isize, nz, nc, ngf, ngpu, n_extra_layers=0):
        super().__init__()
        self.ngpu = ngpu
        # Kept so callers relying on the original AssertionError still see it.
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        # The loops below only terminate when isize == 4 * 2**k. Any other
        # multiple of 16 (48, 80, ...) used to hang. Because this is a real
        # exception rather than an assert, it survives ``python -O``.
        if isize < 16 or isize & (isize - 1):
            raise ValueError(f"isize has to be a power of two >= 16, got {isize}")

        # Channel count for the 4x4x4 stage: start at ngf // 2 and double it
        # once per 2x upsampling needed to reach isize. The pyramid below then
        # halves it back down as the spatial size grows.
        cngf, tisize = ngf // 2, 4
        while tisize != isize:
            cngf *= 2
            tisize *= 2

        main = nn.Sequential(
            # Assumes a latent input of shape (N, nz, 1, 1, 1) -- TODO confirm
            # with callers. The 4/1/0 transposed conv maps 1^3 -> 4^3.
            nn.ConvTranspose3d(nz, cngf, 4, 1, 0, bias=False),
            nn.BatchNorm3d(cngf),
            nn.ReLU(True),
        )

        # Modules are registered under consecutive string indices ("0", "1",
        # ...), so state_dict keys read "main.<i>.*". Keep this numbering
        # stable to stay compatible with previously saved weights.
        i, csize = 3, 4

        # Each pyramid stage doubles the edge length (4/2/1 transposed conv)
        # and halves the channels. Stop at isize // 2, because the output
        # layer performs the final doubling.
        while csize < isize // 2:
            main.add_module(str(i), nn.ConvTranspose3d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module(str(i + 1), nn.BatchNorm3d(cngf // 2))
            main.add_module(str(i + 2), nn.ReLU(True))
            i += 3
            cngf //= 2
            csize *= 2

        # Extra layers: 3/1/1 convs keep both size and channel count unchanged.
        for _ in range(n_extra_layers):
            main.add_module(str(i), nn.Conv3d(cngf, cngf, 3, 1, 1, bias=False))
            main.add_module(str(i + 1), nn.BatchNorm3d(cngf))
            main.add_module(str(i + 2), nn.ReLU(True))
            i += 3

        # Output layer: the final 2x upsampling to isize, projected to nc
        # channels and squashed into [-1, 1].
        main.add_module(str(i), nn.ConvTranspose3d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module(str(i + 1), nn.Tanh())
        self.main = main

    def forward(self, input):
        """Run the generator on ``input`` and return the generated volume."""
        return self.main(input)