from torch import nn

from layers.normalization import AdaIN
class DestyleResBlock(nn.Module):
    """Residual block whose activations are modulated by AdaIN using an
    external feature (style) vector."""

    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1,
                 dilation=1, padding=1, use_dropout=False):
        super().__init__()
        # Use a 1x1 convolution on the skip path when the block changes the
        # channel count, so the shapes match at the residual addition.
        if not channels_in or channels_in == channels_out:
            channels_in = channels_out
            self.projection = None
        else:
            self.projection = nn.Conv2d(channels_in, channels_out,
                                        kernel_size=1, stride=stride, dilation=1)
        self.use_dropout = use_dropout
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size,
                               stride=1, padding=padding, dilation=dilation)
        self.adain = AdaIN()
        if self.use_dropout:
            self.dropout = nn.Dropout()
        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, feat):
        residual = x
        out = self.conv1(x)
        out = self.lrelu1(out)
        out = self.conv2(out)
        # Condition the activations on the external feature vector.
        out = self.adain(out, feat)
        if self.use_dropout:
            out = self.dropout(out)
        if self.projection is not None:
            residual = self.projection(x)
        out = out + residual
        out = self.lrelu2(out)
        return out
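
# Illustrative usage sketch (not part of the original file). The exact AdaIN
# contract lives in layers/normalization.py and is not shown here; the
# assumption below is that `feat` is a per-sample feature/style vector that
# AdaIN maps to per-channel scale and shift:
#
#   block = DestyleResBlock(channels_out=64, kernel_size=3, channels_in=32, stride=2)
#   out = block(x, feat)  # x: (N, 32, H, W) -> out: (N, 64, H/2, W/2)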
class ResBlock(nn.Module):
    """Plain residual block; mirrors DestyleResBlock but without AdaIN."""

    def __init__(self, channels_out, kernel_size, channels_in=None, stride=1,
                 dilation=1, padding=1, use_dropout=False):
        super().__init__()
        # Use a 1x1 convolution on the skip path when the block changes the
        # channel count, so the shapes match at the residual addition.
        if not channels_in or channels_in == channels_out:
            channels_in = channels_out
            self.projection = None
        else:
            self.projection = nn.Conv2d(channels_in, channels_out,
                                        kernel_size=1, stride=stride, dilation=1)
        self.use_dropout = use_dropout
        self.conv1 = nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size,
                               stride=stride, padding=padding, dilation=dilation)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.conv2 = nn.Conv2d(channels_out, channels_out, kernel_size=kernel_size,
                               stride=1, padding=padding, dilation=dilation)
        # Defined but not applied in forward (kept disabled there).
        self.n2 = nn.BatchNorm2d(channels_out)
        if self.use_dropout:
            self.dropout = nn.Dropout()
        self.lrelu2 = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.lrelu1(out)
        out = self.conv2(out)
        # out = self.n2(out)  # batch norm left disabled, as in the original
        if self.use_dropout:
            out = self.dropout(out)
        if self.projection is not None:
            residual = self.projection(x)
        out = out + residual
        out = self.lrelu2(out)
        return out
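
# Illustrative shape check (not part of the original file). With stride=2 and
# differing channel counts, the 1x1 projection keeps the skip path compatible:
#
#   blk = ResBlock(channels_out=128, kernel_size=3, channels_in=64, stride=2)
#   blk(torch.randn(2, 64, 32, 32)).shape  # torch.Size([2, 128, 16, 16])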
class Destyler(nn.Module):
    """Stack of five fully connected layers, applied without intermediate
    nonlinearities."""

    def __init__(self, in_features, num_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, num_features)
        self.fc2 = nn.Linear(num_features, num_features)
        self.fc3 = nn.Linear(num_features, num_features)
        self.fc4 = nn.Linear(num_features, num_features)
        self.fc5 = nn.Linear(num_features, num_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        x = self.fc5(x)
        return x
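
if __name__ == "__main__":
    # Minimal smoke test (illustrative addition, not from the original file).
    # It exercises the two classes with no external dependencies; the feature
    # sizes below are arbitrary choices, not values taken from the repo.
    import torch

    blk = ResBlock(channels_out=128, kernel_size=3, channels_in=64, stride=2)
    y = blk(torch.randn(2, 64, 32, 32))
    assert y.shape == (2, 128, 16, 16)

    destyler = Destyler(in_features=1024, num_features=512)
    code = destyler(torch.randn(2, 1024))
    assert code.shape == (2, 512)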