NoCrypt commited on
Commit
2c9c37b
·
1 Parent(s): 12c866f
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ myvenv
2
+ myvenv/**/*
3
+ __pycache__
4
+ flagged
5
+ *.pth
README.md CHANGED
@@ -1,12 +1 @@
1
- ---
2
- title: Pixelization
3
- emoji: 🚀
4
- colorFrom: blue
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 3.16.2
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ...
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,8 +1,76 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
8
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import functools
3
+ from pixelization import Model
4
+ import torch
5
+ import argparse
6
+ import huggingface_hub
7
+ import os
8
 
9
+ TOKEN = "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"
 
10
 
11
+ def parse_args() -> argparse.Namespace:
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument('--theme', type=str, default='default')
14
+ parser.add_argument('--live', action='store_true')
15
+ parser.add_argument('--share', action='store_true')
16
+ parser.add_argument('--port', type=int)
17
+ parser.add_argument('--disable-queue',
18
+ dest='enable_queue',
19
+ action='store_false')
20
+ parser.add_argument('--allow-flagging', type=str, default='never')
21
+ return parser.parse_args()
22
+
23
+ def main():
24
+ args = parse_args()
25
+
26
+
27
+ # DL MODEL
28
+ # PIX_MODEL
29
+ os.environ['PIX_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "pixelart_vgg19.pth", token=TOKEN);
30
+ # NET_MODEL
31
+ os.environ['NET_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "160_net_G_A.pth", token=TOKEN);
32
+ # ALIAS_MODEL
33
+ os.environ['ALIAS_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "alias_net.pth", token=TOKEN);
34
+
35
+ # # For local testing
36
+ # # PIX_MODEL
37
+ # os.environ['PIX_MODEL'] = "pixelart_vgg19.pth"
38
+ # # NET_MODEL
39
+ # os.environ['NET_MODEL'] = "160_net_G_A.pth"
40
+ # # ALIAS_MODEL
41
+ # os.environ['ALIAS_MODEL'] = "alias_net.pth"
42
+
43
+
44
+ use_cpu = True
45
+ m = Model(device = "cpu" if use_cpu else "cuda")
46
+ m.load()
47
+
48
+ # To use GPU: Change use_cpu to false, and checkout my comment on networks.py at line 107 & 108
49
+ # + Use torch with cuda support (Change in requirements.txt)
50
+
51
+ gr.Interface(m.pixelize_modified,
52
+ [
53
+ gr.components.Image(type='pil', label='Input'),
54
+ gr.components.Slider(minimum=1, maximum=16, value=4, step=1, label='Pixel Size'),
55
+ gr.components.Checkbox(True, label="Upscale after")
56
+ ],
57
+ gr.components.Image(type='pil', label='Output'),
58
+ title="Pixelization",
59
+ description='''
60
+ Demo for [WuZongWei6/Pixelization](https://github.com/WuZongWei6/Pixelization)
61
+
62
+ Models that are used is private to comply with License.
63
+
64
+
65
+ ''',
66
+ theme=args.theme,
67
+ allow_flagging=args.allow_flagging,
68
+ live=args.live,
69
+ ).launch(
70
+ enable_queue=args.enable_queue,
71
+ server_port=args.port,
72
+ share=args.share,
73
+ )
74
+
75
+ if __name__ == '__main__':
76
+ main()
models/__init__.py ADDED
File without changes
models/basic_layer.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ class ModulationConvBlock(nn.Module):
7
+ def __init__(self, input_dim, output_dim, kernel_size, stride=1,
8
+ padding=0, norm='none', activation='relu', pad_type='zero'):
9
+ super(ModulationConvBlock, self).__init__()
10
+ self.in_c = input_dim
11
+ self.out_c = output_dim
12
+ self.ksize = kernel_size
13
+ self.stride = 1
14
+ self.padding = kernel_size // 2
15
+
16
+ self.eps = 1e-8
17
+ weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
18
+ fan_in = kernel_size * kernel_size *input_dim
19
+ wscale = 1.0/np.sqrt(fan_in)
20
+
21
+ self.weight = nn.Parameter(torch.randn(*weight_shape))
22
+ self.wscale = wscale
23
+
24
+ self.bias = nn.Parameter(torch.zeros(output_dim))
25
+
26
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27
+ self.activate_scale = np.sqrt(2.0)
28
+
29
+ def forward(self, x, code):
30
+ batch,in_channel,height,width = x.shape
31
+ weight = self.weight * self.wscale
32
+ _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
33
+ _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
34
+ # demodulation
35
+ _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
36
+ _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
37
+ # fused_modulate
38
+ x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
39
+ weight = _weight.permute(1, 2, 3, 0, 4).reshape(
40
+ self.ksize, self.ksize, self.in_c, batch * self.out_c)
41
+ # not use_conv2d_transpose
42
+ weight = weight.permute(3, 2, 0, 1)
43
+ x = F.conv2d(x,
44
+ weight=weight,
45
+ bias=None,
46
+ stride=self.stride,
47
+ padding=self.padding,
48
+ groups=(batch if True else 1))
49
+
50
+ if True:#self.fused_modulate:
51
+ x = x.view(batch, self.out_c, height, width)
52
+ x = x+self.bias.view(1,-1,1,1)
53
+ x = self.activate(x)*self.activate_scale
54
+ return x
55
+
56
+
57
+ class AliasConvBlock(nn.Module):
58
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
59
+ padding=0, norm='none', activation='relu', pad_type='zero'):
60
+ super(AliasConvBlock, self).__init__()
61
+ self.use_bias = True
62
+ # initialize padding
63
+ if pad_type == 'reflect':
64
+ self.pad = nn.ReflectionPad2d(padding)
65
+ elif pad_type == 'replicate':
66
+ self.pad = nn.ReplicationPad2d(padding)
67
+ elif pad_type == 'zero':
68
+ self.pad = nn.ZeroPad2d(padding)
69
+ else:
70
+ assert 0, "Unsupported padding type: {}".format(pad_type)
71
+
72
+ # initialize normalization
73
+ norm_dim = output_dim
74
+ if norm == 'bn':
75
+ self.norm = nn.BatchNorm2d(norm_dim)
76
+ elif norm == 'in':
77
+ # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
78
+ self.norm = nn.InstanceNorm2d(norm_dim)
79
+ elif norm == 'ln':
80
+ self.norm = LayerNorm(norm_dim)
81
+ elif norm == 'adain':
82
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
83
+ elif norm == 'none' or norm == 'sn':
84
+ self.norm = None
85
+ else:
86
+ assert 0, "Unsupported normalization: {}".format(norm)
87
+
88
+ # initialize activation
89
+ if activation == 'relu':
90
+ self.activation = nn.ReLU(inplace=True)
91
+ elif activation == 'lrelu':
92
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
93
+ elif activation == 'prelu':
94
+ self.activation = nn.PReLU()
95
+ elif activation == 'selu':
96
+ self.activation = nn.SELU(inplace=True)
97
+ elif activation == 'tanh':
98
+ self.activation = nn.Tanh()
99
+ elif activation == 'none':
100
+ self.activation = None
101
+ else:
102
+ assert 0, "Unsupported activation: {}".format(activation)
103
+
104
+ # initialize convolution
105
+ if norm == 'sn':
106
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
107
+
108
+ else:
109
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
110
+
111
+ def forward(self, x):
112
+ x = self.conv(self.pad(x))
113
+ if self.norm:
114
+ x = self.norm(x)
115
+ if self.activation:
116
+ x = self.activation(x)
117
+ return x
118
+
119
+ class AliasResBlocks(nn.Module):
120
+ def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
121
+ super(AliasResBlocks, self).__init__()
122
+ self.model = []
123
+ for i in range(num_blocks):
124
+ self.model += [AliasResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
125
+ self.model = nn.Sequential(*self.model)
126
+
127
+ def forward(self, x):
128
+ return self.model(x)
129
+ class AliasResBlock(nn.Module):
130
+ def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
131
+ super(AliasResBlock, self).__init__()
132
+
133
+ model = []
134
+ model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
135
+ model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
136
+ self.model = nn.Sequential(*model)
137
+
138
+ def forward(self, x):
139
+ residual = x
140
+ out = self.model(x)
141
+ out += residual
142
+ return out
143
+ ##################################################################################
144
+ # Sequential Models
145
+ ##################################################################################
146
+ class ResBlocks(nn.Module):
147
+ def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
148
+ super(ResBlocks, self).__init__()
149
+ self.model = []
150
+ for i in range(num_blocks):
151
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
152
+ self.model = nn.Sequential(*self.model)
153
+
154
+ def forward(self, x):
155
+ return self.model(x)
156
+
157
+
158
+ class MLP(nn.Module):
159
+ def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
160
+ super(MLP, self).__init__()
161
+ self.model = []
162
+ self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
163
+ self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
164
+ for i in range(n_blk - 2):
165
+ self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
166
+ self.model += [linearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
167
+ self.model = nn.Sequential(*self.model)
168
+
169
+ # def forward(self, style0, style1, a=0):
170
+ # return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
171
+ # style1.view(style1.size(0), -1)))
172
+ def forward(self, style0, style1=None, a=0):
173
+ style1 = style0
174
+ return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
175
+ style1.view(style1.size(0), -1)))
176
+ ##################################################################################
177
+ # Basic Blocks
178
+ ##################################################################################
179
+ class ResBlock(nn.Module):
180
+ def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
181
+ super(ResBlock, self).__init__()
182
+
183
+ model = []
184
+ model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
185
+ model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
186
+ self.model = nn.Sequential(*model)
187
+
188
+ def forward(self, x):
189
+ residual = x
190
+ out = self.model(x)
191
+ out += residual
192
+ return out
193
+
194
+
195
+ class ConvBlock(nn.Module):
196
+ def __init__(self, input_dim, output_dim, kernel_size, stride,
197
+ padding=0, norm='none', activation='relu', pad_type='zero'):
198
+ super(ConvBlock, self).__init__()
199
+ self.use_bias = True
200
+ # initialize padding
201
+ if pad_type == 'reflect':
202
+ self.pad = nn.ReflectionPad2d(padding)
203
+ elif pad_type == 'replicate':
204
+ self.pad = nn.ReplicationPad2d(padding)
205
+ elif pad_type == 'zero':
206
+ self.pad = nn.ZeroPad2d(padding)
207
+ else:
208
+ assert 0, "Unsupported padding type: {}".format(pad_type)
209
+
210
+ # initialize normalization
211
+ norm_dim = output_dim
212
+ if norm == 'bn':
213
+ self.norm = nn.BatchNorm2d(norm_dim)
214
+ elif norm == 'in':
215
+ # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
216
+ self.norm = nn.InstanceNorm2d(norm_dim)
217
+ elif norm == 'ln':
218
+ self.norm = LayerNorm(norm_dim)
219
+ elif norm == 'adain':
220
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
221
+ elif norm == 'none' or norm == 'sn':
222
+ self.norm = None
223
+ else:
224
+ assert 0, "Unsupported normalization: {}".format(norm)
225
+
226
+ # initialize activation
227
+ if activation == 'relu':
228
+ self.activation = nn.ReLU(inplace=True)
229
+ elif activation == 'lrelu':
230
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
231
+ elif activation == 'prelu':
232
+ self.activation = nn.PReLU()
233
+ elif activation == 'selu':
234
+ self.activation = nn.SELU(inplace=True)
235
+ elif activation == 'tanh':
236
+ self.activation = nn.Tanh()
237
+ elif activation == 'none':
238
+ self.activation = None
239
+ else:
240
+ assert 0, "Unsupported activation: {}".format(activation)
241
+
242
+ # initialize convolution
243
+ if norm == 'sn':
244
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
245
+
246
+ else:
247
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
248
+
249
+ def forward(self, x):
250
+ x = self.conv(self.pad(x))
251
+ if self.norm:
252
+ x = self.norm(x)
253
+ if self.activation:
254
+ x = self.activation(x)
255
+ return x
256
+
257
+ class linearBlock(nn.Module):
258
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
259
+ super(linearBlock, self).__init__()
260
+ use_bias = True
261
+ # initialize fully connected layer
262
+ if norm == 'sn':
263
+ self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
264
+ else:
265
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
266
+
267
+ # initialize normalization
268
+ norm_dim = output_dim
269
+ if norm == 'bn':
270
+ self.norm = nn.BatchNorm1d(norm_dim)
271
+ elif norm == 'in':
272
+ self.norm = nn.InstanceNorm1d(norm_dim)
273
+ elif norm == 'ln':
274
+ self.norm = LayerNorm(norm_dim)
275
+ elif norm == 'none' or norm == 'sn':
276
+ self.norm = None
277
+ else:
278
+ assert 0, "Unsupported normalization: {}".format(norm)
279
+
280
+ # initialize activation
281
+ if activation == 'relu':
282
+ self.activation = nn.ReLU(inplace=True)
283
+ elif activation == 'lrelu':
284
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
285
+ elif activation == 'prelu':
286
+ self.activation = nn.PReLU()
287
+ elif activation == 'selu':
288
+ self.activation = nn.SELU(inplace=True)
289
+ elif activation == 'tanh':
290
+ self.activation = nn.Tanh()
291
+ elif activation == 'none':
292
+ self.activation = None
293
+ else:
294
+ assert 0, "Unsupported activation: {}".format(activation)
295
+
296
+ def forward(self, x):
297
+ out = self.fc(x)
298
+ if self.norm:
299
+ out = self.norm(out)
300
+ if self.activation:
301
+ out = self.activation(out)
302
+ return out
303
+ ##################################################################################
304
+ # Normalization layers
305
+ ##################################################################################
306
+ class AdaptiveInstanceNorm2d(nn.Module):
307
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
308
+ super(AdaptiveInstanceNorm2d, self).__init__()
309
+ self.num_features = num_features
310
+ self.eps = eps
311
+ self.momentum = momentum
312
+ # weight and bias are dynamically assigned
313
+ self.weight = None
314
+ self.bias = None
315
+ # just dummy buffers, not used
316
+ self.register_buffer('running_mean', torch.zeros(num_features))
317
+ self.register_buffer('running_var', torch.ones(num_features))
318
+
319
+ def forward(self, x):
320
+ assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
321
+ b, c = x.size(0), x.size(1)
322
+ running_mean = self.running_mean.repeat(b)
323
+ running_var = self.running_var.repeat(b)
324
+
325
+ # Apply instance norm
326
+ x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
327
+
328
+ out = F.batch_norm(
329
+ x_reshaped, running_mean, running_var, self.weight, self.bias,
330
+ True, self.momentum, self.eps)
331
+
332
+ return out.view(b, c, *x.size()[2:])
333
+
334
+ def __repr__(self):
335
+ return self.__class__.__name__ + '(' + str(self.num_features) + ')'
336
+
337
+
338
+ class LayerNorm(nn.Module):
339
+ def __init__(self, num_features, eps=1e-5, affine=True):
340
+ super(LayerNorm, self).__init__()
341
+ self.num_features = num_features
342
+ self.affine = affine
343
+ self.eps = eps
344
+
345
+ if self.affine:
346
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
347
+ self.beta = nn.Parameter(torch.zeros(num_features))
348
+
349
+ def forward(self, x):
350
+ shape = [-1] + [1] * (x.dim() - 1)
351
+ # print(x.size())
352
+ if x.size(0) == 1:
353
+ # These two lines run much faster in pytorch 0.4 than the two lines listed below.
354
+ mean = x.view(-1).mean().view(*shape)
355
+ std = x.view(-1).std().view(*shape)
356
+ else:
357
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
358
+ std = x.view(x.size(0), -1).std(1).view(*shape)
359
+
360
+ x = (x - mean) / (std + self.eps)
361
+
362
+ if self.affine:
363
+ shape = [1, -1] + [1] * (x.dim() - 2)
364
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
365
+ return x
366
+
367
+
368
+ def l2normalize(v, eps=1e-12):
369
+ return v / (v.norm() + eps)
370
+
371
+
372
+ class SpectralNorm(nn.Module):
373
+ """
374
+ Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
375
+ and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
376
+ """
377
+
378
+ def __init__(self, module, name='weight', power_iterations=1):
379
+ super(SpectralNorm, self).__init__()
380
+ self.module = module
381
+ self.name = name
382
+ self.power_iterations = power_iterations
383
+ if not self._made_params():
384
+ self._make_params()
385
+
386
+ def _update_u_v(self):
387
+ u = getattr(self.module, self.name + "_u")
388
+ v = getattr(self.module, self.name + "_v")
389
+ w = getattr(self.module, self.name + "_bar")
390
+
391
+ height = w.data.shape[0]
392
+ for _ in range(self.power_iterations):
393
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
394
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
395
+
396
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
397
+ sigma = u.dot(w.view(height, -1).mv(v))
398
+ setattr(self.module, self.name, w / sigma.expand_as(w))
399
+
400
+ def _made_params(self):
401
+ try:
402
+ u = getattr(self.module, self.name + "_u")
403
+ v = getattr(self.module, self.name + "_v")
404
+ w = getattr(self.module, self.name + "_bar")
405
+ return True
406
+ except AttributeError:
407
+ return False
408
+
409
+ def _make_params(self):
410
+ w = getattr(self.module, self.name)
411
+
412
+ height = w.data.shape[0]
413
+ width = w.view(height, -1).data.shape[1]
414
+
415
+ u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
416
+ v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
417
+ u.data = l2normalize(u.data)
418
+ v.data = l2normalize(v.data)
419
+ w_bar = nn.Parameter(w.data)
420
+
421
+ del self.module._parameters[self.name]
422
+
423
+ self.module.register_parameter(self.name + "_u", u)
424
+ self.module.register_parameter(self.name + "_v", v)
425
+ self.module.register_parameter(self.name + "_bar", w_bar)
426
+
427
+ def forward(self, *args):
428
+ self._update_u_v()
429
+ return self.module.forward(*args)
models/c2pDis.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .basic_layer import *
2
+ import math
3
+ from torch.nn import Parameter
4
+ #from pytorch_metric_learning import losses
5
+
6
+ '''
7
+ Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
8
+ '''
9
+ def cosine_sim(x1, x2, dim=1, eps=1e-8):
10
+ ip = torch.mm(x1, x2.t()) # w 7*512
11
+ w1 = torch.norm(x1, 2, dim)
12
+ w2 = torch.norm(x2, 2, dim)
13
+ return ip / torch.ger(w1,w2).clamp(min=eps)
14
+
15
+ class MarginCosineProduct(nn.Module):
16
+ r"""Implement of large margin cosine distance: :
17
+ Args:
18
+ in_features: size of each input sample
19
+ out_features: size of each output sample
20
+ s: norm of input feature
21
+ m: margin
22
+ """
23
+
24
+ def __init__(self, in_features, out_features, s=30.0, m=0.40):
25
+ super(MarginCosineProduct, self).__init__()
26
+ self.in_features = in_features
27
+ self.out_features = out_features
28
+ self.s = s
29
+ self.m = m
30
+ self.weight = Parameter(torch.Tensor(out_features, in_features)) # 7 512
31
+ nn.init.xavier_uniform_(self.weight)
32
+ #stdv = 1. / math.sqrt(self.weight.size(1))
33
+ #self.weight.data.uniform_(-stdv, stdv)
34
+
35
+ def forward(self, input, label):
36
+ cosine = cosine_sim(input, self.weight) # 1*512 7*512
37
+ # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
38
+ # --------------------------- convert label to one-hot ---------------------------
39
+ # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
40
+ one_hot = torch.zeros_like(cosine)
41
+ one_hot.scatter_(1, label.view(-1, 1), 1.0)
42
+ # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
43
+ output = self.s * (cosine - one_hot * self.m)
44
+
45
+ return output
46
+
47
+ def __repr__(self):
48
+ return self.__class__.__name__ + '(' \
49
+ + 'in_features=' + str(self.in_features) \
50
+ + ', out_features=' + str(self.out_features) \
51
+ + ', s=' + str(self.s) \
52
+ + ', m=' + str(self.m) + ')'
53
+
54
+ class ArcMarginProduct(nn.Module):
55
+ def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
56
+ super(ArcMarginProduct, self).__init__()
57
+ self.in_feature = in_feature
58
+ self.out_feature = out_feature
59
+ self.s = s
60
+ self.m = m
61
+ self.weight = Parameter(torch.Tensor(out_feature, in_feature))
62
+ nn.init.xavier_uniform_(self.weight)
63
+
64
+ self.easy_margin = easy_margin
65
+ self.cos_m = math.cos(m)
66
+ self.sin_m = math.sin(m)
67
+
68
+ # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
69
+ self.th = math.cos(math.pi - m)
70
+ self.mm = math.sin(math.pi - m) * m
71
+
72
+ def forward(self, x, label):
73
+ # cos(theta)
74
+ cosine = F.linear(F.normalize(x), F.normalize(self.weight))
75
+ # cos(theta + m)
76
+ sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
77
+ phi = cosine * self.cos_m - sine * self.sin_m
78
+
79
+ if self.easy_margin:
80
+ phi = torch.where(cosine > 0, phi, cosine)
81
+ else:
82
+ phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
83
+
84
+ #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
85
+ one_hot = torch.zeros_like(cosine)
86
+ one_hot.scatter_(1, label.view(-1, 1), 1)
87
+ output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
88
+ output = output * self.s
89
+
90
+ return output
91
+
92
+
93
+ class MultiMarginProduct(nn.Module):
94
+ def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
95
+ super(MultiMarginProduct, self).__init__()
96
+ self.in_feature = in_feature
97
+ self.out_feature = out_feature
98
+ self.s = s
99
+ self.m1 = m1
100
+ self.m2 = m2
101
+ self.weight = Parameter(torch.Tensor(out_feature, in_feature))
102
+ nn.init.xavier_uniform_(self.weight)
103
+
104
+ self.easy_margin = easy_margin
105
+ self.cos_m1 = math.cos(m1)
106
+ self.sin_m1 = math.sin(m1)
107
+
108
+ # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
109
+ self.th = math.cos(math.pi - m1)
110
+ self.mm = math.sin(math.pi - m1) * m1
111
+
112
+ def forward(self, x, label):
113
+ # cos(theta)
114
+ cosine = F.linear(F.normalize(x), F.normalize(self.weight))
115
+ # cos(theta + m1)
116
+ sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
117
+ phi = cosine * self.cos_m1 - sine * self.sin_m1
118
+
119
+ if self.easy_margin:
120
+ phi = torch.where(cosine > 0, phi, cosine)
121
+ else:
122
+ phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
123
+
124
+
125
+ one_hot = torch.zeros_like(cosine)
126
+ one_hot.scatter_(1, label.view(-1, 1), 1)
127
+ output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # additive angular margin
128
+ output = output - one_hot * self.m2 # additive cosine margin
129
+ output = output * self.s
130
+
131
+ return output
132
+
133
+
134
+ class CPDis(nn.Module):
135
+ """PatchGAN."""
136
+ def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
137
+ super(CPDis, self).__init__()
138
+
139
+ layers = []
140
+ if norm == 'SN':
141
+ layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
142
+ else:
143
+ layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
144
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
145
+
146
+ curr_dim = conv_dim
147
+ for i in range(1, repeat_num):
148
+ if norm == 'SN':
149
+ layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
150
+ else:
151
+ layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
152
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
153
+ curr_dim = curr_dim * 2
154
+
155
+ # k_size = int(image_size / np.power(2, repeat_num))
156
+ if norm == 'SN':
157
+ layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
158
+ else:
159
+ layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
160
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
161
+ curr_dim = curr_dim * 2
162
+
163
+ self.main = nn.Sequential(*layers)
164
+ if norm == 'SN':
165
+ self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
166
+ else:
167
+ self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
168
+
169
+ def forward(self, x):
170
+ if x.ndim == 5:
171
+ x = x.squeeze(0)
172
+ assert x.ndim == 4, x.ndim
173
+ h = self.main(x)
174
+ # out_real = self.conv1(h)
175
+ out_makeup = self.conv1(h)
176
+ # return out_real.squeeze(), out_makeup.squeeze()
177
+ return out_makeup
178
+
179
+
180
+ class CPDis_cls(nn.Module):
181
+ """PatchGAN."""
182
+ def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
183
+ super(CPDis_cls, self).__init__()
184
+
185
+ layers = []
186
+ if norm == 'SN':
187
+ layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
188
+ else:
189
+ layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
190
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
191
+
192
+ curr_dim = conv_dim
193
+ for i in range(1, repeat_num):
194
+ if norm == 'SN':
195
+ layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
196
+ else:
197
+ layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
198
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
199
+ curr_dim = curr_dim * 2
200
+
201
+ # k_size = int(image_size / np.power(2, repeat_num))
202
+ if norm == 'SN':
203
+ layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
204
+ else:
205
+ layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
206
+ layers.append(nn.LeakyReLU(0.01, inplace=True))
207
+ curr_dim = curr_dim * 2
208
+
209
+ self.main = nn.Sequential(*layers)
210
+ if norm == 'SN':
211
+ self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
212
+ self.classifier_pool = nn.AdaptiveAvgPool2d(1)
213
+ self.classifier_conv = nn.Conv2d(512, 512, 1, 1, 0)
214
+ self.classifier = MarginCosineProduct(512,7)#ArcMarginProduct(512, 7)
215
+ print("Using Large Margin Cosine Loss.")
216
+
217
+ else:
218
+ self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
219
+
220
+ def forward(self, x, label):
221
+ if x.ndim == 5:
222
+ x = x.squeeze(0)
223
+ assert x.ndim == 4, x.ndim
224
+ h = self.main(x) # ([1, 512, 31, 31])
225
+ #print(out_cls.shape)
226
+ out_cls = self.classifier_pool(h)
227
+ #print(out_cls.shape)
228
+ out_cls = self.classifier_conv(out_cls)
229
+ #print(out_cls.shape)
230
+ out_cls = torch.squeeze(out_cls, -1)
231
+ out_cls = torch.squeeze(out_cls, -1)
232
+ out_cls = self.classifier(out_cls, label)
233
+ out_makeup = self.conv1(h) # torch.Size([1, 1, 30, 30])
234
+ # return out_real.squeeze(), out_makeup.squeeze()
235
+ return out_makeup, out_cls
236
+
237
+ class SpectralNorm(object):
238
+ def __init__(self):
239
+ self.name = "weight"
240
+ # print(self.name)
241
+ self.power_iterations = 1
242
+
243
+ def compute_weight(self, module):
244
+ u = getattr(module, self.name + "_u")
245
+ v = getattr(module, self.name + "_v")
246
+ w = getattr(module, self.name + "_bar")
247
+
248
+ height = w.data.shape[0]
249
+ for _ in range(self.power_iterations):
250
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
251
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
252
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
253
+ sigma = u.dot(w.view(height, -1).mv(v))
254
+ return w / sigma.expand_as(w)
255
+
256
+ @staticmethod
257
+ def apply(module):
258
+ name = "weight"
259
+ fn = SpectralNorm()
260
+
261
+ try:
262
+ u = getattr(module, name + "_u")
263
+ v = getattr(module, name + "_v")
264
+ w = getattr(module, name + "_bar")
265
+ except AttributeError:
266
+ w = getattr(module, name)
267
+ height = w.data.shape[0]
268
+ width = w.view(height, -1).data.shape[1]
269
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
270
+ v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
271
+ w_bar = Parameter(w.data)
272
+
273
+ # del module._parameters[name]
274
+
275
+ module.register_parameter(name + "_u", u)
276
+ module.register_parameter(name + "_v", v)
277
+ module.register_parameter(name + "_bar", w_bar)
278
+
279
+ # remove w from parameter list
280
+ del module._parameters[name]
281
+
282
+ setattr(module, name, fn.compute_weight(module))
283
+
284
+ # recompute weight before every forward()
285
+ module.register_forward_pre_hook(fn)
286
+
287
+ return fn
288
+
289
+ def remove(self, module):
290
+ weight = self.compute_weight(module)
291
+ delattr(module, self.name)
292
+ del module._parameters[self.name + '_u']
293
+ del module._parameters[self.name + '_v']
294
+ del module._parameters[self.name + '_bar']
295
+ module.register_parameter(self.name, Parameter(weight.data))
296
+
297
+ def __call__(self, module, inputs):
298
+ setattr(module, self.name, self.compute_weight(module))
299
+
300
+ def spectral_norm(module):
301
+ SpectralNorm.apply(module)
302
+ return module
303
+
304
+ def remove_spectral_norm(module):
305
+ name = 'weight'
306
+ for k, hook in module._forward_pre_hooks.items():
307
+ if isinstance(hook, SpectralNorm) and hook.name == name:
308
+ hook.remove(module)
309
+ del module._forward_pre_hooks[k]
310
+ return module
311
+
312
+ raise ValueError("spectral_norm of '{}' not found in {}"
313
+ .format(name, module))
models/c2pGen.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .basic_layer import *
2
+ import torchvision.models as models
3
+ import os
4
+
5
+
6
+
7
+ class AliasNet(nn.Module):
8
+ def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
9
+ super(AliasNet, self).__init__()
10
+ self.RGBEnc = AliasRGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
11
+ self.RGBDec = AliasRGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
12
+ activ=activ, pad_type=pad_type)
13
+
14
+ def forward(self, x):
15
+ x = self.RGBEnc(x)
16
+ x = self.RGBDec(x)
17
+ return x
18
+
19
+
20
+ class AliasRGBEncoder(nn.Module):
21
+ def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
22
+ super(AliasRGBEncoder, self).__init__()
23
+ self.model = []
24
+ self.model += [AliasConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
25
+ # downsampling blocks
26
+ for i in range(n_downsample):
27
+ self.model += [AliasConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
28
+ dim *= 2
29
+ # residual blocks
30
+ self.model += [AliasResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
31
+ self.model = nn.Sequential(*self.model)
32
+ self.output_dim = dim
33
+
34
+ def forward(self, x):
35
+ return self.model(x)
36
+
37
+
38
+ class AliasRGBDecoder(nn.Module):
39
+ def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
40
+ super(AliasRGBDecoder, self).__init__()
41
+ # self.model = []
42
+ # # AdaIN residual blocks
43
+ # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
44
+ # # upsampling blocks
45
+ # for i in range(n_upsample):
46
+ # self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
47
+ # ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
48
+ # dim //= 2
49
+ # # use reflection padding in the last conv layer
50
+ # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
51
+ # self.model = nn.Sequential(*self.model)
52
+ self.Res_Blocks = AliasResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
53
+ self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
54
+ self.conv_1 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
55
+ dim //= 2
56
+ self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
57
+ self.conv_2 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
58
+ dim //= 2
59
+ self.conv_3 = AliasConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
60
+
61
+ def forward(self, x):
62
+ x = self.Res_Blocks(x)
63
+ # print(x.shape)
64
+ x = self.upsample_block1(x)
65
+ # print(x.shape)
66
+ x = self.conv_1(x)
67
+ # print(x_small.shape)
68
+ x = self.upsample_block2(x)
69
+ # print(x.shape)
70
+ x = self.conv_2(x)
71
+ # print(x_middle.shape)
72
+ x = self.conv_3(x)
73
+ # print(x_big.shape)
74
+ return x
75
+
76
+
77
+ class C2PGen(nn.Module):
78
+ def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, style_dim, mlp_dim, activ='relu', pad_type='reflect'):
79
+ super(C2PGen, self).__init__()
80
+ self.PBEnc = PixelBlockEncoder(input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
81
+ self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
82
+ self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='adain',
83
+ activ=activ, pad_type=pad_type)
84
+ self.MLP = MLP(style_dim, 2048, mlp_dim, 3, norm='none', activ=activ)
85
+
86
+ def forward(self, clipart, pixelart, s=1):
87
+ feature = self.RGBEnc(clipart)
88
+ code = self.PBEnc(pixelart)
89
+ result, cellcode = self.fuse(feature, code, s)
90
+ return result#, cellcode #return cellcode when visualizing the cell size code
91
+
92
+ def fuse(self, content, style_code, s=1):
93
+ #print("MLP input:code's shape:", style_code.shape)
94
+ adain_params = self.MLP(style_code) * s # [batch,2048]
95
+ #print("MLP output:adain_params's shape", adain_params.shape)
96
+ #self.assign_adain_params(adain_params, self.RGBDec)
97
+ images = self.RGBDec(content, adain_params)
98
+ return images, adain_params
99
+
100
+ def assign_adain_params(self, adain_params, model):
101
+ # assign the adain_params to the AdaIN layers in model
102
+ for m in model.modules():
103
+ if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
104
+ mean = adain_params[:, :m.num_features]
105
+ std = adain_params[:, m.num_features:2 * m.num_features]
106
+ m.bias = mean.contiguous().view(-1)
107
+ m.weight = std.contiguous().view(-1)
108
+ if adain_params.size(1) > 2 * m.num_features:
109
+ adain_params = adain_params[:, 2 * m.num_features:]
110
+
111
+ def get_num_adain_params(self, model):
112
+ # return the number of AdaIN parameters needed by the model
113
+ num_adain_params = 0
114
+ for m in model.modules():
115
+ if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
116
+ num_adain_params += 2 * m.num_features
117
+ return num_adain_params
118
+
119
+
120
+ class PixelBlockEncoder(nn.Module):
121
+ def __init__(self, input_dim, dim, style_dim, norm, activ, pad_type):
122
+ super(PixelBlockEncoder, self).__init__()
123
+ vgg19 = models.vgg.vgg19()
124
+ vgg19.classifier._modules['6'] = nn.Linear(4096, 7, bias=True)
125
+ vgg19.load_state_dict(torch.load('./pixelart_vgg19.pth' if not os.environ['PIX_MODEL'] else os.environ['PIX_MODEL'], map_location=torch.device('cpu')))
126
+ self.vgg = vgg19.features
127
+ for p in self.vgg.parameters():
128
+ p.requires_grad = False
129
+ # vgg19 = models.vgg.vgg19(pretrained=False)
130
+ # vgg19.load_state_dict(torch.load('./vgg.pth'))
131
+ # self.vgg = vgg19.features
132
+ # for p in self.vgg.parameters():
133
+ # p.requires_grad = False
134
+
135
+
136
+ self.conv1 = ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type) # 3->64,concat
137
+ dim = dim * 2
138
+ self.conv2 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 128->128
139
+ dim = dim * 2
140
+ self.conv3 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 256->256
141
+ dim = dim * 2
142
+ self.conv4 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 512->512
143
+ dim = dim * 2
144
+
145
+ self.model = []
146
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
147
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
148
+ self.model = nn.Sequential(*self.model)
149
+ self.output_dim = dim
150
+
151
+ def get_features(self, image, model, layers=None):
152
+ if layers is None:
153
+ layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
154
+ features = {}
155
+ x = image
156
+ # model._modules is a dictionary holding each module in the model
157
+ for name, layer in model._modules.items():
158
+ x = layer(x)
159
+ if name in layers:
160
+ features[layers[name]] = x
161
+ return features
162
+
163
+ def componet_enc(self, x):
164
+ # x [16,3,256,256]
165
+ # factor_img [16,7,256,256]
166
+ vgg_aux = self.get_features(x, self.vgg) # x是3通道灰度图
167
+ #x = torch.cat([x, factor_img], dim=1) # [16,3+7,256,256]
168
+ x = self.conv1(x) # 64 256 256
169
+ x = torch.cat([x, vgg_aux['conv1_1']], dim=1) # 128 256 256
170
+ x = self.conv2(x) # 128 128 128
171
+ x = torch.cat([x, vgg_aux['conv2_1']], dim=1) # 256 128 128
172
+ x = self.conv3(x) # 256 64 64
173
+ x = torch.cat([x, vgg_aux['conv3_1']], dim=1) # 512 64 64
174
+ x = self.conv4(x) # 512 32 32
175
+ x = torch.cat([x, vgg_aux['conv4_1']], dim=1) # 1024 32 32
176
+ x = self.model(x)
177
+ return x
178
+
179
+ def forward(self, x):
180
+ code = self.componet_enc(x)
181
+ return code
182
+
183
+ class RGBEncoder(nn.Module):
184
+ def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
185
+ super(RGBEncoder, self).__init__()
186
+ self.model = []
187
+ self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
188
+ # downsampling blocks
189
+ for i in range(n_downsample):
190
+ self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
191
+ dim *= 2
192
+ # residual blocks
193
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
194
+ self.model = nn.Sequential(*self.model)
195
+ self.output_dim = dim
196
+
197
+ def forward(self, x):
198
+ return self.model(x)
199
+
200
+
201
+ class RGBDecoder(nn.Module):
202
+ def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
203
+ super(RGBDecoder, self).__init__()
204
+ # self.model = []
205
+ # # AdaIN residual blocks
206
+ # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
207
+ # # upsampling blocks
208
+ # for i in range(n_upsample):
209
+ # self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
210
+ # ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
211
+ # dim //= 2
212
+ # # use reflection padding in the last conv layer
213
+ # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
214
+ # self.model = nn.Sequential(*self.model)
215
+ #self.Res_Blocks = ModulationResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
216
+ self.mod_conv_1 = ModulationConvBlock(256,256,3)
217
+ self.mod_conv_2 = ModulationConvBlock(256,256,3)
218
+ self.mod_conv_3 = ModulationConvBlock(256,256,3)
219
+ self.mod_conv_4 = ModulationConvBlock(256,256,3)
220
+ self.mod_conv_5 = ModulationConvBlock(256,256,3)
221
+ self.mod_conv_6 = ModulationConvBlock(256,256,3)
222
+ self.mod_conv_7 = ModulationConvBlock(256,256,3)
223
+ self.mod_conv_8 = ModulationConvBlock(256,256,3)
224
+ self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
225
+ self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
226
+ dim //= 2
227
+ self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
228
+ self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
229
+ dim //= 2
230
+ self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
231
+
232
+ # def forward(self, x):
233
+ # residual = x
234
+ # out = self.model(x)
235
+ # out += residual
236
+ # return out
237
+ def forward(self, x, code):
238
+ residual = x
239
+ x = self.mod_conv_1(x, code[:, :256])
240
+ x = self.mod_conv_2(x, code[:, 256*1:256*2])
241
+ x += residual
242
+ residual = x
243
+ x = self.mod_conv_2(x, code[:, 256*2:256 * 3])
244
+ x = self.mod_conv_2(x, code[:, 256*3:256 * 4])
245
+ x += residual
246
+ residual =x
247
+ x = self.mod_conv_2(x, code[:, 256*4:256 * 5])
248
+ x = self.mod_conv_2(x, code[:, 256*5:256 * 6])
249
+ x += residual
250
+ residual = x
251
+ x = self.mod_conv_2(x, code[:, 256*6:256 * 7])
252
+ x = self.mod_conv_2(x, code[:, 256*7:256 * 8])
253
+ x += residual
254
+ # print(x.shape)
255
+ x = self.upsample_block1(x)
256
+ # print(x.shape)
257
+ x = self.conv_1(x)
258
+ # print(x_small.shape)
259
+ x = self.upsample_block2(x)
260
+ # print(x.shape)
261
+ x = self.conv_2(x)
262
+ # print(x_middle.shape)
263
+ x = self.conv_3(x)
264
+ # print(x_big.shape)
265
+ return x
266
+
models/networks.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import init
4
+ import functools
5
+ from torch.optim import lr_scheduler
6
+ from .c2pGen import *
7
+ from .p2cGen import *
8
+ from .c2pDis import *
9
+
10
+ class Identity(nn.Module):
11
+ def forward(self, x):
12
+ return x
13
+
14
+ def get_norm_layer(norm_type='instance'):
15
+ """Return a normalization layer
16
+
17
+ Parameters:
18
+ norm_type (str) -- the name of the normalization layer: batch | instance | none
19
+
20
+ For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
21
+ For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
22
+ """
23
+ if norm_type == 'batch':
24
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
25
+ elif norm_type == 'instance':
26
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
27
+ elif norm_type == 'none':
28
+ def norm_layer(x): return Identity()
29
+ else:
30
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
31
+ return norm_layer
32
+
33
+
34
+ def get_scheduler(optimizer, opt):
35
+ """Return a learning rate scheduler
36
+
37
+ Parameters:
38
+ optimizer -- the optimizer of the network
39
+ opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
40
+ opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
41
+
42
+ For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
43
+ and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
44
+ For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
45
+ See https://pytorch.org/docs/stable/optim.html for more details.
46
+ """
47
+ if opt.lr_policy == 'linear':
48
+ def lambda_rule(epoch):
49
+ lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
50
+ return lr_l
51
+ scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
52
+ elif opt.lr_policy == 'step':
53
+ scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
54
+ elif opt.lr_policy == 'plateau':
55
+ scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
56
+ elif opt.lr_policy == 'cosine':
57
+ scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
58
+ else:
59
+ return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
60
+ return scheduler
61
+
62
+
63
+ def init_weights(net, init_type='normal', init_gain=0.02):
64
+ """Initialize network weights.
65
+
66
+ Parameters:
67
+ net (network) -- network to be initialized
68
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
69
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
70
+
71
+ """
72
+ def init_func(m): # define the initialization function
73
+ classname = m.__class__.__name__
74
+ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
75
+ if init_type == 'normal':
76
+ init.normal_(m.weight.data, 0.0, init_gain)
77
+ elif init_type == 'xavier':
78
+ init.xavier_normal_(m.weight.data, gain=init_gain)
79
+ elif init_type == 'kaiming':
80
+ init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
81
+ elif init_type == 'orthogonal':
82
+ init.orthogonal_(m.weight.data, gain=init_gain)
83
+ else:
84
+ raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
85
+ if hasattr(m, 'bias') and m.bias is not None:
86
+ init.constant_(m.bias.data, 0.0)
87
+ elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
88
+ init.normal_(m.weight.data, 1.0, init_gain)
89
+ init.constant_(m.bias.data, 0.0)
90
+
91
+ #print('initialize network with %s' % init_type)
92
+ net.apply(init_func) # apply the initialization function <init_func>
93
+
94
+
95
+ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
96
+ """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
97
+ Parameters:
98
+ net (network) -- the network to be initialized
99
+ init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
100
+ gain (float) -- scaling factor for normal, xavier and orthogonal.
101
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
102
+
103
+ Return an initialized network.
104
+ """
105
+ gpu_ids = [0]
106
+ if len(gpu_ids) > 0:
107
+ # assert(torch.cuda.is_available()) #uncomment this for using gpu
108
+ net.to(torch.device("cpu")) #change this for using gpu to gpu_ids[0]
109
+ net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
110
+ init_weights(net, init_type, init_gain=init_gain)
111
+ return net
112
+
113
+
114
+ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
115
+ """Create a generator
116
+
117
+ Parameters:
118
+ input_nc (int) -- the number of channels in input images
119
+ output_nc (int) -- the number of channels in output images
120
+ ngf (int) -- the number of filters in the last conv layer
121
+ netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
122
+ norm (str) -- the name of normalization layers used in the network: batch | instance | none
123
+ use_dropout (bool) -- if use dropout layers.
124
+ init_type (str) -- the name of our initialization method.
125
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
126
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
127
+
128
+ Returns a generator
129
+ """
130
+ net = None
131
+ norm_layer = get_norm_layer(norm_type=norm)
132
+
133
+ if netG == 'c2pGen': # style_dim mlp_dim
134
+ net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')
135
+ #print('c2pgen resblock is 8')
136
+ elif netG == 'p2cGen':
137
+ net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
138
+ elif netG == 'antialias':
139
+ net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
140
+ else:
141
+ raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
142
+ return init_net(net, init_type, init_gain, gpu_ids)
143
+
144
+
145
+
146
+ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
147
+ """Create a discriminator
148
+
149
+ Parameters:
150
+ input_nc (int) -- the number of channels in input images
151
+ ndf (int) -- the number of filters in the first conv layer
152
+ netD (str) -- the architecture's name: basic | n_layers | pixel
153
+ n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
154
+ norm (str) -- the type of normalization layers used in the network.
155
+ init_type (str) -- the name of the initialization method.
156
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
157
+ gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
158
+
159
+ Returns a discriminator
160
+ """
161
+ net = None
162
+ norm_layer = get_norm_layer(norm_type=norm)
163
+
164
+
165
+ if netD == 'CPDis':
166
+ net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
167
+ elif netD == 'CPDis_cls':
168
+ net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
169
+ else:
170
+ raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
171
+ return init_net(net, init_type, init_gain, gpu_ids)
172
+
173
+
174
+ class GANLoss(nn.Module):
175
+ """Define different GAN objectives.
176
+
177
+ The GANLoss class abstracts away the need to create the target label tensor
178
+ that has the same size as the input.
179
+ """
180
+
181
+ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
182
+ """ Initialize the GANLoss class.
183
+
184
+ Parameters:
185
+ gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
186
+ target_real_label (bool) - - label for a real image
187
+ target_fake_label (bool) - - label of a fake image
188
+
189
+ Note: Do not use sigmoid as the last layer of Discriminator.
190
+ LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
191
+ """
192
+ super(GANLoss, self).__init__()
193
+ self.register_buffer('real_label', torch.tensor(target_real_label))
194
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
195
+ self.gan_mode = gan_mode
196
+ if gan_mode == 'lsgan':
197
+ self.loss = nn.MSELoss()
198
+ elif gan_mode == 'vanilla':
199
+ self.loss = nn.BCEWithLogitsLoss()
200
+ elif gan_mode in ['wgangp']:
201
+ self.loss = None
202
+ else:
203
+ raise NotImplementedError('gan mode %s not implemented' % gan_mode)
204
+
205
+ def get_target_tensor(self, prediction, target_is_real):
206
+ """Create label tensors with the same size as the input.
207
+
208
+ Parameters:
209
+ prediction (tensor) - - tpyically the prediction from a discriminator
210
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
211
+
212
+ Returns:
213
+ A label tensor filled with ground truth label, and with the size of the input
214
+ """
215
+
216
+ if target_is_real:
217
+ target_tensor = self.real_label
218
+ else:
219
+ target_tensor = self.fake_label
220
+ return target_tensor.expand_as(prediction)
221
+
222
+ def __call__(self, prediction, target_is_real):
223
+ """Calculate loss given Discriminator's output and grount truth labels.
224
+
225
+ Parameters:
226
+ prediction (tensor) - - tpyically the prediction output from a discriminator
227
+ target_is_real (bool) - - if the ground truth label is for real images or fake images
228
+
229
+ Returns:
230
+ the calculated loss.
231
+ """
232
+ if self.gan_mode in ['lsgan', 'vanilla']:
233
+ target_tensor = self.get_target_tensor(prediction, target_is_real)
234
+ loss = self.loss(prediction, target_tensor)
235
+ elif self.gan_mode == 'wgangp':
236
+ if target_is_real:
237
+ loss = -prediction.mean()
238
+ else:
239
+ loss = prediction.mean()
240
+ return loss
241
+
242
+
243
+
244
+
models/p2cGen.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .basic_layer import *
2
+
3
+
4
+ class P2CGen(nn.Module):
5
+ def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
6
+ super(P2CGen, self).__init__()
7
+ self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
8
+ self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
9
+ activ=activ, pad_type=pad_type)
10
+
11
+ def forward(self, x):
12
+ x = self.RGBEnc(x)
13
+ # print("encoder->>", x.shape)
14
+ x = self.RGBDec(x)
15
+ # print(x_small.shape)
16
+ # print(x_middle.shape)
17
+ # print(x_big.shape)
18
+ #return y_small, y_middle, y_big
19
+ return x
20
+
21
+
22
+ class RGBEncoder(nn.Module):
23
+ def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
24
+ super(RGBEncoder, self).__init__()
25
+ self.model = []
26
+ self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
27
+ # downsampling blocks
28
+ for i in range(n_downsample):
29
+ self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
30
+ dim *= 2
31
+ # residual blocks
32
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
33
+ self.model = nn.Sequential(*self.model)
34
+ self.output_dim = dim
35
+
36
+ def forward(self, x):
37
+ return self.model(x)
38
+
39
+
40
+ class RGBDecoder(nn.Module):
41
+ def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
42
+ super(RGBDecoder, self).__init__()
43
+ # self.model = []
44
+ # # AdaIN residual blocks
45
+ # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
46
+ # # upsampling blocks
47
+ # for i in range(n_upsample):
48
+ # self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
49
+ # ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
50
+ # dim //= 2
51
+ # # use reflection padding in the last conv layer
52
+ # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
53
+ # self.model = nn.Sequential(*self.model)
54
+ self.Res_Blocks = ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
55
+ self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
56
+ self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
57
+ dim //= 2
58
+ self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
59
+ self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
60
+ dim //= 2
61
+ self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
62
+
63
+ def forward(self, x):
64
+ x = self.Res_Blocks(x)
65
+ # print(x.shape)
66
+ x = self.upsample_block1(x)
67
+ # print(x.shape)
68
+ x = self.conv_1(x)
69
+ # print(x_small.shape)
70
+ x = self.upsample_block2(x)
71
+ # print(x.shape)
72
+ x = self.conv_2(x)
73
+ # print(x_middle.shape)
74
+ x = self.conv_3(x)
75
+ # print(x_big.shape)
76
+ return x
pixelization.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import numpy as np
6
+ from models.networks import define_G
7
+ import glob
8
+
9
+
10
+ class Model():
11
+ def __init__(self, device="cpu"):
12
+ self.device = torch.device(device)
13
+ self.G_A_net = None
14
+ self.alias_net = None
15
+ self.ref_t = None
16
+
17
+ def load(self):
18
+ with torch.no_grad():
19
+ self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
20
+ self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0])
21
+
22
+ G_A_state = torch.load("160_net_G_A.pth" if not os.environ['NET_MODEL'] else os.environ['NET_MODEL'], map_location=str(self.device))
23
+ for p in list(G_A_state.keys()):
24
+ G_A_state["module."+str(p)] = G_A_state.pop(p)
25
+ self.G_A_net.load_state_dict(G_A_state)
26
+
27
+ alias_state = torch.load("alias_net.pth" if not os.environ['ALIAS_MODEL'] else os.environ['ALIAS_MODEL'], map_location=str(self.device))
28
+ for p in list(alias_state.keys()):
29
+ alias_state["module."+str(p)] = alias_state.pop(p)
30
+ self.alias_net.load_state_dict(alias_state)
31
+
32
+ ref_img = Image.open("reference.png").convert('L')
33
+ self.ref_t = process(greyscale(ref_img)).to(self.device)
34
+
35
+ def pixelize(self, in_img, out_img):
36
+ with torch.no_grad():
37
+ in_img = Image.open(in_img).convert('RGB')
38
+ in_t = process(in_img).to(self.device)
39
+
40
+ out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))
41
+
42
+ save(out_t, out_img)
43
+
44
+ def pixelize_modified(self, in_img, pixel_size, upscale_after) -> Image.Image:
45
+ with torch.no_grad():
46
+ in_img = in_img.convert('RGB')
47
+
48
+ # limit in_img size to 1024x1024 so it didn't destroyed by large image
49
+ if in_img.size[0] > 1024 or in_img.size[1] > 1024:
50
+ in_img.thumbnail((1024, 1024), Image.NEAREST)
51
+
52
+ in_img.resize((in_img.size[0] * 4 // pixel_size, in_img.size[1] * 4 // pixel_size))
53
+
54
+ in_t = process(in_img).to(self.device)
55
+
56
+ out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))
57
+ img = to_image(out_t, pixel_size, upscale_after)
58
+ return img
59
+
60
+ def to_image(tensor, pixel_size, upscale_after):
61
+ img = tensor.data[0].cpu().float().numpy()
62
+ img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
63
+ img = img.astype(np.uint8)
64
+ img = Image.fromarray(img)
65
+ img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
66
+ if upscale_after:
67
+ img = img.resize((img.size[0]*pixel_size, img.size[1]*pixel_size), resample=Image.Resampling.NEAREST)
68
+
69
+ return img
70
+
71
+
72
+ def greyscale(img):
73
+ gray = np.array(img.convert('L'))
74
+ tmp = np.expand_dims(gray, axis=2)
75
+ tmp = np.concatenate((tmp, tmp, tmp), axis=-1)
76
+ return Image.fromarray(tmp)
77
+
78
+ def process(img):
79
+ ow,oh = img.size
80
+
81
+ nw = int(round(ow / 4) * 4)
82
+ nh = int(round(oh / 4) * 4)
83
+
84
+ left = (ow - nw)//2
85
+ top = (oh - nh)//2
86
+ right = left + nw
87
+ bottom = top + nh
88
+
89
+ img = img.crop((left, top, right, bottom))
90
+
91
+ trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
92
+
93
+ return trans(img)[None, :, :, :]
94
+
95
+ def save(tensor, file):
96
+ img = tensor.data[0].cpu().float().numpy()
97
+ img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
98
+ img = img.astype(np.uint8)
99
+ img = Image.fromarray(img)
100
+ img = img.resize((img.size[0]//4, img.size[1]//4), resample=Image.Resampling.NEAREST)
101
+ img = img.resize((img.size[0]*4, img.size[1]*4), resample=Image.Resampling.NEAREST)
102
+ img.save(file)
103
+
104
+ def pixelize_cli():
105
+ import argparse
106
+ import os
107
+ parser = argparse.ArgumentParser(description='Pixelization')
108
+ parser.add_argument('--input', type=str, default=None, required=True, help='path to image or directory')
109
+ parser.add_argument('--output', type=str, default=None, required=False, help='path to save image/images')
110
+ parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU')
111
+
112
+ args = parser.parse_args()
113
+ in_path = args.input
114
+ out_path = args.output
115
+ use_cpu = args.cpu
116
+
117
+ if not os.path.exists("alias_net.pth" if not os.environ['ALIAS_MODEL'] else os.environ['ALIAS_MODEL']):
118
+ print("missing models")
119
+
120
+ pairs = []
121
+
122
+ if os.path.isdir(in_path):
123
+ in_images = glob.glob(in_path + "/*.png") + glob.glob(in_path + "/*.jpg")
124
+ if not out_path:
125
+ out_path = os.path.join(in_path, "outputs")
126
+ if not os.path.exists(out_path):
127
+ os.makedirs(out_path)
128
+ elif os.path.isfile(out_path):
129
+ print("output cant be a file if input is a directory")
130
+ return
131
+ for i in in_images:
132
+ pairs += [(i, i.replace(in_path, out_path))]
133
+ elif os.path.isfile(in_path):
134
+ if not out_path:
135
+ base, ext = os.path.splitext(in_path)
136
+ out_path = base+"_pixelized"+ext
137
+ else:
138
+ if os.path.isdir(out_path):
139
+ _, file = os.path.split(in_path)
140
+ out_path = os.path.join(out_path, file)
141
+ pairs = [(in_path, out_path)]
142
+
143
+ m = Model(device = "cpu" if use_cpu else "cuda")
144
+ m.load()
145
+
146
+ for in_file, out_file in pairs:
147
+ print("PROCESSING", in_file, "TO", out_file)
148
+ m.pixelize(in_file, out_file)
149
+
150
+ if __name__ == "__main__":
151
+ pixelize_cli()
reference.png ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ transforms
4
+ numpy==1.24.1
5
+ pillow