init
- .gitignore +5 -0
- README.md +1 -12
- app.py +73 -5
- models/__init__.py +0 -0
- models/basic_layer.py +429 -0
- models/c2pDis.py +313 -0
- models/c2pGen.py +266 -0
- models/networks.py +244 -0
- models/p2cGen.py +76 -0
- pixelization.py +151 -0
- reference.png +0 -0
- requirements.txt +5 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
+myvenv
+myvenv/**/*
+__pycache__
+flagged
+*.pth
README.md
CHANGED
@@ -1,12 +1 @@
----
-title: Pixelization
-emoji: 🚀
-colorFrom: blue
-colorTo: gray
-sdk: gradio
-sdk_version: 3.16.2
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+...
app.py
CHANGED
@@ -1,8 +1,76 @@
 import gradio as gr
+import functools
+from pixelization import Model
+import torch
+import argparse
+import huggingface_hub
+import os
 
-def greet(name):
-    return "Hello " + name + "!!"
+TOKEN = "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"
 
-
-
-
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--theme', type=str, default='default')
+    parser.add_argument('--live', action='store_true')
+    parser.add_argument('--share', action='store_true')
+    parser.add_argument('--port', type=int)
+    parser.add_argument('--disable-queue',
+                        dest='enable_queue',
+                        action='store_false')
+    parser.add_argument('--allow-flagging', type=str, default='never')
+    return parser.parse_args()
+
+def main():
+    args = parse_args()
+
+
+    # DL MODEL
+    # PIX_MODEL
+    os.environ['PIX_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "pixelart_vgg19.pth", token=TOKEN)
+    # NET_MODEL
+    os.environ['NET_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "160_net_G_A.pth", token=TOKEN)
+    # ALIAS_MODEL
+    os.environ['ALIAS_MODEL'] = huggingface_hub.hf_hub_download("NoCrypt/pixelization_models", "alias_net.pth", token=TOKEN)
+
+    # # For local testing
+    # # PIX_MODEL
+    # os.environ['PIX_MODEL'] = "pixelart_vgg19.pth"
+    # # NET_MODEL
+    # os.environ['NET_MODEL'] = "160_net_G_A.pth"
+    # # ALIAS_MODEL
+    # os.environ['ALIAS_MODEL'] = "alias_net.pth"
+
+
+    use_cpu = True
+    m = Model(device="cpu" if use_cpu else "cuda")
+    m.load()
+
+    # To use GPU: change use_cpu to False, and check my comment on networks.py at lines 107 & 108
+    # + use torch with CUDA support (change in requirements.txt)
+
+    gr.Interface(m.pixelize_modified,
+                 [
+                     gr.components.Image(type='pil', label='Input'),
+                     gr.components.Slider(minimum=1, maximum=16, value=4, step=1, label='Pixel Size'),
+                     gr.components.Checkbox(True, label="Upscale after")
+                 ],
+                 gr.components.Image(type='pil', label='Output'),
+                 title="Pixelization",
+                 description='''
+                 Demo for [WuZongWei6/Pixelization](https://github.com/WuZongWei6/Pixelization)
+
+                 The models used are kept private to comply with the license.
+
+
+                 ''',
+                 theme=args.theme,
+                 allow_flagging=args.allow_flagging,
+                 live=args.live,
+                 ).launch(
+                     enable_queue=args.enable_queue,
+                     server_port=args.port,
+                     share=args.share,
+                 )
+
+if __name__ == '__main__':
+    main()
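
For a quick local smoke test of the wiring above, the Gradio callback can be called directly. A sketch only, not part of the commit: it assumes the three .pth files sit next to the script (per the commented-out "For local testing" block) and that `Model.pixelize_modified` takes and returns PIL images, as the Interface components imply.

```python
# Sketch: local smoke test of app.py's wiring (assumptions noted above).
import os

# Point the model env vars at local copies, as in the commented-out block.
os.environ['PIX_MODEL'] = "pixelart_vgg19.pth"
os.environ['NET_MODEL'] = "160_net_G_A.pth"
os.environ['ALIAS_MODEL'] = "alias_net.pth"

from PIL import Image
from pixelization import Model

m = Model(device="cpu")
m.load()
# Same positional args the gr.Interface passes: image, pixel size, upscale flag.
result = m.pixelize_modified(Image.open("reference.png"), 4, True)
result.save("pixelized.png")  # assumes a PIL image comes back (output type 'pil')
```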
models/__init__.py
ADDED
File without changes
models/basic_layer.py
ADDED
@@ -0,0 +1,429 @@
+import torch.nn as nn
+import torch
+import torch.nn.functional as F
+import numpy as np
+
+class ModulationConvBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size, stride=1,
+                 padding=0, norm='none', activation='relu', pad_type='zero'):
+        super(ModulationConvBlock, self).__init__()
+        self.in_c = input_dim
+        self.out_c = output_dim
+        self.ksize = kernel_size
+        self.stride = 1
+        self.padding = kernel_size // 2
+
+        self.eps = 1e-8
+        weight_shape = (output_dim, input_dim, kernel_size, kernel_size)
+        fan_in = kernel_size * kernel_size * input_dim
+        wscale = 1.0 / np.sqrt(fan_in)
+
+        self.weight = nn.Parameter(torch.randn(*weight_shape))
+        self.wscale = wscale
+
+        self.bias = nn.Parameter(torch.zeros(output_dim))
+
+        self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+        self.activate_scale = np.sqrt(2.0)
+
+    def forward(self, x, code):
+        batch, in_channel, height, width = x.shape
+        weight = self.weight * self.wscale
+        _weight = weight.view(1, self.ksize, self.ksize, self.in_c, self.out_c)
+        _weight = _weight * code.view(batch, 1, 1, self.in_c, 1)
+        # demodulation
+        _weight_norm = torch.sqrt(torch.sum(_weight ** 2, dim=[1, 2, 3]) + self.eps)
+        _weight = _weight / _weight_norm.view(batch, 1, 1, 1, self.out_c)
+        # fused_modulate
+        x = x.view(1, batch * self.in_c, x.shape[2], x.shape[3])
+        weight = _weight.permute(1, 2, 3, 0, 4).reshape(
+            self.ksize, self.ksize, self.in_c, batch * self.out_c)
+        # not use_conv2d_transpose
+        weight = weight.permute(3, 2, 0, 1)
+        x = F.conv2d(x,
+                     weight=weight,
+                     bias=None,
+                     stride=self.stride,
+                     padding=self.padding,
+                     groups=(batch if True else 1))
+
+        if True:  # self.fused_modulate:
+            x = x.view(batch, self.out_c, height, width)
+        x = x + self.bias.view(1, -1, 1, 1)
+        x = self.activate(x) * self.activate_scale
+        return x
+
+
+class AliasConvBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size, stride,
+                 padding=0, norm='none', activation='relu', pad_type='zero'):
+        super(AliasConvBlock, self).__init__()
+        self.use_bias = True
+        # initialize padding
+        if pad_type == 'reflect':
+            self.pad = nn.ReflectionPad2d(padding)
+        elif pad_type == 'replicate':
+            self.pad = nn.ReplicationPad2d(padding)
+        elif pad_type == 'zero':
+            self.pad = nn.ZeroPad2d(padding)
+        else:
+            assert 0, "Unsupported padding type: {}".format(pad_type)
+
+        # initialize normalization
+        norm_dim = output_dim
+        if norm == 'bn':
+            self.norm = nn.BatchNorm2d(norm_dim)
+        elif norm == 'in':
+            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
+            self.norm = nn.InstanceNorm2d(norm_dim)
+        elif norm == 'ln':
+            self.norm = LayerNorm(norm_dim)
+        elif norm == 'adain':
+            self.norm = AdaptiveInstanceNorm2d(norm_dim)
+        elif norm == 'none' or norm == 'sn':
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # initialize activation
+        if activation == 'relu':
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == 'lrelu':
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == 'prelu':
+            self.activation = nn.PReLU()
+        elif activation == 'selu':
+            self.activation = nn.SELU(inplace=True)
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'none':
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+        # initialize convolution
+        if norm == 'sn':
+            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+        else:
+            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+    def forward(self, x):
+        x = self.conv(self.pad(x))
+        if self.norm:
+            x = self.norm(x)
+        if self.activation:
+            x = self.activation(x)
+        return x
+
+class AliasResBlocks(nn.Module):
+    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
+        super(AliasResBlocks, self).__init__()
+        self.model = []
+        for i in range(num_blocks):
+            self.model += [AliasResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, x):
+        return self.model(x)
+class AliasResBlock(nn.Module):
+    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
+        super(AliasResBlock, self).__init__()
+
+        model = []
+        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
+        model += [AliasConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        residual = x
+        out = self.model(x)
+        out += residual
+        return out
+##################################################################################
+# Sequential Models
+##################################################################################
+class ResBlocks(nn.Module):
+    def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
+        super(ResBlocks, self).__init__()
+        self.model = []
+        for i in range(num_blocks):
+            self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
+        self.model = nn.Sequential(*self.model)
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class MLP(nn.Module):
+    def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
+        super(MLP, self).__init__()
+        self.model = []
+        self.model += [linearBlock(input_dim, input_dim, norm=norm, activation=activ)]
+        self.model += [linearBlock(input_dim, dim, norm=norm, activation=activ)]
+        for i in range(n_blk - 2):
+            self.model += [linearBlock(dim, dim, norm=norm, activation=activ)]
+        self.model += [linearBlock(dim, output_dim, norm='none', activation='none')]  # no output activations
+        self.model = nn.Sequential(*self.model)
+
+    # def forward(self, style0, style1, a=0):
+    #     return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
+    #         style1.view(style1.size(0), -1)))
+    def forward(self, style0, style1=None, a=0):
+        style1 = style0
+        return self.model[3]((1 - a) * self.model[0:3](style0.view(style0.size(0), -1)) + a * self.model[0:3](
+            style1.view(style1.size(0), -1)))
+##################################################################################
+# Basic Blocks
+##################################################################################
+class ResBlock(nn.Module):
+    def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
+        super(ResBlock, self).__init__()
+
+        model = []
+        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
+        model += [ConvBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        residual = x
+        out = self.model(x)
+        out += residual
+        return out
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel_size, stride,
+                 padding=0, norm='none', activation='relu', pad_type='zero'):
+        super(ConvBlock, self).__init__()
+        self.use_bias = True
+        # initialize padding
+        if pad_type == 'reflect':
+            self.pad = nn.ReflectionPad2d(padding)
+        elif pad_type == 'replicate':
+            self.pad = nn.ReplicationPad2d(padding)
+        elif pad_type == 'zero':
+            self.pad = nn.ZeroPad2d(padding)
+        else:
+            assert 0, "Unsupported padding type: {}".format(pad_type)
+
+        # initialize normalization
+        norm_dim = output_dim
+        if norm == 'bn':
+            self.norm = nn.BatchNorm2d(norm_dim)
+        elif norm == 'in':
+            # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
+            self.norm = nn.InstanceNorm2d(norm_dim)
+        elif norm == 'ln':
+            self.norm = LayerNorm(norm_dim)
+        elif norm == 'adain':
+            self.norm = AdaptiveInstanceNorm2d(norm_dim)
+        elif norm == 'none' or norm == 'sn':
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # initialize activation
+        if activation == 'relu':
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == 'lrelu':
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == 'prelu':
+            self.activation = nn.PReLU()
+        elif activation == 'selu':
+            self.activation = nn.SELU(inplace=True)
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'none':
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+        # initialize convolution
+        if norm == 'sn':
+            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+        else:
+            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
+
+    def forward(self, x):
+        x = self.conv(self.pad(x))
+        if self.norm:
+            x = self.norm(x)
+        if self.activation:
+            x = self.activation(x)
+        return x
+
+class linearBlock(nn.Module):
+    def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
+        super(linearBlock, self).__init__()
+        use_bias = True
+        # initialize fully connected layer
+        if norm == 'sn':
+            self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
+        else:
+            self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
+
+        # initialize normalization
+        norm_dim = output_dim
+        if norm == 'bn':
+            self.norm = nn.BatchNorm1d(norm_dim)
+        elif norm == 'in':
+            self.norm = nn.InstanceNorm1d(norm_dim)
+        elif norm == 'ln':
+            self.norm = LayerNorm(norm_dim)
+        elif norm == 'none' or norm == 'sn':
+            self.norm = None
+        else:
+            assert 0, "Unsupported normalization: {}".format(norm)
+
+        # initialize activation
+        if activation == 'relu':
+            self.activation = nn.ReLU(inplace=True)
+        elif activation == 'lrelu':
+            self.activation = nn.LeakyReLU(0.2, inplace=True)
+        elif activation == 'prelu':
+            self.activation = nn.PReLU()
+        elif activation == 'selu':
+            self.activation = nn.SELU(inplace=True)
+        elif activation == 'tanh':
+            self.activation = nn.Tanh()
+        elif activation == 'none':
+            self.activation = None
+        else:
+            assert 0, "Unsupported activation: {}".format(activation)
+
+    def forward(self, x):
+        out = self.fc(x)
+        if self.norm:
+            out = self.norm(out)
+        if self.activation:
+            out = self.activation(out)
+        return out
+##################################################################################
+# Normalization layers
+##################################################################################
+class AdaptiveInstanceNorm2d(nn.Module):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1):
+        super(AdaptiveInstanceNorm2d, self).__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.momentum = momentum
+        # weight and bias are dynamically assigned
+        self.weight = None
+        self.bias = None
+        # just dummy buffers, not used
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+
+    def forward(self, x):
+        assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
+        b, c = x.size(0), x.size(1)
+        running_mean = self.running_mean.repeat(b)
+        running_var = self.running_var.repeat(b)
+
+        # Apply instance norm
+        x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
+
+        out = F.batch_norm(
+            x_reshaped, running_mean, running_var, self.weight, self.bias,
+            True, self.momentum, self.eps)
+
+        return out.view(b, c, *x.size()[2:])
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' + str(self.num_features) + ')'
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, num_features, eps=1e-5, affine=True):
+        super(LayerNorm, self).__init__()
+        self.num_features = num_features
+        self.affine = affine
+        self.eps = eps
+
+        if self.affine:
+            self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
+            self.beta = nn.Parameter(torch.zeros(num_features))
+
+    def forward(self, x):
+        shape = [-1] + [1] * (x.dim() - 1)
+        # print(x.size())
+        if x.size(0) == 1:
+            # These two lines run much faster in pytorch 0.4 than the two lines listed below.
+            mean = x.view(-1).mean().view(*shape)
+            std = x.view(-1).std().view(*shape)
+        else:
+            mean = x.view(x.size(0), -1).mean(1).view(*shape)
+            std = x.view(x.size(0), -1).std(1).view(*shape)
+
+        x = (x - mean) / (std + self.eps)
+
+        if self.affine:
+            shape = [1, -1] + [1] * (x.dim() - 2)
+            x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+        return x
+
+
+def l2normalize(v, eps=1e-12):
+    return v / (v.norm() + eps)
+
+
+class SpectralNorm(nn.Module):
+    """
+    Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
+    and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
+    """
+
+    def __init__(self, module, name='weight', power_iterations=1):
+        super(SpectralNorm, self).__init__()
+        self.module = module
+        self.name = name
+        self.power_iterations = power_iterations
+        if not self._made_params():
+            self._make_params()
+
+    def _update_u_v(self):
+        u = getattr(self.module, self.name + "_u")
+        v = getattr(self.module, self.name + "_v")
+        w = getattr(self.module, self.name + "_bar")
+
+        height = w.data.shape[0]
+        for _ in range(self.power_iterations):
+            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
+        sigma = u.dot(w.view(height, -1).mv(v))
+        setattr(self.module, self.name, w / sigma.expand_as(w))
+
+    def _made_params(self):
+        try:
+            u = getattr(self.module, self.name + "_u")
+            v = getattr(self.module, self.name + "_v")
+            w = getattr(self.module, self.name + "_bar")
+            return True
+        except AttributeError:
+            return False
+
+    def _make_params(self):
+        w = getattr(self.module, self.name)
+
+        height = w.data.shape[0]
+        width = w.view(height, -1).data.shape[1]
+
+        u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+        v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+        u.data = l2normalize(u.data)
+        v.data = l2normalize(v.data)
+        w_bar = nn.Parameter(w.data)
+
+        del self.module._parameters[self.name]
+
+        self.module.register_parameter(self.name + "_u", u)
+        self.module.register_parameter(self.name + "_v", v)
+        self.module.register_parameter(self.name + "_bar", w_bar)
+
+    def forward(self, *args):
+        self._update_u_v()
+        return self.module.forward(*args)
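
ModulationConvBlock above is a StyleGAN2-style modulated convolution: the per-input-channel `code` scales the kernel, the kernel is then demodulated, and the batch is folded into a grouped convolution. A minimal shape check, assuming the repo layout of this commit and run from its root:

```python
# Minimal shape check for ModulationConvBlock (a sketch, not from the commit).
import torch
from models.basic_layer import ModulationConvBlock

block = ModulationConvBlock(input_dim=256, output_dim=256, kernel_size=3)
x = torch.randn(2, 256, 32, 32)   # (batch, in_c, H, W)
code = torch.randn(2, 256)        # one modulation scale per input channel
out = block(x, code)
print(out.shape)                  # torch.Size([2, 256, 32, 32]); k=3 pads to keep H, W
```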
models/c2pDis.py
ADDED
@@ -0,0 +1,313 @@
+from .basic_layer import *
+import math
+from torch.nn import Parameter
+#from pytorch_metric_learning import losses
+
+'''
+Margin code is borrowed from https://github.com/MuggleWang/CosFace_pytorch and https://github.com/wujiyang/Face_Pytorch.
+'''
+def cosine_sim(x1, x2, dim=1, eps=1e-8):
+    ip = torch.mm(x1, x2.t())  # w 7*512
+    w1 = torch.norm(x1, 2, dim)
+    w2 = torch.norm(x2, 2, dim)
+    return ip / torch.ger(w1, w2).clamp(min=eps)
+
+class MarginCosineProduct(nn.Module):
+    r"""Implement of large margin cosine distance: :
+    Args:
+        in_features: size of each input sample
+        out_features: size of each output sample
+        s: norm of input feature
+        m: margin
+    """
+
+    def __init__(self, in_features, out_features, s=30.0, m=0.40):
+        super(MarginCosineProduct, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.s = s
+        self.m = m
+        self.weight = Parameter(torch.Tensor(out_features, in_features))  # 7 512
+        nn.init.xavier_uniform_(self.weight)
+        #stdv = 1. / math.sqrt(self.weight.size(1))
+        #self.weight.data.uniform_(-stdv, stdv)
+
+    def forward(self, input, label):
+        cosine = cosine_sim(input, self.weight)  # 1*512 7*512
+        # cosine = F.linear(F.normalize(input), F.normalize(self.weight))
+        # --------------------------- convert label to one-hot ---------------------------
+        # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
+        one_hot = torch.zeros_like(cosine)
+        one_hot.scatter_(1, label.view(-1, 1), 1.0)
+        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
+        output = self.s * (cosine - one_hot * self.m)
+
+        return output
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + 'in_features=' + str(self.in_features) \
+               + ', out_features=' + str(self.out_features) \
+               + ', s=' + str(self.s) \
+               + ', m=' + str(self.m) + ')'
+
+class ArcMarginProduct(nn.Module):
+    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m=0.50, easy_margin=False):
+        super(ArcMarginProduct, self).__init__()
+        self.in_feature = in_feature
+        self.out_feature = out_feature
+        self.s = s
+        self.m = m
+        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
+        nn.init.xavier_uniform_(self.weight)
+
+        self.easy_margin = easy_margin
+        self.cos_m = math.cos(m)
+        self.sin_m = math.sin(m)
+
+        # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
+        self.th = math.cos(math.pi - m)
+        self.mm = math.sin(math.pi - m) * m
+
+    def forward(self, x, label):
+        # cos(theta)
+        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
+        # cos(theta + m)
+        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
+        phi = cosine * self.cos_m - sine * self.sin_m
+
+        if self.easy_margin:
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
+
+        #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
+        one_hot = torch.zeros_like(cosine)
+        one_hot.scatter_(1, label.view(-1, 1), 1)
+        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
+        output = output * self.s
+
+        return output
+
+
+class MultiMarginProduct(nn.Module):
+    def __init__(self, in_feature=128, out_feature=10575, s=32.0, m1=0.20, m2=0.35, easy_margin=False):
+        super(MultiMarginProduct, self).__init__()
+        self.in_feature = in_feature
+        self.out_feature = out_feature
+        self.s = s
+        self.m1 = m1
+        self.m2 = m2
+        self.weight = Parameter(torch.Tensor(out_feature, in_feature))
+        nn.init.xavier_uniform_(self.weight)
+
+        self.easy_margin = easy_margin
+        self.cos_m1 = math.cos(m1)
+        self.sin_m1 = math.sin(m1)
+
+        # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
+        self.th = math.cos(math.pi - m1)
+        self.mm = math.sin(math.pi - m1) * m1
+
+    def forward(self, x, label):
+        # cos(theta)
+        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
+        # cos(theta + m1)
+        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
+        phi = cosine * self.cos_m1 - sine * self.sin_m1
+
+        if self.easy_margin:
+            phi = torch.where(cosine > 0, phi, cosine)
+        else:
+            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)
+
+
+        one_hot = torch.zeros_like(cosine)
+        one_hot.scatter_(1, label.view(-1, 1), 1)
+        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # additive angular margin
+        output = output - one_hot * self.m2  # additive cosine margin
+        output = output * self.s
+
+        return output
+
+
+class CPDis(nn.Module):
+    """PatchGAN."""
+    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
+        super(CPDis, self).__init__()
+
+        layers = []
+        if norm == 'SN':
+            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
+        else:
+            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
+        layers.append(nn.LeakyReLU(0.01, inplace=True))
+
+        curr_dim = conv_dim
+        for i in range(1, repeat_num):
+            if norm == 'SN':
+                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
+            else:
+                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
+            layers.append(nn.LeakyReLU(0.01, inplace=True))
+            curr_dim = curr_dim * 2
+
+        # k_size = int(image_size / np.power(2, repeat_num))
+        if norm == 'SN':
+            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
+        else:
+            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
+        layers.append(nn.LeakyReLU(0.01, inplace=True))
+        curr_dim = curr_dim * 2
+
+        self.main = nn.Sequential(*layers)
+        if norm == 'SN':
+            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
+        else:
+            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
+
+    def forward(self, x):
+        if x.ndim == 5:
+            x = x.squeeze(0)
+        assert x.ndim == 4, x.ndim
+        h = self.main(x)
+        # out_real = self.conv1(h)
+        out_makeup = self.conv1(h)
+        # return out_real.squeeze(), out_makeup.squeeze()
+        return out_makeup
+
+
+class CPDis_cls(nn.Module):
+    """PatchGAN."""
+    def __init__(self, image_size=256, conv_dim=64, repeat_num=3, norm='SN'):
+        super(CPDis_cls, self).__init__()
+
+        layers = []
+        if norm == 'SN':
+            layers.append(spectral_norm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
+        else:
+            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
+        layers.append(nn.LeakyReLU(0.01, inplace=True))
+
+        curr_dim = conv_dim
+        for i in range(1, repeat_num):
+            if norm == 'SN':
+                layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
+            else:
+                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
+            layers.append(nn.LeakyReLU(0.01, inplace=True))
+            curr_dim = curr_dim * 2
+
+        # k_size = int(image_size / np.power(2, repeat_num))
+        if norm == 'SN':
+            layers.append(spectral_norm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
+        else:
+            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
+        layers.append(nn.LeakyReLU(0.01, inplace=True))
+        curr_dim = curr_dim * 2
+
+        self.main = nn.Sequential(*layers)
+        if norm == 'SN':
+            self.conv1 = spectral_norm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
+            self.classifier_pool = nn.AdaptiveAvgPool2d(1)
+            self.classifier_conv = nn.Conv2d(512, 512, 1, 1, 0)
+            self.classifier = MarginCosineProduct(512, 7)  # ArcMarginProduct(512, 7)
+            print("Using Large Margin Cosine Loss.")
+
+        else:
+            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
+
+    def forward(self, x, label):
+        if x.ndim == 5:
+            x = x.squeeze(0)
+        assert x.ndim == 4, x.ndim
+        h = self.main(x)  # ([1, 512, 31, 31])
+        #print(out_cls.shape)
+        out_cls = self.classifier_pool(h)
+        #print(out_cls.shape)
+        out_cls = self.classifier_conv(out_cls)
+        #print(out_cls.shape)
+        out_cls = torch.squeeze(out_cls, -1)
+        out_cls = torch.squeeze(out_cls, -1)
+        out_cls = self.classifier(out_cls, label)
+        out_makeup = self.conv1(h)  # torch.Size([1, 1, 30, 30])
+        # return out_real.squeeze(), out_makeup.squeeze()
+        return out_makeup, out_cls
+
+class SpectralNorm(object):
+    def __init__(self):
+        self.name = "weight"
+        # print(self.name)
+        self.power_iterations = 1
+
+    def compute_weight(self, module):
+        u = getattr(module, self.name + "_u")
+        v = getattr(module, self.name + "_v")
+        w = getattr(module, self.name + "_bar")
+
+        height = w.data.shape[0]
+        for _ in range(self.power_iterations):
+            v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+            u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+        # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
+        sigma = u.dot(w.view(height, -1).mv(v))
+        return w / sigma.expand_as(w)
+
+    @staticmethod
+    def apply(module):
+        name = "weight"
+        fn = SpectralNorm()
+
+        try:
+            u = getattr(module, name + "_u")
+            v = getattr(module, name + "_v")
+            w = getattr(module, name + "_bar")
+        except AttributeError:
+            w = getattr(module, name)
+            height = w.data.shape[0]
+            width = w.view(height, -1).data.shape[1]
+            u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+            v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+            w_bar = Parameter(w.data)
+
+            # del module._parameters[name]
+
+            module.register_parameter(name + "_u", u)
+            module.register_parameter(name + "_v", v)
+            module.register_parameter(name + "_bar", w_bar)
+
+        # remove w from parameter list
+        del module._parameters[name]
+
+        setattr(module, name, fn.compute_weight(module))
+
+        # recompute weight before every forward()
+        module.register_forward_pre_hook(fn)
+
+        return fn
+
+    def remove(self, module):
+        weight = self.compute_weight(module)
+        delattr(module, self.name)
+        del module._parameters[self.name + '_u']
+        del module._parameters[self.name + '_v']
+        del module._parameters[self.name + '_bar']
+        module.register_parameter(self.name, Parameter(weight.data))
+
+    def __call__(self, module, inputs):
+        setattr(module, self.name, self.compute_weight(module))
+
+def spectral_norm(module):
+    SpectralNorm.apply(module)
+    return module
+
+def remove_spectral_norm(module):
+    name = 'weight'
+    for k, hook in module._forward_pre_hooks.items():
+        if isinstance(hook, SpectralNorm) and hook.name == name:
+            hook.remove(module)
+            del module._forward_pre_hooks[k]
+            return module
+
+    raise ValueError("spectral_norm of '{}' not found in {}"
+                     .format(name, module))
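
MarginCosineProduct is a CosFace-style head: it emits scaled cosine logits with the margin m subtracted only at the target class, so it pairs with ordinary cross-entropy. A small sketch of driving the 512-dim, 7-class head that CPDis_cls wires up above (the 512/7 sizes come straight from that code):

```python
# Sketch: large-margin cosine head + cross-entropy (not from the commit).
import torch
import torch.nn.functional as F
from models.c2pDis import MarginCosineProduct

head = MarginCosineProduct(in_features=512, out_features=7, s=30.0, m=0.40)
feats = torch.randn(4, 512)          # pooled discriminator features
labels = torch.randint(0, 7, (4,))   # cell-size class ids
logits = head(feats, labels)         # (4, 7); margin applied only at the label
loss = F.cross_entropy(logits, labels)
```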
models/c2pGen.py
ADDED
@@ -0,0 +1,266 @@
+from .basic_layer import *
+import torchvision.models as models
+import os
+
+
+
+class AliasNet(nn.Module):
+    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
+        super(AliasNet, self).__init__()
+        self.RGBEnc = AliasRGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
+        self.RGBDec = AliasRGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
+                                      activ=activ, pad_type=pad_type)
+
+    def forward(self, x):
+        x = self.RGBEnc(x)
+        x = self.RGBDec(x)
+        return x
+
+
+class AliasRGBEncoder(nn.Module):
+    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
+        super(AliasRGBEncoder, self).__init__()
+        self.model = []
+        self.model += [AliasConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
+        # downsampling blocks
+        for i in range(n_downsample):
+            self.model += [AliasConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
+            dim *= 2
+        # residual blocks
+        self.model += [AliasResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
+        self.model = nn.Sequential(*self.model)
+        self.output_dim = dim
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class AliasRGBDecoder(nn.Module):
+    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
+        super(AliasRGBDecoder, self).__init__()
+        # self.model = []
+        # # AdaIN residual blocks
+        # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
+        # # upsampling blocks
+        # for i in range(n_upsample):
+        #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
+        #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
+        #     dim //= 2
+        # # use reflection padding in the last conv layer
+        # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
+        # self.model = nn.Sequential(*self.model)
+        self.Res_Blocks = AliasResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
+        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv_1 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
+        dim //= 2
+        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv_2 = AliasConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
+        dim //= 2
+        self.conv_3 = AliasConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
+
+    def forward(self, x):
+        x = self.Res_Blocks(x)
+        # print(x.shape)
+        x = self.upsample_block1(x)
+        # print(x.shape)
+        x = self.conv_1(x)
+        # print(x_small.shape)
+        x = self.upsample_block2(x)
+        # print(x.shape)
+        x = self.conv_2(x)
+        # print(x_middle.shape)
+        x = self.conv_3(x)
+        # print(x_big.shape)
+        return x
+
+
+class C2PGen(nn.Module):
+    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, style_dim, mlp_dim, activ='relu', pad_type='reflect'):
+        super(C2PGen, self).__init__()
+        self.PBEnc = PixelBlockEncoder(input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
+        self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
+        self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='adain',
+                                 activ=activ, pad_type=pad_type)
+        self.MLP = MLP(style_dim, 2048, mlp_dim, 3, norm='none', activ=activ)
+
+    def forward(self, clipart, pixelart, s=1):
+        feature = self.RGBEnc(clipart)
+        code = self.PBEnc(pixelart)
+        result, cellcode = self.fuse(feature, code, s)
+        return result  # , cellcode  # return cellcode when visualizing the cell size code
+
+    def fuse(self, content, style_code, s=1):
+        #print("MLP input:code's shape:", style_code.shape)
+        adain_params = self.MLP(style_code) * s  # [batch,2048]
+        #print("MLP output:adain_params's shape", adain_params.shape)
+        #self.assign_adain_params(adain_params, self.RGBDec)
+        images = self.RGBDec(content, adain_params)
+        return images, adain_params
+
+    def assign_adain_params(self, adain_params, model):
+        # assign the adain_params to the AdaIN layers in model
+        for m in model.modules():
+            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
+                mean = adain_params[:, :m.num_features]
+                std = adain_params[:, m.num_features:2 * m.num_features]
+                m.bias = mean.contiguous().view(-1)
+                m.weight = std.contiguous().view(-1)
+                if adain_params.size(1) > 2 * m.num_features:
+                    adain_params = adain_params[:, 2 * m.num_features:]
+
+    def get_num_adain_params(self, model):
+        # return the number of AdaIN parameters needed by the model
+        num_adain_params = 0
+        for m in model.modules():
+            if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
+                num_adain_params += 2 * m.num_features
+        return num_adain_params
+
+
+class PixelBlockEncoder(nn.Module):
+    def __init__(self, input_dim, dim, style_dim, norm, activ, pad_type):
+        super(PixelBlockEncoder, self).__init__()
+        vgg19 = models.vgg.vgg19()
+        vgg19.classifier._modules['6'] = nn.Linear(4096, 7, bias=True)
+        vgg19.load_state_dict(torch.load('./pixelart_vgg19.pth' if not os.environ['PIX_MODEL'] else os.environ['PIX_MODEL'], map_location=torch.device('cpu')))
+        self.vgg = vgg19.features
+        for p in self.vgg.parameters():
+            p.requires_grad = False
+        # vgg19 = models.vgg.vgg19(pretrained=False)
+        # vgg19.load_state_dict(torch.load('./vgg.pth'))
+        # self.vgg = vgg19.features
+        # for p in self.vgg.parameters():
+        #     p.requires_grad = False
+
+
+        self.conv1 = ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)  # 3->64,concat
+        dim = dim * 2
+        self.conv2 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 128->128
+        dim = dim * 2
+        self.conv3 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 256->256
+        dim = dim * 2
+        self.conv4 = ConvBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)  # 512->512
+        dim = dim * 2
+
+        self.model = []
+        self.model += [nn.AdaptiveAvgPool2d(1)]  # global average pooling
+        self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
+        self.model = nn.Sequential(*self.model)
+        self.output_dim = dim
+
+    def get_features(self, image, model, layers=None):
+        if layers is None:
+            layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
+        features = {}
+        x = image
+        # model._modules is a dictionary holding each module in the model
+        for name, layer in model._modules.items():
+            x = layer(x)
+            if name in layers:
+                features[layers[name]] = x
+        return features
+
+    def componet_enc(self, x):
+        # x [16,3,256,256]
+        # factor_img [16,7,256,256]
+        vgg_aux = self.get_features(x, self.vgg)  # x is a 3-channel grayscale image
+        #x = torch.cat([x, factor_img], dim=1)  # [16,3+7,256,256]
+        x = self.conv1(x)  # 64 256 256
+        x = torch.cat([x, vgg_aux['conv1_1']], dim=1)  # 128 256 256
+        x = self.conv2(x)  # 128 128 128
+        x = torch.cat([x, vgg_aux['conv2_1']], dim=1)  # 256 128 128
+        x = self.conv3(x)  # 256 64 64
+        x = torch.cat([x, vgg_aux['conv3_1']], dim=1)  # 512 64 64
+        x = self.conv4(x)  # 512 32 32
+        x = torch.cat([x, vgg_aux['conv4_1']], dim=1)  # 1024 32 32
+        x = self.model(x)
+        return x
+
+    def forward(self, x):
+        code = self.componet_enc(x)
+        return code
+
+class RGBEncoder(nn.Module):
+    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
+        super(RGBEncoder, self).__init__()
+        self.model = []
+        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
+        # downsampling blocks
+        for i in range(n_downsample):
+            self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
+            dim *= 2
+        # residual blocks
+        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
+        self.model = nn.Sequential(*self.model)
+        self.output_dim = dim
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class RGBDecoder(nn.Module):
+    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
+        super(RGBDecoder, self).__init__()
+        # self.model = []
+        # # AdaIN residual blocks
+        # self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
+        # # upsampling blocks
+        # for i in range(n_upsample):
+        #     self.model += [nn.Upsample(scale_factor=2, mode='nearest'),
+        #                    ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
+        #     dim //= 2
+        # # use reflection padding in the last conv layer
+        # self.model += [ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
+        # self.model = nn.Sequential(*self.model)
+        #self.Res_Blocks = ModulationResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
+        self.mod_conv_1 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_2 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_3 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_4 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_5 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_6 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_7 = ModulationConvBlock(256, 256, 3)
+        self.mod_conv_8 = ModulationConvBlock(256, 256, 3)
+        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
+        dim //= 2
+        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
+        self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
+        dim //= 2
+        self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
+
+    # def forward(self, x):
+    #     residual = x
+    #     out = self.model(x)
+    #     out += residual
+    #     return out
+    def forward(self, x, code):
+        residual = x
+        x = self.mod_conv_1(x, code[:, :256])
+        x = self.mod_conv_2(x, code[:, 256*1:256*2])
+        x += residual
+        residual = x
+        x = self.mod_conv_2(x, code[:, 256*2:256*3])
+        x = self.mod_conv_2(x, code[:, 256*3:256*4])
+        x += residual
+        residual = x
+        x = self.mod_conv_2(x, code[:, 256*4:256*5])
+        x = self.mod_conv_2(x, code[:, 256*5:256*6])
+        x += residual
+        residual = x
+        x = self.mod_conv_2(x, code[:, 256*6:256*7])
+        x = self.mod_conv_2(x, code[:, 256*7:256*8])
+        x += residual
+        # print(x.shape)
+        x = self.upsample_block1(x)
+        # print(x.shape)
+        x = self.conv_1(x)
+        # print(x_small.shape)
+        x = self.upsample_block2(x)
+        # print(x.shape)
+        x = self.conv_2(x)
+        # print(x_middle.shape)
+        x = self.conv_3(x)
+        # print(x_big.shape)
+        return x
+
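
The 2048-dim style code that C2PGen's MLP emits is consumed by RGBDecoder in eight 256-wide slices, one per modulated conv, with residual skips between pairs, before two upsampling stages restore full resolution. A standalone sketch of the decoder (randomly initialized; in the real model the code comes from PBEnc + MLP):

```python
# Sketch: RGBDecoder consumes a 2048-d code as 8 x 256 modulation slices.
import torch
from models.c2pGen import RGBDecoder

dec = RGBDecoder(dim=256, output_dim=3, n_upsample=2, n_res=4, res_norm='adain')
feat = torch.randn(1, 256, 64, 64)   # RGBEncoder output for a 256x256 input
code = torch.randn(1, 2048)          # stand-in for MLP(PBEnc(pixelart))
img = dec(feat, code)
print(img.shape)                     # torch.Size([1, 3, 256, 256]), tanh range
```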
models/networks.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init
|
4 |
+
import functools
|
5 |
+
from torch.optim import lr_scheduler
|
6 |
+
from .c2pGen import *
|
7 |
+
from .p2cGen import *
|
8 |
+
from .c2pDis import *
|
9 |
+
|
10 |
+
class Identity(nn.Module):
|
11 |
+
def forward(self, x):
|
12 |
+
return x
|
13 |
+
|
14 |
+
def get_norm_layer(norm_type='instance'):
|
15 |
+
"""Return a normalization layer
|
16 |
+
|
17 |
+
Parameters:
|
18 |
+
norm_type (str) -- the name of the normalization layer: batch | instance | none
|
19 |
+
|
20 |
+
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
|
21 |
+
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
|
22 |
+
"""
|
23 |
+
if norm_type == 'batch':
|
24 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
25 |
+
elif norm_type == 'instance':
|
26 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
27 |
+
elif norm_type == 'none':
|
28 |
+
def norm_layer(x): return Identity()
|
29 |
+
else:
|
30 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
31 |
+
return norm_layer
|
32 |
+
|
33 |
+
|
34 |
+
def get_scheduler(optimizer, opt):
|
35 |
+
"""Return a learning rate scheduler
|
36 |
+
|
37 |
+
Parameters:
|
38 |
+
optimizer -- the optimizer of the network
|
39 |
+
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
|
40 |
+
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
|
41 |
+
|
42 |
+
For 'linear', we keep the same learning rate for the first <opt.n_epochs> epochs
|
43 |
+
and linearly decay the rate to zero over the next <opt.n_epochs_decay> epochs.
|
44 |
+
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
|
45 |
+
See https://pytorch.org/docs/stable/optim.html for more details.
|
46 |
+
"""
|
47 |
+
if opt.lr_policy == 'linear':
|
48 |
+
def lambda_rule(epoch):
|
49 |
+
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
|
50 |
+
return lr_l
|
51 |
+
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
|
52 |
+
elif opt.lr_policy == 'step':
|
53 |
+
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
|
54 |
+
elif opt.lr_policy == 'plateau':
|
55 |
+
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
|
56 |
+
elif opt.lr_policy == 'cosine':
|
57 |
+
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
|
58 |
+
else:
|
59 |
+
return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
|
60 |
+
return scheduler
|
61 |
+
|
62 |
+
|
63 |
+
def init_weights(net, init_type='normal', init_gain=0.02):
|
64 |
+
"""Initialize network weights.
|
65 |
+
|
66 |
+
Parameters:
|
67 |
+
net (network) -- network to be initialized
|
68 |
+
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
|
69 |
+
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
|
70 |
+
|
71 |
+
"""
|
72 |
+
def init_func(m): # define the initialization function
|
73 |
+
classname = m.__class__.__name__
|
74 |
+
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
75 |
+
if init_type == 'normal':
|
76 |
+
init.normal_(m.weight.data, 0.0, init_gain)
|
77 |
+
elif init_type == 'xavier':
|
78 |
+
init.xavier_normal_(m.weight.data, gain=init_gain)
|
79 |
+
elif init_type == 'kaiming':
|
80 |
+
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
81 |
+
elif init_type == 'orthogonal':
|
82 |
+
init.orthogonal_(m.weight.data, gain=init_gain)
|
83 |
+
else:
|
84 |
+
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
|
85 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
86 |
+
init.constant_(m.bias.data, 0.0)
|
87 |
+
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
|
88 |
+
init.normal_(m.weight.data, 1.0, init_gain)
|
89 |
+
init.constant_(m.bias.data, 0.0)
|
90 |
+
|
91 |
+
#print('initialize network with %s' % init_type)
|
92 |
+
net.apply(init_func) # apply the initialization function <init_func>
|
93 |
+
|
94 |
+
|
95 |
+
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Return an initialized network.
    """
    gpu_ids = [0]  # hard-coded override: the net is always wrapped in DataParallel, whose state dicts use a "module." prefix
    if len(gpu_ids) > 0:
        # assert(torch.cuda.is_available())  # uncomment this when running on GPU
        net.to(torch.device("cpu"))  # to run on GPU, change "cpu" to gpu_ids[0]
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    init_weights(net, init_type, init_gain=init_gain)
    return net

def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a generator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        output_nc (int)    -- the number of channels in output images
        ngf (int)          -- the number of filters in the last conv layer
        netG (str)         -- the architecture's name: c2pGen | p2cGen | antialias
        norm (str)         -- the name of normalization layers used in the network: batch | instance | none
        use_dropout (bool) -- whether to use dropout layers.
        init_type (str)    -- the name of our initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a generator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'c2pGen':
        net = C2PGen(input_nc, output_nc, ngf, 2, 4, 256, 256, activ='relu', pad_type='reflect')  # the two 256s are style_dim and mlp_dim
    elif netG == 'p2cGen':
        net = P2CGen(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    elif netG == 'antialias':
        net = AliasNet(input_nc, output_nc, ngf, 2, 3, activ='relu', pad_type='reflect')
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)

def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Create a discriminator

    Parameters:
        input_nc (int)     -- the number of channels in input images
        ndf (int)          -- the number of filters in the first conv layer
        netD (str)         -- the architecture's name: CPDis | CPDis_cls
        n_layers_D (int)   -- kept for interface compatibility; not used by the architectures below
        norm (str)         -- the type of normalization layers used in the network.
        init_type (str)    -- the name of the initialization method.
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

    Returns a discriminator
    """
    net = None
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'CPDis':
        net = CPDis(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    elif netD == 'CPDis_cls':
        net = CPDis_cls(image_size=256, conv_dim=64, repeat_num=3, norm='SN')
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
    return init_net(net, init_type, init_gain, gpu_ids)

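
# A minimal sketch wiring the two factories together as training code would
# (values are illustrative; note that 'c2pGen' may also load the pretrained
# VGG weights its constructor expects, see PIX_MODEL in app.py).
def _build_nets_demo():
    netG = define_G(3, 3, 64, 'c2pGen', norm='instance')
    netD = define_D(3, 64, 'CPDis')
    return netG, netD
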
class GANLoss(nn.Module):
    """Define different GAN objectives.

    The GANLoss class abstracts away the need to create the target label tensor
    that has the same size as the input.
    """

    def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
        """Initialize the GANLoss class.

        Parameters:
            gan_mode (str)            -- the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
            target_real_label (float) -- label for a real image
            target_fake_label (float) -- label for a fake image

        Note: Do not use sigmoid as the last layer of the discriminator.
        LSGAN needs no sigmoid; vanilla GANs handle it with BCEWithLogitsLoss.
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.gan_mode = gan_mode
        if gan_mode == 'lsgan':
            self.loss = nn.MSELoss()
        elif gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode in ['wgangp']:
            self.loss = None
        else:
            raise NotImplementedError('gan mode %s not implemented' % gan_mode)

    def get_target_tensor(self, prediction, target_is_real):
        """Create label tensors with the same size as the input.

        Parameters:
            prediction (tensor)   -- typically the prediction from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            A label tensor filled with the ground truth label, with the size of the input
        """
        if target_is_real:
            target_tensor = self.real_label
        else:
            target_tensor = self.fake_label
        return target_tensor.expand_as(prediction)

    def __call__(self, prediction, target_is_real):
        """Calculate loss given the discriminator's output and ground truth labels.

        Parameters:
            prediction (tensor)   -- typically the prediction output from a discriminator
            target_is_real (bool) -- whether the ground truth label is for real images or fake images

        Returns:
            the calculated loss.
        """
        if self.gan_mode in ['lsgan', 'vanilla']:
            target_tensor = self.get_target_tensor(prediction, target_is_real)
            loss = self.loss(prediction, target_tensor)
        elif self.gan_mode == 'wgangp':
            if target_is_real:
                loss = -prediction.mean()
            else:
                loss = prediction.mean()
        return loss

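
# A minimal sketch of the intended call pattern: one criterion serves both
# targets, expanding the scalar label to the prediction's shape (the random
# tensors below are placeholders for discriminator outputs).
def _gan_loss_demo():
    criterion = GANLoss('lsgan')
    pred_real = torch.randn(4, 1, 30, 30)
    pred_fake = torch.randn(4, 1, 30, 30)
    loss_D = 0.5 * (criterion(pred_real, True) + criterion(pred_fake, False))
    return loss_D
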
models/p2cGen.py
ADDED
@@ -0,0 +1,76 @@
from .basic_layer import *


class P2CGen(nn.Module):
    def __init__(self, input_dim, output_dim, dim, n_downsample, n_res, activ='relu', pad_type='reflect'):
        super(P2CGen, self).__init__()
        self.RGBEnc = RGBEncoder(input_dim, dim, n_downsample, n_res, "in", activ, pad_type=pad_type)
        self.RGBDec = RGBDecoder(self.RGBEnc.output_dim, output_dim, n_downsample, n_res, res_norm='in',
                                 activ=activ, pad_type=pad_type)

    def forward(self, x):
        x = self.RGBEnc(x)
        x = self.RGBDec(x)
        return x


class RGBEncoder(nn.Module):
    def __init__(self, input_dim, dim, n_downsample, n_res, norm, activ, pad_type):
        super(RGBEncoder, self).__init__()
        self.model = []
        self.model += [ConvBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
        # downsampling blocks
        for i in range(n_downsample):
            self.model += [ConvBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
            dim *= 2
        # residual blocks
        self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
        self.model = nn.Sequential(*self.model)
        self.output_dim = dim

    def forward(self, x):
        return self.model(x)


class RGBDecoder(nn.Module):
    def __init__(self, dim, output_dim, n_upsample, n_res, res_norm, activ='relu', pad_type='zero'):
        super(RGBDecoder, self).__init__()
        self.Res_Blocks = ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)
        self.upsample_block1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_1 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        self.upsample_block2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_2 = ConvBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)
        dim //= 2
        # reflection padding in the last conv layer
        self.conv_3 = ConvBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)

    def forward(self, x):
        x = self.Res_Blocks(x)
        x = self.upsample_block1(x)
        x = self.conv_1(x)
        x = self.upsample_block2(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        return x
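
# A quick shape check (illustrative): with n_downsample=2 in the encoder and
# two matching upsample stages in the decoder, spatial size is preserved.
def _p2c_shape_demo():
    import torch
    net = P2CGen(3, 3, 64, n_downsample=2, n_res=3)
    x = torch.randn(1, 3, 256, 256)
    return net(x).shape  # torch.Size([1, 3, 256, 256])
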
pixelization.py
ADDED
@@ -0,0 +1,151 @@
import os
import glob
import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from models.networks import define_G


class Model():
    def __init__(self, device="cpu"):
        self.device = torch.device(device)
        self.G_A_net = None
        self.alias_net = None
        self.ref_t = None

    def load(self):
        with torch.no_grad():
            self.G_A_net = define_G(3, 3, 64, "c2pGen", "instance", False, "normal", 0.02, [0])
            self.alias_net = define_G(3, 3, 64, "antialias", "instance", False, "normal", 0.02, [0])

            # fall back to local filenames when the *_MODEL environment variables are unset
            G_A_state = torch.load(os.environ.get('NET_MODEL') or "160_net_G_A.pth", map_location=str(self.device))
            for p in list(G_A_state.keys()):
                G_A_state["module." + str(p)] = G_A_state.pop(p)  # add the "module." prefix expected by DataParallel
            self.G_A_net.load_state_dict(G_A_state)

            alias_state = torch.load(os.environ.get('ALIAS_MODEL') or "alias_net.pth", map_location=str(self.device))
            for p in list(alias_state.keys()):
                alias_state["module." + str(p)] = alias_state.pop(p)
            self.alias_net.load_state_dict(alias_state)

            ref_img = Image.open("reference.png").convert('L')
            self.ref_t = process(greyscale(ref_img)).to(self.device)
    def pixelize(self, in_img, out_img):
        with torch.no_grad():
            in_img = Image.open(in_img).convert('RGB')
            in_t = process(in_img).to(self.device)

            out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))

            save(out_t, out_img)

    def pixelize_modified(self, in_img, pixel_size, upscale_after) -> Image.Image:
        with torch.no_grad():
            in_img = in_img.convert('RGB')

            # cap the input at 1024x1024 so large images don't blow up memory and runtime
            if in_img.size[0] > 1024 or in_img.size[1] > 1024:
                in_img.thumbnail((1024, 1024), Image.NEAREST)

            # the networks pixelize at a fixed 4px cell size, so rescale the input
            # to make the requested pixel size land on that grid
            # (Image.resize returns a new image rather than resizing in place)
            in_img = in_img.resize((in_img.size[0] * 4 // pixel_size, in_img.size[1] * 4 // pixel_size))

            in_t = process(in_img).to(self.device)

            out_t = self.alias_net(self.G_A_net(in_t, self.ref_t))
            img = to_image(out_t, pixel_size, upscale_after)
            return img

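
# A minimal end-to-end sketch (file names are hypothetical; NET_MODEL and
# ALIAS_MODEL should point at the downloaded checkpoints, and reference.png
# must be present next to the script):
def _pixelize_demo():
    m = Model(device="cpu")
    m.load()
    m.pixelize("input.png", "output_pixelized.png")
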
def to_image(tensor, pixel_size, upscale_after):
    img = tensor.data[0].cpu().float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0  # [-1, 1] -> [0, 255]
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    # collapse each 4x4 network cell to a single pixel, then optionally blow it
    # back up so one cell covers pixel_size x pixel_size pixels
    img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
    if upscale_after:
        img = img.resize((img.size[0] * pixel_size, img.size[1] * pixel_size), resample=Image.Resampling.NEAREST)

    return img

def greyscale(img):
    gray = np.array(img.convert('L'))
    tmp = np.expand_dims(gray, axis=2)
    tmp = np.concatenate((tmp, tmp, tmp), axis=-1)  # replicate the grey channel into RGB
    return Image.fromarray(tmp)

def process(img):
    ow, oh = img.size

    # center-crop to dimensions that are multiples of 4 (the networks' cell size)
    nw = int(round(ow / 4) * 4)
    nh = int(round(oh / 4) * 4)

    left = (ow - nw) // 2
    top = (oh - nh) // 2
    right = left + nw
    bottom = top + nh

    img = img.crop((left, top, right, bottom))

    # to tensor, normalized to [-1, 1], with a leading batch dimension
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    return trans(img)[None, :, :, :]

def save(tensor, file):
    img = tensor.data[0].cpu().float().numpy()
    img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0
    img = img.astype(np.uint8)
    img = Image.fromarray(img)
    # snap the output onto a clean 4px grid: downscale by 4, then upscale back
    img = img.resize((img.size[0] // 4, img.size[1] // 4), resample=Image.Resampling.NEAREST)
    img = img.resize((img.size[0] * 4, img.size[1] * 4), resample=Image.Resampling.NEAREST)
    img.save(file)

def pixelize_cli():
    import argparse
    parser = argparse.ArgumentParser(description='Pixelization')
    parser.add_argument('--input', type=str, default=None, required=True, help='path to image or directory')
    parser.add_argument('--output', type=str, default=None, required=False, help='path to save image/images')
    parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU')

    args = parser.parse_args()
    in_path = args.input
    out_path = args.output
    use_cpu = args.cpu

    if not os.path.exists(os.environ.get('ALIAS_MODEL') or "alias_net.pth"):
        print("missing models")
        return

    pairs = []

    if os.path.isdir(in_path):
        in_images = glob.glob(in_path + "/*.png") + glob.glob(in_path + "/*.jpg")
        if not out_path:
            out_path = os.path.join(in_path, "outputs")
        if not os.path.exists(out_path):
            os.makedirs(out_path)
        elif os.path.isfile(out_path):
            print("output can't be a file if input is a directory")
            return
        for i in in_images:
            pairs += [(i, i.replace(in_path, out_path))]
    elif os.path.isfile(in_path):
        if not out_path:
            base, ext = os.path.splitext(in_path)
            out_path = base + "_pixelized" + ext
        else:
            if os.path.isdir(out_path):
                _, file = os.path.split(in_path)
                out_path = os.path.join(out_path, file)
        pairs = [(in_path, out_path)]

    m = Model(device="cpu" if use_cpu else "cuda")
    m.load()

    for in_file, out_file in pairs:
        print("PROCESSING", in_file, "TO", out_file)
        m.pixelize(in_file, out_file)


if __name__ == "__main__":
    pixelize_cli()
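
# Example invocation (illustrative paths):
#   python pixelization.py --input ./images --output ./images/outputs --cpu
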
reference.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,5 @@
torch
torchvision
numpy==1.24.1
pillow