Spaces:
Build error
Build error
Files for the app
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- FashionGen.py +265 -0
- LICENSE +21 -0
- README.md +2 -13
- TkTorchWindow.py +208 -0
- bpe_simple_vocab_16e6.txt.gz +3 -0
- cache/components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz +3 -0
- clip.py +237 -0
- config.py +72 -0
- decomposition.py +402 -0
- deps/windows/PyOpenGL-3.1.4-cp37-cp37m-win_amd64.whl +3 -0
- deps/windows/glumpy-1.1.0-cp37-cp37m-win_amd64.whl +3 -0
- deps/windows/pycuda-2019.1.2+cuda101-cp37-cp37m-win_amd64.whl +0 -0
- deps/windows/triangle-20190115.3-cp37-cp37m-win_amd64.whl +3 -0
- environment.yml +25 -0
- estimators.py +218 -0
- interactive.py +655 -0
- model_clip.py +436 -0
- models/__init__.py +11 -0
- models/__pycache__/__init__.cpython-310.pyc +0 -0
- models/__pycache__/wrappers.cpython-310.pyc +0 -0
- models/biggan/__init__.py +8 -0
- models/biggan/__pycache__/__init__.cpython-310.pyc +0 -0
- models/biggan/pytorch_biggan/.gitignore +110 -0
- models/biggan/pytorch_biggan/LICENSE +21 -0
- models/biggan/pytorch_biggan/MANIFEST.in +1 -0
- models/biggan/pytorch_biggan/README.md +227 -0
- models/biggan/pytorch_biggan/assets/output_0.png +0 -0
- models/biggan/pytorch_biggan/assets/output_1.png +0 -0
- models/biggan/pytorch_biggan/assets/output_2.png +0 -0
- models/biggan/pytorch_biggan/full_requirements.txt +5 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py +6 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py +70 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py +312 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py +249 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py +345 -0
- models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py +216 -0
- models/biggan/pytorch_biggan/requirements.txt +8 -0
- models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh +21 -0
- models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh +21 -0
- models/biggan/pytorch_biggan/setup.py +69 -0
- models/stylegan/__init__.py +17 -0
- models/stylegan/__pycache__/__init__.cpython-310.pyc +0 -0
- models/stylegan/__pycache__/model.cpython-310.pyc +0 -0
- models/stylegan/model.py +456 -0
- models/stylegan/stylegan_tf/LICENSE.txt +410 -0
- models/stylegan/stylegan_tf/README.md +232 -0
- models/stylegan/stylegan_tf/config.py +18 -0
- models/stylegan/stylegan_tf/dataset_tool.py +645 -0
- models/stylegan/stylegan_tf/dnnlib/__init__.py +20 -0
.gitattributes
CHANGED
@@ -32,3 +32,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
deps/windows/glumpy-1.1.0-cp37-cp37m-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
36 |
+
deps/windows/PyOpenGL-3.1.4-cp37-cp37m-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
37 |
+
deps/windows/triangle-20190115.3-cp37-cp37m-win_amd64.whl filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models/stylegan/stylegan_tf/stylegan-teaser.png filter=lfs diff=lfs merge=lfs -text
|
39 |
+
models/stylegan2/stylegan2-pytorch/doc/sample.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
models/stylegan2/stylegan2-pytorch/doc/stylegan2-church-config-f.png filter=lfs diff=lfs merge=lfs -text
|
FashionGen.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import streamlit as st
|
3 |
+
import torch
|
4 |
+
import PIL
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import imageio
|
8 |
+
from models import get_instrumented_model
|
9 |
+
from decomposition import get_or_compute
|
10 |
+
from config import Config
|
11 |
+
from skimage import img_as_ubyte
|
12 |
+
import clip
|
13 |
+
from torchvision.transforms import Resize, Normalize, Compose, CenterCrop
|
14 |
+
from torch.optim import Adam
|
15 |
+
from stqdm import stqdm
|
16 |
+
|
17 |
+
torch.set_num_threads(8)
|
18 |
+
|
19 |
+
# Speed up computation
|
20 |
+
torch.autograd.set_grad_enabled(True)
|
21 |
+
#torch.backends.cudnn.benchmark = True
|
22 |
+
|
23 |
+
# Specify model to use
|
24 |
+
config = Config(
|
25 |
+
model='StyleGAN2',
|
26 |
+
layer='style',
|
27 |
+
output_class= 'lookbook',
|
28 |
+
components=80,
|
29 |
+
use_w=True,
|
30 |
+
batch_size=5_000, # style layer quite small
|
31 |
+
)
|
32 |
+
|
33 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
34 |
+
|
35 |
+
preprocess = Compose([
|
36 |
+
Resize(224),
|
37 |
+
CenterCrop(224),
|
38 |
+
Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
39 |
+
])
|
40 |
+
|
41 |
+
@st.cache_data
|
42 |
+
def clip_optimized_latent(text, seed, iterations=25, lr=1e-2):
|
43 |
+
seed = int(seed)
|
44 |
+
text_input = clip.tokenize([text]).to(device)
|
45 |
+
|
46 |
+
# Initialize a random latent vector
|
47 |
+
latent_vector = model.sample_latent(1,seed=seed).detach()
|
48 |
+
latent_vector.requires_grad = True
|
49 |
+
latent_vector = [latent_vector]*model.get_max_latents()
|
50 |
+
params = [torch.nn.Parameter(latent_vector[i], requires_grad=True) for i in range(len(latent_vector))]
|
51 |
+
optimizer = Adam(params, lr=lr)
|
52 |
+
|
53 |
+
with torch.no_grad():
|
54 |
+
text_features = clip_model.encode_text(text_input)
|
55 |
+
|
56 |
+
#pbar = tqdm(range(iterations), dynamic_ncols=True)
|
57 |
+
|
58 |
+
for iteration in stqdm(range(iterations)):
|
59 |
+
optimizer.zero_grad()
|
60 |
+
|
61 |
+
# Generate an image from the latent vector
|
62 |
+
image = model.sample(params)
|
63 |
+
image = image.to(device)
|
64 |
+
|
65 |
+
# Preprocess the image for the CLIP model
|
66 |
+
image = preprocess(image)
|
67 |
+
#image = clip_preprocess(Image.fromarray((image_np * 255).astype(np.uint8))).unsqueeze(0).to(device)
|
68 |
+
|
69 |
+
# Extract features from the image
|
70 |
+
image_features = clip_model.encode_image(image)
|
71 |
+
|
72 |
+
# Calculate the loss and backpropagate
|
73 |
+
loss = -torch.cosine_similarity(text_features, image_features).mean()
|
74 |
+
loss.backward()
|
75 |
+
optimizer.step()
|
76 |
+
|
77 |
+
#pbar.set_description(f"Loss: {loss.item()}") # Update the progress bar to show the current loss
|
78 |
+
w = [param.detach().cpu().numpy() for param in params]
|
79 |
+
|
80 |
+
return w
|
81 |
+
|
82 |
+
def mix_w(w1, w2, content, style):
|
83 |
+
for i in range(0,5):
|
84 |
+
w2[i] = w1[i] * (1 - content) + w2[i] * content
|
85 |
+
|
86 |
+
for i in range(5, 16):
|
87 |
+
w2[i] = w1[i] * (1 - style) + w2[i] * style
|
88 |
+
|
89 |
+
return w2
|
90 |
+
|
91 |
+
def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
|
92 |
+
# blockPrint()
|
93 |
+
model.truncation = truncation
|
94 |
+
if w is None:
|
95 |
+
w = model.sample_latent(1, seed=seed).detach().cpu().numpy()
|
96 |
+
w = [w]*model.get_max_latents() # one per layer
|
97 |
+
else:
|
98 |
+
w_numpy = [x.detach().numpy() for x in w]
|
99 |
+
w = [np.expand_dims(x, 0) for x in w_numpy]
|
100 |
+
#w = [x.unsqueeze(0) for x in w]
|
101 |
+
|
102 |
+
|
103 |
+
for l in range(start, end):
|
104 |
+
for i in range(len(directions)):
|
105 |
+
w[l] = w[l] + directions[i] * distances[i] * scale
|
106 |
+
|
107 |
+
torch.cuda.empty_cache()
|
108 |
+
#save image and display
|
109 |
+
out = model.sample(w)
|
110 |
+
out = out.permute(0, 2, 3, 1).cpu().detach().numpy()
|
111 |
+
out = np.clip(out, 0.0, 1.0).squeeze()
|
112 |
+
|
113 |
+
final_im = Image.fromarray((out * 255).astype(np.uint8)).resize((500,500),Image.LANCZOS)
|
114 |
+
|
115 |
+
|
116 |
+
if save is not None:
|
117 |
+
if disp == False:
|
118 |
+
print(save)
|
119 |
+
final_im.save(f'out/{seed}_{save:05}.png')
|
120 |
+
if disp:
|
121 |
+
display(final_im)
|
122 |
+
|
123 |
+
return final_im
|
124 |
+
|
125 |
+
## Generate image for app
|
126 |
+
def generate_image(content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer,w1,w2):
|
127 |
+
|
128 |
+
scale = 1
|
129 |
+
params = {'c0': c0,
|
130 |
+
'c1': c1,
|
131 |
+
'c2': c2,
|
132 |
+
'c3': c3,
|
133 |
+
'c4': c4,
|
134 |
+
'c5': c5,
|
135 |
+
'c6': c6}
|
136 |
+
|
137 |
+
param_indexes = {'c0': 0,
|
138 |
+
'c1': 1,
|
139 |
+
'c2': 2,
|
140 |
+
'c3': 3,
|
141 |
+
'c4': 4,
|
142 |
+
'c5': 5,
|
143 |
+
'c6': 6}
|
144 |
+
|
145 |
+
directions = []
|
146 |
+
distances = []
|
147 |
+
for k, v in params.items():
|
148 |
+
directions.append(latent_dirs[param_indexes[k]])
|
149 |
+
distances.append(v)
|
150 |
+
|
151 |
+
if w1 is not None and w2 is not None:
|
152 |
+
w1 = [torch.from_numpy(x).to(device) for x in w1]
|
153 |
+
w2 = [torch.from_numpy(x).to(device) for x in w2]
|
154 |
+
|
155 |
+
|
156 |
+
#w1 = clip_optimized_latent(text1, seed1, iters)
|
157 |
+
im1 = model.sample(w1)
|
158 |
+
im1_np = im1.permute(0, 2, 3, 1).cpu().detach().numpy()
|
159 |
+
im1_np = np.clip(im1_np, 0.0, 1.0).squeeze()
|
160 |
+
|
161 |
+
#w2 = clip_optimized_latent(text2, seed2, iters)
|
162 |
+
im2 = model.sample(w2)
|
163 |
+
im2_np = im2.permute(0, 2, 3, 1).cpu().detach().numpy()
|
164 |
+
im2_np = np.clip(im2_np, 0.0, 1.0).squeeze()
|
165 |
+
|
166 |
+
combined_im = np.concatenate([im1_np, im2_np], axis=1)
|
167 |
+
input_im = Image.fromarray((combined_im * 255).astype(np.uint8))
|
168 |
+
|
169 |
+
|
170 |
+
mixed_w = mix_w(w1, w2, content, style)
|
171 |
+
return input_im, display_sample_pytorch(seed1, truncation, directions, distances, scale, int(start_layer), int(end_layer), w=mixed_w, disp=False)
|
172 |
+
|
173 |
+
|
174 |
+
# Streamlit app title
|
175 |
+
st.title("FashionGen Demo - AI assisted fashion design")
|
176 |
+
"""This application employs the StyleGAN framework, CLIP and GANSpace exploration techniques to synthesize images of garments from textual inputs. With training based on the comprehensive LookBook dataset, it supports an efficient fashion design process by transforming text into visual concepts, showcasing the practical application of Generative Adversarial Networks (GANs) in the realm of creative design."""
|
177 |
+
|
178 |
+
@st.cache_resource
|
179 |
+
def load_model():
|
180 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
181 |
+
# Load the pre-trained CLIP model
|
182 |
+
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
|
183 |
+
inst = get_instrumented_model(config.model, config.output_class,
|
184 |
+
config.layer, torch.device('cpu'), use_w=config.use_w)
|
185 |
+
return clip_model, inst
|
186 |
+
|
187 |
+
# Then, to load your models, call this function:
|
188 |
+
clip_model, inst = load_model()
|
189 |
+
model = inst.model
|
190 |
+
|
191 |
+
|
192 |
+
path_to_components = get_or_compute(config, inst)
|
193 |
+
comps = np.load(path_to_components)
|
194 |
+
lst = comps.files
|
195 |
+
latent_dirs = []
|
196 |
+
latent_stdevs = []
|
197 |
+
|
198 |
+
load_activations = False
|
199 |
+
|
200 |
+
for item in lst:
|
201 |
+
if load_activations:
|
202 |
+
if item == 'act_comp':
|
203 |
+
for i in range(comps[item].shape[0]):
|
204 |
+
latent_dirs.append(comps[item][i])
|
205 |
+
if item == 'act_stdev':
|
206 |
+
for i in range(comps[item].shape[0]):
|
207 |
+
latent_stdevs.append(comps[item][i])
|
208 |
+
else:
|
209 |
+
if item == 'lat_comp':
|
210 |
+
for i in range(comps[item].shape[0]):
|
211 |
+
latent_dirs.append(comps[item][i])
|
212 |
+
if item == 'lat_stdev':
|
213 |
+
for i in range(comps[item].shape[0]):
|
214 |
+
latent_stdevs.append(comps[item][i])
|
215 |
+
|
216 |
+
## Side bar texts
|
217 |
+
st.sidebar.title('Tuning Parameters')
|
218 |
+
st.sidebar.subheader('(CLIP + GANSpace)')
|
219 |
+
|
220 |
+
|
221 |
+
# Create UI widgets
|
222 |
+
|
223 |
+
if 'seed1' not in st.session_state and 'seed2' not in st.session_state:
|
224 |
+
st.session_state['seed1'] = random.randint(1, 1000)
|
225 |
+
st.session_state['seed2'] = random.randint(1, 1000)
|
226 |
+
seed1 = st.sidebar.number_input("Seed 1", value= st.session_state['seed1'])
|
227 |
+
seed2 = st.sidebar.number_input("Seed 2", value= st.session_state['seed2'])
|
228 |
+
text1 = st.sidebar.text_input("Text Description 1")
|
229 |
+
text2 = st.sidebar.text_input("Text Description 2")
|
230 |
+
iters = st.sidebar.number_input("Iterations for CLIP Optimization", value = 25)
|
231 |
+
submit_button = st.sidebar.button("Submit")
|
232 |
+
content = st.sidebar.slider("Structural Composition", min_value=0.0, max_value=1.0, value=0.5)
|
233 |
+
style = st.sidebar.slider("Style", min_value=0.0, max_value=1.0, value=0.5)
|
234 |
+
truncation = st.sidebar.slider("Dimensional Scaling", min_value=0.0, max_value=1.0, value=0.5)
|
235 |
+
|
236 |
+
slider_min_val = -20
|
237 |
+
slider_max_val = 20
|
238 |
+
slider_step = 1
|
239 |
+
|
240 |
+
c0 = st.sidebar.slider("Sleeve Size Scaling", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
241 |
+
c1 = st.sidebar.slider("Jacket Features", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
242 |
+
c2 = st.sidebar.slider("Women's Overcoat", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
243 |
+
c3 = st.sidebar.slider("Coat", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
244 |
+
c4 = st.sidebar.slider("Graphic Elements", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
245 |
+
c5 = st.sidebar.slider("Darker Color", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
246 |
+
c6 = st.sidebar.slider("Modest Neckline", min_value=slider_min_val, max_value=slider_max_val, value=0)
|
247 |
+
start_layer = st.sidebar.number_input("Start Layer", value=0)
|
248 |
+
end_layer = st.sidebar.number_input("End Layer", value=14)
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
if submit_button: # Execute when the submit button is pressed
|
253 |
+
w1 = clip_optimized_latent(text1, seed1, iters)
|
254 |
+
st.session_state['w1-np'] = w1
|
255 |
+
w2 = clip_optimized_latent(text2, seed2, iters)
|
256 |
+
st.session_state['w2-np'] = w2
|
257 |
+
|
258 |
+
try:
|
259 |
+
input_im, output_im = generate_image(content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer,st.session_state['w1-np'],st.session_state['w2-np'])
|
260 |
+
st.image(input_im, caption="Input Image")
|
261 |
+
st.image(output_im, caption="Output Image")
|
262 |
+
except:
|
263 |
+
pass
|
264 |
+
|
265 |
+
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 prathmeshdahikar
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,13 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 🦀
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: streamlit
|
7 |
-
sdk_version: 1.19.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: afl-3.0
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# FashionGen
|
2 |
+
AI assisted fashion design
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
TkTorchWindow.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
import tkinter as tk
|
12 |
+
import numpy as np
|
13 |
+
import time
|
14 |
+
from contextlib import contextmanager
|
15 |
+
import pycuda.driver
|
16 |
+
from pycuda.gl import graphics_map_flags
|
17 |
+
from glumpy import gloo, gl
|
18 |
+
from pyopengltk import OpenGLFrame
|
19 |
+
import torch
|
20 |
+
from torch.autograd import Variable
|
21 |
+
|
22 |
+
# TkInter widget that can draw torch tensors directly from GPU memory
|
23 |
+
|
24 |
+
@contextmanager
|
25 |
+
def cuda_activate(img):
|
26 |
+
"""Context manager simplifying use of pycuda.gl.RegisteredImage"""
|
27 |
+
mapping = img.map()
|
28 |
+
yield mapping.array(0,0)
|
29 |
+
mapping.unmap()
|
30 |
+
|
31 |
+
def create_shared_texture(w, h, c=4,
|
32 |
+
map_flags=graphics_map_flags.WRITE_DISCARD,
|
33 |
+
dtype=np.uint8):
|
34 |
+
"""Create and return a Texture2D with gloo and pycuda views."""
|
35 |
+
tex = np.zeros((h,w,c), dtype).view(gloo.Texture2D)
|
36 |
+
tex.activate() # force gloo to create on GPU
|
37 |
+
tex.deactivate()
|
38 |
+
cuda_buffer = pycuda.gl.RegisteredImage(
|
39 |
+
int(tex.handle), tex.target, map_flags)
|
40 |
+
return tex, cuda_buffer
|
41 |
+
|
42 |
+
# Shape batch as square if possible
|
43 |
+
def get_grid_dims(B):
|
44 |
+
S = int(B**0.5 + 0.5)
|
45 |
+
while B % S != 0:
|
46 |
+
S -= 1
|
47 |
+
return (B // S, S)
|
48 |
+
|
49 |
+
def create_gl_texture(tensor_shape):
|
50 |
+
if len(tensor_shape) != 4:
|
51 |
+
raise RuntimeError('Please provide a tensor of shape NCHW')
|
52 |
+
|
53 |
+
N, C, H, W = tensor_shape
|
54 |
+
|
55 |
+
cols, rows = get_grid_dims(N)
|
56 |
+
tex, cuda_buffer = create_shared_texture(W*cols, H*rows, 4)
|
57 |
+
|
58 |
+
return tex, cuda_buffer
|
59 |
+
|
60 |
+
# Create window with OpenGL context
|
61 |
+
class TorchImageView(OpenGLFrame):
|
62 |
+
def __init__(self, root = None, show_fps=True, **kwargs):
|
63 |
+
self.root = root or tk.Tk()
|
64 |
+
self.width = kwargs.get('width', 512)
|
65 |
+
self.height = kwargs.get('height', 512)
|
66 |
+
self.show_fps = show_fps
|
67 |
+
self.pycuda_initialized = False
|
68 |
+
self.animate = 0 # disable internal main loop
|
69 |
+
OpenGLFrame.__init__(self, root, **kwargs)
|
70 |
+
|
71 |
+
# Called by pyopengltk.BaseOpenGLFrame
|
72 |
+
# when the frame goes onto the screen
|
73 |
+
def initgl(self):
|
74 |
+
if not self.pycuda_initialized:
|
75 |
+
self.setup_gl(self.width, self.height)
|
76 |
+
self.pycuda_initialized = True
|
77 |
+
|
78 |
+
"""Initalize gl states when the frame is created"""
|
79 |
+
gl.glViewport(0, 0, self.width, self.height)
|
80 |
+
gl.glClearColor(0.0, 0.0, 0.0, 0.0)
|
81 |
+
self.dt_history = [1000/60]
|
82 |
+
self.t0 = time.time()
|
83 |
+
self.t_last = self.t0
|
84 |
+
self.nframes = 0
|
85 |
+
|
86 |
+
def setup_gl(self, width, height):
|
87 |
+
# setup pycuda and torch
|
88 |
+
import pycuda.gl.autoinit
|
89 |
+
import pycuda.gl
|
90 |
+
|
91 |
+
assert torch.cuda.is_available(), "PyTorch: CUDA is not available"
|
92 |
+
print('Using GPU {}'.format(torch.cuda.current_device()))
|
93 |
+
|
94 |
+
# Create tensor to be shared between GL and CUDA
|
95 |
+
# Always overwritten so no sharing is necessary
|
96 |
+
dummy = torch.cuda.FloatTensor((1))
|
97 |
+
dummy.uniform_()
|
98 |
+
dummy = Variable(dummy)
|
99 |
+
|
100 |
+
# Create a buffer with pycuda and gloo views, using tensor created above
|
101 |
+
self.tex, self.cuda_buffer = create_gl_texture((1, 3, width, height))
|
102 |
+
|
103 |
+
# create a shader to program to draw to the screen
|
104 |
+
vertex = """
|
105 |
+
uniform float scale;
|
106 |
+
attribute vec2 position;
|
107 |
+
attribute vec2 texcoord;
|
108 |
+
varying vec2 v_texcoord;
|
109 |
+
void main()
|
110 |
+
{
|
111 |
+
v_texcoord = texcoord;
|
112 |
+
gl_Position = vec4(scale*position, 0.0, 1.0);
|
113 |
+
} """
|
114 |
+
fragment = """
|
115 |
+
uniform sampler2D tex;
|
116 |
+
varying vec2 v_texcoord;
|
117 |
+
void main()
|
118 |
+
{
|
119 |
+
gl_FragColor = texture2D(tex, v_texcoord);
|
120 |
+
} """
|
121 |
+
# Build the program and corresponding buffers (with 4 vertices)
|
122 |
+
self.screen = gloo.Program(vertex, fragment, count=4)
|
123 |
+
|
124 |
+
# NDC coordinates: Texcoords: Vertex order,
|
125 |
+
# (-1, +1) (+1, +1) (0,0) (1,0) triangle strip:
|
126 |
+
# +-------+ +----+ 1----3
|
127 |
+
# | NDC | | | | / |
|
128 |
+
# | SPACE | | | | / |
|
129 |
+
# +-------+ +----+ 2----4
|
130 |
+
# (-1, -1) (+1, -1) (0,1) (1,1)
|
131 |
+
|
132 |
+
# Upload data to GPU
|
133 |
+
self.screen['position'] = [(-1,+1), (-1,-1), (+1,+1), (+1,-1)]
|
134 |
+
self.screen['texcoord'] = [(0,0), (0,1), (1,0), (1,1)]
|
135 |
+
self.screen['scale'] = 1.0
|
136 |
+
self.screen['tex'] = self.tex
|
137 |
+
|
138 |
+
# Don't call directly, use update() instead
|
139 |
+
def redraw(self):
|
140 |
+
t_now = time.time()
|
141 |
+
dt = t_now - self.t_last
|
142 |
+
self.t_last = t_now
|
143 |
+
|
144 |
+
self.dt_history = ([dt] + self.dt_history)[:50]
|
145 |
+
dt_mean = sum(self.dt_history) / len(self.dt_history)
|
146 |
+
|
147 |
+
if self.show_fps and self.nframes % 60 == 0:
|
148 |
+
self.master.title('FPS: {:.0f}'.format(1 / dt_mean))
|
149 |
+
|
150 |
+
def draw(self, img):
|
151 |
+
assert len(img.shape) == 4, "Please provide an NCHW image tensor"
|
152 |
+
assert img.device.type == "cuda", "Please provide a CUDA tensor"
|
153 |
+
|
154 |
+
if img.dtype.is_floating_point:
|
155 |
+
img = (255*img).byte()
|
156 |
+
|
157 |
+
# Tile images
|
158 |
+
N, C, H, W = img.shape
|
159 |
+
|
160 |
+
if N > 1:
|
161 |
+
cols, rows = get_grid_dims(N)
|
162 |
+
img = img.reshape(cols, rows, C, H, W)
|
163 |
+
img = img.permute(2, 1, 3, 0, 4) # [C, rows, H, cols, W]
|
164 |
+
img = img.reshape(1, C, rows*H, cols*W)
|
165 |
+
|
166 |
+
tensor = img.squeeze().permute(1, 2, 0).data # CHW => HWC
|
167 |
+
if C == 3:
|
168 |
+
tensor = torch.cat((tensor, tensor[:,:,0:1]),2) # add the alpha channel
|
169 |
+
tensor[:,:,3] = 1 # set alpha
|
170 |
+
|
171 |
+
tensor = tensor.contiguous()
|
172 |
+
|
173 |
+
tex_h, tex_w, _ = self.tex.shape
|
174 |
+
tensor_h, tensor_w, _ = tensor.shape
|
175 |
+
|
176 |
+
if (tex_h, tex_w) != (tensor_h, tensor_w):
|
177 |
+
print(f'Resizing texture to {tensor_w}*{tensor_h}')
|
178 |
+
self.tex, self.cuda_buffer = create_gl_texture((N, C, H, W)) # original shape
|
179 |
+
self.screen['tex'] = self.tex
|
180 |
+
|
181 |
+
# copy from torch into buffer
|
182 |
+
assert self.tex.nbytes == tensor.numel()*tensor.element_size(), "Tensor and texture shape mismatch!"
|
183 |
+
with cuda_activate(self.cuda_buffer) as ary:
|
184 |
+
cpy = pycuda.driver.Memcpy2D()
|
185 |
+
cpy.set_src_device(tensor.data_ptr())
|
186 |
+
cpy.set_dst_array(ary)
|
187 |
+
cpy.width_in_bytes = cpy.src_pitch = cpy.dst_pitch = self.tex.nbytes//tensor_h
|
188 |
+
cpy.height = tensor_h
|
189 |
+
cpy(aligned=False)
|
190 |
+
torch.cuda.synchronize()
|
191 |
+
|
192 |
+
# draw to screen
|
193 |
+
self.screen.draw(gl.GL_TRIANGLE_STRIP)
|
194 |
+
|
195 |
+
def update(self):
|
196 |
+
self.update_idletasks()
|
197 |
+
self.tkMakeCurrent()
|
198 |
+
self.redraw()
|
199 |
+
self.tkSwapBuffers()
|
200 |
+
|
201 |
+
# USAGE:
|
202 |
+
# root = tk.Tk()
|
203 |
+
# iv = TorchImageView(root, width=512, height=512)
|
204 |
+
# iv.pack(fill='both', expand=True)
|
205 |
+
# while True:
|
206 |
+
# iv.draw(nchw_tensor)
|
207 |
+
# root.update()
|
208 |
+
# iv.update()
|
bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
cache/components/stylegan2-lookbook_style_ipca_c80_n300000_w.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:80cde5f3476909d69649ebcb1f9872d0fd95cb1632770db1b7fb962608f905b8
|
3 |
+
size 312351
|
clip.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Union, List
|
6 |
+
from pkg_resources import packaging
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from model_clip import build_model
|
14 |
+
from simple_tokenizer import SimpleTokenizer as _Tokenizer
|
15 |
+
|
16 |
+
try:
|
17 |
+
from torchvision.transforms import InterpolationMode
|
18 |
+
BICUBIC = InterpolationMode.BICUBIC
|
19 |
+
except ImportError:
|
20 |
+
BICUBIC = Image.BICUBIC
|
21 |
+
|
22 |
+
|
23 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
24 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
25 |
+
|
26 |
+
|
27 |
+
__all__ = ["available_models", "load", "tokenize"]
|
28 |
+
_tokenizer = _Tokenizer()
|
29 |
+
|
30 |
+
_MODELS = {
|
31 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
32 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
33 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
34 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
35 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
36 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
37 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
38 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
39 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def _download(url: str, root: str):
|
44 |
+
os.makedirs(root, exist_ok=True)
|
45 |
+
filename = os.path.basename(url)
|
46 |
+
|
47 |
+
expected_sha256 = url.split("/")[-2]
|
48 |
+
download_target = os.path.join(root, filename)
|
49 |
+
|
50 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
51 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
52 |
+
|
53 |
+
if os.path.isfile(download_target):
|
54 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
55 |
+
return download_target
|
56 |
+
else:
|
57 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
58 |
+
|
59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
61 |
+
while True:
|
62 |
+
buffer = source.read(8192)
|
63 |
+
if not buffer:
|
64 |
+
break
|
65 |
+
|
66 |
+
output.write(buffer)
|
67 |
+
loop.update(len(buffer))
|
68 |
+
|
69 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
70 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
71 |
+
|
72 |
+
return download_target
|
73 |
+
|
74 |
+
|
75 |
+
def _convert_image_to_rgb(image):
|
76 |
+
return image.convert("RGB")
|
77 |
+
|
78 |
+
|
79 |
+
def _transform(n_px):
|
80 |
+
return Compose([
|
81 |
+
Resize(n_px, interpolation=BICUBIC),
|
82 |
+
CenterCrop(n_px),
|
83 |
+
_convert_image_to_rgb,
|
84 |
+
ToTensor(),
|
85 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
86 |
+
])
|
87 |
+
|
88 |
+
|
89 |
+
def available_models() -> List[str]:
|
90 |
+
"""Returns the names of available CLIP models"""
|
91 |
+
return list(_MODELS.keys())
|
92 |
+
|
93 |
+
|
94 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
95 |
+
"""Load a CLIP model
|
96 |
+
|
97 |
+
Parameters
|
98 |
+
----------
|
99 |
+
name : str
|
100 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
101 |
+
|
102 |
+
device : Union[str, torch.device]
|
103 |
+
The device to put the loaded model
|
104 |
+
|
105 |
+
jit : bool
|
106 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
107 |
+
|
108 |
+
download_root: str
|
109 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
110 |
+
|
111 |
+
Returns
|
112 |
+
-------
|
113 |
+
model : torch.nn.Module
|
114 |
+
The CLIP model
|
115 |
+
|
116 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
117 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
118 |
+
"""
|
119 |
+
if name in _MODELS:
|
120 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
121 |
+
elif os.path.isfile(name):
|
122 |
+
model_path = name
|
123 |
+
else:
|
124 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
125 |
+
|
126 |
+
with open(model_path, 'rb') as opened_file:
|
127 |
+
try:
|
128 |
+
# loading JIT archive
|
129 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
130 |
+
state_dict = None
|
131 |
+
except RuntimeError:
|
132 |
+
# loading saved state dict
|
133 |
+
if jit:
|
134 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
135 |
+
jit = False
|
136 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
137 |
+
|
138 |
+
if not jit:
|
139 |
+
model = build_model(state_dict or model.state_dict()).to(device)
|
140 |
+
if str(device) == "cpu":
|
141 |
+
model.float()
|
142 |
+
return model, _transform(model.visual.input_resolution)
|
143 |
+
|
144 |
+
# patch the device names
|
145 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
146 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
147 |
+
|
148 |
+
def patch_device(module):
|
149 |
+
try:
|
150 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
151 |
+
except RuntimeError:
|
152 |
+
graphs = []
|
153 |
+
|
154 |
+
if hasattr(module, "forward1"):
|
155 |
+
graphs.append(module.forward1.graph)
|
156 |
+
|
157 |
+
for graph in graphs:
|
158 |
+
for node in graph.findAllNodes("prim::Constant"):
|
159 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
160 |
+
node.copyAttributes(device_node)
|
161 |
+
|
162 |
+
model.apply(patch_device)
|
163 |
+
patch_device(model.encode_image)
|
164 |
+
patch_device(model.encode_text)
|
165 |
+
|
166 |
+
# patch dtype to float32 on CPU
|
167 |
+
if str(device) == "cpu":
|
168 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
169 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
170 |
+
float_node = float_input.node()
|
171 |
+
|
172 |
+
def patch_float(module):
|
173 |
+
try:
|
174 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
175 |
+
except RuntimeError:
|
176 |
+
graphs = []
|
177 |
+
|
178 |
+
if hasattr(module, "forward1"):
|
179 |
+
graphs.append(module.forward1.graph)
|
180 |
+
|
181 |
+
for graph in graphs:
|
182 |
+
for node in graph.findAllNodes("aten::to"):
|
183 |
+
inputs = list(node.inputs())
|
184 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
185 |
+
if inputs[i].node()["value"] == 5:
|
186 |
+
inputs[i].node().copyAttributes(float_node)
|
187 |
+
|
188 |
+
model.apply(patch_float)
|
189 |
+
patch_float(model.encode_image)
|
190 |
+
patch_float(model.encode_text)
|
191 |
+
|
192 |
+
model.float()
|
193 |
+
|
194 |
+
return model, _transform(model.input_resolution.item())
|
195 |
+
|
196 |
+
|
197 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
|
198 |
+
"""
|
199 |
+
Returns the tokenized representation of given input string(s)
|
200 |
+
|
201 |
+
Parameters
|
202 |
+
----------
|
203 |
+
texts : Union[str, List[str]]
|
204 |
+
An input string or a list of input strings to tokenize
|
205 |
+
|
206 |
+
context_length : int
|
207 |
+
The context length to use; all CLIP models use 77 as the context length
|
208 |
+
|
209 |
+
truncate: bool
|
210 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
211 |
+
|
212 |
+
Returns
|
213 |
+
-------
|
214 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
215 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
216 |
+
"""
|
217 |
+
if isinstance(texts, str):
|
218 |
+
texts = [texts]
|
219 |
+
|
220 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
221 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
222 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
223 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
224 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
225 |
+
else:
|
226 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
227 |
+
|
228 |
+
for i, tokens in enumerate(all_tokens):
|
229 |
+
if len(tokens) > context_length:
|
230 |
+
if truncate:
|
231 |
+
tokens = tokens[:context_length]
|
232 |
+
tokens[-1] = eot_token
|
233 |
+
else:
|
234 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
235 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
236 |
+
|
237 |
+
return result
|
config.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
import sys
|
12 |
+
import argparse
|
13 |
+
import json
|
14 |
+
from copy import deepcopy
|
15 |
+
|
16 |
+
class Config:
|
17 |
+
def __init__(self, **kwargs):
|
18 |
+
self.from_args([]) # set all defaults
|
19 |
+
self.default_args = deepcopy(self.__dict__)
|
20 |
+
self.from_dict(kwargs) # override
|
21 |
+
|
22 |
+
def __str__(self):
|
23 |
+
custom = {}
|
24 |
+
default = {}
|
25 |
+
|
26 |
+
# Find non-default arguments
|
27 |
+
for k, v in self.__dict__.items():
|
28 |
+
if k == 'default_args':
|
29 |
+
continue
|
30 |
+
|
31 |
+
in_default = k in self.default_args
|
32 |
+
same_value = self.default_args.get(k) == v
|
33 |
+
|
34 |
+
if in_default and same_value:
|
35 |
+
default[k] = v
|
36 |
+
else:
|
37 |
+
custom[k] = v
|
38 |
+
|
39 |
+
config = {
|
40 |
+
'custom': custom,
|
41 |
+
'default': default
|
42 |
+
}
|
43 |
+
|
44 |
+
return json.dumps(config, indent=4)
|
45 |
+
|
46 |
+
def __repr__(self):
|
47 |
+
return self.__str__()
|
48 |
+
|
49 |
+
def from_dict(self, dictionary):
|
50 |
+
for k, v in dictionary.items():
|
51 |
+
setattr(self, k, v)
|
52 |
+
return self
|
53 |
+
|
54 |
+
def from_args(self, args=sys.argv[1:]):
|
55 |
+
parser = argparse.ArgumentParser(description='GAN component analysis config')
|
56 |
+
parser.add_argument('--model', dest='model', type=str, default='StyleGAN', help='The network to analyze') # StyleGAN, DCGAN, ProGAN, BigGAN-XYZ
|
57 |
+
parser.add_argument('--layer', dest='layer', type=str, default='g_mapping', help='The layer to analyze')
|
58 |
+
parser.add_argument('--class', dest='output_class', type=str, default=None, help='Output class to generate (BigGAN: Imagenet, ProGAN: LSUN)')
|
59 |
+
parser.add_argument('--est', dest='estimator', type=str, default='ipca', help='The algorithm to use [pca, fbpca, cupca, spca, ica]')
|
60 |
+
parser.add_argument('--sparsity', type=float, default=1.0, help='Sparsity parameter of SPCA')
|
61 |
+
parser.add_argument('--video', dest='make_video', action='store_true', help='Generate output videos (MP4s)')
|
62 |
+
parser.add_argument('--batch', dest='batch_mode', action='store_true', help="Don't open windows, instead save results to file")
|
63 |
+
parser.add_argument('-b', dest='batch_size', type=int, default=None, help='Minibatch size, leave empty for automatic detection')
|
64 |
+
parser.add_argument('-c', dest='components', type=int, default=80, help='Number of components to keep')
|
65 |
+
parser.add_argument('-n', type=int, default=300_000, help='Number of examples to use in decomposition')
|
66 |
+
parser.add_argument('--use_w', action='store_true', help='Use W latent space (StyleGAN(2))')
|
67 |
+
parser.add_argument('--sigma', type=float, default=2.0, help='Number of stdevs to walk in visualize.py')
|
68 |
+
parser.add_argument('--inputs', type=str, default=None, help='Path to directory with named components')
|
69 |
+
parser.add_argument('--seed', type=int, default=None, help='Seed used in decomposition')
|
70 |
+
args = parser.parse_args(args)
|
71 |
+
|
72 |
+
return self.from_dict(args.__dict__)
|
decomposition.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
# Patch for broken CTRL+C handler
|
12 |
+
# https://github.com/ContinuumIO/anaconda-issues/issues/905
|
13 |
+
import os
|
14 |
+
os.environ['FOR_DISABLE_CONSOLE_CTRL_HANDLER'] = '1'
|
15 |
+
|
16 |
+
import numpy as np
|
17 |
+
import os
|
18 |
+
from pathlib import Path
|
19 |
+
import re
|
20 |
+
import sys
|
21 |
+
import datetime
|
22 |
+
import argparse
|
23 |
+
import torch
|
24 |
+
import json
|
25 |
+
from types import SimpleNamespace
|
26 |
+
import scipy
|
27 |
+
from scipy.cluster.vq import kmeans
|
28 |
+
from tqdm import trange
|
29 |
+
from netdissect.nethook import InstrumentedModel
|
30 |
+
from config import Config
|
31 |
+
from estimators import get_estimator
|
32 |
+
from models import get_instrumented_model
|
33 |
+
|
34 |
+
SEED_SAMPLING = 1
|
35 |
+
SEED_RANDOM_DIRS = 2
|
36 |
+
SEED_LINREG = 3
|
37 |
+
SEED_VISUALIZATION = 5
|
38 |
+
|
39 |
+
B = 20
|
40 |
+
n_clusters = 500
|
41 |
+
|
42 |
+
def get_random_dirs(components, dimensions):
|
43 |
+
gen = np.random.RandomState(seed=SEED_RANDOM_DIRS)
|
44 |
+
dirs = gen.normal(size=(components, dimensions))
|
45 |
+
dirs /= np.sqrt(np.sum(dirs**2, axis=1, keepdims=True))
|
46 |
+
return dirs.astype(np.float32)
|
47 |
+
|
48 |
+
# Compute maximum batch size for given VRAM and network
|
49 |
+
def get_max_batch_size(inst, device, layer_name=None):
|
50 |
+
inst.remove_edits()
|
51 |
+
|
52 |
+
# Reset statistics
|
53 |
+
torch.cuda.reset_max_memory_cached(device)
|
54 |
+
torch.cuda.reset_max_memory_allocated(device)
|
55 |
+
total_mem = torch.cuda.get_device_properties(device).total_memory
|
56 |
+
|
57 |
+
B_max = 20
|
58 |
+
|
59 |
+
# Measure actual usage
|
60 |
+
for i in range(2, B_max, 2):
|
61 |
+
z = inst.model.sample_latent(n_samples=i)
|
62 |
+
if layer_name:
|
63 |
+
inst.model.partial_forward(z, layer_name)
|
64 |
+
else:
|
65 |
+
inst.model.forward(z)
|
66 |
+
|
67 |
+
maxmem = torch.cuda.max_memory_allocated(device)
|
68 |
+
del z
|
69 |
+
|
70 |
+
if maxmem > 0.5*total_mem:
|
71 |
+
print('Batch size {:d}: memory usage {:.0f}MB'.format(i, maxmem / 1e6))
|
72 |
+
return i
|
73 |
+
|
74 |
+
return B_max
|
75 |
+
|
76 |
+
# Solve for directions in latent space that match PCs in activaiton space
|
77 |
+
def linreg_lstsq(comp_np, mean_np, stdev_np, inst, config):
|
78 |
+
print('Performing least squares regression', flush=True)
|
79 |
+
|
80 |
+
torch.manual_seed(SEED_LINREG)
|
81 |
+
np.random.seed(SEED_LINREG)
|
82 |
+
|
83 |
+
comp = torch.from_numpy(comp_np).float().to(inst.model.device)
|
84 |
+
mean = torch.from_numpy(mean_np).float().to(inst.model.device)
|
85 |
+
stdev = torch.from_numpy(stdev_np).float().to(inst.model.device)
|
86 |
+
|
87 |
+
n_samp = max(10_000, config.n) // B * B # make divisible
|
88 |
+
n_comp = comp.shape[0]
|
89 |
+
latent_dims = inst.model.get_latent_dims()
|
90 |
+
|
91 |
+
# We're looking for M s.t. M*P*G'(Z) = Z => M*A = Z
|
92 |
+
# Z = batch of latent vectors (n_samples x latent_dims)
|
93 |
+
# G'(Z) = batch of activations at intermediate layer
|
94 |
+
# A = P*G'(Z) = projected activations (n_samples x pca_coords)
|
95 |
+
# M = linear mapping (pca_coords x latent_dims)
|
96 |
+
|
97 |
+
# Minimization min_M ||MA - Z||_l2 rewritten as min_M.T ||A.T*M.T - Z.T||_l2
|
98 |
+
# to match format expected by pytorch.lstsq
|
99 |
+
|
100 |
+
# TODO: regression on pixel-space outputs? (using nonlinear optimizer)
|
101 |
+
# min_M lpips(G_full(MA), G_full(Z))
|
102 |
+
|
103 |
+
# Tensors to fill with data
|
104 |
+
# Dimensions other way around, so these are actually the transposes
|
105 |
+
A = np.zeros((n_samp, n_comp), dtype=np.float32)
|
106 |
+
Z = np.zeros((n_samp, latent_dims), dtype=np.float32)
|
107 |
+
|
108 |
+
# Project tensor X onto PCs, return coordinates
|
109 |
+
def project(X, comp):
|
110 |
+
N = X.shape[0]
|
111 |
+
K = comp.shape[0]
|
112 |
+
coords = torch.bmm(comp.expand([N]+[-1]*comp.ndim), X.view(N, -1, 1))
|
113 |
+
return coords.reshape(N, K)
|
114 |
+
|
115 |
+
for i in trange(n_samp // B, desc='Collecting samples', ascii=True):
|
116 |
+
z = inst.model.sample_latent(B)
|
117 |
+
inst.model.partial_forward(z, config.layer)
|
118 |
+
act = inst.retained_features()[config.layer].reshape(B, -1)
|
119 |
+
|
120 |
+
# Project onto basis
|
121 |
+
act = act - mean
|
122 |
+
coords = project(act, comp)
|
123 |
+
coords_scaled = coords / stdev
|
124 |
+
|
125 |
+
A[i*B:(i+1)*B] = coords_scaled.detach().cpu().numpy()
|
126 |
+
Z[i*B:(i+1)*B] = z.detach().cpu().numpy().reshape(B, -1)
|
127 |
+
|
128 |
+
# Solve least squares fit
|
129 |
+
|
130 |
+
# gelsd = divide-and-conquer SVD; good default
|
131 |
+
# gelsy = complete orthogonal factorization; sometimes faster
|
132 |
+
# gelss = SVD; slow but less memory hungry
|
133 |
+
M_t = scipy.linalg.lstsq(A, Z, lapack_driver='gelsd')[0] # torch.lstsq(Z, A)[0][:n_comp, :]
|
134 |
+
|
135 |
+
# Solution given by rows of M_t
|
136 |
+
Z_comp = M_t[:n_comp, :]
|
137 |
+
Z_mean = np.mean(Z, axis=0, keepdims=True)
|
138 |
+
|
139 |
+
return Z_comp, Z_mean
|
140 |
+
|
141 |
+
def regression(comp, mean, stdev, inst, config):
|
142 |
+
# Sanity check: verify orthonormality
|
143 |
+
M = np.dot(comp, comp.T)
|
144 |
+
if not np.allclose(M, np.identity(M.shape[0])):
|
145 |
+
det = np.linalg.det(M)
|
146 |
+
print(f'WARNING: Computed basis is not orthonormal (determinant={det})')
|
147 |
+
|
148 |
+
return linreg_lstsq(comp, mean, stdev, inst, config)
|
149 |
+
|
150 |
+
def compute(config, dump_name, instrumented_model):
|
151 |
+
global B
|
152 |
+
|
153 |
+
timestamp = lambda : datetime.datetime.now().strftime("%d.%m %H:%M")
|
154 |
+
print(f'[{timestamp()}] Computing', dump_name.name)
|
155 |
+
|
156 |
+
# Ensure reproducibility
|
157 |
+
torch.manual_seed(0) # also sets cuda seeds
|
158 |
+
np.random.seed(0)
|
159 |
+
|
160 |
+
# Speed up backend
|
161 |
+
torch.backends.cudnn.benchmark = True
|
162 |
+
|
163 |
+
has_gpu = torch.cuda.is_available()
|
164 |
+
device = torch.device('cuda' if has_gpu else 'cpu')
|
165 |
+
layer_key = config.layer
|
166 |
+
|
167 |
+
if instrumented_model is None:
|
168 |
+
inst = get_instrumented_model(config.model, config.output_class, layer_key, device)
|
169 |
+
model = inst.model
|
170 |
+
else:
|
171 |
+
print('Reusing InstrumentedModel instance')
|
172 |
+
inst = instrumented_model
|
173 |
+
model = inst.model
|
174 |
+
inst.remove_edits()
|
175 |
+
model.set_output_class(config.output_class)
|
176 |
+
|
177 |
+
# Regress back to w space
|
178 |
+
if config.use_w:
|
179 |
+
print('Using W latent space')
|
180 |
+
model.use_w()
|
181 |
+
|
182 |
+
inst.retain_layer(layer_key)
|
183 |
+
model.partial_forward(model.sample_latent(1), layer_key)
|
184 |
+
sample_shape = inst.retained_features()[layer_key].shape
|
185 |
+
sample_dims = np.prod(sample_shape)
|
186 |
+
print('Feature shape:', sample_shape)
|
187 |
+
|
188 |
+
input_shape = inst.model.get_latent_shape()
|
189 |
+
input_dims = inst.model.get_latent_dims()
|
190 |
+
|
191 |
+
config.components = min(config.components, sample_dims)
|
192 |
+
transformer = get_estimator(config.estimator, config.components, config.sparsity)
|
193 |
+
|
194 |
+
X = None
|
195 |
+
X_global_mean = None
|
196 |
+
|
197 |
+
# Figure out batch size if not provided
|
198 |
+
B = config.batch_size or get_max_batch_size(inst, device, layer_key)
|
199 |
+
|
200 |
+
# Divisible by B (ignored in output name)
|
201 |
+
N = config.n // B * B
|
202 |
+
|
203 |
+
# Compute maximum batch size based on RAM + pagefile budget
|
204 |
+
target_bytes = 20 * 1_000_000_000 # GB
|
205 |
+
feat_size_bytes = sample_dims * np.dtype('float64').itemsize
|
206 |
+
N_limit_RAM = np.floor_divide(target_bytes, feat_size_bytes)
|
207 |
+
if not transformer.batch_support and N > N_limit_RAM:
|
208 |
+
print('WARNING: estimator does not support batching, ' \
|
209 |
+
'given config will use {:.1f} GB memory.'.format(feat_size_bytes / 1_000_000_000 * N))
|
210 |
+
|
211 |
+
# 32-bit LAPACK gets very unhappy about huge matrices (in linalg.svd)
|
212 |
+
if config.estimator == 'ica':
|
213 |
+
lapack_max_N = np.floor_divide(np.iinfo(np.int32).max // 4, sample_dims) # 4x extra buffer
|
214 |
+
if N > lapack_max_N:
|
215 |
+
raise RuntimeError(f'Matrices too large for ICA, please use N <= {lapack_max_N}')
|
216 |
+
|
217 |
+
print('B={}, N={}, dims={}, N/dims={:.1f}'.format(B, N, sample_dims, N/sample_dims), flush=True)
|
218 |
+
|
219 |
+
# Must not depend on chosen batch size (reproducibility)
|
220 |
+
NB = max(B, max(2_000, 3*config.components)) # ipca: as large as possible!
|
221 |
+
|
222 |
+
samples = None
|
223 |
+
if not transformer.batch_support:
|
224 |
+
samples = np.zeros((N + NB, sample_dims), dtype=np.float32)
|
225 |
+
|
226 |
+
torch.manual_seed(config.seed or SEED_SAMPLING)
|
227 |
+
np.random.seed(config.seed or SEED_SAMPLING)
|
228 |
+
|
229 |
+
# Use exactly the same latents regardless of batch size
|
230 |
+
# Store in main memory, since N might be huge (1M+)
|
231 |
+
# Run in batches, since sample_latent() might perform Z -> W mapping
|
232 |
+
n_lat = ((N + NB - 1) // B + 1) * B
|
233 |
+
latents = np.zeros((n_lat, *input_shape[1:]), dtype=np.float32)
|
234 |
+
with torch.no_grad():
|
235 |
+
for i in trange(n_lat // B, desc='Sampling latents'):
|
236 |
+
latents[i*B:(i+1)*B] = model.sample_latent(n_samples=B).cpu().numpy()
|
237 |
+
|
238 |
+
# Decomposition on non-Gaussian latent space
|
239 |
+
samples_are_latents = layer_key in ['g_mapping', 'style'] and inst.model.latent_space_name() == 'W'
|
240 |
+
|
241 |
+
canceled = False
|
242 |
+
try:
|
243 |
+
X = np.ones((NB, sample_dims), dtype=np.float32)
|
244 |
+
action = 'Fitting' if transformer.batch_support else 'Collecting'
|
245 |
+
for gi in trange(0, N, NB, desc=f'{action} batches (NB={NB})', ascii=True):
|
246 |
+
for mb in range(0, NB, B):
|
247 |
+
z = torch.from_numpy(latents[gi+mb:gi+mb+B]).to(device)
|
248 |
+
|
249 |
+
if samples_are_latents:
|
250 |
+
# Decomposition on latents directly (e.g. StyleGAN W)
|
251 |
+
batch = z.reshape((B, -1))
|
252 |
+
else:
|
253 |
+
# Decomposition on intermediate layer
|
254 |
+
with torch.no_grad():
|
255 |
+
model.partial_forward(z, layer_key)
|
256 |
+
|
257 |
+
# Permuted to place PCA dimensions last
|
258 |
+
batch = inst.retained_features()[layer_key].reshape((B, -1))
|
259 |
+
|
260 |
+
space_left = min(B, NB - mb)
|
261 |
+
X[mb:mb+space_left] = batch.cpu().numpy()[:space_left]
|
262 |
+
|
263 |
+
if transformer.batch_support:
|
264 |
+
if not transformer.fit_partial(X.reshape(-1, sample_dims)):
|
265 |
+
break
|
266 |
+
else:
|
267 |
+
samples[gi:gi+NB, :] = X.copy()
|
268 |
+
except KeyboardInterrupt:
|
269 |
+
if not transformer.batch_support:
|
270 |
+
sys.exit(1) # no progress yet
|
271 |
+
|
272 |
+
dump_name = dump_name.parent / dump_name.name.replace(f'n{N}', f'n{gi}')
|
273 |
+
print(f'Saving current state to "{dump_name.name}" before exiting')
|
274 |
+
canceled = True
|
275 |
+
|
276 |
+
if not transformer.batch_support:
|
277 |
+
X = samples # Use all samples
|
278 |
+
X_global_mean = X.mean(axis=0, keepdims=True, dtype=np.float32) # TODO: activations surely multi-modal...!
|
279 |
+
X -= X_global_mean
|
280 |
+
|
281 |
+
print(f'[{timestamp()}] Fitting whole batch')
|
282 |
+
t_start_fit = datetime.datetime.now()
|
283 |
+
|
284 |
+
transformer.fit(X)
|
285 |
+
|
286 |
+
print(f'[{timestamp()}] Done in {datetime.datetime.now() - t_start_fit}')
|
287 |
+
assert np.all(transformer.transformer.mean_ < 1e-3), 'Mean of normalized data should be zero'
|
288 |
+
else:
|
289 |
+
X_global_mean = transformer.transformer.mean_.reshape((1, sample_dims))
|
290 |
+
X = X.reshape(-1, sample_dims)
|
291 |
+
X -= X_global_mean
|
292 |
+
|
293 |
+
X_comp, X_stdev, X_var_ratio = transformer.get_components()
|
294 |
+
|
295 |
+
assert X_comp.shape[1] == sample_dims \
|
296 |
+
and X_comp.shape[0] == config.components \
|
297 |
+
and X_global_mean.shape[1] == sample_dims \
|
298 |
+
and X_stdev.shape[0] == config.components, 'Invalid shape'
|
299 |
+
|
300 |
+
# 'Activations' are really latents in a secondary latent space
|
301 |
+
if samples_are_latents:
|
302 |
+
Z_comp = X_comp
|
303 |
+
Z_global_mean = X_global_mean
|
304 |
+
else:
|
305 |
+
Z_comp, Z_global_mean = regression(X_comp, X_global_mean, X_stdev, inst, config)
|
306 |
+
|
307 |
+
# Normalize
|
308 |
+
Z_comp /= np.linalg.norm(Z_comp, axis=-1, keepdims=True)
|
309 |
+
|
310 |
+
# Random projections
|
311 |
+
# We expect these to explain much less of the variance
|
312 |
+
random_dirs = get_random_dirs(config.components, np.prod(sample_shape))
|
313 |
+
n_rand_samples = min(5000, X.shape[0])
|
314 |
+
X_view = X[:n_rand_samples, :].T
|
315 |
+
assert np.shares_memory(X_view, X), "Error: slice produced copy"
|
316 |
+
X_stdev_random = np.dot(random_dirs, X_view).std(axis=1)
|
317 |
+
|
318 |
+
# Inflate back to proper shapes (for easier broadcasting)
|
319 |
+
X_comp = X_comp.reshape(-1, *sample_shape)
|
320 |
+
X_global_mean = X_global_mean.reshape(sample_shape)
|
321 |
+
Z_comp = Z_comp.reshape(-1, *input_shape)
|
322 |
+
Z_global_mean = Z_global_mean.reshape(input_shape)
|
323 |
+
|
324 |
+
# Compute stdev in latent space if non-Gaussian
|
325 |
+
lat_stdev = np.ones_like(X_stdev)
|
326 |
+
if config.use_w:
|
327 |
+
samples = model.sample_latent(5000).reshape(5000, input_dims).detach().cpu().numpy()
|
328 |
+
coords = np.dot(Z_comp.reshape(-1, input_dims), samples.T)
|
329 |
+
lat_stdev = coords.std(axis=1)
|
330 |
+
|
331 |
+
os.makedirs(dump_name.parent, exist_ok=True)
|
332 |
+
np.savez_compressed(dump_name, **{
|
333 |
+
'act_comp': X_comp.astype(np.float32),
|
334 |
+
'act_mean': X_global_mean.astype(np.float32),
|
335 |
+
'act_stdev': X_stdev.astype(np.float32),
|
336 |
+
'lat_comp': Z_comp.astype(np.float32),
|
337 |
+
'lat_mean': Z_global_mean.astype(np.float32),
|
338 |
+
'lat_stdev': lat_stdev.astype(np.float32),
|
339 |
+
'var_ratio': X_var_ratio.astype(np.float32),
|
340 |
+
'random_stdevs': X_stdev_random.astype(np.float32),
|
341 |
+
})
|
342 |
+
|
343 |
+
if canceled:
|
344 |
+
sys.exit(1)
|
345 |
+
|
346 |
+
# Don't shutdown if passed as param
|
347 |
+
if instrumented_model is None:
|
348 |
+
inst.close()
|
349 |
+
del inst
|
350 |
+
del model
|
351 |
+
|
352 |
+
del X
|
353 |
+
del X_comp
|
354 |
+
del random_dirs
|
355 |
+
del batch
|
356 |
+
del samples
|
357 |
+
del latents
|
358 |
+
torch.cuda.empty_cache()
|
359 |
+
|
360 |
+
# Return cached results or commpute if needed
|
361 |
+
# Pass existing InstrumentedModel instance to reuse it
|
362 |
+
def get_or_compute(config, model=None, submit_config=None, force_recompute=False):
|
363 |
+
if submit_config is None:
|
364 |
+
wrkdir = str(Path(__file__).parent.resolve())
|
365 |
+
submit_config = SimpleNamespace(run_dir_root = wrkdir, run_dir = wrkdir)
|
366 |
+
|
367 |
+
# Called directly by run.py
|
368 |
+
return _compute(submit_config, config, model, force_recompute)
|
369 |
+
|
370 |
+
def _compute(submit_config, config, model=None, force_recompute=False):
|
371 |
+
basedir = Path(submit_config.run_dir)
|
372 |
+
outdir = basedir / 'out'
|
373 |
+
|
374 |
+
if config.n is None:
|
375 |
+
raise RuntimeError('Must specify number of samples with -n=XXX')
|
376 |
+
|
377 |
+
if model and not isinstance(model, InstrumentedModel):
|
378 |
+
raise RuntimeError('Passed model has to be wrapped in "InstrumentedModel"')
|
379 |
+
|
380 |
+
if config.use_w and not 'StyleGAN' in config.model:
|
381 |
+
raise RuntimeError(f'Cannot change latent space of non-StyleGAN model {config.model}')
|
382 |
+
|
383 |
+
transformer = get_estimator(config.estimator, config.components, config.sparsity)
|
384 |
+
dump_name = "{}-{}_{}_{}_n{}{}{}.npz".format(
|
385 |
+
config.model.lower(),
|
386 |
+
config.output_class.replace(' ', '_'),
|
387 |
+
config.layer.lower(),
|
388 |
+
transformer.get_param_str(),
|
389 |
+
config.n,
|
390 |
+
'_w' if config.use_w else '',
|
391 |
+
f'_seed{config.seed}' if config.seed else ''
|
392 |
+
)
|
393 |
+
|
394 |
+
dump_path = basedir / 'cache' / 'components' / dump_name
|
395 |
+
|
396 |
+
if not dump_path.is_file() or force_recompute:
|
397 |
+
print('Not cached')
|
398 |
+
t_start = datetime.datetime.now()
|
399 |
+
compute(config, dump_path, model)
|
400 |
+
print('Total time:', datetime.datetime.now() - t_start)
|
401 |
+
|
402 |
+
return dump_path
|
deps/windows/PyOpenGL-3.1.4-cp37-cp37m-win_amd64.whl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6ff6658d48c4c941bc230e9d46c8b6fe593de1c4c523f7b0b678a6a4f920a1e
|
3 |
+
size 2849264
|
deps/windows/glumpy-1.1.0-cp37-cp37m-win_amd64.whl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6e8984115f12b78ea29196d5d34bdddbf080674c3b6c6673b3daa037b61812cb
|
3 |
+
size 1061208
|
deps/windows/pycuda-2019.1.2+cuda101-cp37-cp37m-win_amd64.whl
ADDED
Binary file (361 kB). View file
|
|
deps/windows/triangle-20190115.3-cp37-cp37m-win_amd64.whl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d86a42322673b599a930384b2272b3c3bc666e1c75c994c12f4dd7065e5d44bc
|
3 |
+
size 1431810
|
environment.yml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: ganspace
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
- conda-forge
|
5 |
+
- pytorch
|
6 |
+
dependencies:
|
7 |
+
- python=3.7
|
8 |
+
- pytorch::pytorch=1.3
|
9 |
+
- pytorch::torchvision
|
10 |
+
- cudatoolkit=10.1
|
11 |
+
- pillow=6.2
|
12 |
+
- ffmpeg
|
13 |
+
- tqdm
|
14 |
+
- scipy
|
15 |
+
- scikit-learn
|
16 |
+
- scikit-image
|
17 |
+
- boto3
|
18 |
+
- requests
|
19 |
+
- nltk
|
20 |
+
- pip
|
21 |
+
- pip:
|
22 |
+
- fbpca
|
23 |
+
- pyopengltk
|
24 |
+
|
25 |
+
# conda env update -f environment.yml --prune
|
estimators.py
ADDED
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
from sklearn.decomposition import FastICA, PCA, IncrementalPCA, MiniBatchSparsePCA, SparsePCA, KernelPCA
|
12 |
+
import fbpca
|
13 |
+
import numpy as np
|
14 |
+
import itertools
|
15 |
+
from types import SimpleNamespace
|
16 |
+
|
17 |
+
# ICA
|
18 |
+
class ICAEstimator():
|
19 |
+
def __init__(self, n_components):
|
20 |
+
self.n_components = n_components
|
21 |
+
self.maxiter = 10000
|
22 |
+
self.whiten = True # ICA: whitening is essential, should not be skipped
|
23 |
+
self.transformer = FastICA(n_components, random_state=0, whiten=self.whiten, max_iter=self.maxiter)
|
24 |
+
self.batch_support = False
|
25 |
+
self.stdev = np.zeros((n_components,))
|
26 |
+
self.total_var = 0.0
|
27 |
+
|
28 |
+
def get_param_str(self):
|
29 |
+
return "ica_c{}{}".format(self.n_components, '_w' if self.whiten else '')
|
30 |
+
|
31 |
+
def fit(self, X):
|
32 |
+
self.transformer.fit(X)
|
33 |
+
if self.transformer.n_iter_ >= self.maxiter:
|
34 |
+
raise RuntimeError(f'FastICA did not converge (N={X.shape[0]}, it={self.maxiter})')
|
35 |
+
|
36 |
+
# Normalize components
|
37 |
+
self.transformer.components_ /= np.sqrt(np.sum(self.transformer.components_**2, axis=-1, keepdims=True))
|
38 |
+
|
39 |
+
# Save variance for later
|
40 |
+
self.total_var = X.var(axis=0).sum()
|
41 |
+
|
42 |
+
# Compute projected standard deviations
|
43 |
+
self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
|
44 |
+
|
45 |
+
# Sort components based on explained variance
|
46 |
+
idx = np.argsort(self.stdev)[::-1]
|
47 |
+
self.stdev = self.stdev[idx]
|
48 |
+
self.transformer.components_[:] = self.transformer.components_[idx]
|
49 |
+
|
50 |
+
def get_components(self):
|
51 |
+
var_ratio = self.stdev**2 / self.total_var
|
52 |
+
return self.transformer.components_, self.stdev, var_ratio # ICA outputs are not normalized
|
53 |
+
|
54 |
+
# Incremental PCA
|
55 |
+
class IPCAEstimator():
|
56 |
+
def __init__(self, n_components):
|
57 |
+
self.n_components = n_components
|
58 |
+
self.whiten = False
|
59 |
+
self.transformer = IncrementalPCA(n_components, whiten=self.whiten, batch_size=max(100, 2*n_components))
|
60 |
+
self.batch_support = True
|
61 |
+
|
62 |
+
def get_param_str(self):
|
63 |
+
return "ipca_c{}{}".format(self.n_components, '_w' if self.whiten else '')
|
64 |
+
|
65 |
+
def fit(self, X):
|
66 |
+
self.transformer.fit(X)
|
67 |
+
|
68 |
+
def fit_partial(self, X):
|
69 |
+
try:
|
70 |
+
self.transformer.partial_fit(X)
|
71 |
+
self.transformer.n_samples_seen_ = \
|
72 |
+
self.transformer.n_samples_seen_.astype(np.int64) # avoid overflow
|
73 |
+
return True
|
74 |
+
except ValueError as e:
|
75 |
+
print(f'\nIPCA error:', e)
|
76 |
+
return False
|
77 |
+
|
78 |
+
def get_components(self):
|
79 |
+
stdev = np.sqrt(self.transformer.explained_variance_) # already sorted
|
80 |
+
var_ratio = self.transformer.explained_variance_ratio_
|
81 |
+
return self.transformer.components_, stdev, var_ratio # PCA outputs are normalized
|
82 |
+
|
83 |
+
# Standard PCA
|
84 |
+
class PCAEstimator():
|
85 |
+
def __init__(self, n_components):
|
86 |
+
self.n_components = n_components
|
87 |
+
self.solver = 'full'
|
88 |
+
self.transformer = PCA(n_components, svd_solver=self.solver)
|
89 |
+
self.batch_support = False
|
90 |
+
|
91 |
+
def get_param_str(self):
|
92 |
+
return f"pca-{self.solver}_c{self.n_components}"
|
93 |
+
|
94 |
+
def fit(self, X):
|
95 |
+
self.transformer.fit(X)
|
96 |
+
|
97 |
+
# Save variance for later
|
98 |
+
self.total_var = X.var(axis=0).sum()
|
99 |
+
|
100 |
+
# Compute projected standard deviations
|
101 |
+
self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
|
102 |
+
|
103 |
+
# Sort components based on explained variance
|
104 |
+
idx = np.argsort(self.stdev)[::-1]
|
105 |
+
self.stdev = self.stdev[idx]
|
106 |
+
self.transformer.components_[:] = self.transformer.components_[idx]
|
107 |
+
|
108 |
+
# Check orthogonality
|
109 |
+
dotps = [np.dot(*self.transformer.components_[[i, j]])
|
110 |
+
for (i, j) in itertools.combinations(range(self.n_components), 2)]
|
111 |
+
if not np.allclose(dotps, 0, atol=1e-4):
|
112 |
+
print('IPCA components not orghogonal, max dot', np.abs(dotps).max())
|
113 |
+
|
114 |
+
self.transformer.mean_ = X.mean(axis=0, keepdims=True)
|
115 |
+
|
116 |
+
def get_components(self):
|
117 |
+
var_ratio = self.stdev**2 / self.total_var
|
118 |
+
return self.transformer.components_, self.stdev, var_ratio
|
119 |
+
|
120 |
+
# Facebook's PCA
|
121 |
+
# Good default choice: very fast and accurate.
|
122 |
+
# Very high sample counts won't fit into RAM,
|
123 |
+
# in which case IncrementalPCA must be used.
|
124 |
+
class FacebookPCAEstimator():
|
125 |
+
def __init__(self, n_components):
|
126 |
+
self.n_components = n_components
|
127 |
+
self.transformer = SimpleNamespace()
|
128 |
+
self.batch_support = False
|
129 |
+
self.n_iter = 2
|
130 |
+
self.l = 2*self.n_components
|
131 |
+
|
132 |
+
def get_param_str(self):
|
133 |
+
return "fbpca_c{}_it{}_l{}".format(self.n_components, self.n_iter, self.l)
|
134 |
+
|
135 |
+
def fit(self, X):
|
136 |
+
U, s, Va = fbpca.pca(X, k=self.n_components, n_iter=self.n_iter, raw=True, l=self.l)
|
137 |
+
self.transformer.components_ = Va
|
138 |
+
|
139 |
+
# Save variance for later
|
140 |
+
self.total_var = X.var(axis=0).sum()
|
141 |
+
|
142 |
+
# Compute projected standard deviations
|
143 |
+
self.stdev = np.dot(self.transformer.components_, X.T).std(axis=1)
|
144 |
+
|
145 |
+
# Sort components based on explained variance
|
146 |
+
idx = np.argsort(self.stdev)[::-1]
|
147 |
+
self.stdev = self.stdev[idx]
|
148 |
+
self.transformer.components_[:] = self.transformer.components_[idx]
|
149 |
+
|
150 |
+
# Check orthogonality
|
151 |
+
dotps = [np.dot(*self.transformer.components_[[i, j]])
|
152 |
+
for (i, j) in itertools.combinations(range(self.n_components), 2)]
|
153 |
+
if not np.allclose(dotps, 0, atol=1e-4):
|
154 |
+
print('FBPCA components not orghogonal, max dot', np.abs(dotps).max())
|
155 |
+
|
156 |
+
self.transformer.mean_ = X.mean(axis=0, keepdims=True)
|
157 |
+
|
158 |
+
def get_components(self):
|
159 |
+
var_ratio = self.stdev**2 / self.total_var
|
160 |
+
return self.transformer.components_, self.stdev, var_ratio
|
161 |
+
|
162 |
+
# Sparse PCA
|
163 |
+
# The algorithm is online along the features direction, not the samples direction
|
164 |
+
# => no partial_fit
|
165 |
+
class SPCAEstimator():
|
166 |
+
def __init__(self, n_components, alpha=10.0):
|
167 |
+
self.n_components = n_components
|
168 |
+
self.whiten = False
|
169 |
+
self.alpha = alpha # higher alpha => sparser components
|
170 |
+
#self.transformer = MiniBatchSparsePCA(n_components, alpha=alpha, n_iter=100,
|
171 |
+
# batch_size=max(20, n_components//5), random_state=0, normalize_components=True)
|
172 |
+
self.transformer = SparsePCA(n_components, alpha=alpha, ridge_alpha=0.01,
|
173 |
+
max_iter=100, random_state=0, n_jobs=-1, normalize_components=True) # TODO: warm start using PCA result?
|
174 |
+
self.batch_support = False # maybe through memmap and HDD-stored tensor
|
175 |
+
self.stdev = np.zeros((n_components,))
|
176 |
+
self.total_var = 0.0
|
177 |
+
|
178 |
+
def get_param_str(self):
|
179 |
+
return "spca_c{}_a{}{}".format(self.n_components, self.alpha, '_w' if self.whiten else '')
|
180 |
+
|
181 |
+
def fit(self, X):
|
182 |
+
self.transformer.fit(X)
|
183 |
+
|
184 |
+
# Save variance for later
|
185 |
+
self.total_var = X.var(axis=0).sum()
|
186 |
+
|
187 |
+
# Compute projected standard deviations
|
188 |
+
# NB: cannot simply project with dot product!
|
189 |
+
self.stdev = self.transformer.transform(X).std(axis=0) # X = (n_samples, n_features)
|
190 |
+
|
191 |
+
# Sort components based on explained variance
|
192 |
+
idx = np.argsort(self.stdev)[::-1]
|
193 |
+
self.stdev = self.stdev[idx]
|
194 |
+
self.transformer.components_[:] = self.transformer.components_[idx]
|
195 |
+
|
196 |
+
# Check orthogonality
|
197 |
+
dotps = [np.dot(*self.transformer.components_[[i, j]])
|
198 |
+
for (i, j) in itertools.combinations(range(self.n_components), 2)]
|
199 |
+
if not np.allclose(dotps, 0, atol=1e-4):
|
200 |
+
print('SPCA components not orghogonal, max dot', np.abs(dotps).max())
|
201 |
+
|
202 |
+
def get_components(self):
|
203 |
+
var_ratio = self.stdev**2 / self.total_var
|
204 |
+
return self.transformer.components_, self.stdev, var_ratio # SPCA outputs are normalized
|
205 |
+
|
206 |
+
def get_estimator(name, n_components, alpha):
|
207 |
+
if name == 'pca':
|
208 |
+
return PCAEstimator(n_components)
|
209 |
+
if name == 'ipca':
|
210 |
+
return IPCAEstimator(n_components)
|
211 |
+
elif name == 'fbpca':
|
212 |
+
return FacebookPCAEstimator(n_components)
|
213 |
+
elif name == 'ica':
|
214 |
+
return ICAEstimator(n_components)
|
215 |
+
elif name == 'spca':
|
216 |
+
return SPCAEstimator(n_components, alpha)
|
217 |
+
else:
|
218 |
+
raise RuntimeError('Unknown estimator')
|
interactive.py
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
# An interactive glumpy (OpenGL) + tkinter viewer for interacting with principal components.
|
12 |
+
# Requires OpenGL and CUDA support for rendering.
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import tkinter as tk
|
17 |
+
from tkinter import ttk
|
18 |
+
from types import SimpleNamespace
|
19 |
+
import matplotlib.pyplot as plt
|
20 |
+
from pathlib import Path
|
21 |
+
from os import makedirs
|
22 |
+
from models import get_instrumented_model
|
23 |
+
from config import Config
|
24 |
+
from decomposition import get_or_compute
|
25 |
+
from torch.nn.functional import interpolate
|
26 |
+
from TkTorchWindow import TorchImageView
|
27 |
+
from functools import partial
|
28 |
+
from platform import system
|
29 |
+
from PIL import Image
|
30 |
+
from utils import pad_frames, prettify_name
|
31 |
+
import pickle
|
32 |
+
|
33 |
+
# For platform specific UI tweaks
|
34 |
+
is_windows = 'Windows' in system()
|
35 |
+
is_linux = 'Linux' in system()
|
36 |
+
is_mac = 'Darwin' in system()
|
37 |
+
|
38 |
+
# Read input parameters
|
39 |
+
args = Config().from_args()
|
40 |
+
|
41 |
+
# Don't bother without GPU
|
42 |
+
assert torch.cuda.is_available(), 'Interactive mode requires CUDA'
|
43 |
+
|
44 |
+
# Use syntax from paper
|
45 |
+
def get_edit_name(idx, s, e, name=None):
|
46 |
+
return 'E({comp}, {edit_range}){edit_name}'.format(
|
47 |
+
comp = idx,
|
48 |
+
edit_range = f'{s}-{e}' if e > s else s,
|
49 |
+
edit_name = f': {name}' if name else ''
|
50 |
+
)
|
51 |
+
|
52 |
+
# Load or compute PCA basis vectors
|
53 |
+
def load_components(class_name, inst):
|
54 |
+
global components, state, use_named_latents
|
55 |
+
|
56 |
+
config = args.from_dict({ 'output_class': class_name })
|
57 |
+
dump_name = get_or_compute(config, inst)
|
58 |
+
data = np.load(dump_name, allow_pickle=False)
|
59 |
+
X_comp = data['act_comp']
|
60 |
+
X_mean = data['act_mean']
|
61 |
+
X_stdev = data['act_stdev']
|
62 |
+
Z_comp = data['lat_comp']
|
63 |
+
Z_mean = data['lat_mean']
|
64 |
+
Z_stdev = data['lat_stdev']
|
65 |
+
random_stdev_act = np.mean(data['random_stdevs'])
|
66 |
+
n_comp = X_comp.shape[0]
|
67 |
+
data.close()
|
68 |
+
|
69 |
+
# Transfer to GPU
|
70 |
+
components = SimpleNamespace(
|
71 |
+
X_comp = torch.from_numpy(X_comp).cuda().float(),
|
72 |
+
X_mean = torch.from_numpy(X_mean).cuda().float(),
|
73 |
+
X_stdev = torch.from_numpy(X_stdev).cuda().float(),
|
74 |
+
Z_comp = torch.from_numpy(Z_comp).cuda().float(),
|
75 |
+
Z_stdev = torch.from_numpy(Z_stdev).cuda().float(),
|
76 |
+
Z_mean = torch.from_numpy(Z_mean).cuda().float(),
|
77 |
+
names = [f'Component {i}' for i in range(n_comp)],
|
78 |
+
latent_types = [model.latent_space_name()]*n_comp,
|
79 |
+
ranges = [(0, model.get_max_latents())]*n_comp,
|
80 |
+
)
|
81 |
+
|
82 |
+
state.component_class = class_name # invalidates cache
|
83 |
+
use_named_latents = False
|
84 |
+
print('Loaded components for', class_name, 'from', dump_name)
|
85 |
+
|
86 |
+
# Load previously exported named components from
|
87 |
+
# directory specified with '--inputs=path/to/comp'
|
88 |
+
def load_named_components(path, class_name):
|
89 |
+
global components, state, use_named_latents
|
90 |
+
|
91 |
+
import glob
|
92 |
+
matches = glob.glob(f'{path}/*.pkl')
|
93 |
+
|
94 |
+
selected = []
|
95 |
+
for dump_path in matches:
|
96 |
+
with open(dump_path, 'rb') as f:
|
97 |
+
data = pickle.load(f)
|
98 |
+
if data['model_name'] != model_name or data['output_class'] != class_name:
|
99 |
+
continue
|
100 |
+
|
101 |
+
if data['latent_space'] != model.latent_space_name():
|
102 |
+
print('Skipping', dump_path, '(wrong latent space)')
|
103 |
+
continue
|
104 |
+
|
105 |
+
selected.append(data)
|
106 |
+
print('Using', dump_path)
|
107 |
+
|
108 |
+
if len(selected) == 0:
|
109 |
+
raise RuntimeError('No valid components in given path.')
|
110 |
+
|
111 |
+
comp_dict = { k : [] for k in ['X_comp', 'Z_comp', 'X_stdev', 'Z_stdev', 'names', 'types', 'layer_names', 'ranges', 'latent_types'] }
|
112 |
+
components = SimpleNamespace(**comp_dict)
|
113 |
+
|
114 |
+
for d in selected:
|
115 |
+
s = d['edit_start']
|
116 |
+
e = d['edit_end']
|
117 |
+
title = get_edit_name(d['component_index'], s, e - 1, d['name']) # show inclusive
|
118 |
+
components.X_comp.append(torch.from_numpy(d['act_comp']).cuda())
|
119 |
+
components.Z_comp.append(torch.from_numpy(d['lat_comp']).cuda())
|
120 |
+
components.X_stdev.append(d['act_stdev'])
|
121 |
+
components.Z_stdev.append(d['lat_stdev'])
|
122 |
+
components.names.append(title)
|
123 |
+
components.types.append(d['edit_type'])
|
124 |
+
components.layer_names.append(d['decomposition']['layer']) # only for act
|
125 |
+
components.ranges.append((s, e))
|
126 |
+
components.latent_types.append(d['latent_space']) # W or Z
|
127 |
+
|
128 |
+
use_named_latents = True
|
129 |
+
print('Loaded named components')
|
130 |
+
|
131 |
+
def setup_model():
|
132 |
+
global model, inst, layer_name, model_name, feat_shape, args, class_name
|
133 |
+
|
134 |
+
model_name = args.model
|
135 |
+
layer_name = args.layer
|
136 |
+
class_name = args.output_class
|
137 |
+
|
138 |
+
# Speed up pytorch
|
139 |
+
torch.autograd.set_grad_enabled(False)
|
140 |
+
torch.backends.cudnn.benchmark = True
|
141 |
+
|
142 |
+
# Load model
|
143 |
+
inst = get_instrumented_model(model_name, class_name, layer_name, torch.device('cuda'), use_w=args.use_w)
|
144 |
+
model = inst.model
|
145 |
+
|
146 |
+
feat_shape = inst.feature_shape[layer_name]
|
147 |
+
sample_dims = np.prod(feat_shape)
|
148 |
+
|
149 |
+
# Initialize
|
150 |
+
if args.inputs:
|
151 |
+
load_named_components(args.inputs, class_name)
|
152 |
+
else:
|
153 |
+
load_components(class_name, inst)
|
154 |
+
|
155 |
+
# Project tensor 'X' onto orthonormal basis 'comp', return coordinates
|
156 |
+
def project_ortho(X, comp):
|
157 |
+
N = comp.shape[0]
|
158 |
+
coords = (comp.reshape(N, -1) * X.reshape(-1)).sum(dim=1)
|
159 |
+
return coords.reshape([N]+[1]*X.ndim)
|
160 |
+
|
161 |
+
def zero_sliders():
|
162 |
+
for v in ui_state.sliders:
|
163 |
+
v.set(0.0)
|
164 |
+
|
165 |
+
def reset_sliders(zero_on_failure=True):
|
166 |
+
global ui_state
|
167 |
+
|
168 |
+
mode = ui_state.mode.get()
|
169 |
+
|
170 |
+
# Not orthogonal: need to solve least-norm problem
|
171 |
+
# Not batch size 1: one set of sliders not enough
|
172 |
+
# Not principal components: unsupported format
|
173 |
+
is_ortho = not (mode == 'latent' and model.latent_space_name() == 'Z')
|
174 |
+
is_single = state.z.shape[0] == 1
|
175 |
+
is_pcs = not use_named_latents
|
176 |
+
|
177 |
+
state.lat_slider_offset = 0
|
178 |
+
state.act_slider_offset = 0
|
179 |
+
|
180 |
+
enabled = False
|
181 |
+
if not (enabled and is_ortho and is_single and is_pcs):
|
182 |
+
if zero_on_failure:
|
183 |
+
zero_sliders()
|
184 |
+
return
|
185 |
+
|
186 |
+
if mode == 'activation':
|
187 |
+
val = state.base_act
|
188 |
+
mean = components.X_mean
|
189 |
+
comp = components.X_comp
|
190 |
+
stdev = components.X_stdev
|
191 |
+
else:
|
192 |
+
val = state.z
|
193 |
+
mean = components.Z_mean
|
194 |
+
comp = components.Z_comp
|
195 |
+
stdev = components.Z_stdev
|
196 |
+
|
197 |
+
n_sliders = len(ui_state.sliders)
|
198 |
+
coords = project_ortho(val - mean, comp)
|
199 |
+
offset = torch.sum(coords[:n_sliders] * comp[:n_sliders], dim=0)
|
200 |
+
scaled_coords = (coords.view(-1) / stdev).detach().cpu().numpy()
|
201 |
+
|
202 |
+
# Part representable by sliders
|
203 |
+
if mode == 'activation':
|
204 |
+
state.act_slider_offset = offset
|
205 |
+
else:
|
206 |
+
state.lat_slider_offset = offset
|
207 |
+
|
208 |
+
for i in range(n_sliders):
|
209 |
+
ui_state.sliders[i].set(round(scaled_coords[i], ndigits=1))
|
210 |
+
|
211 |
+
def setup_ui():
|
212 |
+
global root, toolbar, ui_state, app, canvas
|
213 |
+
|
214 |
+
root = tk.Tk()
|
215 |
+
scale = 1.0
|
216 |
+
app = TorchImageView(root, width=int(scale*1024), height=int(scale*1024), show_fps=False)
|
217 |
+
app.pack(fill=tk.BOTH, expand=tk.YES)
|
218 |
+
root.protocol("WM_DELETE_WINDOW", shutdown)
|
219 |
+
root.title('GANspace')
|
220 |
+
|
221 |
+
toolbar = tk.Toplevel(root)
|
222 |
+
toolbar.protocol("WM_DELETE_WINDOW", shutdown)
|
223 |
+
toolbar.geometry("215x800+0+0")
|
224 |
+
toolbar.title('')
|
225 |
+
|
226 |
+
N_COMPONENTS = min(70, len(components.names))
|
227 |
+
ui_state = SimpleNamespace(
|
228 |
+
sliders = [tk.DoubleVar(value=0.0) for _ in range(N_COMPONENTS)],
|
229 |
+
scales = [],
|
230 |
+
truncation = tk.DoubleVar(value=0.9),
|
231 |
+
outclass = tk.StringVar(value=class_name),
|
232 |
+
random_seed = tk.StringVar(value='0'),
|
233 |
+
mode = tk.StringVar(value='latent'),
|
234 |
+
batch_size = tk.IntVar(value=1), # how many images to show in window
|
235 |
+
edit_layer_start = tk.IntVar(value=0),
|
236 |
+
edit_layer_end = tk.IntVar(value=model.get_max_latents() - 1),
|
237 |
+
slider_max_val = 10.0
|
238 |
+
)
|
239 |
+
|
240 |
+
# Z vs activation mode button
|
241 |
+
#tk.Radiobutton(toolbar, text=f"Latent ({model.latent_space_name()})", variable=ui_state.mode, command=reset_sliders, value='latent').pack(fill="x")
|
242 |
+
#tk.Radiobutton(toolbar, text="Activation", variable=ui_state.mode, command=reset_sliders, value='activation').pack(fill="x")
|
243 |
+
|
244 |
+
# Choose range where latents are modified
|
245 |
+
def set_min(val):
|
246 |
+
ui_state.edit_layer_start.set(min(int(val), ui_state.edit_layer_end.get()))
|
247 |
+
def set_max(val):
|
248 |
+
ui_state.edit_layer_end.set(max(int(val), ui_state.edit_layer_start.get()))
|
249 |
+
max_latent_idx = model.get_max_latents() - 1
|
250 |
+
|
251 |
+
if not use_named_latents:
|
252 |
+
slider_min = tk.Scale(toolbar, command=set_min, variable=ui_state.edit_layer_start,
|
253 |
+
label='Layer start', from_=0, to=max_latent_idx, orient=tk.HORIZONTAL).pack(fill="x")
|
254 |
+
slider_max = tk.Scale(toolbar, command=set_max, variable=ui_state.edit_layer_end,
|
255 |
+
label='Layer end', from_=0, to=max_latent_idx, orient=tk.HORIZONTAL).pack(fill="x")
|
256 |
+
|
257 |
+
# Scrollable list of components
|
258 |
+
outer_frame = tk.Frame(toolbar, borderwidth=2, relief=tk.SUNKEN)
|
259 |
+
canvas = tk.Canvas(outer_frame, highlightthickness=0, borderwidth=0)
|
260 |
+
frame = tk.Frame(canvas)
|
261 |
+
vsb = tk.Scrollbar(outer_frame, orient="vertical", command=canvas.yview)
|
262 |
+
canvas.configure(yscrollcommand=vsb.set)
|
263 |
+
|
264 |
+
vsb.pack(side="right", fill="y")
|
265 |
+
canvas.pack(side="left", fill="both", expand=True)
|
266 |
+
canvas.create_window((4,4), window=frame, anchor="nw")
|
267 |
+
|
268 |
+
def onCanvasConfigure(event):
|
269 |
+
canvas.itemconfigure("all", width=event.width)
|
270 |
+
canvas.configure(scrollregion=canvas.bbox("all"))
|
271 |
+
canvas.bind("<Configure>", onCanvasConfigure)
|
272 |
+
|
273 |
+
def on_scroll(event):
|
274 |
+
delta = 1 if (event.num == 5 or event.delta < 0) else -1
|
275 |
+
canvas.yview_scroll(delta, "units")
|
276 |
+
|
277 |
+
canvas.bind_all("<Button-4>", on_scroll)
|
278 |
+
canvas.bind_all("<Button-5>", on_scroll)
|
279 |
+
canvas.bind_all("<MouseWheel>", on_scroll)
|
280 |
+
canvas.bind_all("<Key>", lambda event : handle_keypress(event.keysym_num))
|
281 |
+
|
282 |
+
# Sliders and buttons
|
283 |
+
for i in range(N_COMPONENTS):
|
284 |
+
inner = tk.Frame(frame, borderwidth=1, background="#aaaaaa")
|
285 |
+
scale = tk.Scale(inner, variable=ui_state.sliders[i], from_=-ui_state.slider_max_val,
|
286 |
+
to=ui_state.slider_max_val, resolution=0.1, orient=tk.HORIZONTAL, label=components.names[i])
|
287 |
+
scale.pack(fill=tk.X, side=tk.LEFT, expand=True)
|
288 |
+
ui_state.scales.append(scale) # for changing label later
|
289 |
+
if not use_named_latents:
|
290 |
+
tk.Button(inner, text=f"Save", command=partial(export_direction, i, inner)).pack(fill=tk.Y, side=tk.RIGHT)
|
291 |
+
inner.pack(fill=tk.X)
|
292 |
+
|
293 |
+
outer_frame.pack(fill="both", expand=True, pady=0)
|
294 |
+
|
295 |
+
tk.Button(toolbar, text="Reset", command=reset_sliders).pack(anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)
|
296 |
+
|
297 |
+
tk.Scale(toolbar, variable=ui_state.truncation, from_=0.01, to=1.0,
|
298 |
+
resolution=0.01, orient=tk.HORIZONTAL, label='Truncation').pack(fill="x")
|
299 |
+
|
300 |
+
tk.Scale(toolbar, variable=ui_state.batch_size, from_=1, to=9,
|
301 |
+
resolution=1, orient=tk.HORIZONTAL, label='Batch size').pack(fill="x")
|
302 |
+
|
303 |
+
# Output class
|
304 |
+
frame = tk.Frame(toolbar)
|
305 |
+
tk.Label(frame, text="Class name").pack(fill="x", side="left")
|
306 |
+
tk.Entry(frame, textvariable=ui_state.outclass).pack(fill="x", side="right", expand=True, padx=5)
|
307 |
+
frame.pack(fill=tk.X, pady=3)
|
308 |
+
|
309 |
+
# Random seed
|
310 |
+
def update_seed():
|
311 |
+
seed_str = ui_state.random_seed.get()
|
312 |
+
if seed_str.isdigit():
|
313 |
+
resample_latent(int(seed_str))
|
314 |
+
frame = tk.Frame(toolbar)
|
315 |
+
tk.Label(frame, text="Seed").pack(fill="x", side="left")
|
316 |
+
tk.Entry(frame, textvariable=ui_state.random_seed, width=12).pack(fill="x", side="left", expand=True, padx=2)
|
317 |
+
tk.Button(frame, text="Update", command=update_seed).pack(fill="y", side="right", padx=3)
|
318 |
+
frame.pack(fill=tk.X, pady=3)
|
319 |
+
|
320 |
+
# Get new latent or new components
|
321 |
+
tk.Button(toolbar, text="Resample latent", command=partial(resample_latent, None, False)).pack(anchor=tk.CENTER, fill=tk.X, padx=4, pady=4)
|
322 |
+
#tk.Button(toolbar, text="Recompute", command=recompute_components).pack(anchor=tk.CENTER, fill=tk.X)
|
323 |
+
|
324 |
+
# App state
|
325 |
+
state = SimpleNamespace(
|
326 |
+
z=None, # current latent(s)
|
327 |
+
lat_slider_offset = 0, # part of lat that is explained by sliders
|
328 |
+
act_slider_offset = 0, # part of act that is explained by sliders
|
329 |
+
component_class=None, # name of current PCs' image class
|
330 |
+
seed=0, # Latent z_i generated by seed+i
|
331 |
+
base_act = None, # activation of considered layer given z
|
332 |
+
)
|
333 |
+
|
334 |
+
def resample_latent(seed=None, only_style=False):
|
335 |
+
class_name = ui_state.outclass.get()
|
336 |
+
if class_name.isnumeric():
|
337 |
+
class_name = int(class_name)
|
338 |
+
|
339 |
+
if hasattr(model, 'is_valid_class'):
|
340 |
+
if not model.is_valid_class(class_name):
|
341 |
+
return
|
342 |
+
|
343 |
+
model.set_output_class(class_name)
|
344 |
+
|
345 |
+
B = ui_state.batch_size.get()
|
346 |
+
state.seed = np.random.randint(np.iinfo(np.int32).max - B) if seed is None else seed
|
347 |
+
ui_state.random_seed.set(str(state.seed))
|
348 |
+
|
349 |
+
# Use consecutive seeds along batch dimension (for easier reproducibility)
|
350 |
+
trunc = ui_state.truncation.get()
|
351 |
+
latents = [model.sample_latent(1, seed=state.seed + i, truncation=trunc) for i in range(B)]
|
352 |
+
|
353 |
+
state.z = torch.cat(latents).clone().detach() # make leaf node
|
354 |
+
assert state.z.is_leaf, 'Latent is not leaf node!'
|
355 |
+
|
356 |
+
if hasattr(model, 'truncation'):
|
357 |
+
model.truncation = ui_state.truncation.get()
|
358 |
+
print(f'Seeds: {state.seed} -> {state.seed + B - 1}' if B > 1 else f'Seed: {state.seed}')
|
359 |
+
|
360 |
+
torch.manual_seed(state.seed)
|
361 |
+
model.partial_forward(state.z, layer_name)
|
362 |
+
state.base_act = inst.retained_features()[layer_name]
|
363 |
+
|
364 |
+
reset_sliders(zero_on_failure=False)
|
365 |
+
|
366 |
+
# Remove focus from text entry
|
367 |
+
canvas.focus_set()
|
368 |
+
|
369 |
+
# Used to recompute after changing class of conditional model
|
370 |
+
def recompute_components():
|
371 |
+
class_name = ui_state.outclass.get()
|
372 |
+
if class_name.isnumeric():
|
373 |
+
class_name = int(class_name)
|
374 |
+
|
375 |
+
if hasattr(model, 'is_valid_class'):
|
376 |
+
if not model.is_valid_class(class_name):
|
377 |
+
return
|
378 |
+
|
379 |
+
if hasattr(model, 'set_output_class'):
|
380 |
+
model.set_output_class(class_name)
|
381 |
+
|
382 |
+
load_components(class_name, inst)
|
383 |
+
|
384 |
+
# Used to detect parameter changes for lazy recomputation
|
385 |
+
class ParamCache():
|
386 |
+
def update(self, **kwargs):
|
387 |
+
dirty = False
|
388 |
+
for argname, val in kwargs.items():
|
389 |
+
# Check pointer, then value
|
390 |
+
current = getattr(self, argname, 0)
|
391 |
+
if current is not val and pickle.dumps(current) != pickle.dumps(val):
|
392 |
+
setattr(self, argname, val)
|
393 |
+
dirty = True
|
394 |
+
return dirty
|
395 |
+
|
396 |
+
cache = ParamCache()
|
397 |
+
|
398 |
+
def l2norm(t):
|
399 |
+
return torch.norm(t.view(t.shape[0], -1), p=2, dim=1, keepdim=True)
|
400 |
+
|
401 |
+
def apply_edit(z0, delta):
|
402 |
+
return z0 + delta
|
403 |
+
|
404 |
+
def reposition_toolbar():
|
405 |
+
size, X, Y = root.winfo_geometry().split('+')
|
406 |
+
W, H = size.split('x')
|
407 |
+
toolbar_W = toolbar.winfo_geometry().split('x')[0]
|
408 |
+
offset_y = -30 if is_linux else 0 # window title bar
|
409 |
+
toolbar.geometry(f'{toolbar_W}x{H}+{int(X)-int(toolbar_W)}+{int(Y)+offset_y}')
|
410 |
+
toolbar.update()
|
411 |
+
|
412 |
+
def on_draw():
|
413 |
+
global img
|
414 |
+
|
415 |
+
n_comp = len(ui_state.sliders)
|
416 |
+
slider_vals = np.array([s.get() for s in ui_state.sliders], dtype=np.float32)
|
417 |
+
|
418 |
+
# Run model sparingly
|
419 |
+
mode = ui_state.mode.get()
|
420 |
+
latent_start = ui_state.edit_layer_start.get()
|
421 |
+
latent_end = ui_state.edit_layer_end.get() + 1 # save as exclusive, show as inclusive
|
422 |
+
|
423 |
+
if cache.update(coords=slider_vals, comp=state.component_class, mode=mode, z=state.z, s=latent_start, e=latent_end):
|
424 |
+
with torch.no_grad():
|
425 |
+
z_base = state.z - state.lat_slider_offset
|
426 |
+
z_deltas = [0.0]*model.get_max_latents()
|
427 |
+
z_delta_global = 0.0
|
428 |
+
|
429 |
+
n_comp = slider_vals.size
|
430 |
+
act_deltas = {}
|
431 |
+
|
432 |
+
if torch.is_tensor(state.act_slider_offset):
|
433 |
+
act_deltas[layer_name] = -state.act_slider_offset
|
434 |
+
|
435 |
+
for space in components.latent_types:
|
436 |
+
assert space == model.latent_space_name(), \
|
437 |
+
'Cannot mix latent spaces (for now)'
|
438 |
+
|
439 |
+
for c in range(n_comp):
|
440 |
+
coord = slider_vals[c]
|
441 |
+
if coord == 0:
|
442 |
+
continue
|
443 |
+
|
444 |
+
edit_mode = components.types[c] if use_named_latents else mode
|
445 |
+
|
446 |
+
# Activation offset
|
447 |
+
if edit_mode in ['activation', 'both']:
|
448 |
+
delta = components.X_comp[c] * components.X_stdev[c] * coord
|
449 |
+
name = components.layer_names[c] if use_named_latents else layer_name
|
450 |
+
act_deltas[name] = act_deltas.get(name, 0.0) + delta
|
451 |
+
|
452 |
+
# Latent offset
|
453 |
+
if edit_mode in ['latent', 'both']:
|
454 |
+
delta = components.Z_comp[c] * components.Z_stdev[c] * coord
|
455 |
+
edit_range = components.ranges[c] if use_named_latents else (latent_start, latent_end)
|
456 |
+
full_range = (edit_range == (0, model.get_max_latents()))
|
457 |
+
|
458 |
+
# Single or multiple offsets?
|
459 |
+
if full_range:
|
460 |
+
z_delta_global = z_delta_global + delta
|
461 |
+
else:
|
462 |
+
for l in range(*edit_range):
|
463 |
+
z_deltas[l] = z_deltas[l] + delta
|
464 |
+
|
465 |
+
# Apply activation deltas
|
466 |
+
inst.remove_edits()
|
467 |
+
for layer, delta in act_deltas.items():
|
468 |
+
inst.edit_layer(layer, offset=delta)
|
469 |
+
|
470 |
+
# Evaluate
|
471 |
+
has_offsets = any(torch.is_tensor(t) for t in z_deltas)
|
472 |
+
z_final = apply_edit(z_base, z_delta_global)
|
473 |
+
if has_offsets:
|
474 |
+
z_final = [apply_edit(z_final, d) for d in z_deltas]
|
475 |
+
img = model.forward(z_final).clamp(0.0, 1.0)
|
476 |
+
|
477 |
+
app.draw(img)
|
478 |
+
|
479 |
+
# Save necessary data to disk for later loading
|
480 |
+
def export_direction(idx, button_frame):
|
481 |
+
name = tk.StringVar(value='')
|
482 |
+
num_strips = tk.IntVar(value=0)
|
483 |
+
strip_width = tk.IntVar(value=5)
|
484 |
+
|
485 |
+
slider_values = np.array([s.get() for s in ui_state.sliders])
|
486 |
+
slider_value = slider_values[idx]
|
487 |
+
if (slider_values != 0).sum() > 1:
|
488 |
+
print('Please modify only one slider')
|
489 |
+
return
|
490 |
+
elif slider_value == 0:
|
491 |
+
print('Modify selected slider to set usable range (currently 0)')
|
492 |
+
return
|
493 |
+
|
494 |
+
popup = tk.Toplevel(root)
|
495 |
+
popup.geometry("200x200+0+0")
|
496 |
+
tk.Label(popup, text="Edit name").pack()
|
497 |
+
tk.Entry(popup, textvariable=name).pack(pady=5)
|
498 |
+
# tk.Scale(popup, from_=0, to=30, variable=num_strips,
|
499 |
+
# resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strips to export').pack()
|
500 |
+
# tk.Scale(popup, from_=3, to=15, variable=strip_width,
|
501 |
+
# resolution=1, orient=tk.HORIZONTAL, length=200, label='Image strip width').pack()
|
502 |
+
tk.Button(popup, text='OK', command=popup.quit).pack()
|
503 |
+
|
504 |
+
canceled = False
|
505 |
+
def on_close():
|
506 |
+
nonlocal canceled
|
507 |
+
canceled = True
|
508 |
+
popup.quit()
|
509 |
+
|
510 |
+
popup.protocol("WM_DELETE_WINDOW", on_close)
|
511 |
+
x = button_frame.winfo_rootx()
|
512 |
+
y = button_frame.winfo_rooty()
|
513 |
+
w = int(button_frame.winfo_geometry().split('x')[0])
|
514 |
+
popup.geometry('%dx%d+%d+%d' % (180, 90, x + w, y))
|
515 |
+
popup.mainloop()
|
516 |
+
popup.destroy()
|
517 |
+
|
518 |
+
# Update slider name
|
519 |
+
label = get_edit_name(idx, ui_state.edit_layer_start.get(),
|
520 |
+
ui_state.edit_layer_end.get(), name.get())
|
521 |
+
ui_state.scales[idx].config(label=label)
|
522 |
+
|
523 |
+
if canceled:
|
524 |
+
return
|
525 |
+
|
526 |
+
params = {
|
527 |
+
'name': name.get(),
|
528 |
+
'sigma_range': slider_value,
|
529 |
+
'component_index': idx,
|
530 |
+
'act_comp': components.X_comp[idx].detach().cpu().numpy(),
|
531 |
+
'lat_comp': components.Z_comp[idx].detach().cpu().numpy(), # either Z or W
|
532 |
+
'latent_space': model.latent_space_name(),
|
533 |
+
'act_stdev': components.X_stdev[idx].item(),
|
534 |
+
'lat_stdev': components.Z_stdev[idx].item(),
|
535 |
+
'model_name': model_name,
|
536 |
+
'output_class': ui_state.outclass.get(), # applied onto
|
537 |
+
'decomposition': {
|
538 |
+
'name': args.estimator,
|
539 |
+
'components': args.components,
|
540 |
+
'samples': args.n,
|
541 |
+
'layer': args.layer,
|
542 |
+
'class_name': state.component_class # computed from
|
543 |
+
},
|
544 |
+
'edit_type': ui_state.mode.get(),
|
545 |
+
'truncation': ui_state.truncation.get(),
|
546 |
+
'edit_start': ui_state.edit_layer_start.get(),
|
547 |
+
'edit_end': ui_state.edit_layer_end.get() + 1, # show as inclusive, save as exclusive
|
548 |
+
'example_seed': state.seed,
|
549 |
+
}
|
550 |
+
|
551 |
+
edit_mode_str = params['edit_type']
|
552 |
+
if edit_mode_str == 'latent':
|
553 |
+
edit_mode_str = model.latent_space_name().lower()
|
554 |
+
|
555 |
+
comp_class = state.component_class
|
556 |
+
appl_class = params['output_class']
|
557 |
+
if comp_class != appl_class:
|
558 |
+
comp_class = f'{comp_class}_onto_{appl_class}'
|
559 |
+
|
560 |
+
file_ident = "{model}-{name}-{cls}-{est}-{mode}-{layer}-comp{idx}-range{start}-{end}".format(
|
561 |
+
model=model_name,
|
562 |
+
name=prettify_name(params['name']),
|
563 |
+
cls=comp_class,
|
564 |
+
est=args.estimator,
|
565 |
+
mode=edit_mode_str,
|
566 |
+
layer=args.layer,
|
567 |
+
idx=idx,
|
568 |
+
start=params['edit_start'],
|
569 |
+
end=params['edit_end'],
|
570 |
+
)
|
571 |
+
|
572 |
+
out_dir = Path(__file__).parent / 'out' / 'directions'
|
573 |
+
makedirs(out_dir / file_ident, exist_ok=True)
|
574 |
+
|
575 |
+
with open(out_dir / f"{file_ident}.pkl", 'wb') as outfile:
|
576 |
+
pickle.dump(params, outfile)
|
577 |
+
|
578 |
+
print(f'Direction "{name.get()}" saved as "{file_ident}.pkl"')
|
579 |
+
|
580 |
+
batch_size = ui_state.batch_size.get()
|
581 |
+
len_padded = ((num_strips.get() - 1) // batch_size + 1) * batch_size
|
582 |
+
orig_seed = state.seed
|
583 |
+
|
584 |
+
reset_sliders()
|
585 |
+
|
586 |
+
# Limit max resolution
|
587 |
+
max_H = 512
|
588 |
+
ratio = min(1.0, max_H / inst.output_shape[2])
|
589 |
+
|
590 |
+
strips = [[] for _ in range(len_padded)]
|
591 |
+
for b in range(0, len_padded, batch_size):
|
592 |
+
# Resample
|
593 |
+
resample_latent((orig_seed + b) % np.iinfo(np.int32).max)
|
594 |
+
|
595 |
+
sigmas = np.linspace(slider_value, -slider_value, strip_width.get(), dtype=np.float32)
|
596 |
+
for sid, sigma in enumerate(sigmas):
|
597 |
+
ui_state.sliders[idx].set(sigma)
|
598 |
+
|
599 |
+
# Advance and show results on screen
|
600 |
+
on_draw()
|
601 |
+
root.update()
|
602 |
+
app.update()
|
603 |
+
|
604 |
+
batch_res = (255*img).byte().permute(0, 2, 3, 1).detach().cpu().numpy()
|
605 |
+
|
606 |
+
for i, data in enumerate(batch_res):
|
607 |
+
# Save individual
|
608 |
+
name_nodots = file_ident.replace('.', '_')
|
609 |
+
outname = out_dir / file_ident / f"{name_nodots}_ex{b+i}_{sid}.png"
|
610 |
+
im = Image.fromarray(data)
|
611 |
+
im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), Image.ANTIALIAS)
|
612 |
+
im.save(outname)
|
613 |
+
strips[b+i].append(data)
|
614 |
+
|
615 |
+
for i, strip in enumerate(strips[:num_strips.get()]):
|
616 |
+
print(f'Saving strip {i + 1}/{num_strips.get()}', end='\r', flush=True)
|
617 |
+
data = np.hstack(pad_frames(strip))
|
618 |
+
im = Image.fromarray(data)
|
619 |
+
im = im.resize((int(ratio*im.size[0]), int(ratio*im.size[1])), Image.ANTIALIAS)
|
620 |
+
im.save(out_dir / file_ident / f"{file_ident}_ex{i}.png")
|
621 |
+
|
622 |
+
# Reset to original state
|
623 |
+
resample_latent(orig_seed)
|
624 |
+
ui_state.sliders[idx].set(slider_value)
|
625 |
+
|
626 |
+
|
627 |
+
# Shared by glumpy and tkinter
|
628 |
+
def handle_keypress(code):
|
629 |
+
if code == 65307: # ESC
|
630 |
+
shutdown()
|
631 |
+
elif code == 65360: # HOME
|
632 |
+
reset_sliders()
|
633 |
+
elif code == 114: # R
|
634 |
+
pass #reset_sliders()
|
635 |
+
|
636 |
+
def shutdown():
|
637 |
+
global pending_close
|
638 |
+
pending_close = True
|
639 |
+
|
640 |
+
def on_key_release(symbol, modifiers):
|
641 |
+
handle_keypress(symbol)
|
642 |
+
|
643 |
+
if __name__=='__main__':
|
644 |
+
setup_model()
|
645 |
+
setup_ui()
|
646 |
+
resample_latent()
|
647 |
+
|
648 |
+
pending_close = False
|
649 |
+
while not pending_close:
|
650 |
+
root.update()
|
651 |
+
app.update()
|
652 |
+
on_draw()
|
653 |
+
reposition_toolbar()
|
654 |
+
|
655 |
+
root.destroy()
|
model_clip.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.relu1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.relu2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.relu3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
37 |
+
("-1", nn.AvgPool2d(stride)),
|
38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
40 |
+
]))
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor):
|
43 |
+
identity = x
|
44 |
+
|
45 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
46 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
47 |
+
out = self.avgpool(out)
|
48 |
+
out = self.bn3(self.conv3(out))
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
identity = self.downsample(x)
|
52 |
+
|
53 |
+
out += identity
|
54 |
+
out = self.relu3(out)
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class AttentionPool2d(nn.Module):
|
59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
60 |
+
super().__init__()
|
61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
66 |
+
self.num_heads = num_heads
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
72 |
+
x, _ = F.multi_head_attention_forward(
|
73 |
+
query=x[:1], key=x, value=x,
|
74 |
+
embed_dim_to_check=x.shape[-1],
|
75 |
+
num_heads=self.num_heads,
|
76 |
+
q_proj_weight=self.q_proj.weight,
|
77 |
+
k_proj_weight=self.k_proj.weight,
|
78 |
+
v_proj_weight=self.v_proj.weight,
|
79 |
+
in_proj_weight=None,
|
80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
81 |
+
bias_k=None,
|
82 |
+
bias_v=None,
|
83 |
+
add_zero_attn=False,
|
84 |
+
dropout_p=0,
|
85 |
+
out_proj_weight=self.c_proj.weight,
|
86 |
+
out_proj_bias=self.c_proj.bias,
|
87 |
+
use_separate_proj_weight=True,
|
88 |
+
training=self.training,
|
89 |
+
need_weights=False
|
90 |
+
)
|
91 |
+
return x.squeeze(0)
|
92 |
+
|
93 |
+
|
94 |
+
class ModifiedResNet(nn.Module):
|
95 |
+
"""
|
96 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
97 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
98 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
99 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
100 |
+
"""
|
101 |
+
|
102 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
103 |
+
super().__init__()
|
104 |
+
self.output_dim = output_dim
|
105 |
+
self.input_resolution = input_resolution
|
106 |
+
|
107 |
+
# the 3-layer stem
|
108 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
109 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
110 |
+
self.relu1 = nn.ReLU(inplace=True)
|
111 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
112 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
113 |
+
self.relu2 = nn.ReLU(inplace=True)
|
114 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
115 |
+
self.bn3 = nn.BatchNorm2d(width)
|
116 |
+
self.relu3 = nn.ReLU(inplace=True)
|
117 |
+
self.avgpool = nn.AvgPool2d(2)
|
118 |
+
|
119 |
+
# residual layers
|
120 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
121 |
+
self.layer1 = self._make_layer(width, layers[0])
|
122 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
123 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
124 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
125 |
+
|
126 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
127 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
128 |
+
|
129 |
+
def _make_layer(self, planes, blocks, stride=1):
|
130 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
131 |
+
|
132 |
+
self._inplanes = planes * Bottleneck.expansion
|
133 |
+
for _ in range(1, blocks):
|
134 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
135 |
+
|
136 |
+
return nn.Sequential(*layers)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
def stem(x):
|
140 |
+
x = self.relu1(self.bn1(self.conv1(x)))
|
141 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
142 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
143 |
+
x = self.avgpool(x)
|
144 |
+
return x
|
145 |
+
|
146 |
+
x = x.type(self.conv1.weight.dtype)
|
147 |
+
x = stem(x)
|
148 |
+
x = self.layer1(x)
|
149 |
+
x = self.layer2(x)
|
150 |
+
x = self.layer3(x)
|
151 |
+
x = self.layer4(x)
|
152 |
+
x = self.attnpool(x)
|
153 |
+
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
class LayerNorm(nn.LayerNorm):
|
158 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
159 |
+
|
160 |
+
def forward(self, x: torch.Tensor):
|
161 |
+
orig_type = x.dtype
|
162 |
+
ret = super().forward(x.type(torch.float32))
|
163 |
+
return ret.type(orig_type)
|
164 |
+
|
165 |
+
|
166 |
+
class QuickGELU(nn.Module):
|
167 |
+
def forward(self, x: torch.Tensor):
|
168 |
+
return x * torch.sigmoid(1.702 * x)
|
169 |
+
|
170 |
+
|
171 |
+
class ResidualAttentionBlock(nn.Module):
|
172 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
173 |
+
super().__init__()
|
174 |
+
|
175 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
176 |
+
self.ln_1 = LayerNorm(d_model)
|
177 |
+
self.mlp = nn.Sequential(OrderedDict([
|
178 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
179 |
+
("gelu", QuickGELU()),
|
180 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
181 |
+
]))
|
182 |
+
self.ln_2 = LayerNorm(d_model)
|
183 |
+
self.attn_mask = attn_mask
|
184 |
+
|
185 |
+
def attention(self, x: torch.Tensor):
|
186 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
187 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor):
|
190 |
+
x = x + self.attention(self.ln_1(x))
|
191 |
+
x = x + self.mlp(self.ln_2(x))
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class Transformer(nn.Module):
|
196 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
197 |
+
super().__init__()
|
198 |
+
self.width = width
|
199 |
+
self.layers = layers
|
200 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
201 |
+
|
202 |
+
def forward(self, x: torch.Tensor):
|
203 |
+
return self.resblocks(x)
|
204 |
+
|
205 |
+
|
206 |
+
class VisionTransformer(nn.Module):
|
207 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
208 |
+
super().__init__()
|
209 |
+
self.input_resolution = input_resolution
|
210 |
+
self.output_dim = output_dim
|
211 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
212 |
+
|
213 |
+
scale = width ** -0.5
|
214 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
215 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
216 |
+
self.ln_pre = LayerNorm(width)
|
217 |
+
|
218 |
+
self.transformer = Transformer(width, layers, heads)
|
219 |
+
|
220 |
+
self.ln_post = LayerNorm(width)
|
221 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
222 |
+
|
223 |
+
def forward(self, x: torch.Tensor):
|
224 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
225 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
226 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
227 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
228 |
+
x = x + self.positional_embedding.to(x.dtype)
|
229 |
+
x = self.ln_pre(x)
|
230 |
+
|
231 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
232 |
+
x = self.transformer(x)
|
233 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
234 |
+
|
235 |
+
x = self.ln_post(x[:, 0, :])
|
236 |
+
|
237 |
+
if self.proj is not None:
|
238 |
+
x = x @ self.proj
|
239 |
+
|
240 |
+
return x
|
241 |
+
|
242 |
+
|
243 |
+
class CLIP(nn.Module):
|
244 |
+
def __init__(self,
|
245 |
+
embed_dim: int,
|
246 |
+
# vision
|
247 |
+
image_resolution: int,
|
248 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
249 |
+
vision_width: int,
|
250 |
+
vision_patch_size: int,
|
251 |
+
# text
|
252 |
+
context_length: int,
|
253 |
+
vocab_size: int,
|
254 |
+
transformer_width: int,
|
255 |
+
transformer_heads: int,
|
256 |
+
transformer_layers: int
|
257 |
+
):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.context_length = context_length
|
261 |
+
|
262 |
+
if isinstance(vision_layers, (tuple, list)):
|
263 |
+
vision_heads = vision_width * 32 // 64
|
264 |
+
self.visual = ModifiedResNet(
|
265 |
+
layers=vision_layers,
|
266 |
+
output_dim=embed_dim,
|
267 |
+
heads=vision_heads,
|
268 |
+
input_resolution=image_resolution,
|
269 |
+
width=vision_width
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
vision_heads = vision_width // 64
|
273 |
+
self.visual = VisionTransformer(
|
274 |
+
input_resolution=image_resolution,
|
275 |
+
patch_size=vision_patch_size,
|
276 |
+
width=vision_width,
|
277 |
+
layers=vision_layers,
|
278 |
+
heads=vision_heads,
|
279 |
+
output_dim=embed_dim
|
280 |
+
)
|
281 |
+
|
282 |
+
self.transformer = Transformer(
|
283 |
+
width=transformer_width,
|
284 |
+
layers=transformer_layers,
|
285 |
+
heads=transformer_heads,
|
286 |
+
attn_mask=self.build_attention_mask()
|
287 |
+
)
|
288 |
+
|
289 |
+
self.vocab_size = vocab_size
|
290 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
291 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
292 |
+
self.ln_final = LayerNorm(transformer_width)
|
293 |
+
|
294 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
295 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
296 |
+
|
297 |
+
self.initialize_parameters()
|
298 |
+
|
299 |
+
def initialize_parameters(self):
|
300 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
301 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
302 |
+
|
303 |
+
if isinstance(self.visual, ModifiedResNet):
|
304 |
+
if self.visual.attnpool is not None:
|
305 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
306 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
307 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
308 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
309 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
310 |
+
|
311 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
312 |
+
for name, param in resnet_block.named_parameters():
|
313 |
+
if name.endswith("bn3.weight"):
|
314 |
+
nn.init.zeros_(param)
|
315 |
+
|
316 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
317 |
+
attn_std = self.transformer.width ** -0.5
|
318 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
319 |
+
for block in self.transformer.resblocks:
|
320 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
321 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
322 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
323 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
324 |
+
|
325 |
+
if self.text_projection is not None:
|
326 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
327 |
+
|
328 |
+
def build_attention_mask(self):
|
329 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
330 |
+
# pytorch uses additive attention mask; fill with -inf
|
331 |
+
mask = torch.empty(self.context_length, self.context_length)
|
332 |
+
mask.fill_(float("-inf"))
|
333 |
+
mask.triu_(1) # zero out the lower diagonal
|
334 |
+
return mask
|
335 |
+
|
336 |
+
@property
|
337 |
+
def dtype(self):
|
338 |
+
return self.visual.conv1.weight.dtype
|
339 |
+
|
340 |
+
def encode_image(self, image):
|
341 |
+
return self.visual(image.type(self.dtype))
|
342 |
+
|
343 |
+
def encode_text(self, text):
|
344 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
345 |
+
|
346 |
+
x = x + self.positional_embedding.type(self.dtype)
|
347 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
348 |
+
x = self.transformer(x)
|
349 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
350 |
+
x = self.ln_final(x).type(self.dtype)
|
351 |
+
|
352 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
353 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
354 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
355 |
+
|
356 |
+
return x
|
357 |
+
|
358 |
+
def forward(self, image, text):
|
359 |
+
image_features = self.encode_image(image)
|
360 |
+
text_features = self.encode_text(text)
|
361 |
+
|
362 |
+
# normalized features
|
363 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
364 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
365 |
+
|
366 |
+
# cosine similarity as logits
|
367 |
+
logit_scale = self.logit_scale.exp()
|
368 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
369 |
+
logits_per_text = logits_per_image.t()
|
370 |
+
|
371 |
+
# shape = [global_batch_size, global_batch_size]
|
372 |
+
return logits_per_image, logits_per_text
|
373 |
+
|
374 |
+
|
375 |
+
def convert_weights(model: nn.Module):
|
376 |
+
"""Convert applicable model parameters to fp16"""
|
377 |
+
|
378 |
+
def _convert_weights_to_fp16(l):
|
379 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
380 |
+
l.weight.data = l.weight.data.half()
|
381 |
+
if l.bias is not None:
|
382 |
+
l.bias.data = l.bias.data.half()
|
383 |
+
|
384 |
+
if isinstance(l, nn.MultiheadAttention):
|
385 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
386 |
+
tensor = getattr(l, attr)
|
387 |
+
if tensor is not None:
|
388 |
+
tensor.data = tensor.data.half()
|
389 |
+
|
390 |
+
for name in ["text_projection", "proj"]:
|
391 |
+
if hasattr(l, name):
|
392 |
+
attr = getattr(l, name)
|
393 |
+
if attr is not None:
|
394 |
+
attr.data = attr.data.half()
|
395 |
+
|
396 |
+
model.apply(_convert_weights_to_fp16)
|
397 |
+
|
398 |
+
|
399 |
+
def build_model(state_dict: dict):
|
400 |
+
vit = "visual.proj" in state_dict
|
401 |
+
|
402 |
+
if vit:
|
403 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
404 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
405 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
406 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
407 |
+
image_resolution = vision_patch_size * grid_size
|
408 |
+
else:
|
409 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
410 |
+
vision_layers = tuple(counts)
|
411 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
412 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
413 |
+
vision_patch_size = None
|
414 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
415 |
+
image_resolution = output_width * 32
|
416 |
+
|
417 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
418 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
419 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
420 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
421 |
+
transformer_heads = transformer_width // 64
|
422 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
423 |
+
|
424 |
+
model = CLIP(
|
425 |
+
embed_dim,
|
426 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
427 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
428 |
+
)
|
429 |
+
|
430 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
431 |
+
if key in state_dict:
|
432 |
+
del state_dict[key]
|
433 |
+
|
434 |
+
convert_weights(model)
|
435 |
+
model.load_state_dict(state_dict)
|
436 |
+
return model.eval()
|
models/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
from .wrappers import *
|
models/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (188 Bytes). View file
|
|
models/__pycache__/wrappers.cpython-310.pyc
ADDED
Binary file (24.3 kB). View file
|
|
models/biggan/__init__.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import sys
|
3 |
+
|
4 |
+
module_path = Path(__file__).parent / 'pytorch_biggan'
|
5 |
+
sys.path.append(str(module_path.resolve()))
|
6 |
+
from pytorch_pretrained_biggan import *
|
7 |
+
from pytorch_pretrained_biggan.model import GenBlock
|
8 |
+
from pytorch_pretrained_biggan.file_utils import http_get, s3_get
|
models/biggan/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (526 Bytes). View file
|
|
models/biggan/pytorch_biggan/.gitignore
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
*.egg-info/
|
24 |
+
.installed.cfg
|
25 |
+
*.egg
|
26 |
+
MANIFEST
|
27 |
+
|
28 |
+
# PyInstaller
|
29 |
+
# Usually these files are written by a python script from a template
|
30 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
31 |
+
*.manifest
|
32 |
+
*.spec
|
33 |
+
|
34 |
+
# Installer logs
|
35 |
+
pip-log.txt
|
36 |
+
pip-delete-this-directory.txt
|
37 |
+
|
38 |
+
# Unit test / coverage reports
|
39 |
+
htmlcov/
|
40 |
+
.tox/
|
41 |
+
.coverage
|
42 |
+
.coverage.*
|
43 |
+
.cache
|
44 |
+
nosetests.xml
|
45 |
+
coverage.xml
|
46 |
+
*.cover
|
47 |
+
.hypothesis/
|
48 |
+
.pytest_cache/
|
49 |
+
|
50 |
+
# Translations
|
51 |
+
*.mo
|
52 |
+
*.pot
|
53 |
+
|
54 |
+
# Django stuff:
|
55 |
+
*.log
|
56 |
+
local_settings.py
|
57 |
+
db.sqlite3
|
58 |
+
|
59 |
+
# Flask stuff:
|
60 |
+
instance/
|
61 |
+
.webassets-cache
|
62 |
+
|
63 |
+
# Scrapy stuff:
|
64 |
+
.scrapy
|
65 |
+
|
66 |
+
# Sphinx documentation
|
67 |
+
docs/_build/
|
68 |
+
|
69 |
+
# PyBuilder
|
70 |
+
target/
|
71 |
+
|
72 |
+
# Jupyter Notebook
|
73 |
+
.ipynb_checkpoints
|
74 |
+
|
75 |
+
# pyenv
|
76 |
+
.python-version
|
77 |
+
|
78 |
+
# celery beat schedule file
|
79 |
+
celerybeat-schedule
|
80 |
+
|
81 |
+
# SageMath parsed files
|
82 |
+
*.sage.py
|
83 |
+
|
84 |
+
# Environments
|
85 |
+
.env
|
86 |
+
.venv
|
87 |
+
env/
|
88 |
+
venv/
|
89 |
+
ENV/
|
90 |
+
env.bak/
|
91 |
+
venv.bak/
|
92 |
+
|
93 |
+
# Spyder project settings
|
94 |
+
.spyderproject
|
95 |
+
.spyproject
|
96 |
+
|
97 |
+
# Rope project settings
|
98 |
+
.ropeproject
|
99 |
+
|
100 |
+
# mkdocs documentation
|
101 |
+
/site
|
102 |
+
|
103 |
+
# mypy
|
104 |
+
.mypy_cache/
|
105 |
+
|
106 |
+
# vscode
|
107 |
+
.vscode/
|
108 |
+
|
109 |
+
# models
|
110 |
+
models/
|
models/biggan/pytorch_biggan/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2020 Erik Härkönen
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
models/biggan/pytorch_biggan/MANIFEST.in
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
include LICENSE
|
models/biggan/pytorch_biggan/README.md
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# BigStyleGAN
|
2 |
+
This is a copy of HuggingFace's BigGAN implementation, with the addition of layerwise latent inputs.
|
3 |
+
|
4 |
+
# PyTorch pretrained BigGAN
|
5 |
+
An op-for-op PyTorch reimplementation of DeepMind's BigGAN model with the pre-trained weights from DeepMind.
|
6 |
+
|
7 |
+
## Introduction
|
8 |
+
|
9 |
+
This repository contains an op-for-op PyTorch reimplementation of DeepMind's BigGAN that was released with the paper [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://openreview.net/forum?id=B1xsqj09Fm) by Andrew Brock, Jeff Donahue and Karen Simonyan.
|
10 |
+
|
11 |
+
This PyTorch implementation of BigGAN is provided with the [pretrained 128x128, 256x256 and 512x512 models by DeepMind](https://tfhub.dev/deepmind/biggan-deep-128/1). We also provide the scripts used to download and convert these models from the TensorFlow Hub models.
|
12 |
+
|
13 |
+
This reimplementation was done from the raw computation graph of the Tensorflow version and behave similarly to the TensorFlow version (variance of the output difference of the order of 1e-5).
|
14 |
+
|
15 |
+
This implementation currently only contains the generator as the weights of the discriminator were not released (although the structure of the discriminator is very similar to the generator so it could be added pretty easily. Tell me if you want to do a PR on that, I would be happy to help.)
|
16 |
+
|
17 |
+
## Installation
|
18 |
+
|
19 |
+
This repo was tested on Python 3.6 and PyTorch 1.0.1
|
20 |
+
|
21 |
+
PyTorch pretrained BigGAN can be installed from pip as follows:
|
22 |
+
```bash
|
23 |
+
pip install pytorch-pretrained-biggan
|
24 |
+
```
|
25 |
+
|
26 |
+
If you simply want to play with the GAN this should be enough.
|
27 |
+
|
28 |
+
If you want to use the conversion scripts and the imagenet utilities, additional requirements are needed, in particular TensorFlow and NLTK. To install all the requirements please use the `full_requirements.txt` file:
|
29 |
+
```bash
|
30 |
+
git clone https://github.com/huggingface/pytorch-pretrained-BigGAN.git
|
31 |
+
cd pytorch-pretrained-BigGAN
|
32 |
+
pip install -r full_requirements.txt
|
33 |
+
```
|
34 |
+
|
35 |
+
## Models
|
36 |
+
|
37 |
+
This repository provide direct and simple access to the pretrained "deep" versions of BigGAN for 128, 256 and 512 pixels resolutions as described in the [associated publication](https://openreview.net/forum?id=B1xsqj09Fm).
|
38 |
+
Here are some details on the models:
|
39 |
+
|
40 |
+
- `BigGAN-deep-128`: a 50.4M parameters model generating 128x128 pixels images, the model dump weights 201 MB,
|
41 |
+
- `BigGAN-deep-256`: a 55.9M parameters model generating 256x256 pixels images, the model dump weights 224 MB,
|
42 |
+
- `BigGAN-deep-512`: a 56.2M parameters model generating 512x512 pixels images, the model dump weights 225 MB.
|
43 |
+
|
44 |
+
Please refer to Appendix B of the paper for details on the architectures.
|
45 |
+
|
46 |
+
All models comprise pre-computed batch norm statistics for 51 truncation values between 0 and 1 (see Appendix C.1 in the paper for details).
|
47 |
+
|
48 |
+
## Usage
|
49 |
+
|
50 |
+
Here is a quick-start example using `BigGAN` with a pre-trained model.
|
51 |
+
|
52 |
+
See the [doc section](#doc) below for details on these classes and methods.
|
53 |
+
|
54 |
+
```python
|
55 |
+
import torch
|
56 |
+
from pytorch_pretrained_biggan import (BigGAN, one_hot_from_names, truncated_noise_sample,
|
57 |
+
save_as_images, display_in_terminal)
|
58 |
+
|
59 |
+
# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
|
60 |
+
import logging
|
61 |
+
logging.basicConfig(level=logging.INFO)
|
62 |
+
|
63 |
+
# Load pre-trained model tokenizer (vocabulary)
|
64 |
+
model = BigGAN.from_pretrained('biggan-deep-256')
|
65 |
+
|
66 |
+
# Prepare a input
|
67 |
+
truncation = 0.4
|
68 |
+
class_vector = one_hot_from_names(['soap bubble', 'coffee', 'mushroom'], batch_size=3)
|
69 |
+
noise_vector = truncated_noise_sample(truncation=truncation, batch_size=3)
|
70 |
+
|
71 |
+
# All in tensors
|
72 |
+
noise_vector = torch.from_numpy(noise_vector)
|
73 |
+
class_vector = torch.from_numpy(class_vector)
|
74 |
+
|
75 |
+
# If you have a GPU, put everything on cuda
|
76 |
+
noise_vector = noise_vector.to('cuda')
|
77 |
+
class_vector = class_vector.to('cuda')
|
78 |
+
model.to('cuda')
|
79 |
+
|
80 |
+
# Generate an image
|
81 |
+
with torch.no_grad():
|
82 |
+
output = model(noise_vector, class_vector, truncation)
|
83 |
+
|
84 |
+
# If you have a GPU put back on CPU
|
85 |
+
output = output.to('cpu')
|
86 |
+
|
87 |
+
# If you have a sixtel compatible terminal you can display the images in the terminal
|
88 |
+
# (see https://github.com/saitoha/libsixel for details)
|
89 |
+
display_in_terminal(output)
|
90 |
+
|
91 |
+
# Save results as png images
|
92 |
+
save_as_images(output)
|
93 |
+
```
|
94 |
+
|
95 |
+
![output_0](assets/output_0.png)
|
96 |
+
![output_1](assets/output_1.png)
|
97 |
+
![output_2](assets/output_2.png)
|
98 |
+
|
99 |
+
## Doc
|
100 |
+
|
101 |
+
### Loading DeepMind's pre-trained weights
|
102 |
+
|
103 |
+
To load one of DeepMind's pre-trained models, instantiate a `BigGAN` model with `from_pretrained()` as:
|
104 |
+
|
105 |
+
```python
|
106 |
+
model = BigGAN.from_pretrained(PRE_TRAINED_MODEL_NAME_OR_PATH, cache_dir=None)
|
107 |
+
```
|
108 |
+
|
109 |
+
where
|
110 |
+
|
111 |
+
- `PRE_TRAINED_MODEL_NAME_OR_PATH` is either:
|
112 |
+
|
113 |
+
- the shortcut name of a Google AI's or OpenAI's pre-trained model selected in the list:
|
114 |
+
|
115 |
+
- `biggan-deep-128`: 12-layer, 768-hidden, 12-heads, 110M parameters
|
116 |
+
- `biggan-deep-256`: 24-layer, 1024-hidden, 16-heads, 340M parameters
|
117 |
+
- `biggan-deep-512`: 12-layer, 768-hidden, 12-heads , 110M parameters
|
118 |
+
|
119 |
+
- a path or url to a pretrained model archive containing:
|
120 |
+
|
121 |
+
- `config.json`: a configuration file for the model, and
|
122 |
+
- `pytorch_model.bin` a PyTorch dump of a pre-trained instance of `BigGAN` (saved with the usual `torch.save()`).
|
123 |
+
|
124 |
+
If `PRE_TRAINED_MODEL_NAME_OR_PATH` is a shortcut name, the pre-trained weights will be downloaded from AWS S3 (see the links [here](pytorch_pretrained_biggan/model.py)) and stored in a cache folder to avoid future download (the cache folder can be found at `~/.pytorch_pretrained_biggan/`).
|
125 |
+
- `cache_dir` can be an optional path to a specific directory to download and cache the pre-trained model weights.
|
126 |
+
|
127 |
+
### Configuration
|
128 |
+
|
129 |
+
`BigGANConfig` is a class to store and load BigGAN configurations. It's defined in [`config.py`](./pytorch_pretrained_biggan/config.py).
|
130 |
+
|
131 |
+
Here are some details on the attributes:
|
132 |
+
|
133 |
+
- `output_dim`: output resolution of the GAN (128, 256 or 512) for the pre-trained models,
|
134 |
+
- `z_dim`: size of the noise vector (128 for the pre-trained models).
|
135 |
+
- `class_embed_dim`: size of the class embedding vectors (128 for the pre-trained models).
|
136 |
+
- `channel_width`: size of each channel (128 for the pre-trained models).
|
137 |
+
- `num_classes`: number of classes in the training dataset, like imagenet (1000 for the pre-trained models).
|
138 |
+
- `layers`: A list of layers definition. Each definition for a layer is a triple of [up-sample in the layer ? (bool), number of input channels (int), number of output channels (int)]
|
139 |
+
- `attention_layer_position`: Position of the self-attention layer in the layer hierarchy (8 for the pre-trained models).
|
140 |
+
- `eps`: epsilon value to use for spectral and batch normalization layers (1e-4 for the pre-trained models).
|
141 |
+
- `n_stats`: number of pre-computed statistics for the batch normalization layers associated to various truncation values between 0 and 1 (51 for the pre-trained models).
|
142 |
+
|
143 |
+
### Model
|
144 |
+
|
145 |
+
`BigGAN` is a PyTorch model (`torch.nn.Module`) of BigGAN defined in [`model.py`](./pytorch_pretrained_biggan/model.py). This model comprises the class embeddings (a linear layer) and the generator with a series of convolutions and conditional batch norms. The discriminator is currently not implemented since pre-trained weights have not been released for it.
|
146 |
+
|
147 |
+
The inputs and output are **identical to the TensorFlow model inputs and outputs**.
|
148 |
+
|
149 |
+
We detail them here.
|
150 |
+
|
151 |
+
`BigGAN` takes as *inputs*:
|
152 |
+
|
153 |
+
- `z`: a torch.FloatTensor of shape [batch_size, config.z_dim] with noise sampled from a truncated normal distribution, and
|
154 |
+
- `class_label`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
|
155 |
+
- `truncation`: a float between 0 (not comprised) and 1. The truncation of the truncated normal used for creating the noise vector. This truncation value is used to selecte between a set of pre-computed statistics (means and variances) for the batch norm layers.
|
156 |
+
|
157 |
+
`BigGAN` *outputs* an array of shape [batch_size, 3, resolution, resolution] where resolution is 128, 256 or 512 depending of the model:
|
158 |
+
|
159 |
+
### Utilities: Images, Noise, Imagenet classes
|
160 |
+
|
161 |
+
We provide a few utility method to use the model. They are defined in [`utils.py`](./pytorch_pretrained_biggan/utils.py).
|
162 |
+
|
163 |
+
Here are some details on these methods:
|
164 |
+
|
165 |
+
- `truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None)`:
|
166 |
+
|
167 |
+
Create a truncated noise vector.
|
168 |
+
- Params:
|
169 |
+
- batch_size: batch size.
|
170 |
+
- dim_z: dimension of z
|
171 |
+
- truncation: truncation value to use
|
172 |
+
- seed: seed for the random generator
|
173 |
+
- Output:
|
174 |
+
array of shape (batch_size, dim_z)
|
175 |
+
|
176 |
+
- `convert_to_images(obj)`:
|
177 |
+
|
178 |
+
Convert an output tensor from BigGAN in a list of images.
|
179 |
+
- Params:
|
180 |
+
- obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
181 |
+
- Output:
|
182 |
+
- list of Pillow Images of size (height, width)
|
183 |
+
|
184 |
+
- `save_as_images(obj, file_name='output')`:
|
185 |
+
|
186 |
+
Convert and save an output tensor from BigGAN in a list of saved images.
|
187 |
+
- Params:
|
188 |
+
- obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
189 |
+
- file_name: path and beggingin of filename to save.
|
190 |
+
Images will be saved as `file_name_{image_number}.png`
|
191 |
+
|
192 |
+
- `display_in_terminal(obj)`:
|
193 |
+
|
194 |
+
Convert and display an output tensor from BigGAN in the terminal. This function use `libsixel` and will only work in a libsixel-compatible terminal. Please refer to https://github.com/saitoha/libsixel for more details.
|
195 |
+
- Params:
|
196 |
+
- obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
197 |
+
- file_name: path and beggingin of filename to save.
|
198 |
+
Images will be saved as `file_name_{image_number}.png`
|
199 |
+
|
200 |
+
- `one_hot_from_int(int_or_list, batch_size=1)`:
|
201 |
+
|
202 |
+
Create a one-hot vector from a class index or a list of class indices.
|
203 |
+
- Params:
|
204 |
+
- int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
|
205 |
+
- batch_size: batch size.
|
206 |
+
- If int_or_list is an int create a batch of identical classes.
|
207 |
+
- If int_or_list is a list, we should have `len(int_or_list) == batch_size`
|
208 |
+
- Output:
|
209 |
+
- array of shape (batch_size, 1000)
|
210 |
+
|
211 |
+
- `one_hot_from_names(class_name, batch_size=1)`:
|
212 |
+
|
213 |
+
Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...). We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one. If we can't find it direcly, we look at the hyponyms and hypernyms of the class name.
|
214 |
+
- Params:
|
215 |
+
- class_name: string containing the name of an imagenet object.
|
216 |
+
- Output:
|
217 |
+
- array of shape (batch_size, 1000)
|
218 |
+
|
219 |
+
## Download and conversion scripts
|
220 |
+
|
221 |
+
Scripts to download and convert the TensorFlow models from TensorFlow Hub are provided in [./scripts](./scripts/).
|
222 |
+
|
223 |
+
The scripts can be used directly as:
|
224 |
+
```bash
|
225 |
+
./scripts/download_tf_hub_models.sh
|
226 |
+
./scripts/convert_tf_hub_models.sh
|
227 |
+
```
|
models/biggan/pytorch_biggan/assets/output_0.png
ADDED
models/biggan/pytorch_biggan/assets/output_1.png
ADDED
models/biggan/pytorch_biggan/assets/output_2.png
ADDED
models/biggan/pytorch_biggan/full_requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tensorflow
|
2 |
+
tensorflow-hub
|
3 |
+
Pillow
|
4 |
+
nltk
|
5 |
+
libsixel-python
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .config import BigGANConfig
|
2 |
+
from .model import BigGAN
|
3 |
+
from .file_utils import PYTORCH_PRETRAINED_BIGGAN_CACHE, cached_path
|
4 |
+
from .utils import (truncated_noise_sample, save_as_images,
|
5 |
+
convert_to_images, display_in_terminal,
|
6 |
+
one_hot_from_int, one_hot_from_names)
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/config.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
"""
|
3 |
+
BigGAN config.
|
4 |
+
"""
|
5 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
6 |
+
|
7 |
+
import copy
|
8 |
+
import json
|
9 |
+
|
10 |
+
class BigGANConfig(object):
|
11 |
+
""" Configuration class to store the configuration of a `BigGAN`.
|
12 |
+
Defaults are for the 128x128 model.
|
13 |
+
layers tuple are (up-sample in the layer ?, input channels, output channels)
|
14 |
+
"""
|
15 |
+
def __init__(self,
|
16 |
+
output_dim=128,
|
17 |
+
z_dim=128,
|
18 |
+
class_embed_dim=128,
|
19 |
+
channel_width=128,
|
20 |
+
num_classes=1000,
|
21 |
+
layers=[(False, 16, 16),
|
22 |
+
(True, 16, 16),
|
23 |
+
(False, 16, 16),
|
24 |
+
(True, 16, 8),
|
25 |
+
(False, 8, 8),
|
26 |
+
(True, 8, 4),
|
27 |
+
(False, 4, 4),
|
28 |
+
(True, 4, 2),
|
29 |
+
(False, 2, 2),
|
30 |
+
(True, 2, 1)],
|
31 |
+
attention_layer_position=8,
|
32 |
+
eps=1e-4,
|
33 |
+
n_stats=51):
|
34 |
+
"""Constructs BigGANConfig. """
|
35 |
+
self.output_dim = output_dim
|
36 |
+
self.z_dim = z_dim
|
37 |
+
self.class_embed_dim = class_embed_dim
|
38 |
+
self.channel_width = channel_width
|
39 |
+
self.num_classes = num_classes
|
40 |
+
self.layers = layers
|
41 |
+
self.attention_layer_position = attention_layer_position
|
42 |
+
self.eps = eps
|
43 |
+
self.n_stats = n_stats
|
44 |
+
|
45 |
+
@classmethod
|
46 |
+
def from_dict(cls, json_object):
|
47 |
+
"""Constructs a `BigGANConfig` from a Python dictionary of parameters."""
|
48 |
+
config = BigGANConfig()
|
49 |
+
for key, value in json_object.items():
|
50 |
+
config.__dict__[key] = value
|
51 |
+
return config
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def from_json_file(cls, json_file):
|
55 |
+
"""Constructs a `BigGANConfig` from a json file of parameters."""
|
56 |
+
with open(json_file, "r", encoding='utf-8') as reader:
|
57 |
+
text = reader.read()
|
58 |
+
return cls.from_dict(json.loads(text))
|
59 |
+
|
60 |
+
def __repr__(self):
|
61 |
+
return str(self.to_json_string())
|
62 |
+
|
63 |
+
def to_dict(self):
|
64 |
+
"""Serializes this instance to a Python dictionary."""
|
65 |
+
output = copy.deepcopy(self.__dict__)
|
66 |
+
return output
|
67 |
+
|
68 |
+
def to_json_string(self):
|
69 |
+
"""Serializes this instance to a JSON string."""
|
70 |
+
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/convert_tf_to_pytorch.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
"""
|
3 |
+
Convert a TF Hub model for BigGAN in a PT one.
|
4 |
+
"""
|
5 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
6 |
+
|
7 |
+
from itertools import chain
|
8 |
+
|
9 |
+
import os
|
10 |
+
import argparse
|
11 |
+
import logging
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.nn.functional import normalize
|
17 |
+
|
18 |
+
from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME
|
19 |
+
from .config import BigGANConfig
|
20 |
+
|
21 |
+
logger = logging.getLogger(__name__)
|
22 |
+
|
23 |
+
|
24 |
+
def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None):
|
25 |
+
try:
|
26 |
+
import numpy as np
|
27 |
+
import tensorflow as tf
|
28 |
+
import tensorflow_hub as hub
|
29 |
+
except ImportError:
|
30 |
+
raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow and TF Hub to be installed. "
|
31 |
+
"Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. "
|
32 |
+
"And see https://github.com/tensorflow/hub for installing Hub. "
|
33 |
+
"Probably pip install tensorflow tensorflow-hub")
|
34 |
+
tf.reset_default_graph()
|
35 |
+
logger.info('Loading BigGAN module from: {}'.format(tf_model_path))
|
36 |
+
module = hub.Module(tf_model_path)
|
37 |
+
inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
|
38 |
+
for k, v in module.get_input_info_dict().items()}
|
39 |
+
output = module(inputs)
|
40 |
+
|
41 |
+
initializer = tf.global_variables_initializer()
|
42 |
+
sess = tf.Session()
|
43 |
+
stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ())
|
44 |
+
numpy_stacks = []
|
45 |
+
for i in stacks:
|
46 |
+
logger.info("Retrieving module_apply_default/stack_{}".format(i))
|
47 |
+
try:
|
48 |
+
stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i)
|
49 |
+
except KeyError:
|
50 |
+
break # We have all the stats
|
51 |
+
numpy_stacks.append(sess.run(stack_var))
|
52 |
+
|
53 |
+
if batch_norm_stats_path is not None:
|
54 |
+
torch.save(numpy_stacks, batch_norm_stats_path)
|
55 |
+
else:
|
56 |
+
return numpy_stacks
|
57 |
+
|
58 |
+
|
59 |
+
def build_tf_to_pytorch_map(model, config):
|
60 |
+
""" Build a map from TF variables to PyTorch modules. """
|
61 |
+
tf_to_pt_map = {}
|
62 |
+
|
63 |
+
# Embeddings and GenZ
|
64 |
+
tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight,
|
65 |
+
'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias,
|
66 |
+
'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig,
|
67 |
+
'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u})
|
68 |
+
|
69 |
+
# GBlock blocks
|
70 |
+
model_layer_idx = 0
|
71 |
+
for i, (up, in_channels, out_channels) in enumerate(config.layers):
|
72 |
+
if i == config.attention_layer_position:
|
73 |
+
model_layer_idx += 1
|
74 |
+
layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/"
|
75 |
+
layer_pnt = model.generator.layers[model_layer_idx]
|
76 |
+
for i in range(4): # Batchnorms
|
77 |
+
batch_str = layer_str + ("BatchNorm_%d/" % i if i > 0 else "BatchNorm/")
|
78 |
+
batch_pnt = getattr(layer_pnt, 'bn_%d' % i)
|
79 |
+
for name in ('offset', 'scale'):
|
80 |
+
sub_module_str = batch_str + name + "/"
|
81 |
+
sub_module_pnt = getattr(batch_pnt, name)
|
82 |
+
tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
|
83 |
+
sub_module_str + "u0": sub_module_pnt.weight_u})
|
84 |
+
for i in range(4): # Convolutions
|
85 |
+
conv_str = layer_str + "conv%d/" % i
|
86 |
+
conv_pnt = getattr(layer_pnt, 'conv_%d' % i)
|
87 |
+
tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias,
|
88 |
+
conv_str + "w/ema_0.9999": conv_pnt.weight_orig,
|
89 |
+
conv_str + "u0": conv_pnt.weight_u})
|
90 |
+
model_layer_idx += 1
|
91 |
+
|
92 |
+
# Attention block
|
93 |
+
layer_str = "Generator/attention/"
|
94 |
+
layer_pnt = model.generator.layers[config.attention_layer_position]
|
95 |
+
tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma})
|
96 |
+
for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'],
|
97 |
+
['g/', 'o_conv/', 'phi/', 'theta/']):
|
98 |
+
sub_module_str = layer_str + tf_name
|
99 |
+
sub_module_pnt = getattr(layer_pnt, pt_name)
|
100 |
+
tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
|
101 |
+
sub_module_str + "u0": sub_module_pnt.weight_u})
|
102 |
+
|
103 |
+
# final batch norm and conv to rgb
|
104 |
+
layer_str = "Generator/BatchNorm/"
|
105 |
+
layer_pnt = model.generator.bn
|
106 |
+
tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias,
|
107 |
+
layer_str + "scale/ema_0.9999": layer_pnt.weight})
|
108 |
+
layer_str = "Generator/conv_to_rgb/"
|
109 |
+
layer_pnt = model.generator.conv_to_rgb
|
110 |
+
tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias,
|
111 |
+
layer_str + "w/ema_0.9999": layer_pnt.weight_orig,
|
112 |
+
layer_str + "u0": layer_pnt.weight_u})
|
113 |
+
return tf_to_pt_map
|
114 |
+
|
115 |
+
|
116 |
+
def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None):
|
117 |
+
""" Load tf checkpoints and standing statistics in a pytorch model
|
118 |
+
"""
|
119 |
+
try:
|
120 |
+
import numpy as np
|
121 |
+
import tensorflow as tf
|
122 |
+
except ImportError:
|
123 |
+
raise ImportError("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
124 |
+
"https://www.tensorflow.org/install/ for installation instructions.")
|
125 |
+
# Load weights from TF model
|
126 |
+
checkpoint_path = tf_model_path + "/variables/variables"
|
127 |
+
init_vars = tf.train.list_variables(checkpoint_path)
|
128 |
+
from pprint import pprint
|
129 |
+
pprint(init_vars)
|
130 |
+
|
131 |
+
# Extract batch norm statistics from model if needed
|
132 |
+
if batch_norm_stats_path:
|
133 |
+
stats = torch.load(batch_norm_stats_path)
|
134 |
+
else:
|
135 |
+
logger.info("Extracting batch norm stats")
|
136 |
+
stats = extract_batch_norm_stats(tf_model_path)
|
137 |
+
|
138 |
+
# Build TF to PyTorch weights loading map
|
139 |
+
tf_to_pt_map = build_tf_to_pytorch_map(model, config)
|
140 |
+
|
141 |
+
tf_weights = {}
|
142 |
+
for name in tf_to_pt_map.keys():
|
143 |
+
array = tf.train.load_variable(checkpoint_path, name)
|
144 |
+
tf_weights[name] = array
|
145 |
+
# logger.info("Loading TF weight {} with shape {}".format(name, array.shape))
|
146 |
+
|
147 |
+
# Load parameters
|
148 |
+
with torch.no_grad():
|
149 |
+
pt_params_pnt = set()
|
150 |
+
for name, pointer in tf_to_pt_map.items():
|
151 |
+
array = tf_weights[name]
|
152 |
+
if pointer.dim() == 1:
|
153 |
+
if pointer.dim() < array.ndim:
|
154 |
+
array = np.squeeze(array)
|
155 |
+
elif pointer.dim() == 2: # Weights
|
156 |
+
array = np.transpose(array)
|
157 |
+
elif pointer.dim() == 4: # Convolutions
|
158 |
+
array = np.transpose(array, (3, 2, 0, 1))
|
159 |
+
else:
|
160 |
+
raise "Wrong dimensions to adjust: " + str((pointer.shape, array.shape))
|
161 |
+
if pointer.shape != array.shape:
|
162 |
+
raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape)))
|
163 |
+
logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape))
|
164 |
+
pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array)
|
165 |
+
tf_weights.pop(name, None)
|
166 |
+
pt_params_pnt.add(pointer.data_ptr())
|
167 |
+
|
168 |
+
# Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model):
|
169 |
+
for module in model.modules():
|
170 |
+
for n, buffer in module.named_buffers():
|
171 |
+
if n == 'weight_v':
|
172 |
+
weight_mat = module.weight_orig
|
173 |
+
weight_mat = weight_mat.reshape(weight_mat.size(0), -1)
|
174 |
+
u = module.weight_u
|
175 |
+
|
176 |
+
v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps)
|
177 |
+
buffer.data = v
|
178 |
+
pt_params_pnt.add(buffer.data_ptr())
|
179 |
+
|
180 |
+
u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps)
|
181 |
+
module.weight_u.data = u
|
182 |
+
pt_params_pnt.add(module.weight_u.data_ptr())
|
183 |
+
|
184 |
+
# Load batch norm statistics
|
185 |
+
index = 0
|
186 |
+
for layer in model.generator.layers:
|
187 |
+
if not hasattr(layer, 'bn_0'):
|
188 |
+
continue
|
189 |
+
for i in range(4): # Batchnorms
|
190 |
+
bn_pointer = getattr(layer, 'bn_%d' % i)
|
191 |
+
pointer = bn_pointer.running_means
|
192 |
+
if pointer.shape != stats[index].shape:
|
193 |
+
raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape))
|
194 |
+
pointer.data = torch.from_numpy(stats[index])
|
195 |
+
pt_params_pnt.add(pointer.data_ptr())
|
196 |
+
|
197 |
+
pointer = bn_pointer.running_vars
|
198 |
+
if pointer.shape != stats[index+1].shape:
|
199 |
+
raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape))
|
200 |
+
pointer.data = torch.from_numpy(stats[index+1])
|
201 |
+
pt_params_pnt.add(pointer.data_ptr())
|
202 |
+
|
203 |
+
index += 2
|
204 |
+
|
205 |
+
bn_pointer = model.generator.bn
|
206 |
+
pointer = bn_pointer.running_means
|
207 |
+
if pointer.shape != stats[index].shape:
|
208 |
+
raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape))
|
209 |
+
pointer.data = torch.from_numpy(stats[index])
|
210 |
+
pt_params_pnt.add(pointer.data_ptr())
|
211 |
+
|
212 |
+
pointer = bn_pointer.running_vars
|
213 |
+
if pointer.shape != stats[index+1].shape:
|
214 |
+
raise "Wrong dimensions: " + str((pointer.shape, stats[index].shape))
|
215 |
+
pointer.data = torch.from_numpy(stats[index+1])
|
216 |
+
pt_params_pnt.add(pointer.data_ptr())
|
217 |
+
|
218 |
+
remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers()) \
|
219 |
+
if t.data_ptr() not in pt_params_pnt)
|
220 |
+
|
221 |
+
logger.info("TF Weights not copied to PyTorch model: {} -".format(', '.join(tf_weights.keys())))
|
222 |
+
logger.info("Remanining parameters/buffers from PyTorch model: {} -".format(', '.join(remaining_params)))
|
223 |
+
|
224 |
+
return model
|
225 |
+
|
226 |
+
|
227 |
+
BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
|
228 |
+
layers=[(False, 16, 16),
|
229 |
+
(True, 16, 16),
|
230 |
+
(False, 16, 16),
|
231 |
+
(True, 16, 8),
|
232 |
+
(False, 8, 8),
|
233 |
+
(True, 8, 4),
|
234 |
+
(False, 4, 4),
|
235 |
+
(True, 4, 2),
|
236 |
+
(False, 2, 2),
|
237 |
+
(True, 2, 1)],
|
238 |
+
attention_layer_position=8, eps=1e-4, n_stats=51)
|
239 |
+
|
240 |
+
BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
|
241 |
+
layers=[(False, 16, 16),
|
242 |
+
(True, 16, 16),
|
243 |
+
(False, 16, 16),
|
244 |
+
(True, 16, 8),
|
245 |
+
(False, 8, 8),
|
246 |
+
(True, 8, 8),
|
247 |
+
(False, 8, 8),
|
248 |
+
(True, 8, 4),
|
249 |
+
(False, 4, 4),
|
250 |
+
(True, 4, 2),
|
251 |
+
(False, 2, 2),
|
252 |
+
(True, 2, 1)],
|
253 |
+
attention_layer_position=8, eps=1e-4, n_stats=51)
|
254 |
+
|
255 |
+
BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
|
256 |
+
layers=[(False, 16, 16),
|
257 |
+
(True, 16, 16),
|
258 |
+
(False, 16, 16),
|
259 |
+
(True, 16, 8),
|
260 |
+
(False, 8, 8),
|
261 |
+
(True, 8, 8),
|
262 |
+
(False, 8, 8),
|
263 |
+
(True, 8, 4),
|
264 |
+
(False, 4, 4),
|
265 |
+
(True, 4, 2),
|
266 |
+
(False, 2, 2),
|
267 |
+
(True, 2, 1),
|
268 |
+
(False, 1, 1),
|
269 |
+
(True, 1, 1)],
|
270 |
+
attention_layer_position=8, eps=1e-4, n_stats=51)
|
271 |
+
|
272 |
+
|
273 |
+
def main():
|
274 |
+
parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model in a PyTorch model")
|
275 |
+
parser.add_argument("--model_type", type=str, default="", required=True,
|
276 |
+
help="BigGAN model type (128, 256, 512)")
|
277 |
+
parser.add_argument("--tf_model_path", type=str, default="", required=True,
|
278 |
+
help="Path of the downloaded TF Hub model")
|
279 |
+
parser.add_argument("--pt_save_path", type=str, default="",
|
280 |
+
help="Folder to save the PyTorch model (default: Folder of the TF Hub model)")
|
281 |
+
parser.add_argument("--batch_norm_stats_path", type=str, default="",
|
282 |
+
help="Path of previously extracted batch norm statistics")
|
283 |
+
args = parser.parse_args()
|
284 |
+
|
285 |
+
logging.basicConfig(level=logging.INFO)
|
286 |
+
|
287 |
+
if not args.pt_save_path:
|
288 |
+
args.pt_save_path = args.tf_model_path
|
289 |
+
|
290 |
+
if args.model_type == "128":
|
291 |
+
config = BigGAN128
|
292 |
+
elif args.model_type == "256":
|
293 |
+
config = BigGAN256
|
294 |
+
elif args.model_type == "512":
|
295 |
+
config = BigGAN512
|
296 |
+
else:
|
297 |
+
raise ValueError("model_type should be one of 128, 256 or 512")
|
298 |
+
|
299 |
+
model = BigGAN(config)
|
300 |
+
model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path)
|
301 |
+
|
302 |
+
model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME)
|
303 |
+
config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME)
|
304 |
+
|
305 |
+
logger.info("Save model dump to {}".format(model_save_path))
|
306 |
+
torch.save(model.state_dict(), model_save_path)
|
307 |
+
logger.info("Save configuration file to {}".format(config_save_path))
|
308 |
+
with open(config_save_path, "w", encoding="utf-8") as f:
|
309 |
+
f.write(config.to_json_string())
|
310 |
+
|
311 |
+
if __name__ == "__main__":
|
312 |
+
main()
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/file_utils.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Utilities for working with the local dataset cache.
|
3 |
+
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
4 |
+
Copyright by the AllenNLP authors.
|
5 |
+
"""
|
6 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
7 |
+
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import shutil
|
12 |
+
import tempfile
|
13 |
+
from functools import wraps
|
14 |
+
from hashlib import sha256
|
15 |
+
import sys
|
16 |
+
from io import open
|
17 |
+
|
18 |
+
import boto3
|
19 |
+
import requests
|
20 |
+
from botocore.exceptions import ClientError
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
try:
|
24 |
+
from urllib.parse import urlparse
|
25 |
+
except ImportError:
|
26 |
+
from urlparse import urlparse
|
27 |
+
|
28 |
+
try:
|
29 |
+
from pathlib import Path
|
30 |
+
PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
|
31 |
+
Path.home() / '.pytorch_pretrained_biggan'))
|
32 |
+
except (AttributeError, ImportError):
|
33 |
+
PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
|
34 |
+
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))
|
35 |
+
|
36 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
37 |
+
|
38 |
+
|
39 |
+
def url_to_filename(url, etag=None):
|
40 |
+
"""
|
41 |
+
Convert `url` into a hashed filename in a repeatable way.
|
42 |
+
If `etag` is specified, append its hash to the url's, delimited
|
43 |
+
by a period.
|
44 |
+
"""
|
45 |
+
url_bytes = url.encode('utf-8')
|
46 |
+
url_hash = sha256(url_bytes)
|
47 |
+
filename = url_hash.hexdigest()
|
48 |
+
|
49 |
+
if etag:
|
50 |
+
etag_bytes = etag.encode('utf-8')
|
51 |
+
etag_hash = sha256(etag_bytes)
|
52 |
+
filename += '.' + etag_hash.hexdigest()
|
53 |
+
|
54 |
+
return filename
|
55 |
+
|
56 |
+
|
57 |
+
def filename_to_url(filename, cache_dir=None):
|
58 |
+
"""
|
59 |
+
Return the url and etag (which may be ``None``) stored for `filename`.
|
60 |
+
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
61 |
+
"""
|
62 |
+
if cache_dir is None:
|
63 |
+
cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
|
64 |
+
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
65 |
+
cache_dir = str(cache_dir)
|
66 |
+
|
67 |
+
cache_path = os.path.join(cache_dir, filename)
|
68 |
+
if not os.path.exists(cache_path):
|
69 |
+
raise EnvironmentError("file {} not found".format(cache_path))
|
70 |
+
|
71 |
+
meta_path = cache_path + '.json'
|
72 |
+
if not os.path.exists(meta_path):
|
73 |
+
raise EnvironmentError("file {} not found".format(meta_path))
|
74 |
+
|
75 |
+
with open(meta_path, encoding="utf-8") as meta_file:
|
76 |
+
metadata = json.load(meta_file)
|
77 |
+
url = metadata['url']
|
78 |
+
etag = metadata['etag']
|
79 |
+
|
80 |
+
return url, etag
|
81 |
+
|
82 |
+
|
83 |
+
def cached_path(url_or_filename, cache_dir=None):
|
84 |
+
"""
|
85 |
+
Given something that might be a URL (or might be a local path),
|
86 |
+
determine which. If it's a URL, download the file and cache it, and
|
87 |
+
return the path to the cached file. If it's already a local path,
|
88 |
+
make sure the file exists and then return the path.
|
89 |
+
"""
|
90 |
+
if cache_dir is None:
|
91 |
+
cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
|
92 |
+
if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
|
93 |
+
url_or_filename = str(url_or_filename)
|
94 |
+
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
95 |
+
cache_dir = str(cache_dir)
|
96 |
+
|
97 |
+
parsed = urlparse(url_or_filename)
|
98 |
+
|
99 |
+
if parsed.scheme in ('http', 'https', 's3'):
|
100 |
+
# URL, so get it from the cache (downloading if necessary)
|
101 |
+
return get_from_cache(url_or_filename, cache_dir)
|
102 |
+
elif os.path.exists(url_or_filename):
|
103 |
+
# File, and it exists.
|
104 |
+
return url_or_filename
|
105 |
+
elif parsed.scheme == '':
|
106 |
+
# File, but it doesn't exist.
|
107 |
+
raise EnvironmentError("file {} not found".format(url_or_filename))
|
108 |
+
else:
|
109 |
+
# Something unknown
|
110 |
+
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
111 |
+
|
112 |
+
|
113 |
+
def split_s3_path(url):
|
114 |
+
"""Split a full s3 path into the bucket name and path."""
|
115 |
+
parsed = urlparse(url)
|
116 |
+
if not parsed.netloc or not parsed.path:
|
117 |
+
raise ValueError("bad s3 path {}".format(url))
|
118 |
+
bucket_name = parsed.netloc
|
119 |
+
s3_path = parsed.path
|
120 |
+
# Remove '/' at beginning of path.
|
121 |
+
if s3_path.startswith("/"):
|
122 |
+
s3_path = s3_path[1:]
|
123 |
+
return bucket_name, s3_path
|
124 |
+
|
125 |
+
|
126 |
+
def s3_request(func):
|
127 |
+
"""
|
128 |
+
Wrapper function for s3 requests in order to create more helpful error
|
129 |
+
messages.
|
130 |
+
"""
|
131 |
+
|
132 |
+
@wraps(func)
|
133 |
+
def wrapper(url, *args, **kwargs):
|
134 |
+
try:
|
135 |
+
return func(url, *args, **kwargs)
|
136 |
+
except ClientError as exc:
|
137 |
+
if int(exc.response["Error"]["Code"]) == 404:
|
138 |
+
raise EnvironmentError("file {} not found".format(url))
|
139 |
+
else:
|
140 |
+
raise
|
141 |
+
|
142 |
+
return wrapper
|
143 |
+
|
144 |
+
|
145 |
+
@s3_request
|
146 |
+
def s3_etag(url):
|
147 |
+
"""Check ETag on S3 object."""
|
148 |
+
s3_resource = boto3.resource("s3")
|
149 |
+
bucket_name, s3_path = split_s3_path(url)
|
150 |
+
s3_object = s3_resource.Object(bucket_name, s3_path)
|
151 |
+
return s3_object.e_tag
|
152 |
+
|
153 |
+
|
154 |
+
@s3_request
|
155 |
+
def s3_get(url, temp_file):
|
156 |
+
"""Pull a file directly from S3."""
|
157 |
+
s3_resource = boto3.resource("s3")
|
158 |
+
bucket_name, s3_path = split_s3_path(url)
|
159 |
+
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
160 |
+
|
161 |
+
|
162 |
+
def http_get(url, temp_file):
|
163 |
+
req = requests.get(url, stream=True)
|
164 |
+
content_length = req.headers.get('Content-Length')
|
165 |
+
total = int(content_length) if content_length is not None else None
|
166 |
+
progress = tqdm(unit="B", total=total)
|
167 |
+
for chunk in req.iter_content(chunk_size=1024):
|
168 |
+
if chunk: # filter out keep-alive new chunks
|
169 |
+
progress.update(len(chunk))
|
170 |
+
temp_file.write(chunk)
|
171 |
+
progress.close()
|
172 |
+
|
173 |
+
|
174 |
+
def get_from_cache(url, cache_dir=None):
|
175 |
+
"""
|
176 |
+
Given a URL, look for the corresponding dataset in the local cache.
|
177 |
+
If it's not there, download it. Then return the path to the cached file.
|
178 |
+
"""
|
179 |
+
if cache_dir is None:
|
180 |
+
cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
|
181 |
+
if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
|
182 |
+
cache_dir = str(cache_dir)
|
183 |
+
|
184 |
+
if not os.path.exists(cache_dir):
|
185 |
+
os.makedirs(cache_dir)
|
186 |
+
|
187 |
+
# Get eTag to add to filename, if it exists.
|
188 |
+
if url.startswith("s3://"):
|
189 |
+
etag = s3_etag(url)
|
190 |
+
else:
|
191 |
+
response = requests.head(url, allow_redirects=True)
|
192 |
+
if response.status_code != 200:
|
193 |
+
raise IOError("HEAD request failed for url {} with status code {}"
|
194 |
+
.format(url, response.status_code))
|
195 |
+
etag = response.headers.get("ETag")
|
196 |
+
|
197 |
+
filename = url_to_filename(url, etag)
|
198 |
+
|
199 |
+
# get cache path to put the file
|
200 |
+
cache_path = os.path.join(cache_dir, filename)
|
201 |
+
|
202 |
+
if not os.path.exists(cache_path):
|
203 |
+
# Download to temporary file, then copy to cache dir once finished.
|
204 |
+
# Otherwise you get corrupt cache entries if the download gets interrupted.
|
205 |
+
with tempfile.NamedTemporaryFile() as temp_file:
|
206 |
+
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
|
207 |
+
|
208 |
+
# GET file object
|
209 |
+
if url.startswith("s3://"):
|
210 |
+
s3_get(url, temp_file)
|
211 |
+
else:
|
212 |
+
http_get(url, temp_file)
|
213 |
+
|
214 |
+
# we are copying the file before closing it, so flush to avoid truncation
|
215 |
+
temp_file.flush()
|
216 |
+
# shutil.copyfileobj() starts at the current position, so go to the start
|
217 |
+
temp_file.seek(0)
|
218 |
+
|
219 |
+
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
|
220 |
+
with open(cache_path, 'wb') as cache_file:
|
221 |
+
shutil.copyfileobj(temp_file, cache_file)
|
222 |
+
|
223 |
+
logger.info("creating metadata file for %s", cache_path)
|
224 |
+
meta = {'url': url, 'etag': etag}
|
225 |
+
meta_path = cache_path + '.json'
|
226 |
+
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
227 |
+
json.dump(meta, meta_file)
|
228 |
+
|
229 |
+
logger.info("removing temp file %s", temp_file.name)
|
230 |
+
|
231 |
+
return cache_path
|
232 |
+
|
233 |
+
|
234 |
+
def read_set_from_file(filename):
|
235 |
+
'''
|
236 |
+
Extract a de-duped collection (set) of text from a file.
|
237 |
+
Expected file format is one item per line.
|
238 |
+
'''
|
239 |
+
collection = set()
|
240 |
+
with open(filename, 'r', encoding='utf-8') as file_:
|
241 |
+
for line in file_:
|
242 |
+
collection.add(line.rstrip())
|
243 |
+
return collection
|
244 |
+
|
245 |
+
|
246 |
+
def get_file_extension(path, dot=True, lower=True):
|
247 |
+
ext = os.path.splitext(path)[1]
|
248 |
+
ext = ext if dot else ext[1:]
|
249 |
+
return ext.lower() if lower else ext
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/model.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
""" BigGAN PyTorch model.
|
3 |
+
From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
|
4 |
+
By Andrew Brocky, Jeff Donahuey and Karen Simonyan.
|
5 |
+
https://openreview.net/forum?id=B1xsqj09Fm
|
6 |
+
|
7 |
+
PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
|
8 |
+
Some part of the code are adapted from https://github.com/brain-research/self-attention-gan
|
9 |
+
|
10 |
+
This version only comprises the generator (since the discriminator's weights are not released).
|
11 |
+
This version only comprises the "deep" version of BigGAN (see publication).
|
12 |
+
|
13 |
+
Modified by Erik Härkönen:
|
14 |
+
* Added support for per-layer latent vectors
|
15 |
+
"""
|
16 |
+
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
17 |
+
|
18 |
+
import os
|
19 |
+
import logging
|
20 |
+
import math
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
import torch.nn as nn
|
25 |
+
import torch.nn.functional as F
|
26 |
+
|
27 |
+
from .config import BigGANConfig
|
28 |
+
from .file_utils import cached_path
|
29 |
+
|
30 |
+
logger = logging.getLogger(__name__)
|
31 |
+
|
32 |
+
PRETRAINED_MODEL_ARCHIVE_MAP = {
|
33 |
+
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
|
34 |
+
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
|
35 |
+
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
|
36 |
+
}
|
37 |
+
|
38 |
+
PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
39 |
+
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
|
40 |
+
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
|
41 |
+
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
|
42 |
+
}
|
43 |
+
|
44 |
+
WEIGHTS_NAME = 'pytorch_model.bin'
|
45 |
+
CONFIG_NAME = 'config.json'
|
46 |
+
|
47 |
+
|
48 |
+
def snconv2d(eps=1e-12, **kwargs):
|
49 |
+
return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)
|
50 |
+
|
51 |
+
def snlinear(eps=1e-12, **kwargs):
|
52 |
+
return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)
|
53 |
+
|
54 |
+
def sn_embedding(eps=1e-12, **kwargs):
|
55 |
+
return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)
|
56 |
+
|
57 |
+
class SelfAttn(nn.Module):
|
58 |
+
""" Self attention Layer"""
|
59 |
+
def __init__(self, in_channels, eps=1e-12):
|
60 |
+
super(SelfAttn, self).__init__()
|
61 |
+
self.in_channels = in_channels
|
62 |
+
self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
|
63 |
+
kernel_size=1, bias=False, eps=eps)
|
64 |
+
self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
|
65 |
+
kernel_size=1, bias=False, eps=eps)
|
66 |
+
self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
|
67 |
+
kernel_size=1, bias=False, eps=eps)
|
68 |
+
self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
|
69 |
+
kernel_size=1, bias=False, eps=eps)
|
70 |
+
self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
|
71 |
+
self.softmax = nn.Softmax(dim=-1)
|
72 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
_, ch, h, w = x.size()
|
76 |
+
# Theta path
|
77 |
+
theta = self.snconv1x1_theta(x)
|
78 |
+
theta = theta.view(-1, ch//8, h*w)
|
79 |
+
# Phi path
|
80 |
+
phi = self.snconv1x1_phi(x)
|
81 |
+
phi = self.maxpool(phi)
|
82 |
+
phi = phi.view(-1, ch//8, h*w//4)
|
83 |
+
# Attn map
|
84 |
+
attn = torch.bmm(theta.permute(0, 2, 1), phi)
|
85 |
+
attn = self.softmax(attn)
|
86 |
+
# g path
|
87 |
+
g = self.snconv1x1_g(x)
|
88 |
+
g = self.maxpool(g)
|
89 |
+
g = g.view(-1, ch//2, h*w//4)
|
90 |
+
# Attn_g - o_conv
|
91 |
+
attn_g = torch.bmm(g, attn.permute(0, 2, 1))
|
92 |
+
attn_g = attn_g.view(-1, ch//2, h, w)
|
93 |
+
attn_g = self.snconv1x1_o_conv(attn_g)
|
94 |
+
# Out
|
95 |
+
out = x + self.gamma*attn_g
|
96 |
+
return out
|
97 |
+
|
98 |
+
|
99 |
+
class BigGANBatchNorm(nn.Module):
|
100 |
+
""" This is a batch norm module that can handle conditional input and can be provided with pre-computed
|
101 |
+
activation means and variances for various truncation parameters.
|
102 |
+
|
103 |
+
We cannot just rely on torch.batch_norm since it cannot handle
|
104 |
+
batched weights (pytorch 1.0.1). We computate batch_norm our-self without updating running means and variances.
|
105 |
+
If you want to train this model you should add running means and variance computation logic.
|
106 |
+
"""
|
107 |
+
def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
|
108 |
+
super(BigGANBatchNorm, self).__init__()
|
109 |
+
self.num_features = num_features
|
110 |
+
self.eps = eps
|
111 |
+
self.conditional = conditional
|
112 |
+
|
113 |
+
# We use pre-computed statistics for n_stats values of truncation between 0 and 1
|
114 |
+
self.register_buffer('running_means', torch.zeros(n_stats, num_features))
|
115 |
+
self.register_buffer('running_vars', torch.ones(n_stats, num_features))
|
116 |
+
self.step_size = 1.0 / (n_stats - 1)
|
117 |
+
|
118 |
+
if conditional:
|
119 |
+
assert condition_vector_dim is not None
|
120 |
+
self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
|
121 |
+
self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
|
122 |
+
else:
|
123 |
+
self.weight = torch.nn.Parameter(torch.Tensor(num_features))
|
124 |
+
self.bias = torch.nn.Parameter(torch.Tensor(num_features))
|
125 |
+
|
126 |
+
def forward(self, x, truncation, condition_vector=None):
|
127 |
+
# Retreive pre-computed statistics associated to this truncation
|
128 |
+
coef, start_idx = math.modf(truncation / self.step_size)
|
129 |
+
start_idx = int(start_idx)
|
130 |
+
if coef != 0.0: # Interpolate
|
131 |
+
running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
|
132 |
+
running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
|
133 |
+
else:
|
134 |
+
running_mean = self.running_means[start_idx]
|
135 |
+
running_var = self.running_vars[start_idx]
|
136 |
+
|
137 |
+
if self.conditional:
|
138 |
+
running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
139 |
+
running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
140 |
+
|
141 |
+
weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
|
142 |
+
bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)
|
143 |
+
|
144 |
+
out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
|
145 |
+
else:
|
146 |
+
out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
|
147 |
+
training=False, momentum=0.0, eps=self.eps)
|
148 |
+
|
149 |
+
return out
|
150 |
+
|
151 |
+
|
152 |
+
class GenBlock(nn.Module):
|
153 |
+
def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
|
154 |
+
n_stats=51, eps=1e-12):
|
155 |
+
super(GenBlock, self).__init__()
|
156 |
+
self.up_sample = up_sample
|
157 |
+
self.drop_channels = (in_size != out_size)
|
158 |
+
middle_size = in_size // reduction_factor
|
159 |
+
|
160 |
+
self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
|
161 |
+
self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)
|
162 |
+
|
163 |
+
self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
|
164 |
+
self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
|
165 |
+
|
166 |
+
self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
|
167 |
+
self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)
|
168 |
+
|
169 |
+
self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
|
170 |
+
self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)
|
171 |
+
|
172 |
+
self.relu = nn.ReLU()
|
173 |
+
|
174 |
+
def forward(self, x, cond_vector, truncation):
|
175 |
+
x0 = x
|
176 |
+
|
177 |
+
x = self.bn_0(x, truncation, cond_vector)
|
178 |
+
x = self.relu(x)
|
179 |
+
x = self.conv_0(x)
|
180 |
+
|
181 |
+
x = self.bn_1(x, truncation, cond_vector)
|
182 |
+
x = self.relu(x)
|
183 |
+
if self.up_sample:
|
184 |
+
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
185 |
+
x = self.conv_1(x)
|
186 |
+
|
187 |
+
x = self.bn_2(x, truncation, cond_vector)
|
188 |
+
x = self.relu(x)
|
189 |
+
x = self.conv_2(x)
|
190 |
+
|
191 |
+
x = self.bn_3(x, truncation, cond_vector)
|
192 |
+
x = self.relu(x)
|
193 |
+
x = self.conv_3(x)
|
194 |
+
|
195 |
+
if self.drop_channels:
|
196 |
+
new_channels = x0.shape[1] // 2
|
197 |
+
x0 = x0[:, :new_channels, ...]
|
198 |
+
if self.up_sample:
|
199 |
+
x0 = F.interpolate(x0, scale_factor=2, mode='nearest')
|
200 |
+
|
201 |
+
out = x + x0
|
202 |
+
return out
|
203 |
+
|
204 |
+
class Generator(nn.Module):
|
205 |
+
def __init__(self, config):
|
206 |
+
super(Generator, self).__init__()
|
207 |
+
self.config = config
|
208 |
+
ch = config.channel_width
|
209 |
+
condition_vector_dim = config.z_dim * 2
|
210 |
+
|
211 |
+
self.gen_z = snlinear(in_features=condition_vector_dim,
|
212 |
+
out_features=4 * 4 * 16 * ch, eps=config.eps)
|
213 |
+
|
214 |
+
layers = []
|
215 |
+
for i, layer in enumerate(config.layers):
|
216 |
+
if i == config.attention_layer_position:
|
217 |
+
layers.append(SelfAttn(ch*layer[1], eps=config.eps))
|
218 |
+
layers.append(GenBlock(ch*layer[1],
|
219 |
+
ch*layer[2],
|
220 |
+
condition_vector_dim,
|
221 |
+
up_sample=layer[0],
|
222 |
+
n_stats=config.n_stats,
|
223 |
+
eps=config.eps))
|
224 |
+
self.layers = nn.ModuleList(layers)
|
225 |
+
|
226 |
+
self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
|
227 |
+
self.relu = nn.ReLU()
|
228 |
+
self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
|
229 |
+
self.tanh = nn.Tanh()
|
230 |
+
|
231 |
+
def forward(self, cond_vector, truncation):
|
232 |
+
z = self.gen_z(cond_vector[0])
|
233 |
+
|
234 |
+
# We use this conversion step to be able to use TF weights:
|
235 |
+
# TF convention on shape is [batch, height, width, channels]
|
236 |
+
# PT convention on shape is [batch, channels, height, width]
|
237 |
+
z = z.view(-1, 4, 4, 16 * self.config.channel_width)
|
238 |
+
z = z.permute(0, 3, 1, 2).contiguous()
|
239 |
+
|
240 |
+
cond_idx = 1
|
241 |
+
for i, layer in enumerate(self.layers):
|
242 |
+
if isinstance(layer, GenBlock):
|
243 |
+
z = layer(z, cond_vector[cond_idx], truncation)
|
244 |
+
cond_idx += 1
|
245 |
+
else:
|
246 |
+
z = layer(z)
|
247 |
+
|
248 |
+
z = self.bn(z, truncation)
|
249 |
+
z = self.relu(z)
|
250 |
+
z = self.conv_to_rgb(z)
|
251 |
+
z = z[:, :3, ...]
|
252 |
+
z = self.tanh(z)
|
253 |
+
return z
|
254 |
+
|
255 |
+
class BigGAN(nn.Module):
|
256 |
+
"""BigGAN Generator."""
|
257 |
+
|
258 |
+
@classmethod
|
259 |
+
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
260 |
+
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
|
261 |
+
model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
|
262 |
+
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
|
263 |
+
else:
|
264 |
+
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
265 |
+
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
|
266 |
+
|
267 |
+
try:
|
268 |
+
resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
|
269 |
+
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
270 |
+
except EnvironmentError:
|
271 |
+
logger.error("Wrong model name, should be a valid path to a folder containing "
|
272 |
+
"a {} file and a {} file or a model name in {}".format(
|
273 |
+
WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
|
274 |
+
raise
|
275 |
+
|
276 |
+
logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))
|
277 |
+
|
278 |
+
# Load config
|
279 |
+
config = BigGANConfig.from_json_file(resolved_config_file)
|
280 |
+
logger.info("Model config {}".format(config))
|
281 |
+
|
282 |
+
# Instantiate model.
|
283 |
+
model = cls(config, *inputs, **kwargs)
|
284 |
+
state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
|
285 |
+
model.load_state_dict(state_dict, strict=False)
|
286 |
+
return model
|
287 |
+
|
288 |
+
def __init__(self, config):
|
289 |
+
super(BigGAN, self).__init__()
|
290 |
+
self.config = config
|
291 |
+
self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
|
292 |
+
self.generator = Generator(config)
|
293 |
+
self.n_latents = len(config.layers) + 1 # one for gen_z + one per layer
|
294 |
+
|
295 |
+
def forward(self, z, class_label, truncation):
|
296 |
+
assert 0 < truncation <= 1
|
297 |
+
|
298 |
+
if not isinstance(z, list):
|
299 |
+
z = self.n_latents*[z]
|
300 |
+
|
301 |
+
if isinstance(class_label, list):
|
302 |
+
embed = [self.embeddings(l) for l in class_label]
|
303 |
+
else:
|
304 |
+
embed = self.n_latents*[self.embeddings(class_label)]
|
305 |
+
|
306 |
+
assert len(z) == self.n_latents, f'Expected {self.n_latents} latents, got {len(z)}'
|
307 |
+
assert len(embed) == self.n_latents, f'Expected {self.n_latents} class vectors, got {len(class_label)}'
|
308 |
+
|
309 |
+
cond_vectors = [torch.cat((z, e), dim=1) for (z, e) in zip(z, embed)]
|
310 |
+
z = self.generator(cond_vectors, truncation)
|
311 |
+
return z
|
312 |
+
|
313 |
+
|
314 |
+
if __name__ == "__main__":
|
315 |
+
import PIL
|
316 |
+
from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
|
317 |
+
from .convert_tf_to_pytorch import load_tf_weights_in_biggan
|
318 |
+
|
319 |
+
load_cache = False
|
320 |
+
cache_path = './saved_model.pt'
|
321 |
+
config = BigGANConfig()
|
322 |
+
model = BigGAN(config)
|
323 |
+
if not load_cache:
|
324 |
+
model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin')
|
325 |
+
torch.save(model.state_dict(), cache_path)
|
326 |
+
else:
|
327 |
+
model.load_state_dict(torch.load(cache_path))
|
328 |
+
|
329 |
+
model.eval()
|
330 |
+
|
331 |
+
truncation = 0.4
|
332 |
+
noise = truncated_noise_sample(batch_size=2, truncation=truncation)
|
333 |
+
label = one_hot_from_names('diver', batch_size=2)
|
334 |
+
|
335 |
+
# Tests
|
336 |
+
# noise = np.zeros((1, 128))
|
337 |
+
# label = [983]
|
338 |
+
|
339 |
+
noise = torch.tensor(noise, dtype=torch.float)
|
340 |
+
label = torch.tensor(label, dtype=torch.float)
|
341 |
+
with torch.no_grad():
|
342 |
+
outputs = model(noise, label, truncation)
|
343 |
+
print(outputs.shape)
|
344 |
+
|
345 |
+
save_as_images(outputs)
|
models/biggan/pytorch_biggan/pytorch_pretrained_biggan/utils.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
""" BigGAN utilities to prepare truncated noise samples and convert/save/display output images.
|
3 |
+
Also comprise ImageNet utilities to prepare one hot input vectors for ImageNet classes.
|
4 |
+
We use Wordnet so you can just input a name in a string and automatically get a corresponding
|
5 |
+
imagenet class if it exists (or a hypo/hypernym exists in imagenet).
|
6 |
+
"""
|
7 |
+
from __future__ import absolute_import, division, print_function, unicode_literals
|
8 |
+
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
from io import BytesIO
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from scipy.stats import truncnorm
|
15 |
+
|
16 |
+
logger = logging.getLogger(__name__)
|
17 |
+
|
18 |
+
NUM_CLASSES = 1000
|
19 |
+
|
20 |
+
|
21 |
+
def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
|
22 |
+
""" Create a truncated noise vector.
|
23 |
+
Params:
|
24 |
+
batch_size: batch size.
|
25 |
+
dim_z: dimension of z
|
26 |
+
truncation: truncation value to use
|
27 |
+
seed: seed for the random generator
|
28 |
+
Output:
|
29 |
+
array of shape (batch_size, dim_z)
|
30 |
+
"""
|
31 |
+
state = None if seed is None else np.random.RandomState(seed)
|
32 |
+
values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
|
33 |
+
return truncation * values
|
34 |
+
|
35 |
+
|
36 |
+
def convert_to_images(obj):
|
37 |
+
""" Convert an output tensor from BigGAN in a list of images.
|
38 |
+
Params:
|
39 |
+
obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
40 |
+
Output:
|
41 |
+
list of Pillow Images of size (height, width)
|
42 |
+
"""
|
43 |
+
try:
|
44 |
+
import PIL
|
45 |
+
except ImportError:
|
46 |
+
raise ImportError("Please install Pillow to use images: pip install Pillow")
|
47 |
+
|
48 |
+
if not isinstance(obj, np.ndarray):
|
49 |
+
obj = obj.detach().numpy()
|
50 |
+
|
51 |
+
obj = obj.transpose((0, 2, 3, 1))
|
52 |
+
obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)
|
53 |
+
|
54 |
+
img = []
|
55 |
+
for i, out in enumerate(obj):
|
56 |
+
out_array = np.asarray(np.uint8(out), dtype=np.uint8)
|
57 |
+
img.append(PIL.Image.fromarray(out_array))
|
58 |
+
return img
|
59 |
+
|
60 |
+
|
61 |
+
def save_as_images(obj, file_name='output'):
|
62 |
+
""" Convert and save an output tensor from BigGAN in a list of saved images.
|
63 |
+
Params:
|
64 |
+
obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
65 |
+
file_name: path and beggingin of filename to save.
|
66 |
+
Images will be saved as `file_name_{image_number}.png`
|
67 |
+
"""
|
68 |
+
img = convert_to_images(obj)
|
69 |
+
|
70 |
+
for i, out in enumerate(img):
|
71 |
+
current_file_name = file_name + '_%d.png' % i
|
72 |
+
logger.info("Saving image to {}".format(current_file_name))
|
73 |
+
out.save(current_file_name, 'png')
|
74 |
+
|
75 |
+
|
76 |
+
def display_in_terminal(obj):
|
77 |
+
""" Convert and display an output tensor from BigGAN in the terminal.
|
78 |
+
This function use `libsixel` and will only work in a libsixel-compatible terminal.
|
79 |
+
Please refer to https://github.com/saitoha/libsixel for more details.
|
80 |
+
|
81 |
+
Params:
|
82 |
+
obj: tensor or numpy array of shape (batch_size, channels, height, width)
|
83 |
+
file_name: path and beggingin of filename to save.
|
84 |
+
Images will be saved as `file_name_{image_number}.png`
|
85 |
+
"""
|
86 |
+
try:
|
87 |
+
import PIL
|
88 |
+
from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize,
|
89 |
+
sixel_dither_set_palette, sixel_dither_set_pixelformat,
|
90 |
+
sixel_dither_get, sixel_encode, sixel_dither_unref,
|
91 |
+
sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888,
|
92 |
+
SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8,
|
93 |
+
SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1)
|
94 |
+
except ImportError:
|
95 |
+
raise ImportError("Display in Terminal requires Pillow, libsixel "
|
96 |
+
"and a libsixel compatible terminal. "
|
97 |
+
"Please read info at https://github.com/saitoha/libsixel "
|
98 |
+
"and install with pip install Pillow libsixel-python")
|
99 |
+
|
100 |
+
s = BytesIO()
|
101 |
+
|
102 |
+
images = convert_to_images(obj)
|
103 |
+
widths, heights = zip(*(i.size for i in images))
|
104 |
+
|
105 |
+
output_width = sum(widths)
|
106 |
+
output_height = max(heights)
|
107 |
+
|
108 |
+
output_image = PIL.Image.new('RGB', (output_width, output_height))
|
109 |
+
|
110 |
+
x_offset = 0
|
111 |
+
for im in images:
|
112 |
+
output_image.paste(im, (x_offset,0))
|
113 |
+
x_offset += im.size[0]
|
114 |
+
|
115 |
+
try:
|
116 |
+
data = output_image.tobytes()
|
117 |
+
except NotImplementedError:
|
118 |
+
data = output_image.tostring()
|
119 |
+
output = sixel_output_new(lambda data, s: s.write(data), s)
|
120 |
+
|
121 |
+
try:
|
122 |
+
if output_image.mode == 'RGBA':
|
123 |
+
dither = sixel_dither_new(256)
|
124 |
+
sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888)
|
125 |
+
elif output_image.mode == 'RGB':
|
126 |
+
dither = sixel_dither_new(256)
|
127 |
+
sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888)
|
128 |
+
elif output_image.mode == 'P':
|
129 |
+
palette = output_image.getpalette()
|
130 |
+
dither = sixel_dither_new(256)
|
131 |
+
sixel_dither_set_palette(dither, palette)
|
132 |
+
sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8)
|
133 |
+
elif output_image.mode == 'L':
|
134 |
+
dither = sixel_dither_get(SIXEL_BUILTIN_G8)
|
135 |
+
sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8)
|
136 |
+
elif output_image.mode == '1':
|
137 |
+
dither = sixel_dither_get(SIXEL_BUILTIN_G1)
|
138 |
+
sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1)
|
139 |
+
else:
|
140 |
+
raise RuntimeError('unexpected output_image mode')
|
141 |
+
try:
|
142 |
+
sixel_encode(data, output_width, output_height, 1, dither, output)
|
143 |
+
print(s.getvalue().decode('ascii'))
|
144 |
+
finally:
|
145 |
+
sixel_dither_unref(dither)
|
146 |
+
finally:
|
147 |
+
sixel_output_unref(output)
|
148 |
+
|
149 |
+
|
150 |
+
def one_hot_from_int(int_or_list, batch_size=1):
|
151 |
+
""" Create a one-hot vector from a class index or a list of class indices.
|
152 |
+
Params:
|
153 |
+
int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
|
154 |
+
batch_size: batch size.
|
155 |
+
If int_or_list is an int create a batch of identical classes.
|
156 |
+
If int_or_list is a list, we should have `len(int_or_list) == batch_size`
|
157 |
+
Output:
|
158 |
+
array of shape (batch_size, 1000)
|
159 |
+
"""
|
160 |
+
if isinstance(int_or_list, int):
|
161 |
+
int_or_list = [int_or_list]
|
162 |
+
|
163 |
+
if len(int_or_list) == 1 and batch_size > 1:
|
164 |
+
int_or_list = [int_or_list[0]] * batch_size
|
165 |
+
|
166 |
+
assert batch_size == len(int_or_list)
|
167 |
+
|
168 |
+
array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)
|
169 |
+
for i, j in enumerate(int_or_list):
|
170 |
+
array[i, j] = 1.0
|
171 |
+
return array
|
172 |
+
|
173 |
+
|
174 |
+
def one_hot_from_names(class_name_or_list, batch_size=1):
|
175 |
+
""" Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...).
|
176 |
+
We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one.
|
177 |
+
If we can't find it direcly, we look at the hyponyms and hypernyms of the class name.
|
178 |
+
|
179 |
+
Params:
|
180 |
+
class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch).
|
181 |
+
Output:
|
182 |
+
array of shape (batch_size, 1000)
|
183 |
+
"""
|
184 |
+
try:
|
185 |
+
from nltk.corpus import wordnet as wn
|
186 |
+
except ImportError:
|
187 |
+
raise ImportError("You need to install nltk to use this function")
|
188 |
+
|
189 |
+
if not isinstance(class_name_or_list, (list, tuple)):
|
190 |
+
class_name_or_list = [class_name_or_list]
|
191 |
+
else:
|
192 |
+
batch_size = max(batch_size, len(class_name_or_list))
|
193 |
+
|
194 |
+
classes = []
|
195 |
+
for class_name in class_name_or_list:
|
196 |
+
class_name = class_name.replace(" ", "_")
|
197 |
+
|
198 |
+
original_synsets = wn.synsets(class_name)
|
199 |
+
original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets)) # keep only names
|
200 |
+
if not original_synsets:
|
201 |
+
return None
|
202 |
+
|
203 |
+
possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets))
|
204 |
+
if possible_synsets:
|
205 |
+
classes.append(IMAGENET[possible_synsets[0].offset()])
|
206 |
+
else:
|
207 |
+
# try hypernyms and hyponyms
|
208 |
+
possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], [])
|
209 |
+
possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets))
|
210 |
+
if possible_synsets:
|
211 |
+
classes.append(IMAGENET[possible_synsets[0].offset()])
|
212 |
+
|
213 |
+
return one_hot_from_int(classes, batch_size=batch_size)
|
214 |
+
|
215 |
+
|
216 |
+
IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999}
|
models/biggan/pytorch_biggan/requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PyTorch
|
2 |
+
torch>=0.4.1
|
3 |
+
# progress bars in model download and training scripts
|
4 |
+
tqdm
|
5 |
+
# Accessing files from S3 directly.
|
6 |
+
boto3
|
7 |
+
# Used for downloading models over HTTP
|
8 |
+
requests
|
models/biggan/pytorch_biggan/scripts/convert_tf_hub_models.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
|
8 |
+
set -e
|
9 |
+
set -x
|
10 |
+
|
11 |
+
models="128 256 512"
|
12 |
+
|
13 |
+
mkdir -p models/model_128
|
14 |
+
mkdir -p models/model_256
|
15 |
+
mkdir -p models/model_512
|
16 |
+
|
17 |
+
# Convert TF Hub models.
|
18 |
+
for model in $models
|
19 |
+
do
|
20 |
+
pytorch_pretrained_biggan --model_type $model --tf_model_path models/model_$model --pt_save_path models/model_$model
|
21 |
+
done
|
models/biggan/pytorch_biggan/scripts/download_tf_hub_models.sh
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
|
8 |
+
set -e
|
9 |
+
set -x
|
10 |
+
|
11 |
+
models="128 256 512"
|
12 |
+
|
13 |
+
mkdir -p models/model_128
|
14 |
+
mkdir -p models/model_256
|
15 |
+
mkdir -p models/model_512
|
16 |
+
|
17 |
+
# Download TF Hub models.
|
18 |
+
for model in $models
|
19 |
+
do
|
20 |
+
curl -L "https://tfhub.dev/deepmind/biggan-deep-$model/1?tf-hub-format=compressed" | tar -zxvC models/model_$model
|
21 |
+
done
|
models/biggan/pytorch_biggan/setup.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
|
3 |
+
|
4 |
+
To create the package for pypi.
|
5 |
+
|
6 |
+
1. Change the version in __init__.py and setup.py.
|
7 |
+
|
8 |
+
2. Commit these changes with the message: "Release: VERSION"
|
9 |
+
|
10 |
+
3. Add a tag in git to mark the release: "git tag VERSION -m'Adds tag VERSION for pypi' "
|
11 |
+
Push the tag to git: git push --tags origin master
|
12 |
+
|
13 |
+
4. Build both the sources and the wheel. Do not change anything in setup.py between
|
14 |
+
creating the wheel and the source distribution (obviously).
|
15 |
+
|
16 |
+
For the wheel, run: "python setup.py bdist_wheel" in the top level allennlp directory.
|
17 |
+
(this will build a wheel for the python version you use to build it - make sure you use python 3.x).
|
18 |
+
|
19 |
+
For the sources, run: "python setup.py sdist"
|
20 |
+
You should now have a /dist directory with both .whl and .tar.gz source versions of allennlp.
|
21 |
+
|
22 |
+
5. Check that everything looks correct by uploading the package to the pypi test server:
|
23 |
+
|
24 |
+
twine upload dist/* -r pypitest
|
25 |
+
(pypi suggest using twine as other methods upload files via plaintext.)
|
26 |
+
|
27 |
+
Check that you can install it in a virtualenv by running:
|
28 |
+
pip install -i https://testpypi.python.org/pypi allennlp
|
29 |
+
|
30 |
+
6. Upload the final version to actual pypi:
|
31 |
+
twine upload dist/* -r pypi
|
32 |
+
|
33 |
+
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
34 |
+
|
35 |
+
"""
|
36 |
+
from io import open
|
37 |
+
from setuptools import find_packages, setup
|
38 |
+
|
39 |
+
setup(
|
40 |
+
name="pytorch_pretrained_biggan",
|
41 |
+
version="0.1.0",
|
42 |
+
author="Thomas Wolf",
|
43 |
+
author_email="[email protected]",
|
44 |
+
description="PyTorch version of DeepMind's BigGAN model with pre-trained models",
|
45 |
+
long_description=open("README.md", "r", encoding='utf-8').read(),
|
46 |
+
long_description_content_type="text/markdown",
|
47 |
+
keywords='BIGGAN GAN deep learning google deepmind',
|
48 |
+
license='Apache',
|
49 |
+
url="https://github.com/huggingface/pytorch-pretrained-BigGAN",
|
50 |
+
packages=find_packages(exclude=["*.tests", "*.tests.*",
|
51 |
+
"tests.*", "tests"]),
|
52 |
+
install_requires=['torch>=0.4.1',
|
53 |
+
'numpy',
|
54 |
+
'boto3',
|
55 |
+
'requests',
|
56 |
+
'tqdm'],
|
57 |
+
tests_require=['pytest'],
|
58 |
+
entry_points={
|
59 |
+
'console_scripts': [
|
60 |
+
"pytorch_pretrained_biggan=pytorch_pretrained_biggan.convert_tf_to_pytorch:main",
|
61 |
+
]
|
62 |
+
},
|
63 |
+
classifiers=[
|
64 |
+
'Intended Audience :: Science/Research',
|
65 |
+
'License :: OSI Approved :: Apache Software License',
|
66 |
+
'Programming Language :: Python :: 3',
|
67 |
+
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
68 |
+
],
|
69 |
+
)
|
models/stylegan/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
from pathlib import Path
|
12 |
+
import sys
|
13 |
+
|
14 |
+
#module_path = Path(__file__).parent / 'pytorch_biggan'
|
15 |
+
#sys.path.append(str(module_path.resolve()))
|
16 |
+
|
17 |
+
from .model import StyleGAN_G, NoiseLayer
|
models/stylegan/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (289 Bytes). View file
|
|
models/stylegan/__pycache__/model.cpython-310.pyc
ADDED
Binary file (16.4 kB). View file
|
|
models/stylegan/model.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 Erik Härkönen. All rights reserved.
|
2 |
+
# This file is licensed to you under the Apache License, Version 2.0 (the "License");
|
3 |
+
# you may not use this file except in compliance with the License. You may obtain a copy
|
4 |
+
# of the License at http://www.apache.org/licenses/LICENSE-2.0
|
5 |
+
|
6 |
+
# Unless required by applicable law or agreed to in writing, software distributed under
|
7 |
+
# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR REPRESENTATIONS
|
8 |
+
# OF ANY KIND, either express or implied. See the License for the specific language
|
9 |
+
# governing permissions and limitations under the License.
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.functional as F
|
14 |
+
|
15 |
+
from collections import OrderedDict
|
16 |
+
from pathlib import Path
|
17 |
+
import requests
|
18 |
+
import pickle
|
19 |
+
import sys
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
|
23 |
+
# Reimplementation of StyleGAN in PyTorch
|
24 |
+
# Source: https://github.com/lernapparat/lernapparat/blob/master/style_gan/pytorch_style_gan.ipynb
|
25 |
+
|
26 |
+
class MyLinear(nn.Module):
|
27 |
+
"""Linear layer with equalized learning rate and custom learning rate multiplier."""
|
28 |
+
def __init__(self, input_size, output_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True):
|
29 |
+
super().__init__()
|
30 |
+
he_std = gain * input_size**(-0.5) # He init
|
31 |
+
# Equalized learning rate and custom learning rate multiplier.
|
32 |
+
if use_wscale:
|
33 |
+
init_std = 1.0 / lrmul
|
34 |
+
self.w_mul = he_std * lrmul
|
35 |
+
else:
|
36 |
+
init_std = he_std / lrmul
|
37 |
+
self.w_mul = lrmul
|
38 |
+
self.weight = torch.nn.Parameter(torch.randn(output_size, input_size) * init_std)
|
39 |
+
if bias:
|
40 |
+
self.bias = torch.nn.Parameter(torch.zeros(output_size))
|
41 |
+
self.b_mul = lrmul
|
42 |
+
else:
|
43 |
+
self.bias = None
|
44 |
+
|
45 |
+
def forward(self, x):
|
46 |
+
bias = self.bias
|
47 |
+
if bias is not None:
|
48 |
+
bias = bias * self.b_mul
|
49 |
+
return F.linear(x, self.weight * self.w_mul, bias)
|
50 |
+
|
51 |
+
class MyConv2d(nn.Module):
|
52 |
+
"""Conv layer with equalized learning rate and custom learning rate multiplier."""
|
53 |
+
def __init__(self, input_channels, output_channels, kernel_size, gain=2**(0.5), use_wscale=False, lrmul=1, bias=True,
|
54 |
+
intermediate=None, upscale=False):
|
55 |
+
super().__init__()
|
56 |
+
if upscale:
|
57 |
+
self.upscale = Upscale2d()
|
58 |
+
else:
|
59 |
+
self.upscale = None
|
60 |
+
he_std = gain * (input_channels * kernel_size ** 2) ** (-0.5) # He init
|
61 |
+
self.kernel_size = kernel_size
|
62 |
+
if use_wscale:
|
63 |
+
init_std = 1.0 / lrmul
|
64 |
+
self.w_mul = he_std * lrmul
|
65 |
+
else:
|
66 |
+
init_std = he_std / lrmul
|
67 |
+
self.w_mul = lrmul
|
68 |
+
self.weight = torch.nn.Parameter(torch.randn(output_channels, input_channels, kernel_size, kernel_size) * init_std)
|
69 |
+
if bias:
|
70 |
+
self.bias = torch.nn.Parameter(torch.zeros(output_channels))
|
71 |
+
self.b_mul = lrmul
|
72 |
+
else:
|
73 |
+
self.bias = None
|
74 |
+
self.intermediate = intermediate
|
75 |
+
|
76 |
+
def forward(self, x):
|
77 |
+
bias = self.bias
|
78 |
+
if bias is not None:
|
79 |
+
bias = bias * self.b_mul
|
80 |
+
|
81 |
+
have_convolution = False
|
82 |
+
if self.upscale is not None and min(x.shape[2:]) * 2 >= 128:
|
83 |
+
# this is the fused upscale + conv from StyleGAN, sadly this seems incompatible with the non-fused way
|
84 |
+
# this really needs to be cleaned up and go into the conv...
|
85 |
+
w = self.weight * self.w_mul
|
86 |
+
w = w.permute(1, 0, 2, 3)
|
87 |
+
# probably applying a conv on w would be more efficient. also this quadruples the weight (average)?!
|
88 |
+
w = F.pad(w, (1,1,1,1))
|
89 |
+
w = w[:, :, 1:, 1:]+ w[:, :, :-1, 1:] + w[:, :, 1:, :-1] + w[:, :, :-1, :-1]
|
90 |
+
x = F.conv_transpose2d(x, w, stride=2, padding=(w.size(-1)-1)//2)
|
91 |
+
have_convolution = True
|
92 |
+
elif self.upscale is not None:
|
93 |
+
x = self.upscale(x)
|
94 |
+
|
95 |
+
if not have_convolution and self.intermediate is None:
|
96 |
+
return F.conv2d(x, self.weight * self.w_mul, bias, padding=self.kernel_size//2)
|
97 |
+
elif not have_convolution:
|
98 |
+
x = F.conv2d(x, self.weight * self.w_mul, None, padding=self.kernel_size//2)
|
99 |
+
|
100 |
+
if self.intermediate is not None:
|
101 |
+
x = self.intermediate(x)
|
102 |
+
if bias is not None:
|
103 |
+
x = x + bias.view(1, -1, 1, 1)
|
104 |
+
return x
|
105 |
+
|
106 |
+
class NoiseLayer(nn.Module):
|
107 |
+
"""adds noise. noise is per pixel (constant over channels) with per-channel weight"""
|
108 |
+
def __init__(self, channels):
|
109 |
+
super().__init__()
|
110 |
+
self.weight = nn.Parameter(torch.zeros(channels))
|
111 |
+
self.noise = None
|
112 |
+
|
113 |
+
def forward(self, x, noise=None):
|
114 |
+
if noise is None and self.noise is None:
|
115 |
+
noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device, dtype=x.dtype)
|
116 |
+
elif noise is None:
|
117 |
+
# here is a little trick: if you get all the noiselayers and set each
|
118 |
+
# modules .noise attribute, you can have pre-defined noise.
|
119 |
+
# Very useful for analysis
|
120 |
+
noise = self.noise
|
121 |
+
x = x + self.weight.view(1, -1, 1, 1) * noise
|
122 |
+
return x
|
123 |
+
|
124 |
+
class StyleMod(nn.Module):
|
125 |
+
def __init__(self, latent_size, channels, use_wscale):
|
126 |
+
super(StyleMod, self).__init__()
|
127 |
+
self.lin = MyLinear(latent_size,
|
128 |
+
channels * 2,
|
129 |
+
gain=1.0, use_wscale=use_wscale)
|
130 |
+
|
131 |
+
def forward(self, x, latent):
|
132 |
+
style = self.lin(latent) # style => [batch_size, n_channels*2]
|
133 |
+
shape = [-1, 2, x.size(1)] + (x.dim() - 2) * [1]
|
134 |
+
style = style.view(shape) # [batch_size, 2, n_channels, ...]
|
135 |
+
x = x * (style[:, 0] + 1.) + style[:, 1]
|
136 |
+
return x
|
137 |
+
|
138 |
+
class PixelNormLayer(nn.Module):
|
139 |
+
def __init__(self, epsilon=1e-8):
|
140 |
+
super().__init__()
|
141 |
+
self.epsilon = epsilon
|
142 |
+
def forward(self, x):
|
143 |
+
return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
|
144 |
+
|
145 |
+
class BlurLayer(nn.Module):
|
146 |
+
def __init__(self, kernel=[1, 2, 1], normalize=True, flip=False, stride=1):
|
147 |
+
super(BlurLayer, self).__init__()
|
148 |
+
kernel=[1, 2, 1]
|
149 |
+
kernel = torch.tensor(kernel, dtype=torch.float32)
|
150 |
+
kernel = kernel[:, None] * kernel[None, :]
|
151 |
+
kernel = kernel[None, None]
|
152 |
+
if normalize:
|
153 |
+
kernel = kernel / kernel.sum()
|
154 |
+
if flip:
|
155 |
+
kernel = kernel[:, :, ::-1, ::-1]
|
156 |
+
self.register_buffer('kernel', kernel)
|
157 |
+
self.stride = stride
|
158 |
+
|
159 |
+
def forward(self, x):
|
160 |
+
# expand kernel channels
|
161 |
+
kernel = self.kernel.expand(x.size(1), -1, -1, -1)
|
162 |
+
x = F.conv2d(
|
163 |
+
x,
|
164 |
+
kernel,
|
165 |
+
stride=self.stride,
|
166 |
+
padding=int((self.kernel.size(2)-1)/2),
|
167 |
+
groups=x.size(1)
|
168 |
+
)
|
169 |
+
return x
|
170 |
+
|
171 |
+
def upscale2d(x, factor=2, gain=1):
|
172 |
+
assert x.dim() == 4
|
173 |
+
if gain != 1:
|
174 |
+
x = x * gain
|
175 |
+
if factor != 1:
|
176 |
+
shape = x.shape
|
177 |
+
x = x.view(shape[0], shape[1], shape[2], 1, shape[3], 1).expand(-1, -1, -1, factor, -1, factor)
|
178 |
+
x = x.contiguous().view(shape[0], shape[1], factor * shape[2], factor * shape[3])
|
179 |
+
return x
|
180 |
+
|
181 |
+
class Upscale2d(nn.Module):
|
182 |
+
def __init__(self, factor=2, gain=1):
|
183 |
+
super().__init__()
|
184 |
+
assert isinstance(factor, int) and factor >= 1
|
185 |
+
self.gain = gain
|
186 |
+
self.factor = factor
|
187 |
+
def forward(self, x):
|
188 |
+
return upscale2d(x, factor=self.factor, gain=self.gain)
|
189 |
+
|
190 |
+
class G_mapping(nn.Sequential):
|
191 |
+
def __init__(self, nonlinearity='lrelu', use_wscale=True):
|
192 |
+
act, gain = {'relu': (torch.relu, np.sqrt(2)),
|
193 |
+
'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
|
194 |
+
layers = [
|
195 |
+
('pixel_norm', PixelNormLayer()),
|
196 |
+
('dense0', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
197 |
+
('dense0_act', act),
|
198 |
+
('dense1', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
199 |
+
('dense1_act', act),
|
200 |
+
('dense2', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
201 |
+
('dense2_act', act),
|
202 |
+
('dense3', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
203 |
+
('dense3_act', act),
|
204 |
+
('dense4', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
205 |
+
('dense4_act', act),
|
206 |
+
('dense5', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
207 |
+
('dense5_act', act),
|
208 |
+
('dense6', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
209 |
+
('dense6_act', act),
|
210 |
+
('dense7', MyLinear(512, 512, gain=gain, lrmul=0.01, use_wscale=use_wscale)),
|
211 |
+
('dense7_act', act)
|
212 |
+
]
|
213 |
+
super().__init__(OrderedDict(layers))
|
214 |
+
|
215 |
+
def forward(self, x):
|
216 |
+
return super().forward(x)
|
217 |
+
|
218 |
+
class Truncation(nn.Module):
|
219 |
+
def __init__(self, avg_latent, max_layer=8, threshold=0.7):
|
220 |
+
super().__init__()
|
221 |
+
self.max_layer = max_layer
|
222 |
+
self.threshold = threshold
|
223 |
+
self.register_buffer('avg_latent', avg_latent)
|
224 |
+
def forward(self, x):
|
225 |
+
assert x.dim() == 3
|
226 |
+
interp = torch.lerp(self.avg_latent, x, self.threshold)
|
227 |
+
do_trunc = (torch.arange(x.size(1)) < self.max_layer).view(1, -1, 1)
|
228 |
+
return torch.where(do_trunc, interp, x)
|
229 |
+
|
230 |
+
class LayerEpilogue(nn.Module):
|
231 |
+
"""Things to do at the end of each layer."""
|
232 |
+
def __init__(self, channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
|
233 |
+
super().__init__()
|
234 |
+
layers = []
|
235 |
+
if use_noise:
|
236 |
+
layers.append(('noise', NoiseLayer(channels)))
|
237 |
+
layers.append(('activation', activation_layer))
|
238 |
+
if use_pixel_norm:
|
239 |
+
layers.append(('pixel_norm', PixelNorm()))
|
240 |
+
if use_instance_norm:
|
241 |
+
layers.append(('instance_norm', nn.InstanceNorm2d(channels)))
|
242 |
+
self.top_epi = nn.Sequential(OrderedDict(layers))
|
243 |
+
if use_styles:
|
244 |
+
self.style_mod = StyleMod(dlatent_size, channels, use_wscale=use_wscale)
|
245 |
+
else:
|
246 |
+
self.style_mod = None
|
247 |
+
def forward(self, x, dlatents_in_slice=None):
|
248 |
+
x = self.top_epi(x)
|
249 |
+
if self.style_mod is not None:
|
250 |
+
x = self.style_mod(x, dlatents_in_slice)
|
251 |
+
else:
|
252 |
+
assert dlatents_in_slice is None
|
253 |
+
return x
|
254 |
+
|
255 |
+
|
256 |
+
class InputBlock(nn.Module):
|
257 |
+
def __init__(self, nf, dlatent_size, const_input_layer, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
|
258 |
+
super().__init__()
|
259 |
+
self.const_input_layer = const_input_layer
|
260 |
+
self.nf = nf
|
261 |
+
if self.const_input_layer:
|
262 |
+
# called 'const' in tf
|
263 |
+
self.const = nn.Parameter(torch.ones(1, nf, 4, 4))
|
264 |
+
self.bias = nn.Parameter(torch.ones(nf))
|
265 |
+
else:
|
266 |
+
self.dense = MyLinear(dlatent_size, nf*16, gain=gain/4, use_wscale=use_wscale) # tweak gain to match the official implementation of Progressing GAN
|
267 |
+
self.epi1 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
|
268 |
+
self.conv = MyConv2d(nf, nf, 3, gain=gain, use_wscale=use_wscale)
|
269 |
+
self.epi2 = LayerEpilogue(nf, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
|
270 |
+
|
271 |
+
def forward(self, dlatents_in_range):
|
272 |
+
batch_size = dlatents_in_range.size(0)
|
273 |
+
if self.const_input_layer:
|
274 |
+
x = self.const.expand(batch_size, -1, -1, -1)
|
275 |
+
x = x + self.bias.view(1, -1, 1, 1)
|
276 |
+
else:
|
277 |
+
x = self.dense(dlatents_in_range[:, 0]).view(batch_size, self.nf, 4, 4)
|
278 |
+
x = self.epi1(x, dlatents_in_range[:, 0])
|
279 |
+
x = self.conv(x)
|
280 |
+
x = self.epi2(x, dlatents_in_range[:, 1])
|
281 |
+
return x
|
282 |
+
|
283 |
+
|
284 |
+
class GSynthesisBlock(nn.Module):
|
285 |
+
def __init__(self, in_channels, out_channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer):
|
286 |
+
# 2**res x 2**res # res = 3..resolution_log2
|
287 |
+
super().__init__()
|
288 |
+
if blur_filter:
|
289 |
+
blur = BlurLayer(blur_filter)
|
290 |
+
else:
|
291 |
+
blur = None
|
292 |
+
self.conv0_up = MyConv2d(in_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale,
|
293 |
+
intermediate=blur, upscale=True)
|
294 |
+
self.epi1 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
|
295 |
+
self.conv1 = MyConv2d(out_channels, out_channels, kernel_size=3, gain=gain, use_wscale=use_wscale)
|
296 |
+
self.epi2 = LayerEpilogue(out_channels, dlatent_size, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, activation_layer)
|
297 |
+
|
298 |
+
def forward(self, x, dlatents_in_range):
|
299 |
+
x = self.conv0_up(x)
|
300 |
+
x = self.epi1(x, dlatents_in_range[:, 0])
|
301 |
+
x = self.conv1(x)
|
302 |
+
x = self.epi2(x, dlatents_in_range[:, 1])
|
303 |
+
return x
|
304 |
+
|
305 |
+
class G_synthesis(nn.Module):
|
306 |
+
def __init__(self,
|
307 |
+
dlatent_size = 512, # Disentangled latent (W) dimensionality.
|
308 |
+
num_channels = 3, # Number of output color channels.
|
309 |
+
resolution = 1024, # Output resolution.
|
310 |
+
fmap_base = 8192, # Overall multiplier for the number of feature maps.
|
311 |
+
fmap_decay = 1.0, # log2 feature map reduction when doubling the resolution.
|
312 |
+
fmap_max = 512, # Maximum number of feature maps in any layer.
|
313 |
+
use_styles = True, # Enable style inputs?
|
314 |
+
const_input_layer = True, # First layer is a learned constant?
|
315 |
+
use_noise = True, # Enable noise inputs?
|
316 |
+
randomize_noise = True, # True = randomize noise inputs every time (non-deterministic), False = read noise inputs from variables.
|
317 |
+
nonlinearity = 'lrelu', # Activation function: 'relu', 'lrelu'
|
318 |
+
use_wscale = True, # Enable equalized learning rate?
|
319 |
+
use_pixel_norm = False, # Enable pixelwise feature vector normalization?
|
320 |
+
use_instance_norm = True, # Enable instance normalization?
|
321 |
+
dtype = torch.float32, # Data type to use for activations and outputs.
|
322 |
+
blur_filter = [1,2,1], # Low-pass filter to apply when resampling activations. None = no filtering.
|
323 |
+
):
|
324 |
+
|
325 |
+
super().__init__()
|
326 |
+
def nf(stage):
|
327 |
+
return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)
|
328 |
+
self.dlatent_size = dlatent_size
|
329 |
+
resolution_log2 = int(np.log2(resolution))
|
330 |
+
assert resolution == 2**resolution_log2 and resolution >= 4
|
331 |
+
|
332 |
+
act, gain = {'relu': (torch.relu, np.sqrt(2)),
|
333 |
+
'lrelu': (nn.LeakyReLU(negative_slope=0.2), np.sqrt(2))}[nonlinearity]
|
334 |
+
num_layers = resolution_log2 * 2 - 2
|
335 |
+
num_styles = num_layers if use_styles else 1
|
336 |
+
torgbs = []
|
337 |
+
blocks = []
|
338 |
+
for res in range(2, resolution_log2 + 1):
|
339 |
+
channels = nf(res-1)
|
340 |
+
name = '{s}x{s}'.format(s=2**res)
|
341 |
+
if res == 2:
|
342 |
+
blocks.append((name,
|
343 |
+
InputBlock(channels, dlatent_size, const_input_layer, gain, use_wscale,
|
344 |
+
use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
|
345 |
+
|
346 |
+
else:
|
347 |
+
blocks.append((name,
|
348 |
+
GSynthesisBlock(last_channels, channels, blur_filter, dlatent_size, gain, use_wscale, use_noise, use_pixel_norm, use_instance_norm, use_styles, act)))
|
349 |
+
last_channels = channels
|
350 |
+
self.torgb = MyConv2d(channels, num_channels, 1, gain=1, use_wscale=use_wscale)
|
351 |
+
self.blocks = nn.ModuleDict(OrderedDict(blocks))
|
352 |
+
|
353 |
+
def forward(self, dlatents_in):
|
354 |
+
# Input: Disentangled latents (W) [minibatch, num_layers, dlatent_size].
|
355 |
+
# lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0), trainable=False), dtype)
|
356 |
+
batch_size = dlatents_in.size(0)
|
357 |
+
for i, m in enumerate(self.blocks.values()):
|
358 |
+
if i == 0:
|
359 |
+
x = m(dlatents_in[:, 2*i:2*i+2])
|
360 |
+
else:
|
361 |
+
x = m(x, dlatents_in[:, 2*i:2*i+2])
|
362 |
+
rgb = self.torgb(x)
|
363 |
+
return rgb
|
364 |
+
|
365 |
+
|
366 |
+
class StyleGAN_G(nn.Sequential):
|
367 |
+
def __init__(self, resolution, truncation=1.0):
|
368 |
+
self.resolution = resolution
|
369 |
+
self.layers = OrderedDict([
|
370 |
+
('g_mapping', G_mapping()),
|
371 |
+
#('truncation', Truncation(avg_latent)),
|
372 |
+
('g_synthesis', G_synthesis(resolution=resolution)),
|
373 |
+
])
|
374 |
+
super().__init__(self.layers)
|
375 |
+
|
376 |
+
def forward(self, x, latent_is_w=False):
|
377 |
+
if isinstance(x, list):
|
378 |
+
assert len(x) == 18, 'Must provide 1 or 18 latents'
|
379 |
+
if not latent_is_w:
|
380 |
+
x = [self.layers['g_mapping'].forward(l) for l in x]
|
381 |
+
x = torch.stack(x, dim=1)
|
382 |
+
else:
|
383 |
+
if not latent_is_w:
|
384 |
+
x = self.layers['g_mapping'].forward(x)
|
385 |
+
x = x.unsqueeze(1).expand(-1, 18, -1)
|
386 |
+
|
387 |
+
x = self.layers['g_synthesis'].forward(x)
|
388 |
+
|
389 |
+
return x
|
390 |
+
|
391 |
+
# From: https://github.com/lernapparat/lernapparat/releases/download/v2019-02-01/
|
392 |
+
def load_weights(self, checkpoint):
|
393 |
+
self.load_state_dict(torch.load(checkpoint))
|
394 |
+
|
395 |
+
def export_from_tf(self, pickle_path):
|
396 |
+
module_path = Path(__file__).parent / 'stylegan_tf'
|
397 |
+
sys.path.append(str(module_path.resolve()))
|
398 |
+
|
399 |
+
import dnnlib, dnnlib.tflib, pickle, torch, collections
|
400 |
+
dnnlib.tflib.init_tf()
|
401 |
+
|
402 |
+
weights = pickle.load(open(pickle_path,'rb'))
|
403 |
+
weights_pt = [collections.OrderedDict([(k, torch.from_numpy(v.value().eval())) for k,v in w.trainables.items()]) for w in weights]
|
404 |
+
#torch.save(weights_pt, pytorch_name)
|
405 |
+
|
406 |
+
# then on the PyTorch side run
|
407 |
+
state_G, state_D, state_Gs = weights_pt #torch.load('./karras2019stylegan-ffhq-1024x1024.pt')
|
408 |
+
def key_translate(k):
|
409 |
+
k = k.lower().split('/')
|
410 |
+
if k[0] == 'g_synthesis':
|
411 |
+
if not k[1].startswith('torgb'):
|
412 |
+
k.insert(1, 'blocks')
|
413 |
+
k = '.'.join(k)
|
414 |
+
k = (k.replace('const.const','const').replace('const.bias','bias').replace('const.stylemod','epi1.style_mod.lin')
|
415 |
+
.replace('const.noise.weight','epi1.top_epi.noise.weight')
|
416 |
+
.replace('conv.noise.weight','epi2.top_epi.noise.weight')
|
417 |
+
.replace('conv.stylemod','epi2.style_mod.lin')
|
418 |
+
.replace('conv0_up.noise.weight', 'epi1.top_epi.noise.weight')
|
419 |
+
.replace('conv0_up.stylemod','epi1.style_mod.lin')
|
420 |
+
.replace('conv1.noise.weight', 'epi2.top_epi.noise.weight')
|
421 |
+
.replace('conv1.stylemod','epi2.style_mod.lin')
|
422 |
+
.replace('torgb_lod0','torgb'))
|
423 |
+
else:
|
424 |
+
k = '.'.join(k)
|
425 |
+
return k
|
426 |
+
|
427 |
+
def weight_translate(k, w):
|
428 |
+
k = key_translate(k)
|
429 |
+
if k.endswith('.weight'):
|
430 |
+
if w.dim() == 2:
|
431 |
+
w = w.t()
|
432 |
+
elif w.dim() == 1:
|
433 |
+
pass
|
434 |
+
else:
|
435 |
+
assert w.dim() == 4
|
436 |
+
w = w.permute(3, 2, 0, 1)
|
437 |
+
return w
|
438 |
+
|
439 |
+
# we delete the useless torgb filters
|
440 |
+
param_dict = {key_translate(k) : weight_translate(k, v) for k,v in state_Gs.items() if 'torgb_lod' not in key_translate(k)}
|
441 |
+
if 1:
|
442 |
+
sd_shapes = {k : v.shape for k,v in self.state_dict().items()}
|
443 |
+
param_shapes = {k : v.shape for k,v in param_dict.items() }
|
444 |
+
|
445 |
+
for k in list(sd_shapes)+list(param_shapes):
|
446 |
+
pds = param_shapes.get(k)
|
447 |
+
sds = sd_shapes.get(k)
|
448 |
+
if pds is None:
|
449 |
+
print ("sd only", k, sds)
|
450 |
+
elif sds is None:
|
451 |
+
print ("pd only", k, pds)
|
452 |
+
elif sds != pds:
|
453 |
+
print ("mismatch!", k, pds, sds)
|
454 |
+
|
455 |
+
self.load_state_dict(param_dict, strict=False) # needed for the blur kernels
|
456 |
+
torch.save(self.state_dict(), Path(pickle_path).with_suffix('.pt'))
|
models/stylegan/stylegan_tf/LICENSE.txt
ADDED
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
|
3 |
+
|
4 |
+
Attribution-NonCommercial 4.0 International
|
5 |
+
|
6 |
+
=======================================================================
|
7 |
+
|
8 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
9 |
+
does not provide legal services or legal advice. Distribution of
|
10 |
+
Creative Commons public licenses does not create a lawyer-client or
|
11 |
+
other relationship. Creative Commons makes its licenses and related
|
12 |
+
information available on an "as-is" basis. Creative Commons gives no
|
13 |
+
warranties regarding its licenses, any material licensed under their
|
14 |
+
terms and conditions, or any related information. Creative Commons
|
15 |
+
disclaims all liability for damages resulting from their use to the
|
16 |
+
fullest extent possible.
|
17 |
+
|
18 |
+
Using Creative Commons Public Licenses
|
19 |
+
|
20 |
+
Creative Commons public licenses provide a standard set of terms and
|
21 |
+
conditions that creators and other rights holders may use to share
|
22 |
+
original works of authorship and other material subject to copyright
|
23 |
+
and certain other rights specified in the public license below. The
|
24 |
+
following considerations are for informational purposes only, are not
|
25 |
+
exhaustive, and do not form part of our licenses.
|
26 |
+
|
27 |
+
Considerations for licensors: Our public licenses are
|
28 |
+
intended for use by those authorized to give the public
|
29 |
+
permission to use material in ways otherwise restricted by
|
30 |
+
copyright and certain other rights. Our licenses are
|
31 |
+
irrevocable. Licensors should read and understand the terms
|
32 |
+
and conditions of the license they choose before applying it.
|
33 |
+
Licensors should also secure all rights necessary before
|
34 |
+
applying our licenses so that the public can reuse the
|
35 |
+
material as expected. Licensors should clearly mark any
|
36 |
+
material not subject to the license. This includes other CC-
|
37 |
+
licensed material, or material used under an exception or
|
38 |
+
limitation to copyright. More considerations for licensors:
|
39 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
40 |
+
|
41 |
+
Considerations for the public: By using one of our public
|
42 |
+
licenses, a licensor grants the public permission to use the
|
43 |
+
licensed material under specified terms and conditions. If
|
44 |
+
the licensor's permission is not necessary for any reason--for
|
45 |
+
example, because of any applicable exception or limitation to
|
46 |
+
copyright--then that use is not regulated by the license. Our
|
47 |
+
licenses grant only permissions under copyright and certain
|
48 |
+
other rights that a licensor has authority to grant. Use of
|
49 |
+
the licensed material may still be restricted for other
|
50 |
+
reasons, including because others have copyright or other
|
51 |
+
rights in the material. A licensor may make special requests,
|
52 |
+
such as asking that all changes be marked or described.
|
53 |
+
Although not required by our licenses, you are encouraged to
|
54 |
+
respect those requests where reasonable. More_considerations
|
55 |
+
for the public:
|
56 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
57 |
+
|
58 |
+
=======================================================================
|
59 |
+
|
60 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
61 |
+
License
|
62 |
+
|
63 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
64 |
+
to be bound by the terms and conditions of this Creative Commons
|
65 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
66 |
+
License"). To the extent this Public License may be interpreted as a
|
67 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
68 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
69 |
+
such rights in consideration of benefits the Licensor receives from
|
70 |
+
making the Licensed Material available under these terms and
|
71 |
+
conditions.
|
72 |
+
|
73 |
+
|
74 |
+
Section 1 -- Definitions.
|
75 |
+
|
76 |
+
a. Adapted Material means material subject to Copyright and Similar
|
77 |
+
Rights that is derived from or based upon the Licensed Material
|
78 |
+
and in which the Licensed Material is translated, altered,
|
79 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
80 |
+
permission under the Copyright and Similar Rights held by the
|
81 |
+
Licensor. For purposes of this Public License, where the Licensed
|
82 |
+
Material is a musical work, performance, or sound recording,
|
83 |
+
Adapted Material is always produced where the Licensed Material is
|
84 |
+
synched in timed relation with a moving image.
|
85 |
+
|
86 |
+
b. Adapter's License means the license You apply to Your Copyright
|
87 |
+
and Similar Rights in Your contributions to Adapted Material in
|
88 |
+
accordance with the terms and conditions of this Public License.
|
89 |
+
|
90 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
91 |
+
closely related to copyright including, without limitation,
|
92 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
93 |
+
Rights, without regard to how the rights are labeled or
|
94 |
+
categorized. For purposes of this Public License, the rights
|
95 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
96 |
+
Rights.
|
97 |
+
d. Effective Technological Measures means those measures that, in the
|
98 |
+
absence of proper authority, may not be circumvented under laws
|
99 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
100 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
101 |
+
agreements.
|
102 |
+
|
103 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
104 |
+
any other exception or limitation to Copyright and Similar Rights
|
105 |
+
that applies to Your use of the Licensed Material.
|
106 |
+
|
107 |
+
f. Licensed Material means the artistic or literary work, database,
|
108 |
+
or other material to which the Licensor applied this Public
|
109 |
+
License.
|
110 |
+
|
111 |
+
g. Licensed Rights means the rights granted to You subject to the
|
112 |
+
terms and conditions of this Public License, which are limited to
|
113 |
+
all Copyright and Similar Rights that apply to Your use of the
|
114 |
+
Licensed Material and that the Licensor has authority to license.
|
115 |
+
|
116 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
117 |
+
under this Public License.
|
118 |
+
|
119 |
+
i. NonCommercial means not primarily intended for or directed towards
|
120 |
+
commercial advantage or monetary compensation. For purposes of
|
121 |
+
this Public License, the exchange of the Licensed Material for
|
122 |
+
other material subject to Copyright and Similar Rights by digital
|
123 |
+
file-sharing or similar means is NonCommercial provided there is
|
124 |
+
no payment of monetary compensation in connection with the
|
125 |
+
exchange.
|
126 |
+
|
127 |
+
j. Share means to provide material to the public by any means or
|
128 |
+
process that requires permission under the Licensed Rights, such
|
129 |
+
as reproduction, public display, public performance, distribution,
|
130 |
+
dissemination, communication, or importation, and to make material
|
131 |
+
available to the public including in ways that members of the
|
132 |
+
public may access the material from a place and at a time
|
133 |
+
individually chosen by them.
|
134 |
+
|
135 |
+
k. Sui Generis Database Rights means rights other than copyright
|
136 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
137 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
138 |
+
as amended and/or succeeded, as well as other essentially
|
139 |
+
equivalent rights anywhere in the world.
|
140 |
+
|
141 |
+
l. You means the individual or entity exercising the Licensed Rights
|
142 |
+
under this Public License. Your has a corresponding meaning.
|
143 |
+
|
144 |
+
|
145 |
+
Section 2 -- Scope.
|
146 |
+
|
147 |
+
a. License grant.
|
148 |
+
|
149 |
+
1. Subject to the terms and conditions of this Public License,
|
150 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
151 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
152 |
+
exercise the Licensed Rights in the Licensed Material to:
|
153 |
+
|
154 |
+
a. reproduce and Share the Licensed Material, in whole or
|
155 |
+
in part, for NonCommercial purposes only; and
|
156 |
+
|
157 |
+
b. produce, reproduce, and Share Adapted Material for
|
158 |
+
NonCommercial purposes only.
|
159 |
+
|
160 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
161 |
+
Exceptions and Limitations apply to Your use, this Public
|
162 |
+
License does not apply, and You do not need to comply with
|
163 |
+
its terms and conditions.
|
164 |
+
|
165 |
+
3. Term. The term of this Public License is specified in Section
|
166 |
+
6(a).
|
167 |
+
|
168 |
+
4. Media and formats; technical modifications allowed. The
|
169 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
170 |
+
all media and formats whether now known or hereafter created,
|
171 |
+
and to make technical modifications necessary to do so. The
|
172 |
+
Licensor waives and/or agrees not to assert any right or
|
173 |
+
authority to forbid You from making technical modifications
|
174 |
+
necessary to exercise the Licensed Rights, including
|
175 |
+
technical modifications necessary to circumvent Effective
|
176 |
+
Technological Measures. For purposes of this Public License,
|
177 |
+
simply making modifications authorized by this Section 2(a)
|
178 |
+
(4) never produces Adapted Material.
|
179 |
+
|
180 |
+
5. Downstream recipients.
|
181 |
+
|
182 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
183 |
+
recipient of the Licensed Material automatically
|
184 |
+
receives an offer from the Licensor to exercise the
|
185 |
+
Licensed Rights under the terms and conditions of this
|
186 |
+
Public License.
|
187 |
+
|
188 |
+
b. No downstream restrictions. You may not offer or impose
|
189 |
+
any additional or different terms or conditions on, or
|
190 |
+
apply any Effective Technological Measures to, the
|
191 |
+
Licensed Material if doing so restricts exercise of the
|
192 |
+
Licensed Rights by any recipient of the Licensed
|
193 |
+
Material.
|
194 |
+
|
195 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
196 |
+
may be construed as permission to assert or imply that You
|
197 |
+
are, or that Your use of the Licensed Material is, connected
|
198 |
+
with, or sponsored, endorsed, or granted official status by,
|
199 |
+
the Licensor or others designated to receive attribution as
|
200 |
+
provided in Section 3(a)(1)(A)(i).
|
201 |
+
|
202 |
+
b. Other rights.
|
203 |
+
|
204 |
+
1. Moral rights, such as the right of integrity, are not
|
205 |
+
licensed under this Public License, nor are publicity,
|
206 |
+
privacy, and/or other similar personality rights; however, to
|
207 |
+
the extent possible, the Licensor waives and/or agrees not to
|
208 |
+
assert any such rights held by the Licensor to the limited
|
209 |
+
extent necessary to allow You to exercise the Licensed
|
210 |
+
Rights, but not otherwise.
|
211 |
+
|
212 |
+
2. Patent and trademark rights are not licensed under this
|
213 |
+
Public License.
|
214 |
+
|
215 |
+
3. To the extent possible, the Licensor waives any right to
|
216 |
+
collect royalties from You for the exercise of the Licensed
|
217 |
+
Rights, whether directly or through a collecting society
|
218 |
+
under any voluntary or waivable statutory or compulsory
|
219 |
+
licensing scheme. In all other cases the Licensor expressly
|
220 |
+
reserves any right to collect such royalties, including when
|
221 |
+
the Licensed Material is used other than for NonCommercial
|
222 |
+
purposes.
|
223 |
+
|
224 |
+
|
225 |
+
Section 3 -- License Conditions.
|
226 |
+
|
227 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
228 |
+
following conditions.
|
229 |
+
|
230 |
+
a. Attribution.
|
231 |
+
|
232 |
+
1. If You Share the Licensed Material (including in modified
|
233 |
+
form), You must:
|
234 |
+
|
235 |
+
a. retain the following if it is supplied by the Licensor
|
236 |
+
with the Licensed Material:
|
237 |
+
|
238 |
+
i. identification of the creator(s) of the Licensed
|
239 |
+
Material and any others designated to receive
|
240 |
+
attribution, in any reasonable manner requested by
|
241 |
+
the Licensor (including by pseudonym if
|
242 |
+
designated);
|
243 |
+
|
244 |
+
ii. a copyright notice;
|
245 |
+
|
246 |
+
iii. a notice that refers to this Public License;
|
247 |
+
|
248 |
+
iv. a notice that refers to the disclaimer of
|
249 |
+
warranties;
|
250 |
+
|
251 |
+
v. a URI or hyperlink to the Licensed Material to the
|
252 |
+
extent reasonably practicable;
|
253 |
+
|
254 |
+
b. indicate if You modified the Licensed Material and
|
255 |
+
retain an indication of any previous modifications; and
|
256 |
+
|
257 |
+
c. indicate the Licensed Material is licensed under this
|
258 |
+
Public License, and include the text of, or the URI or
|
259 |
+
hyperlink to, this Public License.
|
260 |
+
|
261 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
262 |
+
reasonable manner based on the medium, means, and context in
|
263 |
+
which You Share the Licensed Material. For example, it may be
|
264 |
+
reasonable to satisfy the conditions by providing a URI or
|
265 |
+
hyperlink to a resource that includes the required
|
266 |
+
information.
|
267 |
+
|
268 |
+
3. If requested by the Licensor, You must remove any of the
|
269 |
+
information required by Section 3(a)(1)(A) to the extent
|
270 |
+
reasonably practicable.
|
271 |
+
|
272 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
273 |
+
License You apply must not prevent recipients of the Adapted
|
274 |
+
Material from complying with this Public License.
|
275 |
+
|
276 |
+
|
277 |
+
Section 4 -- Sui Generis Database Rights.
|
278 |
+
|
279 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
280 |
+
apply to Your use of the Licensed Material:
|
281 |
+
|
282 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
283 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
284 |
+
portion of the contents of the database for NonCommercial purposes
|
285 |
+
only;
|
286 |
+
|
287 |
+
b. if You include all or a substantial portion of the database
|
288 |
+
contents in a database in which You have Sui Generis Database
|
289 |
+
Rights, then the database in which You have Sui Generis Database
|
290 |
+
Rights (but not its individual contents) is Adapted Material; and
|
291 |
+
|
292 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
293 |
+
all or a substantial portion of the contents of the database.
|
294 |
+
|
295 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
296 |
+
replace Your obligations under this Public License where the Licensed
|
297 |
+
Rights include other Copyright and Similar Rights.
|
298 |
+
|
299 |
+
|
300 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
301 |
+
|
302 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
303 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
304 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
305 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
306 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
307 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
308 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
309 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
310 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
311 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
312 |
+
|
313 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
314 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
315 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
316 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
317 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
318 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
319 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
320 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
321 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
322 |
+
|
323 |
+
c. The disclaimer of warranties and limitation of liability provided
|
324 |
+
above shall be interpreted in a manner that, to the extent
|
325 |
+
possible, most closely approximates an absolute disclaimer and
|
326 |
+
waiver of all liability.
|
327 |
+
|
328 |
+
|
329 |
+
Section 6 -- Term and Termination.
|
330 |
+
|
331 |
+
a. This Public License applies for the term of the Copyright and
|
332 |
+
Similar Rights licensed here. However, if You fail to comply with
|
333 |
+
this Public License, then Your rights under this Public License
|
334 |
+
terminate automatically.
|
335 |
+
|
336 |
+
b. Where Your right to use the Licensed Material has terminated under
|
337 |
+
Section 6(a), it reinstates:
|
338 |
+
|
339 |
+
1. automatically as of the date the violation is cured, provided
|
340 |
+
it is cured within 30 days of Your discovery of the
|
341 |
+
violation; or
|
342 |
+
|
343 |
+
2. upon express reinstatement by the Licensor.
|
344 |
+
|
345 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
346 |
+
right the Licensor may have to seek remedies for Your violations
|
347 |
+
of this Public License.
|
348 |
+
|
349 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
350 |
+
Licensed Material under separate terms or conditions or stop
|
351 |
+
distributing the Licensed Material at any time; however, doing so
|
352 |
+
will not terminate this Public License.
|
353 |
+
|
354 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
355 |
+
License.
|
356 |
+
|
357 |
+
|
358 |
+
Section 7 -- Other Terms and Conditions.
|
359 |
+
|
360 |
+
a. The Licensor shall not be bound by any additional or different
|
361 |
+
terms or conditions communicated by You unless expressly agreed.
|
362 |
+
|
363 |
+
b. Any arrangements, understandings, or agreements regarding the
|
364 |
+
Licensed Material not stated herein are separate from and
|
365 |
+
independent of the terms and conditions of this Public License.
|
366 |
+
|
367 |
+
|
368 |
+
Section 8 -- Interpretation.
|
369 |
+
|
370 |
+
a. For the avoidance of doubt, this Public License does not, and
|
371 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
372 |
+
conditions on any use of the Licensed Material that could lawfully
|
373 |
+
be made without permission under this Public License.
|
374 |
+
|
375 |
+
b. To the extent possible, if any provision of this Public License is
|
376 |
+
deemed unenforceable, it shall be automatically reformed to the
|
377 |
+
minimum extent necessary to make it enforceable. If the provision
|
378 |
+
cannot be reformed, it shall be severed from this Public License
|
379 |
+
without affecting the enforceability of the remaining terms and
|
380 |
+
conditions.
|
381 |
+
|
382 |
+
c. No term or condition of this Public License will be waived and no
|
383 |
+
failure to comply consented to unless expressly agreed to by the
|
384 |
+
Licensor.
|
385 |
+
|
386 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
387 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
388 |
+
that apply to the Licensor or You, including from the legal
|
389 |
+
processes of any jurisdiction or authority.
|
390 |
+
|
391 |
+
=======================================================================
|
392 |
+
|
393 |
+
Creative Commons is not a party to its public
|
394 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
395 |
+
its public licenses to material it publishes and in those instances
|
396 |
+
will be considered the "Licensor." The text of the Creative Commons
|
397 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
398 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
399 |
+
material is shared under a Creative Commons public license or as
|
400 |
+
otherwise permitted by the Creative Commons policies published at
|
401 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
402 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
403 |
+
of Creative Commons without its prior written consent including,
|
404 |
+
without limitation, in connection with any unauthorized modifications
|
405 |
+
to any of its public licenses or any other arrangements,
|
406 |
+
understandings, or agreements concerning use of licensed material. For
|
407 |
+
the avoidance of doubt, this paragraph does not form part of the
|
408 |
+
public licenses.
|
409 |
+
|
410 |
+
Creative Commons may be contacted at creativecommons.org.
|
models/stylegan/stylegan_tf/README.md
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## StyleGAN — Official TensorFlow Implementation
|
2 |
+
![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg?style=plastic)
|
3 |
+
![TensorFlow 1.10](https://img.shields.io/badge/tensorflow-1.10-green.svg?style=plastic)
|
4 |
+
![cuDNN 7.3.1](https://img.shields.io/badge/cudnn-7.3.1-green.svg?style=plastic)
|
5 |
+
![License CC BY-NC](https://img.shields.io/badge/license-CC_BY--NC-green.svg?style=plastic)
|
6 |
+
|
7 |
+
![Teaser image](./stylegan-teaser.png)
|
8 |
+
**Picture:** *These people are not real – they were produced by our generator that allows control over different aspects of the image.*
|
9 |
+
|
10 |
+
This repository contains the official TensorFlow implementation of the following paper:
|
11 |
+
|
12 |
+
> **A Style-Based Generator Architecture for Generative Adversarial Networks**<br>
|
13 |
+
> Tero Karras (NVIDIA), Samuli Laine (NVIDIA), Timo Aila (NVIDIA)<br>
|
14 |
+
> https://arxiv.org/abs/1812.04948
|
15 |
+
>
|
16 |
+
> **Abstract:** *We propose an alternative generator architecture for generative adversarial networks, borrowing from style transfer literature. The new architecture leads to an automatically learned, unsupervised separation of high-level attributes (e.g., pose and identity when trained on human faces) and stochastic variation in the generated images (e.g., freckles, hair), and it enables intuitive, scale-specific control of the synthesis. The new generator improves the state-of-the-art in terms of traditional distribution quality metrics, leads to demonstrably better interpolation properties, and also better disentangles the latent factors of variation. To quantify interpolation quality and disentanglement, we propose two new, automated methods that are applicable to any generator architecture. Finally, we introduce a new, highly varied and high-quality dataset of human faces.*
|
17 |
+
|
18 |
+
For business inquiries, please contact [[email protected]](mailto:[email protected])<br>
|
19 |
+
For press and other inquiries, please contact Hector Marinez at [[email protected]](mailto:[email protected])<br>
|
20 |
+
|
21 |
+
**★★★ NEW: StyleGAN2 is available at [https://github.com/NVlabs/stylegan2](https://github.com/NVlabs/stylegan2) ★★★**
|
22 |
+
|
23 |
+
## Resources
|
24 |
+
|
25 |
+
Material related to our paper is available via the following links:
|
26 |
+
|
27 |
+
- Paper: https://arxiv.org/abs/1812.04948
|
28 |
+
- Video: https://youtu.be/kSLJriaOumA
|
29 |
+
- Code: https://github.com/NVlabs/stylegan
|
30 |
+
- FFHQ: https://github.com/NVlabs/ffhq-dataset
|
31 |
+
|
32 |
+
Additional material can be found on Google Drive:
|
33 |
+
|
34 |
+
| Path | Description
|
35 |
+
| :--- | :----------
|
36 |
+
| [StyleGAN](https://drive.google.com/open?id=1uka3a1noXHAydRPRbknqwKVGODvnmUBX) | Main folder.
|
37 |
+
| ├ [stylegan-paper.pdf](https://drive.google.com/open?id=1v-HkF3Ehrpon7wVIx4r5DLcko_U_V6Lt) | High-quality version of the paper PDF.
|
38 |
+
| ├ [stylegan-video.mp4](https://drive.google.com/open?id=1uzwkZHQX_9pYg1i0d1Nbe3D9xPO8-qBf) | High-quality version of the result video.
|
39 |
+
| ├ [images](https://drive.google.com/open?id=1-l46akONUWF6LCpDoeq63H53rD7MeiTd) | Example images produced using our generator.
|
40 |
+
| │ ├ [representative-images](https://drive.google.com/open?id=1ToY5P4Vvf5_c3TyUizQ8fckFFoFtBvD8) | High-quality images to be used in articles, blog posts, etc.
|
41 |
+
| │ └ [100k-generated-images](https://drive.google.com/open?id=100DJ0QXyG89HZzB4w2Cbyf4xjNK54cQ1) | 100,000 generated images for different amounts of truncation.
|
42 |
+
| │    ├ [ffhq-1024x1024](https://drive.google.com/open?id=14lm8VRN1pr4g_KVe6_LvyDX1PObst6d4) | Generated using Flickr-Faces-HQ dataset at 1024×1024.
|
43 |
+
| │    ├ [bedrooms-256x256](https://drive.google.com/open?id=1Vxz9fksw4kgjiHrvHkX4Hze4dyThFW6t) | Generated using LSUN Bedroom dataset at 256×256.
|
44 |
+
| │    ├ [cars-512x384](https://drive.google.com/open?id=1MFCvOMdLE2_mpeLPTiDw5dxc2CRuKkzS) | Generated using LSUN Car dataset at 512×384.
|
45 |
+
| │    └ [cats-256x256](https://drive.google.com/open?id=1gq-Gj3GRFiyghTPKhp8uDMA9HV_0ZFWQ) | Generated using LSUN Cat dataset at 256×256.
|
46 |
+
| ├ [videos](https://drive.google.com/open?id=1N8pOd_Bf8v89NGUaROdbD8-ayLPgyRRo) | Example videos produced using our generator.
|
47 |
+
| │ └ [high-quality-video-clips](https://drive.google.com/open?id=1NFO7_vH0t98J13ckJYFd7kuaTkyeRJ86) | Individual segments of the result video as high-quality MP4.
|
48 |
+
| ├ [ffhq-dataset](https://drive.google.com/open?id=1u2xu7bSrWxrbUxk-dT-UvEJq8IjdmNTP) | Raw data for the [Flickr-Faces-HQ dataset](https://github.com/NVlabs/ffhq-dataset).
|
49 |
+
| └ [networks](https://drive.google.com/open?id=1MASQyN5m0voPcx7-9K0r5gObhvvPups7) | Pre-trained networks as pickled instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py).
|
50 |
+
|    ├ [stylegan-ffhq-1024x1024.pkl](https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ) | StyleGAN trained with Flickr-Faces-HQ dataset at 1024×1024.
|
51 |
+
|    ├ [stylegan-celebahq-1024x1024.pkl](https://drive.google.com/uc?id=1MGqJl28pN4t7SAtSrPdSRJSQJqahkzUf) | StyleGAN trained with CelebA-HQ dataset at 1024×1024.
|
52 |
+
|    ├ [stylegan-bedrooms-256x256.pkl](https://drive.google.com/uc?id=1MOSKeGF0FJcivpBI7s63V9YHloUTORiF) | StyleGAN trained with LSUN Bedroom dataset at 256×256.
|
53 |
+
|    ├ [stylegan-cars-512x384.pkl](https://drive.google.com/uc?id=1MJ6iCfNtMIRicihwRorsM3b7mmtmK9c3) | StyleGAN trained with LSUN Car dataset at 512×384.
|
54 |
+
|    ├ [stylegan-cats-256x256.pkl](https://drive.google.com/uc?id=1MQywl0FNt6lHu8E_EUqnRbviagS7fbiJ) | StyleGAN trained with LSUN Cat dataset at 256×256.
|
55 |
+
|    └ [metrics](https://drive.google.com/open?id=1MvYdWCBuMfnoYGptRH-AgKLbPTsIQLhl) | Auxiliary networks for the quality and disentanglement metrics.
|
56 |
+
|       ├ [inception_v3_features.pkl](https://drive.google.com/uc?id=1MzTY44rLToO5APn8TZmfR7_ENSe5aZUn) | Standard [Inception-v3](https://arxiv.org/abs/1512.00567) classifier that outputs a raw feature vector.
|
57 |
+
|       ├ [vgg16_zhang_perceptual.pkl](https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2) | Standard [LPIPS](https://arxiv.org/abs/1801.03924) metric to estimate perceptual similarity.
|
58 |
+
|       ├ [celebahq-classifier-00-male.pkl](https://drive.google.com/uc?id=1Q5-AI6TwWhCVM7Muu4tBM7rp5nG_gmCX) | Binary classifier trained to detect a single attribute of CelebA-HQ.
|
59 |
+
|       └ ⋯ | Please see the file listing for remaining networks.
|
60 |
+
|
61 |
+
## Licenses
|
62 |
+
|
63 |
+
All material, excluding the Flickr-Faces-HQ dataset, is made available under [Creative Commons BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license by NVIDIA Corporation. You can **use, redistribute, and adapt** the material for **non-commercial purposes**, as long as you give appropriate credit by **citing our paper** and **indicating any changes** that you've made.
|
64 |
+
|
65 |
+
For license information regarding the FFHQ dataset, please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
|
66 |
+
|
67 |
+
`inception_v3_features.pkl` and `inception_v3_softmax.pkl` are derived from the pre-trained [Inception-v3](https://arxiv.org/abs/1512.00567) network by Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens, and Zbigniew Wojna. The network was originally shared under [Apache 2.0](https://github.com/tensorflow/models/blob/master/LICENSE) license on the [TensorFlow Models](https://github.com/tensorflow/models) repository.
|
68 |
+
|
69 |
+
`vgg16.pkl` and `vgg16_zhang_perceptual.pkl` are derived from the pre-trained [VGG-16](https://arxiv.org/abs/1409.1556) network by Karen Simonyan and Andrew Zisserman. The network was originally shared under [Creative Commons BY 4.0](https://creativecommons.org/licenses/by/4.0/) license on the [Very Deep Convolutional Networks for Large-Scale Visual Recognition](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) project page.
|
70 |
+
|
71 |
+
`vgg16_zhang_perceptual.pkl` is further derived from the pre-trained [LPIPS](https://arxiv.org/abs/1801.03924) weights by Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, and Oliver Wang. The weights were originally shared under [BSD 2-Clause "Simplified" License](https://github.com/richzhang/PerceptualSimilarity/blob/master/LICENSE) on the [PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity) repository.
|
72 |
+
|
73 |
+
## System requirements
|
74 |
+
|
75 |
+
* Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons.
|
76 |
+
* 64-bit Python 3.6 installation. We recommend Anaconda3 with numpy 1.14.3 or newer.
|
77 |
+
* TensorFlow 1.10.0 or newer with GPU support.
|
78 |
+
* One or more high-end NVIDIA GPUs with at least 11GB of DRAM. We recommend NVIDIA DGX-1 with 8 Tesla V100 GPUs.
|
79 |
+
* NVIDIA driver 391.35 or newer, CUDA toolkit 9.0 or newer, cuDNN 7.3.1 or newer.
|
80 |
+
|
81 |
+
## Using pre-trained networks
|
82 |
+
|
83 |
+
A minimal example of using a pre-trained StyleGAN generator is given in [pretrained_example.py](./pretrained_example.py). When executed, the script downloads a pre-trained StyleGAN generator from Google Drive and uses it to generate an image:
|
84 |
+
|
85 |
+
```
|
86 |
+
> python pretrained_example.py
|
87 |
+
Downloading https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ .... done
|
88 |
+
|
89 |
+
Gs Params OutputShape WeightShape
|
90 |
+
--- --- --- ---
|
91 |
+
latents_in - (?, 512) -
|
92 |
+
...
|
93 |
+
images_out - (?, 3, 1024, 1024) -
|
94 |
+
--- --- --- ---
|
95 |
+
Total 26219627
|
96 |
+
|
97 |
+
> ls results
|
98 |
+
example.png # https://drive.google.com/uc?id=1UDLT_zb-rof9kKH0GwiJW_bS9MoZi8oP
|
99 |
+
```
|
100 |
+
|
101 |
+
A more advanced example is given in [generate_figures.py](./generate_figures.py). The script reproduces the figures from our paper in order to illustrate style mixing, noise inputs, and truncation:
|
102 |
+
```
|
103 |
+
> python generate_figures.py
|
104 |
+
results/figure02-uncurated-ffhq.png # https://drive.google.com/uc?id=1U3r1xgcD7o-Fd0SBRpq8PXYajm7_30cu
|
105 |
+
results/figure03-style-mixing.png # https://drive.google.com/uc?id=1U-nlMDtpnf1RcYkaFQtbh5oxnhA97hy6
|
106 |
+
results/figure04-noise-detail.png # https://drive.google.com/uc?id=1UX3m39u_DTU6eLnEW6MqGzbwPFt2R9cG
|
107 |
+
results/figure05-noise-components.png # https://drive.google.com/uc?id=1UQKPcvYVeWMRccGMbs2pPD9PVv1QDyp_
|
108 |
+
results/figure08-truncation-trick.png # https://drive.google.com/uc?id=1ULea0C12zGlxdDQFNLXOWZCHi3QNfk_v
|
109 |
+
results/figure10-uncurated-bedrooms.png # https://drive.google.com/uc?id=1UEBnms1XMfj78OHj3_cx80mUf_m9DUJr
|
110 |
+
results/figure11-uncurated-cars.png # https://drive.google.com/uc?id=1UO-4JtAs64Kun5vIj10UXqAJ1d5Ir1Ke
|
111 |
+
results/figure12-uncurated-cats.png # https://drive.google.com/uc?id=1USnJc14prlu3QAYxstrtlfXC9sDWPA-W
|
112 |
+
```
|
113 |
+
|
114 |
+
The pre-trained networks are stored as standard pickle files on Google Drive:
|
115 |
+
|
116 |
+
```
|
117 |
+
# Load pre-trained network.
|
118 |
+
url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
|
119 |
+
with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
|
120 |
+
_G, _D, Gs = pickle.load(f)
|
121 |
+
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
|
122 |
+
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
|
123 |
+
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
|
124 |
+
```
|
125 |
+
|
126 |
+
The above code downloads the file and unpickles it to yield 3 instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py). To generate images, you will typically want to use `Gs` – the other two networks are provided for completeness. In order for `pickle.load()` to work, you will need to have the `dnnlib` source directory in your PYTHONPATH and a `tf.Session` set as default. The session can initialized by calling `dnnlib.tflib.init_tf()`.
|
127 |
+
|
128 |
+
There are three ways to use the pre-trained generator:
|
129 |
+
|
130 |
+
1. Use `Gs.run()` for immediate-mode operation where the inputs and outputs are numpy arrays:
|
131 |
+
```
|
132 |
+
# Pick latent vector.
|
133 |
+
rnd = np.random.RandomState(5)
|
134 |
+
latents = rnd.randn(1, Gs.input_shape[1])
|
135 |
+
|
136 |
+
# Generate image.
|
137 |
+
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
|
138 |
+
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
|
139 |
+
```
|
140 |
+
The first argument is a batch of latent vectors of shape `[num, 512]`. The second argument is reserved for class labels (not used by StyleGAN). The remaining keyword arguments are optional and can be used to further modify the operation (see below). The output is a batch of images, whose format is dictated by the `output_transform` argument.
|
141 |
+
|
142 |
+
2. Use `Gs.get_output_for()` to incorporate the generator as a part of a larger TensorFlow expression:
|
143 |
+
```
|
144 |
+
latents = tf.random_normal([self.minibatch_per_gpu] + Gs_clone.input_shape[1:])
|
145 |
+
images = Gs_clone.get_output_for(latents, None, is_validation=True, randomize_noise=True)
|
146 |
+
images = tflib.convert_images_to_uint8(images)
|
147 |
+
result_expr.append(inception_clone.get_output_for(images))
|
148 |
+
```
|
149 |
+
The above code is from [metrics/frechet_inception_distance.py](./metrics/frechet_inception_distance.py). It generates a batch of random images and feeds them directly to the [Inception-v3](https://arxiv.org/abs/1512.00567) network without having to convert the data to numpy arrays in between.
|
150 |
+
|
151 |
+
3. Look up `Gs.components.mapping` and `Gs.components.synthesis` to access individual sub-networks of the generator. Similar to `Gs`, the sub-networks are represented as independent instances of [dnnlib.tflib.Network](./dnnlib/tflib/network.py):
|
152 |
+
```
|
153 |
+
src_latents = np.stack(np.random.RandomState(seed).randn(Gs.input_shape[1]) for seed in src_seeds)
|
154 |
+
src_dlatents = Gs.components.mapping.run(src_latents, None) # [seed, layer, component]
|
155 |
+
src_images = Gs.components.synthesis.run(src_dlatents, randomize_noise=False, **synthesis_kwargs)
|
156 |
+
```
|
157 |
+
The above code is from [generate_figures.py](./generate_figures.py). It first transforms a batch of latent vectors into the intermediate *W* space using the mapping network and then turns these vectors into a batch of images using the synthesis network. The `dlatents` array stores a separate copy of the same *w* vector for each layer of the synthesis network to facilitate style mixing.
|
158 |
+
|
159 |
+
The exact details of the generator are defined in [training/networks_stylegan.py](./training/networks_stylegan.py) (see `G_style`, `G_mapping`, and `G_synthesis`). The following keyword arguments can be specified to modify the behavior when calling `run()` and `get_output_for()`:
|
160 |
+
|
161 |
+
* `truncation_psi` and `truncation_cutoff` control the truncation trick that that is performed by default when using `Gs` (ψ=0.7, cutoff=8). It can be disabled by setting `truncation_psi=1` or `is_validation=True`, and the image quality can be further improved at the cost of variation by setting e.g. `truncation_psi=0.5`. Note that truncation is always disabled when using the sub-networks directly. The average *w* needed to manually perform the truncation trick can be looked up using `Gs.get_var('dlatent_avg')`.
|
162 |
+
|
163 |
+
* `randomize_noise` determines whether to use re-randomize the noise inputs for each generated image (`True`, default) or whether to use specific noise values for the entire minibatch (`False`). The specific values can be accessed via the `tf.Variable` instances that are found using `[var for name, var in Gs.components.synthesis.vars.items() if name.startswith('noise')]`.
|
164 |
+
|
165 |
+
* When using the mapping network directly, you can specify `dlatent_broadcast=None` to disable the automatic duplication of `dlatents` over the layers of the synthesis network.
|
166 |
+
|
167 |
+
* Runtime performance can be fine-tuned via `structure='fixed'` and `dtype='float16'`. The former disables support for progressive growing, which is not needed for a fully-trained generator, and the latter performs all computation using half-precision floating point arithmetic.
|
168 |
+
|
169 |
+
## Preparing datasets for training
|
170 |
+
|
171 |
+
The training and evaluation scripts operate on datasets stored as multi-resolution TFRecords. Each dataset is represented by a directory containing the same image data in several resolutions to enable efficient streaming. There is a separate *.tfrecords file for each resolution, and if the dataset contains labels, they are stored in a separate file as well. By default, the scripts expect to find the datasets at `datasets/<NAME>/<NAME>-<RESOLUTION>.tfrecords`. The directory can be changed by editing [config.py](./config.py):
|
172 |
+
|
173 |
+
```
|
174 |
+
result_dir = 'results'
|
175 |
+
data_dir = 'datasets'
|
176 |
+
cache_dir = 'cache'
|
177 |
+
```
|
178 |
+
|
179 |
+
To obtain the FFHQ dataset (`datasets/ffhq`), please refer to the [Flickr-Faces-HQ repository](https://github.com/NVlabs/ffhq-dataset).
|
180 |
+
|
181 |
+
To obtain the CelebA-HQ dataset (`datasets/celebahq`), please refer to the [Progressive GAN repository](https://github.com/tkarras/progressive_growing_of_gans).
|
182 |
+
|
183 |
+
To obtain other datasets, including LSUN, please consult their corresponding project pages. The datasets can be converted to multi-resolution TFRecords using the provided [dataset_tool.py](./dataset_tool.py):
|
184 |
+
|
185 |
+
```
|
186 |
+
> python dataset_tool.py create_lsun datasets/lsun-bedroom-full ~/lsun/bedroom_lmdb --resolution 256
|
187 |
+
> python dataset_tool.py create_lsun_wide datasets/lsun-car-512x384 ~/lsun/car_lmdb --width 512 --height 384
|
188 |
+
> python dataset_tool.py create_lsun datasets/lsun-cat-full ~/lsun/cat_lmdb --resolution 256
|
189 |
+
> python dataset_tool.py create_cifar10 datasets/cifar10 ~/cifar10
|
190 |
+
> python dataset_tool.py create_from_images datasets/custom-dataset ~/custom-images
|
191 |
+
```
|
192 |
+
|
193 |
+
## Training networks
|
194 |
+
|
195 |
+
Once the datasets are set up, you can train your own StyleGAN networks as follows:
|
196 |
+
|
197 |
+
1. Edit [train.py](./train.py) to specify the dataset and training configuration by uncommenting or editing specific lines.
|
198 |
+
2. Run the training script with `python train.py`.
|
199 |
+
3. The results are written to a newly created directory `results/<ID>-<DESCRIPTION>`.
|
200 |
+
4. The training may take several days (or weeks) to complete, depending on the configuration.
|
201 |
+
|
202 |
+
By default, `train.py` is configured to train the highest-quality StyleGAN (configuration F in Table 1) for the FFHQ dataset at 1024×1024 resolution using 8 GPUs. Please note that we have used 8 GPUs in all of our experiments. Training with fewer GPUs may not produce identical results – if you wish to compare against our technique, we strongly recommend using the same number of GPUs.
|
203 |
+
|
204 |
+
Expected training times for the default configuration using Tesla V100 GPUs:
|
205 |
+
|
206 |
+
| GPUs | 1024×1024 | 512×512 | 256×256 |
|
207 |
+
| :--- | :-------------- | :------------ | :------------ |
|
208 |
+
| 1 | 41 days 4 hours | 24 days 21 hours | 14 days 22 hours |
|
209 |
+
| 2 | 21 days 22 hours | 13 days 7 hours | 9 days 5 hours |
|
210 |
+
| 4 | 11 days 8 hours | 7 days 0 hours | 4 days 21 hours |
|
211 |
+
| 8 | 6 days 14 hours | 4 days 10 hours | 3 days 8 hours |
|
212 |
+
|
213 |
+
## Evaluating quality and disentanglement
|
214 |
+
|
215 |
+
The quality and disentanglement metrics used in our paper can be evaluated using [run_metrics.py](./run_metrics.py). By default, the script will evaluate the Fréchet Inception Distance (`fid50k`) for the pre-trained FFHQ generator and write the results into a newly created directory under `results`. The exact behavior can be changed by uncommenting or editing specific lines in [run_metrics.py](./run_metrics.py).
|
216 |
+
|
217 |
+
Expected evaluation time and results for the pre-trained FFHQ generator using one Tesla V100 GPU:
|
218 |
+
|
219 |
+
| Metric | Time | Result | Description
|
220 |
+
| :----- | :--- | :----- | :----------
|
221 |
+
| fid50k | 16 min | 4.4159 | Fréchet Inception Distance using 50,000 images.
|
222 |
+
| ppl_zfull | 55 min | 664.8854 | Perceptual Path Length for full paths in *Z*.
|
223 |
+
| ppl_wfull | 55 min | 233.3059 | Perceptual Path Length for full paths in *W*.
|
224 |
+
| ppl_zend | 55 min | 666.1057 | Perceptual Path Length for path endpoints in *Z*.
|
225 |
+
| ppl_wend | 55 min | 197.2266 | Perceptual Path Length for path endpoints in *W*.
|
226 |
+
| ls | 10 hours | z: 165.0106<br>w: 3.7447 | Linear Separability in *Z* and *W*.
|
227 |
+
|
228 |
+
Please note that the exact results may vary from run to run due to the non-deterministic nature of TensorFlow.
|
229 |
+
|
230 |
+
## Acknowledgements
|
231 |
+
|
232 |
+
We thank Jaakko Lehtinen, David Luebke, and Tuomas Kynkäänniemi for in-depth discussions and helpful comments; Janne Hellsten, Tero Kuosmanen, and Pekka Jänis for compute infrastructure and help with the code release.
|
models/stylegan/stylegan_tf/config.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Global configuration."""
|
9 |
+
|
10 |
+
#----------------------------------------------------------------------------
|
11 |
+
# Paths.
|
12 |
+
|
13 |
+
result_dir = 'results'
|
14 |
+
data_dir = 'datasets'
|
15 |
+
cache_dir = 'cache'
|
16 |
+
run_dir_ignore = ['results', 'datasets', 'cache']
|
17 |
+
|
18 |
+
#----------------------------------------------------------------------------
|
models/stylegan/stylegan_tf/dataset_tool.py
ADDED
@@ -0,0 +1,645 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
"""Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN."""
|
9 |
+
|
10 |
+
# pylint: disable=too-many-lines
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import glob
|
14 |
+
import argparse
|
15 |
+
import threading
|
16 |
+
import six.moves.queue as Queue # pylint: disable=import-error
|
17 |
+
import traceback
|
18 |
+
import numpy as np
|
19 |
+
import tensorflow as tf
|
20 |
+
import PIL.Image
|
21 |
+
import dnnlib.tflib as tflib
|
22 |
+
|
23 |
+
from training import dataset
|
24 |
+
|
25 |
+
#----------------------------------------------------------------------------
|
26 |
+
|
27 |
+
def error(msg):
|
28 |
+
print('Error: ' + msg)
|
29 |
+
exit(1)
|
30 |
+
|
31 |
+
#----------------------------------------------------------------------------
|
32 |
+
|
33 |
+
class TFRecordExporter:
|
34 |
+
def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
|
35 |
+
self.tfrecord_dir = tfrecord_dir
|
36 |
+
self.tfr_prefix = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
|
37 |
+
self.expected_images = expected_images
|
38 |
+
self.cur_images = 0
|
39 |
+
self.shape = None
|
40 |
+
self.resolution_log2 = None
|
41 |
+
self.tfr_writers = []
|
42 |
+
self.print_progress = print_progress
|
43 |
+
self.progress_interval = progress_interval
|
44 |
+
|
45 |
+
if self.print_progress:
|
46 |
+
print('Creating dataset "%s"' % tfrecord_dir)
|
47 |
+
if not os.path.isdir(self.tfrecord_dir):
|
48 |
+
os.makedirs(self.tfrecord_dir)
|
49 |
+
assert os.path.isdir(self.tfrecord_dir)
|
50 |
+
|
51 |
+
def close(self):
|
52 |
+
if self.print_progress:
|
53 |
+
print('%-40s\r' % 'Flushing data...', end='', flush=True)
|
54 |
+
for tfr_writer in self.tfr_writers:
|
55 |
+
tfr_writer.close()
|
56 |
+
self.tfr_writers = []
|
57 |
+
if self.print_progress:
|
58 |
+
print('%-40s\r' % '', end='', flush=True)
|
59 |
+
print('Added %d images.' % self.cur_images)
|
60 |
+
|
61 |
+
def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
|
62 |
+
order = np.arange(self.expected_images)
|
63 |
+
np.random.RandomState(123).shuffle(order)
|
64 |
+
return order
|
65 |
+
|
66 |
+
def add_image(self, img):
|
67 |
+
if self.print_progress and self.cur_images % self.progress_interval == 0:
|
68 |
+
print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
|
69 |
+
if self.shape is None:
|
70 |
+
self.shape = img.shape
|
71 |
+
self.resolution_log2 = int(np.log2(self.shape[1]))
|
72 |
+
assert self.shape[0] in [1, 3]
|
73 |
+
assert self.shape[1] == self.shape[2]
|
74 |
+
assert self.shape[1] == 2**self.resolution_log2
|
75 |
+
tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
|
76 |
+
for lod in range(self.resolution_log2 - 1):
|
77 |
+
tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
|
78 |
+
self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
|
79 |
+
assert img.shape == self.shape
|
80 |
+
for lod, tfr_writer in enumerate(self.tfr_writers):
|
81 |
+
if lod:
|
82 |
+
img = img.astype(np.float32)
|
83 |
+
img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
|
84 |
+
quant = np.rint(img).clip(0, 255).astype(np.uint8)
|
85 |
+
ex = tf.train.Example(features=tf.train.Features(feature={
|
86 |
+
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
|
87 |
+
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))
|
88 |
+
tfr_writer.write(ex.SerializeToString())
|
89 |
+
self.cur_images += 1
|
90 |
+
|
91 |
+
def add_labels(self, labels):
|
92 |
+
if self.print_progress:
|
93 |
+
print('%-40s\r' % 'Saving labels...', end='', flush=True)
|
94 |
+
assert labels.shape[0] == self.cur_images
|
95 |
+
with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
|
96 |
+
np.save(f, labels.astype(np.float32))
|
97 |
+
|
98 |
+
def __enter__(self):
|
99 |
+
return self
|
100 |
+
|
101 |
+
def __exit__(self, *args):
|
102 |
+
self.close()
|
103 |
+
|
104 |
+
#----------------------------------------------------------------------------
|
105 |
+
|
106 |
+
class ExceptionInfo(object):
|
107 |
+
def __init__(self):
|
108 |
+
self.value = sys.exc_info()[1]
|
109 |
+
self.traceback = traceback.format_exc()
|
110 |
+
|
111 |
+
#----------------------------------------------------------------------------
|
112 |
+
|
113 |
+
class WorkerThread(threading.Thread):
|
114 |
+
def __init__(self, task_queue):
|
115 |
+
threading.Thread.__init__(self)
|
116 |
+
self.task_queue = task_queue
|
117 |
+
|
118 |
+
def run(self):
|
119 |
+
while True:
|
120 |
+
func, args, result_queue = self.task_queue.get()
|
121 |
+
if func is None:
|
122 |
+
break
|
123 |
+
try:
|
124 |
+
result = func(*args)
|
125 |
+
except:
|
126 |
+
result = ExceptionInfo()
|
127 |
+
result_queue.put((result, args))
|
128 |
+
|
129 |
+
#----------------------------------------------------------------------------
|
130 |
+
|
131 |
+
class ThreadPool(object):
|
132 |
+
def __init__(self, num_threads):
|
133 |
+
assert num_threads >= 1
|
134 |
+
self.task_queue = Queue.Queue()
|
135 |
+
self.result_queues = dict()
|
136 |
+
self.num_threads = num_threads
|
137 |
+
for _idx in range(self.num_threads):
|
138 |
+
thread = WorkerThread(self.task_queue)
|
139 |
+
thread.daemon = True
|
140 |
+
thread.start()
|
141 |
+
|
142 |
+
def add_task(self, func, args=()):
|
143 |
+
assert hasattr(func, '__call__') # must be a function
|
144 |
+
if func not in self.result_queues:
|
145 |
+
self.result_queues[func] = Queue.Queue()
|
146 |
+
self.task_queue.put((func, args, self.result_queues[func]))
|
147 |
+
|
148 |
+
def get_result(self, func): # returns (result, args)
|
149 |
+
result, args = self.result_queues[func].get()
|
150 |
+
if isinstance(result, ExceptionInfo):
|
151 |
+
print('\n\nWorker thread caught an exception:\n' + result.traceback)
|
152 |
+
raise result.value
|
153 |
+
return result, args
|
154 |
+
|
155 |
+
def finish(self):
|
156 |
+
for _idx in range(self.num_threads):
|
157 |
+
self.task_queue.put((None, (), None))
|
158 |
+
|
159 |
+
def __enter__(self): # for 'with' statement
|
160 |
+
return self
|
161 |
+
|
162 |
+
def __exit__(self, *excinfo):
|
163 |
+
self.finish()
|
164 |
+
|
165 |
+
def process_items_concurrently(self, item_iterator, process_func=lambda x: x, pre_func=lambda x: x, post_func=lambda x: x, max_items_in_flight=None):
|
166 |
+
if max_items_in_flight is None: max_items_in_flight = self.num_threads * 4
|
167 |
+
assert max_items_in_flight >= 1
|
168 |
+
results = []
|
169 |
+
retire_idx = [0]
|
170 |
+
|
171 |
+
def task_func(prepared, _idx):
|
172 |
+
return process_func(prepared)
|
173 |
+
|
174 |
+
def retire_result():
|
175 |
+
processed, (_prepared, idx) = self.get_result(task_func)
|
176 |
+
results[idx] = processed
|
177 |
+
while retire_idx[0] < len(results) and results[retire_idx[0]] is not None:
|
178 |
+
yield post_func(results[retire_idx[0]])
|
179 |
+
results[retire_idx[0]] = None
|
180 |
+
retire_idx[0] += 1
|
181 |
+
|
182 |
+
for idx, item in enumerate(item_iterator):
|
183 |
+
prepared = pre_func(item)
|
184 |
+
results.append(None)
|
185 |
+
self.add_task(func=task_func, args=(prepared, idx))
|
186 |
+
while retire_idx[0] < idx - max_items_in_flight + 2:
|
187 |
+
for res in retire_result(): yield res
|
188 |
+
while retire_idx[0] < len(results):
|
189 |
+
for res in retire_result(): yield res
|
190 |
+
|
191 |
+
#----------------------------------------------------------------------------
|
192 |
+
|
193 |
+
def display(tfrecord_dir):
|
194 |
+
print('Loading dataset "%s"' % tfrecord_dir)
|
195 |
+
tflib.init_tf({'gpu_options.allow_growth': True})
|
196 |
+
dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size='full', repeat=False, shuffle_mb=0)
|
197 |
+
tflib.init_uninitialized_vars()
|
198 |
+
import cv2 # pip install opencv-python
|
199 |
+
|
200 |
+
idx = 0
|
201 |
+
while True:
|
202 |
+
try:
|
203 |
+
images, labels = dset.get_minibatch_np(1)
|
204 |
+
except tf.errors.OutOfRangeError:
|
205 |
+
break
|
206 |
+
if idx == 0:
|
207 |
+
print('Displaying images')
|
208 |
+
cv2.namedWindow('dataset_tool')
|
209 |
+
print('Press SPACE or ENTER to advance, ESC to exit')
|
210 |
+
print('\nidx = %-8d\nlabel = %s' % (idx, labels[0].tolist()))
|
211 |
+
cv2.imshow('dataset_tool', images[0].transpose(1, 2, 0)[:, :, ::-1]) # CHW => HWC, RGB => BGR
|
212 |
+
idx += 1
|
213 |
+
if cv2.waitKey() == 27:
|
214 |
+
break
|
215 |
+
print('\nDisplayed %d images.' % idx)
|
216 |
+
|
217 |
+
#----------------------------------------------------------------------------
|
218 |
+
|
219 |
+
def extract(tfrecord_dir, output_dir):
|
220 |
+
print('Loading dataset "%s"' % tfrecord_dir)
|
221 |
+
tflib.init_tf({'gpu_options.allow_growth': True})
|
222 |
+
dset = dataset.TFRecordDataset(tfrecord_dir, max_label_size=0, repeat=False, shuffle_mb=0)
|
223 |
+
tflib.init_uninitialized_vars()
|
224 |
+
|
225 |
+
print('Extracting images to "%s"' % output_dir)
|
226 |
+
if not os.path.isdir(output_dir):
|
227 |
+
os.makedirs(output_dir)
|
228 |
+
idx = 0
|
229 |
+
while True:
|
230 |
+
if idx % 10 == 0:
|
231 |
+
print('%d\r' % idx, end='', flush=True)
|
232 |
+
try:
|
233 |
+
images, _labels = dset.get_minibatch_np(1)
|
234 |
+
except tf.errors.OutOfRangeError:
|
235 |
+
break
|
236 |
+
if images.shape[1] == 1:
|
237 |
+
img = PIL.Image.fromarray(images[0][0], 'L')
|
238 |
+
else:
|
239 |
+
img = PIL.Image.fromarray(images[0].transpose(1, 2, 0), 'RGB')
|
240 |
+
img.save(os.path.join(output_dir, 'img%08d.png' % idx))
|
241 |
+
idx += 1
|
242 |
+
print('Extracted %d images.' % idx)
|
243 |
+
|
244 |
+
#----------------------------------------------------------------------------
|
245 |
+
|
246 |
+
def compare(tfrecord_dir_a, tfrecord_dir_b, ignore_labels):
|
247 |
+
max_label_size = 0 if ignore_labels else 'full'
|
248 |
+
print('Loading dataset "%s"' % tfrecord_dir_a)
|
249 |
+
tflib.init_tf({'gpu_options.allow_growth': True})
|
250 |
+
dset_a = dataset.TFRecordDataset(tfrecord_dir_a, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
|
251 |
+
print('Loading dataset "%s"' % tfrecord_dir_b)
|
252 |
+
dset_b = dataset.TFRecordDataset(tfrecord_dir_b, max_label_size=max_label_size, repeat=False, shuffle_mb=0)
|
253 |
+
tflib.init_uninitialized_vars()
|
254 |
+
|
255 |
+
print('Comparing datasets')
|
256 |
+
idx = 0
|
257 |
+
identical_images = 0
|
258 |
+
identical_labels = 0
|
259 |
+
while True:
|
260 |
+
if idx % 100 == 0:
|
261 |
+
print('%d\r' % idx, end='', flush=True)
|
262 |
+
try:
|
263 |
+
images_a, labels_a = dset_a.get_minibatch_np(1)
|
264 |
+
except tf.errors.OutOfRangeError:
|
265 |
+
images_a, labels_a = None, None
|
266 |
+
try:
|
267 |
+
images_b, labels_b = dset_b.get_minibatch_np(1)
|
268 |
+
except tf.errors.OutOfRangeError:
|
269 |
+
images_b, labels_b = None, None
|
270 |
+
if images_a is None or images_b is None:
|
271 |
+
if images_a is not None or images_b is not None:
|
272 |
+
print('Datasets contain different number of images')
|
273 |
+
break
|
274 |
+
if images_a.shape == images_b.shape and np.all(images_a == images_b):
|
275 |
+
identical_images += 1
|
276 |
+
else:
|
277 |
+
print('Image %d is different' % idx)
|
278 |
+
if labels_a.shape == labels_b.shape and np.all(labels_a == labels_b):
|
279 |
+
identical_labels += 1
|
280 |
+
else:
|
281 |
+
print('Label %d is different' % idx)
|
282 |
+
idx += 1
|
283 |
+
print('Identical images: %d / %d' % (identical_images, idx))
|
284 |
+
if not ignore_labels:
|
285 |
+
print('Identical labels: %d / %d' % (identical_labels, idx))
|
286 |
+
|
287 |
+
#----------------------------------------------------------------------------
|
288 |
+
|
289 |
+
def create_mnist(tfrecord_dir, mnist_dir):
|
290 |
+
print('Loading MNIST from "%s"' % mnist_dir)
|
291 |
+
import gzip
|
292 |
+
with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
|
293 |
+
images = np.frombuffer(file.read(), np.uint8, offset=16)
|
294 |
+
with gzip.open(os.path.join(mnist_dir, 'train-labels-idx1-ubyte.gz'), 'rb') as file:
|
295 |
+
labels = np.frombuffer(file.read(), np.uint8, offset=8)
|
296 |
+
images = images.reshape(-1, 1, 28, 28)
|
297 |
+
images = np.pad(images, [(0,0), (0,0), (2,2), (2,2)], 'constant', constant_values=0)
|
298 |
+
assert images.shape == (60000, 1, 32, 32) and images.dtype == np.uint8
|
299 |
+
assert labels.shape == (60000,) and labels.dtype == np.uint8
|
300 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
301 |
+
assert np.min(labels) == 0 and np.max(labels) == 9
|
302 |
+
onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
|
303 |
+
onehot[np.arange(labels.size), labels] = 1.0
|
304 |
+
|
305 |
+
with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
|
306 |
+
order = tfr.choose_shuffled_order()
|
307 |
+
for idx in range(order.size):
|
308 |
+
tfr.add_image(images[order[idx]])
|
309 |
+
tfr.add_labels(onehot[order])
|
310 |
+
|
311 |
+
#----------------------------------------------------------------------------
|
312 |
+
|
313 |
+
def create_mnistrgb(tfrecord_dir, mnist_dir, num_images=1000000, random_seed=123):
|
314 |
+
print('Loading MNIST from "%s"' % mnist_dir)
|
315 |
+
import gzip
|
316 |
+
with gzip.open(os.path.join(mnist_dir, 'train-images-idx3-ubyte.gz'), 'rb') as file:
|
317 |
+
images = np.frombuffer(file.read(), np.uint8, offset=16)
|
318 |
+
images = images.reshape(-1, 28, 28)
|
319 |
+
images = np.pad(images, [(0,0), (2,2), (2,2)], 'constant', constant_values=0)
|
320 |
+
assert images.shape == (60000, 32, 32) and images.dtype == np.uint8
|
321 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
322 |
+
|
323 |
+
with TFRecordExporter(tfrecord_dir, num_images) as tfr:
|
324 |
+
rnd = np.random.RandomState(random_seed)
|
325 |
+
for _idx in range(num_images):
|
326 |
+
tfr.add_image(images[rnd.randint(images.shape[0], size=3)])
|
327 |
+
|
328 |
+
#----------------------------------------------------------------------------
|
329 |
+
|
330 |
+
def create_cifar10(tfrecord_dir, cifar10_dir):
|
331 |
+
print('Loading CIFAR-10 from "%s"' % cifar10_dir)
|
332 |
+
import pickle
|
333 |
+
images = []
|
334 |
+
labels = []
|
335 |
+
for batch in range(1, 6):
|
336 |
+
with open(os.path.join(cifar10_dir, 'data_batch_%d' % batch), 'rb') as file:
|
337 |
+
data = pickle.load(file, encoding='latin1')
|
338 |
+
images.append(data['data'].reshape(-1, 3, 32, 32))
|
339 |
+
labels.append(data['labels'])
|
340 |
+
images = np.concatenate(images)
|
341 |
+
labels = np.concatenate(labels)
|
342 |
+
assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
|
343 |
+
assert labels.shape == (50000,) and labels.dtype == np.int32
|
344 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
345 |
+
assert np.min(labels) == 0 and np.max(labels) == 9
|
346 |
+
onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
|
347 |
+
onehot[np.arange(labels.size), labels] = 1.0
|
348 |
+
|
349 |
+
with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
|
350 |
+
order = tfr.choose_shuffled_order()
|
351 |
+
for idx in range(order.size):
|
352 |
+
tfr.add_image(images[order[idx]])
|
353 |
+
tfr.add_labels(onehot[order])
|
354 |
+
|
355 |
+
#----------------------------------------------------------------------------
|
356 |
+
|
357 |
+
def create_cifar100(tfrecord_dir, cifar100_dir):
|
358 |
+
print('Loading CIFAR-100 from "%s"' % cifar100_dir)
|
359 |
+
import pickle
|
360 |
+
with open(os.path.join(cifar100_dir, 'train'), 'rb') as file:
|
361 |
+
data = pickle.load(file, encoding='latin1')
|
362 |
+
images = data['data'].reshape(-1, 3, 32, 32)
|
363 |
+
labels = np.array(data['fine_labels'])
|
364 |
+
assert images.shape == (50000, 3, 32, 32) and images.dtype == np.uint8
|
365 |
+
assert labels.shape == (50000,) and labels.dtype == np.int32
|
366 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
367 |
+
assert np.min(labels) == 0 and np.max(labels) == 99
|
368 |
+
onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
|
369 |
+
onehot[np.arange(labels.size), labels] = 1.0
|
370 |
+
|
371 |
+
with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
|
372 |
+
order = tfr.choose_shuffled_order()
|
373 |
+
for idx in range(order.size):
|
374 |
+
tfr.add_image(images[order[idx]])
|
375 |
+
tfr.add_labels(onehot[order])
|
376 |
+
|
377 |
+
#----------------------------------------------------------------------------
|
378 |
+
|
379 |
+
def create_svhn(tfrecord_dir, svhn_dir):
|
380 |
+
print('Loading SVHN from "%s"' % svhn_dir)
|
381 |
+
import pickle
|
382 |
+
images = []
|
383 |
+
labels = []
|
384 |
+
for batch in range(1, 4):
|
385 |
+
with open(os.path.join(svhn_dir, 'train_%d.pkl' % batch), 'rb') as file:
|
386 |
+
data = pickle.load(file, encoding='latin1')
|
387 |
+
images.append(data[0])
|
388 |
+
labels.append(data[1])
|
389 |
+
images = np.concatenate(images)
|
390 |
+
labels = np.concatenate(labels)
|
391 |
+
assert images.shape == (73257, 3, 32, 32) and images.dtype == np.uint8
|
392 |
+
assert labels.shape == (73257,) and labels.dtype == np.uint8
|
393 |
+
assert np.min(images) == 0 and np.max(images) == 255
|
394 |
+
assert np.min(labels) == 0 and np.max(labels) == 9
|
395 |
+
onehot = np.zeros((labels.size, np.max(labels) + 1), dtype=np.float32)
|
396 |
+
onehot[np.arange(labels.size), labels] = 1.0
|
397 |
+
|
398 |
+
with TFRecordExporter(tfrecord_dir, images.shape[0]) as tfr:
|
399 |
+
order = tfr.choose_shuffled_order()
|
400 |
+
for idx in range(order.size):
|
401 |
+
tfr.add_image(images[order[idx]])
|
402 |
+
tfr.add_labels(onehot[order])
|
403 |
+
|
404 |
+
#----------------------------------------------------------------------------
|
405 |
+
|
406 |
+
def create_lsun(tfrecord_dir, lmdb_dir, resolution=256, max_images=None):
|
407 |
+
print('Loading LSUN dataset from "%s"' % lmdb_dir)
|
408 |
+
import lmdb # pip install lmdb # pylint: disable=import-error
|
409 |
+
import cv2 # pip install opencv-python
|
410 |
+
import io
|
411 |
+
with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
|
412 |
+
total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
|
413 |
+
if max_images is None:
|
414 |
+
max_images = total_images
|
415 |
+
with TFRecordExporter(tfrecord_dir, max_images) as tfr:
|
416 |
+
for _idx, (_key, value) in enumerate(txn.cursor()):
|
417 |
+
try:
|
418 |
+
try:
|
419 |
+
img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
|
420 |
+
if img is None:
|
421 |
+
raise IOError('cv2.imdecode failed')
|
422 |
+
img = img[:, :, ::-1] # BGR => RGB
|
423 |
+
except IOError:
|
424 |
+
img = np.asarray(PIL.Image.open(io.BytesIO(value)))
|
425 |
+
crop = np.min(img.shape[:2])
|
426 |
+
img = img[(img.shape[0] - crop) // 2 : (img.shape[0] + crop) // 2, (img.shape[1] - crop) // 2 : (img.shape[1] + crop) // 2]
|
427 |
+
img = PIL.Image.fromarray(img, 'RGB')
|
428 |
+
img = img.resize((resolution, resolution), PIL.Image.ANTIALIAS)
|
429 |
+
img = np.asarray(img)
|
430 |
+
img = img.transpose([2, 0, 1]) # HWC => CHW
|
431 |
+
tfr.add_image(img)
|
432 |
+
except:
|
433 |
+
print(sys.exc_info()[1])
|
434 |
+
if tfr.cur_images == max_images:
|
435 |
+
break
|
436 |
+
|
437 |
+
#----------------------------------------------------------------------------
|
438 |
+
|
439 |
+
def create_lsun_wide(tfrecord_dir, lmdb_dir, width=512, height=384, max_images=None):
|
440 |
+
assert width == 2 ** int(np.round(np.log2(width)))
|
441 |
+
assert height <= width
|
442 |
+
print('Loading LSUN dataset from "%s"' % lmdb_dir)
|
443 |
+
import lmdb # pip install lmdb # pylint: disable=import-error
|
444 |
+
import cv2 # pip install opencv-python
|
445 |
+
import io
|
446 |
+
with lmdb.open(lmdb_dir, readonly=True).begin(write=False) as txn:
|
447 |
+
total_images = txn.stat()['entries'] # pylint: disable=no-value-for-parameter
|
448 |
+
if max_images is None:
|
449 |
+
max_images = total_images
|
450 |
+
with TFRecordExporter(tfrecord_dir, max_images, print_progress=False) as tfr:
|
451 |
+
for idx, (_key, value) in enumerate(txn.cursor()):
|
452 |
+
try:
|
453 |
+
try:
|
454 |
+
img = cv2.imdecode(np.fromstring(value, dtype=np.uint8), 1)
|
455 |
+
if img is None:
|
456 |
+
raise IOError('cv2.imdecode failed')
|
457 |
+
img = img[:, :, ::-1] # BGR => RGB
|
458 |
+
except IOError:
|
459 |
+
img = np.asarray(PIL.Image.open(io.BytesIO(value)))
|
460 |
+
|
461 |
+
ch = int(np.round(width * img.shape[0] / img.shape[1]))
|
462 |
+
if img.shape[1] < width or ch < height:
|
463 |
+
continue
|
464 |
+
|
465 |
+
img = img[(img.shape[0] - ch) // 2 : (img.shape[0] + ch) // 2]
|
466 |
+
img = PIL.Image.fromarray(img, 'RGB')
|
467 |
+
img = img.resize((width, height), PIL.Image.ANTIALIAS)
|
468 |
+
img = np.asarray(img)
|
469 |
+
img = img.transpose([2, 0, 1]) # HWC => CHW
|
470 |
+
|
471 |
+
canvas = np.zeros([3, width, width], dtype=np.uint8)
|
472 |
+
canvas[:, (width - height) // 2 : (width + height) // 2] = img
|
473 |
+
tfr.add_image(canvas)
|
474 |
+
print('\r%d / %d => %d ' % (idx + 1, total_images, tfr.cur_images), end='')
|
475 |
+
|
476 |
+
except:
|
477 |
+
print(sys.exc_info()[1])
|
478 |
+
if tfr.cur_images == max_images:
|
479 |
+
break
|
480 |
+
print()
|
481 |
+
|
482 |
+
#----------------------------------------------------------------------------
|
483 |
+
|
484 |
+
def create_celeba(tfrecord_dir, celeba_dir, cx=89, cy=121):
|
485 |
+
print('Loading CelebA from "%s"' % celeba_dir)
|
486 |
+
glob_pattern = os.path.join(celeba_dir, 'img_align_celeba_png', '*.png')
|
487 |
+
image_filenames = sorted(glob.glob(glob_pattern))
|
488 |
+
expected_images = 202599
|
489 |
+
if len(image_filenames) != expected_images:
|
490 |
+
error('Expected to find %d images' % expected_images)
|
491 |
+
|
492 |
+
with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
|
493 |
+
order = tfr.choose_shuffled_order()
|
494 |
+
for idx in range(order.size):
|
495 |
+
img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
|
496 |
+
assert img.shape == (218, 178, 3)
|
497 |
+
img = img[cy - 64 : cy + 64, cx - 64 : cx + 64]
|
498 |
+
img = img.transpose(2, 0, 1) # HWC => CHW
|
499 |
+
tfr.add_image(img)
|
500 |
+
|
501 |
+
#----------------------------------------------------------------------------
|
502 |
+
|
503 |
+
def create_from_images(tfrecord_dir, image_dir, shuffle):
|
504 |
+
print('Loading images from "%s"' % image_dir)
|
505 |
+
image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
|
506 |
+
if len(image_filenames) == 0:
|
507 |
+
error('No input images found')
|
508 |
+
|
509 |
+
img = np.asarray(PIL.Image.open(image_filenames[0]))
|
510 |
+
resolution = img.shape[0]
|
511 |
+
channels = img.shape[2] if img.ndim == 3 else 1
|
512 |
+
if img.shape[1] != resolution:
|
513 |
+
error('Input images must have the same width and height')
|
514 |
+
if resolution != 2 ** int(np.floor(np.log2(resolution))):
|
515 |
+
error('Input image resolution must be a power-of-two')
|
516 |
+
if channels not in [1, 3]:
|
517 |
+
error('Input images must be stored as RGB or grayscale')
|
518 |
+
|
519 |
+
with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
|
520 |
+
order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
|
521 |
+
for idx in range(order.size):
|
522 |
+
img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
|
523 |
+
if channels == 1:
|
524 |
+
img = img[np.newaxis, :, :] # HW => CHW
|
525 |
+
else:
|
526 |
+
img = img.transpose([2, 0, 1]) # HWC => CHW
|
527 |
+
tfr.add_image(img)
|
528 |
+
|
529 |
+
#----------------------------------------------------------------------------
|
530 |
+
|
531 |
+
def create_from_hdf5(tfrecord_dir, hdf5_filename, shuffle):
|
532 |
+
print('Loading HDF5 archive from "%s"' % hdf5_filename)
|
533 |
+
import h5py # conda install h5py
|
534 |
+
with h5py.File(hdf5_filename, 'r') as hdf5_file:
|
535 |
+
hdf5_data = max([value for key, value in hdf5_file.items() if key.startswith('data')], key=lambda lod: lod.shape[3])
|
536 |
+
with TFRecordExporter(tfrecord_dir, hdf5_data.shape[0]) as tfr:
|
537 |
+
order = tfr.choose_shuffled_order() if shuffle else np.arange(hdf5_data.shape[0])
|
538 |
+
for idx in range(order.size):
|
539 |
+
tfr.add_image(hdf5_data[order[idx]])
|
540 |
+
npy_filename = os.path.splitext(hdf5_filename)[0] + '-labels.npy'
|
541 |
+
if os.path.isfile(npy_filename):
|
542 |
+
tfr.add_labels(np.load(npy_filename)[order])
|
543 |
+
|
544 |
+
#----------------------------------------------------------------------------
|
545 |
+
|
546 |
+
def execute_cmdline(argv):
|
547 |
+
prog = argv[0]
|
548 |
+
parser = argparse.ArgumentParser(
|
549 |
+
prog = prog,
|
550 |
+
description = 'Tool for creating multi-resolution TFRecords datasets for StyleGAN and ProGAN.',
|
551 |
+
epilog = 'Type "%s <command> -h" for more information.' % prog)
|
552 |
+
|
553 |
+
subparsers = parser.add_subparsers(dest='command')
|
554 |
+
subparsers.required = True
|
555 |
+
def add_command(cmd, desc, example=None):
|
556 |
+
epilog = 'Example: %s %s' % (prog, example) if example is not None else None
|
557 |
+
return subparsers.add_parser(cmd, description=desc, help=desc, epilog=epilog)
|
558 |
+
|
559 |
+
p = add_command( 'display', 'Display images in dataset.',
|
560 |
+
'display datasets/mnist')
|
561 |
+
p.add_argument( 'tfrecord_dir', help='Directory containing dataset')
|
562 |
+
|
563 |
+
p = add_command( 'extract', 'Extract images from dataset.',
|
564 |
+
'extract datasets/mnist mnist-images')
|
565 |
+
p.add_argument( 'tfrecord_dir', help='Directory containing dataset')
|
566 |
+
p.add_argument( 'output_dir', help='Directory to extract the images into')
|
567 |
+
|
568 |
+
p = add_command( 'compare', 'Compare two datasets.',
|
569 |
+
'compare datasets/mydataset datasets/mnist')
|
570 |
+
p.add_argument( 'tfrecord_dir_a', help='Directory containing first dataset')
|
571 |
+
p.add_argument( 'tfrecord_dir_b', help='Directory containing second dataset')
|
572 |
+
p.add_argument( '--ignore_labels', help='Ignore labels (default: 0)', type=int, default=0)
|
573 |
+
|
574 |
+
p = add_command( 'create_mnist', 'Create dataset for MNIST.',
|
575 |
+
'create_mnist datasets/mnist ~/downloads/mnist')
|
576 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
577 |
+
p.add_argument( 'mnist_dir', help='Directory containing MNIST')
|
578 |
+
|
579 |
+
p = add_command( 'create_mnistrgb', 'Create dataset for MNIST-RGB.',
|
580 |
+
'create_mnistrgb datasets/mnistrgb ~/downloads/mnist')
|
581 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
582 |
+
p.add_argument( 'mnist_dir', help='Directory containing MNIST')
|
583 |
+
p.add_argument( '--num_images', help='Number of composite images to create (default: 1000000)', type=int, default=1000000)
|
584 |
+
p.add_argument( '--random_seed', help='Random seed (default: 123)', type=int, default=123)
|
585 |
+
|
586 |
+
p = add_command( 'create_cifar10', 'Create dataset for CIFAR-10.',
|
587 |
+
'create_cifar10 datasets/cifar10 ~/downloads/cifar10')
|
588 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
589 |
+
p.add_argument( 'cifar10_dir', help='Directory containing CIFAR-10')
|
590 |
+
|
591 |
+
p = add_command( 'create_cifar100', 'Create dataset for CIFAR-100.',
|
592 |
+
'create_cifar100 datasets/cifar100 ~/downloads/cifar100')
|
593 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
594 |
+
p.add_argument( 'cifar100_dir', help='Directory containing CIFAR-100')
|
595 |
+
|
596 |
+
p = add_command( 'create_svhn', 'Create dataset for SVHN.',
|
597 |
+
'create_svhn datasets/svhn ~/downloads/svhn')
|
598 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
599 |
+
p.add_argument( 'svhn_dir', help='Directory containing SVHN')
|
600 |
+
|
601 |
+
p = add_command( 'create_lsun', 'Create dataset for single LSUN category.',
|
602 |
+
'create_lsun datasets/lsun-car-100k ~/downloads/lsun/car_lmdb --resolution 256 --max_images 100000')
|
603 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
604 |
+
p.add_argument( 'lmdb_dir', help='Directory containing LMDB database')
|
605 |
+
p.add_argument( '--resolution', help='Output resolution (default: 256)', type=int, default=256)
|
606 |
+
p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None)
|
607 |
+
|
608 |
+
p = add_command( 'create_lsun_wide', 'Create LSUN dataset with non-square aspect ratio.',
|
609 |
+
'create_lsun_wide datasets/lsun-car-512x384 ~/downloads/lsun/car_lmdb --width 512 --height 384')
|
610 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
611 |
+
p.add_argument( 'lmdb_dir', help='Directory containing LMDB database')
|
612 |
+
p.add_argument( '--width', help='Output width (default: 512)', type=int, default=512)
|
613 |
+
p.add_argument( '--height', help='Output height (default: 384)', type=int, default=384)
|
614 |
+
p.add_argument( '--max_images', help='Maximum number of images (default: none)', type=int, default=None)
|
615 |
+
|
616 |
+
p = add_command( 'create_celeba', 'Create dataset for CelebA.',
|
617 |
+
'create_celeba datasets/celeba ~/downloads/celeba')
|
618 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
619 |
+
p.add_argument( 'celeba_dir', help='Directory containing CelebA')
|
620 |
+
p.add_argument( '--cx', help='Center X coordinate (default: 89)', type=int, default=89)
|
621 |
+
p.add_argument( '--cy', help='Center Y coordinate (default: 121)', type=int, default=121)
|
622 |
+
|
623 |
+
p = add_command( 'create_from_images', 'Create dataset from a directory full of images.',
|
624 |
+
'create_from_images datasets/mydataset myimagedir')
|
625 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
626 |
+
p.add_argument( 'image_dir', help='Directory containing the images')
|
627 |
+
p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
|
628 |
+
|
629 |
+
p = add_command( 'create_from_hdf5', 'Create dataset from legacy HDF5 archive.',
|
630 |
+
'create_from_hdf5 datasets/celebahq ~/downloads/celeba-hq-1024x1024.h5')
|
631 |
+
p.add_argument( 'tfrecord_dir', help='New dataset directory to be created')
|
632 |
+
p.add_argument( 'hdf5_filename', help='HDF5 archive containing the images')
|
633 |
+
p.add_argument( '--shuffle', help='Randomize image order (default: 1)', type=int, default=1)
|
634 |
+
|
635 |
+
args = parser.parse_args(argv[1:] if len(argv) > 1 else ['-h'])
|
636 |
+
func = globals()[args.command]
|
637 |
+
del args.command
|
638 |
+
func(**vars(args))
|
639 |
+
|
640 |
+
#----------------------------------------------------------------------------
|
641 |
+
|
642 |
+
if __name__ == "__main__":
|
643 |
+
execute_cmdline(sys.argv)
|
644 |
+
|
645 |
+
#----------------------------------------------------------------------------
|
models/stylegan/stylegan_tf/dnnlib/__init__.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
|
2 |
+
#
|
3 |
+
# This work is licensed under the Creative Commons Attribution-NonCommercial
|
4 |
+
# 4.0 International License. To view a copy of this license, visit
|
5 |
+
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
|
6 |
+
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
|
7 |
+
|
8 |
+
from . import submission
|
9 |
+
|
10 |
+
from .submission.run_context import RunContext
|
11 |
+
|
12 |
+
from .submission.submit import SubmitTarget
|
13 |
+
from .submission.submit import PathType
|
14 |
+
from .submission.submit import SubmitConfig
|
15 |
+
from .submission.submit import get_path_from_template
|
16 |
+
from .submission.submit import submit_run
|
17 |
+
|
18 |
+
from .util import EasyDict
|
19 |
+
|
20 |
+
submit_config: SubmitConfig = None # Package level variable for SubmitConfig which is only valid when inside the run function.
|