Tu Bui commited on
Commit
6142a25
·
0 Parent(s):

first commit

Browse files
Dockerfile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM tuvbui/torchcpu:torch111
2
+ ADD cldm ./cldm
3
+ ADD flae ./flae
4
+ ADD ldm ./ldm
5
+ ADD tools ./tools
6
+ ADD pages ./pages
7
+ ADD Embed_Secret.py .
8
+
9
+ EXPOSE 7860
10
+ CMD streamlit run Embed_Secret.py --server.enableXsrfProtection=false --server.port 7860 -- --weight https://kahlan.cvssp.org/data/Flickr25K/tubui/stega/unet100b_croprs/epoch=000070-step=000219999.ckpt --config https://kahlan.cvssp.org/data/Flickr25K/tubui/stega/unet100b_croprs/-project.yaml
Embed_Secret.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ streamlit app demo
5
+ how to run:
6
+ streamlit run app.py --server.port 8501
7
+
8
+ @author: Tu Bui @surrey.ac.uk
9
+ """
10
+ import os, sys, torch
11
+ import argparse
12
+ from pathlib import Path
13
+ import numpy as np
14
+ import pickle
15
+ import pytorch_lightning as pl
16
+ from torchvision import transforms
17
+ import argparse
18
+ from ldm.util import instantiate_from_config
19
+ from omegaconf import OmegaConf
20
+ from PIL import Image
21
+ from tools.augment_imagenetc import RandomImagenetC
22
+ from io import BytesIO
23
+ from tools.helpers import welcome_message
24
+ from tools.ecc import BCH, RSC
25
+
26
+ import streamlit as st
27
+ from streamlit.source_util import (
28
+ page_icon_and_name,
29
+ calc_md5,
30
+ get_pages,
31
+ _on_pages_changed
32
+ )
33
+
34
+ model_names = ['UNet']
35
+ SECRET_LEN = 100
36
+
37
+
38
+ def delete_page(main_script_path_str, page_name):
39
+
40
+ current_pages = get_pages(main_script_path_str)
41
+
42
+ for key, value in current_pages.items():
43
+ print(value['page_name'])
44
+ if value['page_name'] == page_name:
45
+ del current_pages[key]
46
+ break
47
+ else:
48
+ pass
49
+ _on_pages_changed.send()
50
+
51
+
52
+ def add_page(main_script_path_str, page_name):
53
+
54
+ pages = get_pages(main_script_path_str)
55
+ main_script_path = Path(main_script_path_str)
56
+ pages_dir = main_script_path.parent / "pages"
57
+ # st.write(list(pages_dir.glob("*.py"))+list(main_script_path.parent.glob("*.py")))
58
+ script_path = [f for f in list(pages_dir.glob("*.py"))+list(main_script_path.parent.glob("*.py")) if f.name.find(page_name) != -1][0]
59
+ script_path_str = str(script_path.resolve())
60
+ pi, pn = page_icon_and_name(script_path)
61
+ psh = calc_md5(script_path_str)
62
+ pages[psh] = {
63
+ "page_script_hash": psh,
64
+ "page_name": pn,
65
+ "icon": pi,
66
+ "script_path": script_path_str,
67
+ }
68
+ _on_pages_changed.send()
69
+
70
+ def unormalize(x):
71
+ # convert x in range [-1, 1], (B,C,H,W), tensor to [0, 255], uint8, numpy, (B,H,W,C)
72
+ x = torch.clamp((x + 1) * 127.5, 0, 255).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
73
+ return x
74
+
75
+ def to_bytes(x, mime):
76
+ x = Image.fromarray(x)
77
+ buf = BytesIO()
78
+ f = "JPEG" if mime == 'image/jpeg' else "PNG"
79
+ x.save(buf, format=f)
80
+ byte_im = buf.getvalue()
81
+ return byte_im
82
+
83
+
84
+ def load_UNet(args):
85
+ print('args: ', args)
86
+ # # crop safe model
87
+ # config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_tform2/configs/-project.yaml'
88
+ # weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_tform2/checkpoints/epoch=000060-step=000189999.ckpt'
89
+
90
+ # # resized crop safe model
91
+ # config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
92
+ # weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
93
+
94
+ config_file = args.config_file
95
+ weight_file = args.weight_file
96
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
97
+ if weight_file.startswith('http'): # download from url
98
+ weight_dir = Path('./weights')
99
+ weight_dir.mkdir(exist_ok=True)
100
+ weight_path = weight_dir / weight_file.split('/')[-1]
101
+ config_path = weight_dir / config_file.split('/')[-1]
102
+ if not weight_path.exists():
103
+ import wget
104
+ print(f'Downloading {weight_file}...')
105
+ with st.spinner("Downloading model... this may take awhile!"):
106
+ wget.download(weight_file, str(weight_path))
107
+ wget.download(config_file, str(config_path))
108
+ weight_file = str(weight_path)
109
+ config_file = str(config_path)
110
+
111
+ config = OmegaConf.load(config_file).model
112
+ secret_len = config.params.secret_len
113
+ assert SECRET_LEN == secret_len
114
+ model = instantiate_from_config(config)
115
+ state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
116
+ if 'global_step' in state_dict:
117
+ print(f'Global step: {state_dict["global_step"]}, epoch: {state_dict["epoch"]}')
118
+
119
+ if 'state_dict' in state_dict:
120
+ state_dict = state_dict['state_dict']
121
+ misses, ignores = model.load_state_dict(state_dict, strict=False)
122
+ print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
123
+ model = model.to(device)
124
+ model.eval()
125
+ return model
126
+
127
+ def embed_secret(model_name, model, cover, tform, secret):
128
+ if model_name == 'UNet':
129
+ w, h = cover.size
130
+ with torch.no_grad():
131
+ im = tform(cover).unsqueeze(0).cuda() # 1, 3, 256, 256
132
+ stego, _ = model(im, secret) # 1, 3, 256, 256
133
+ res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
134
+ res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
135
+ res = res.permute(0,2,3,1).cpu().numpy() # (1,256,256,3)
136
+ stego_uint8 = np.clip(res[0] + np.array(cover)/127.5-1., -1,1)*127.5+127.5 # (256, 256, 3), ndarray, uint8
137
+ stego_uint8 = stego_uint8.astype(np.uint8)
138
+ else:
139
+ raise NotImplementedError
140
+ return stego_uint8
141
+
142
+ def identity(x):
143
+ return x
144
+
145
+ def decode_secret(model_name, model, im, tform):
146
+ if model_name in ['RoSteALS', 'UNet']:
147
+ with torch.no_grad():
148
+ im = tform(im).unsqueeze(0).cuda() # 1, 3, 256, 256
149
+ secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
150
+ else:
151
+ raise NotImplementedError
152
+ return secret_pred
153
+
154
+
155
+ @st.cache_resource
156
+ def load_model(model_name, _args):
157
+ if model_name == 'UNet':
158
+ tform_emb = transforms.Compose([
159
+ transforms.Resize((256,256)),
160
+ transforms.ToTensor(),
161
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
162
+ ])
163
+ tform_det = transforms.Compose([
164
+ transforms.Resize((224,224)),
165
+ transforms.ToTensor(),
166
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
167
+ ])
168
+ model = load_UNet(_args)
169
+ else:
170
+ raise NotImplementedError
171
+ return model, tform_emb, tform_det
172
+
173
+
174
+ @st.cache_resource
175
+ def load_ecc(ecc_name):
176
+ if ecc_name == 'BCH':
177
+ # ecc = BCH(285, 10, SECRET_LEN, verbose=True)
178
+ ecc = BCH(payload_len= SECRET_LEN, verbose=True)
179
+ elif ecc_name == 'RSC':
180
+ ecc = RSC(data_bytes=16, ecc_bytes=4, verbose=True)
181
+ return ecc
182
+
183
+
184
+ class Resize(object):
185
+ def __init__(self, size=None) -> None:
186
+ self.size = size
187
+ def __call__(self, x, size=None):
188
+ if isinstance(x, np.ndarray):
189
+ x = Image.fromarray(x)
190
+ new_size = size if size is not None else self.size
191
+ if min(x.size) > min(new_size): # downsample
192
+ x = x.resize(new_size, Image.LANCZOS)
193
+ else: # upsample
194
+ x = x.resize(new_size, Image.BILINEAR)
195
+ x = np.array(x)
196
+ return x
197
+
198
+
199
+ def parse_st_args():
200
+ # usage: streamlit run app.py -- --arg1 val1 --arg2 val2
201
+ parser = argparse.ArgumentParser()
202
+ parser.add_argument('--weight', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt')
203
+ parser.add_argument('--config', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml')
204
+ # parser.add_argument('--cpu', action='store_true')
205
+ args = parser.parse_args()
206
+ return args
207
+
208
+
209
+ def app(args):
210
+ # delete_page('Embed_Secret', 'Extract_Secret')
211
+ st.title('Watermarking Demo')
212
+ # setup model
213
+ model_name = st.selectbox("Choose the model", model_names)
214
+ model, tform_emb, tform_det = load_model(model_name, args)
215
+ display_width = 300
216
+
217
+ # ecc
218
+ ecc = load_ecc('BCH')
219
+ assert ecc.get_total_len() == SECRET_LEN
220
+
221
+ # setup st
222
+ st.subheader("Input")
223
+ image_file = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
224
+ if image_file is not None:
225
+ print('Image: ', image_file.name)
226
+ ext = image_file.name.split('.')[-1]
227
+ im = Image.open(image_file).convert('RGB')
228
+ size0 = im.size
229
+ st.image(im, width=display_width)
230
+ secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
231
+ assert len(secret_text) <= ecc.data_len
232
+
233
+ # embed
234
+ st.subheader("Embed results")
235
+ status = st.empty()
236
+ prep = transforms.Compose([
237
+ transforms.Resize((256,256)),
238
+ transforms.CenterCrop((224,224))
239
+ ])
240
+ if image_file is not None and secret_text is not None:
241
+ secret = ecc.encode_text([secret_text]) # (1, len)
242
+ secret = torch.from_numpy(secret).float().cuda()
243
+ # im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
244
+ stego = embed_secret(model_name, model, im, tform_emb, secret)
245
+ st.image(stego, width=display_width)
246
+
247
+ # download button
248
+ mime='image/jpeg' if ext=='jpg' else f'image/{ext}'
249
+ stego_bytes = to_bytes(stego, mime)
250
+ st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)
251
+
252
+ # verify secret
253
+ stego_processed = prep(Image.fromarray(stego))
254
+ secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
255
+ bit_acc = (secret_pred == secret.cpu().numpy()).mean()
256
+ secret_pred = ecc.decode_text(secret_pred)[0]
257
+ status.markdown('**Secret recovery check:** ' + secret_pred, unsafe_allow_html=True)
258
+ status.markdown('**Bit accuracy:** ' + str(bit_acc), unsafe_allow_html=True)
259
+
260
+ if __name__ == '__main__':
261
+ args = parse_st_args()
262
+ app(args)
263
+
264
+
265
+
cldm/ae.py ADDED
@@ -0,0 +1,727 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import einops
3
+ import torch
4
+ import torch as th
5
+ import torch.nn as nn
6
+ from torch.nn import functional as thf
7
+ import pytorch_lightning as pl
8
+ import torchvision
9
+ from copy import deepcopy
10
+ from ldm.modules.diffusionmodules.util import (
11
+ conv_nd,
12
+ linear,
13
+ zero_module,
14
+ timestep_embedding,
15
+ )
16
+ from contextlib import contextmanager, nullcontext
17
+ from einops import rearrange, repeat
18
+ from torchvision.utils import make_grid
19
+ from ldm.modules.attention import SpatialTransformer
20
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
21
+ from ldm.models.diffusion.ddpm import LatentDiffusion
22
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config, default
23
+ from ldm.models.diffusion.ddim import DDIMSampler
24
+ from ldm.modules.ema import LitEma
25
+ from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
26
+ from ldm.modules.diffusionmodules.model import Encoder
27
+ import lpips
28
+ import kornia
29
+ from kornia import color
30
+
31
+ def disabled_train(self, mode=True):
32
+ """Overwrite model.train with this function to make sure train/eval mode
33
+ does not change anymore."""
34
+ return self
35
+
36
+ class View(nn.Module):
37
+ def __init__(self, *shape):
38
+ super().__init__()
39
+ self.shape = shape
40
+
41
+ def forward(self, x):
42
+ return x.view(*self.shape)
43
+
44
+
45
+ class SecretEncoder3(nn.Module):
46
+ def __init__(self, secret_len, base_res=16, resolution=64) -> None:
47
+ super().__init__()
48
+ log_resolution = int(np.log2(resolution))
49
+ log_base = int(np.log2(base_res))
50
+ self.secret_len = secret_len
51
+ self.secret_scaler = nn.Sequential(
52
+ nn.Linear(secret_len, base_res*base_res*3),
53
+ nn.SiLU(),
54
+ View(-1, 3, base_res, base_res),
55
+ nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
56
+ zero_module(conv_nd(2, 3, 3, 3, padding=1))
57
+ ) # secret len -> ch x res x res
58
+
59
+ def copy_encoder_weight(self, ae_model):
60
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
61
+ return None
62
+
63
+ def encode(self, x):
64
+ x = self.secret_scaler(x)
65
+ return x
66
+
67
+ def forward(self, x, c):
68
+ # x: [B, C, H, W], c: [B, secret_len]
69
+ c = self.encode(c)
70
+ return c, None
71
+
72
+
73
+ class SecretEncoder4(nn.Module):
74
+ """same as SecretEncoder3 but with ch as input"""
75
+ def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
76
+ super().__init__()
77
+ log_resolution = int(np.log2(resolution))
78
+ log_base = int(np.log2(base_res))
79
+ self.secret_len = secret_len
80
+ self.secret_scaler = nn.Sequential(
81
+ nn.Linear(secret_len, base_res*base_res*ch),
82
+ nn.SiLU(),
83
+ View(-1, ch, base_res, base_res),
84
+ nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
85
+ zero_module(conv_nd(2, ch, ch, 3, padding=1))
86
+ ) # secret len -> ch x res x res
87
+
88
+ def copy_encoder_weight(self, ae_model):
89
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
90
+ return None
91
+
92
+ def encode(self, x):
93
+ x = self.secret_scaler(x)
94
+ return x
95
+
96
+ def forward(self, x, c):
97
+ # x: [B, C, H, W], c: [B, secret_len]
98
+ c = self.encode(c)
99
+ return c, None
100
+
101
+ class SecretEncoder6(nn.Module):
102
+ """join img emb with secret emb"""
103
+ def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
104
+ super().__init__()
105
+ assert emode in ['c3', 'c2', 'm3']
106
+
107
+ if emode == 'c3': # c3: concat c and x each has ch channels
108
+ secret_ch = ch
109
+ join_ch = 2*ch
110
+ elif emode == 'c2': # c2: concat c (2) and x ave (1)
111
+ secret_ch = 2
112
+ join_ch = ch
113
+ elif emode == 'm3': # m3: multiply c (ch) and x (ch)
114
+ secret_ch = ch
115
+ join_ch = ch
116
+
117
+ # m3: multiply c (ch) and x ave (1)
118
+ log_resolution = int(np.log2(resolution))
119
+ log_base = int(np.log2(base_res))
120
+ self.secret_len = secret_len
121
+ self.emode = emode
122
+ self.resolution = resolution
123
+ self.secret_scaler = nn.Sequential(
124
+ nn.Linear(secret_len, base_res*base_res*secret_ch),
125
+ nn.SiLU(),
126
+ View(-1, secret_ch, base_res, base_res),
127
+ nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
128
+ ) # secret len -> ch x res x res
129
+ self.join_encoder = nn.Sequential(
130
+ conv_nd(2, join_ch, join_ch, 3, padding=1),
131
+ nn.SiLU(),
132
+ conv_nd(2, join_ch, ch, 3, padding=1),
133
+ nn.SiLU(),
134
+ conv_nd(2, ch, ch, 3, padding=1),
135
+ nn.SiLU()
136
+ )
137
+ self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
138
+
139
+ def copy_encoder_weight(self, ae_model):
140
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
141
+ return None
142
+
143
+ def encode(self, x):
144
+ x = self.secret_scaler(x)
145
+ return x
146
+
147
+ def forward(self, x, c):
148
+ # x: [B, C, H, W], c: [B, secret_len]
149
+ c = self.encode(c)
150
+ if self.emode == 'c3':
151
+ x = torch.cat([x, c], dim=1)
152
+ elif self.emode == 'c2':
153
+ x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
154
+ elif self.emode == 'm3':
155
+ x = x * c
156
+ dx = self.join_encoder(x)
157
+ dx = self.out_layer(dx)
158
+ return dx, None
159
+
160
+ class SecretEncoder5(nn.Module):
161
+ """same as SecretEncoder3 but with ch as input"""
162
+ def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
163
+ super().__init__()
164
+ log_resolution = int(np.log2(resolution))
165
+ log_base = int(np.log2(base_res))
166
+ self.secret_len = secret_len
167
+ self.joint = joint
168
+ self.resolution = resolution
169
+ self.secret_scaler = nn.Sequential(
170
+ nn.Linear(secret_len, base_res*base_res*ch),
171
+ nn.SiLU(),
172
+ View(-1, ch, base_res, base_res),
173
+ nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
174
+ ) # secret len -> ch x res x res
175
+ if joint:
176
+ self.join_encoder = nn.Sequential(
177
+ conv_nd(2, 2*ch, 2*ch, 3, padding=1),
178
+ nn.SiLU(),
179
+ conv_nd(2, 2*ch, ch, 3, padding=1),
180
+ nn.SiLU()
181
+ )
182
+ self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
183
+
184
+ def copy_encoder_weight(self, ae_model):
185
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
186
+ return None
187
+
188
+ def encode(self, x):
189
+ x = self.secret_scaler(x)
190
+ return x
191
+
192
+ def forward(self, x, c):
193
+ # x: [B, C, H, W], c: [B, secret_len]
194
+ c = self.encode(c)
195
+ if self.joint:
196
+ x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear", align_corners=False, antialias=True)
197
+ c = self.join_encoder(torch.cat([x, c], dim=1))
198
+ c = self.out_layer(c)
199
+ return c, None
200
+
201
+
202
+ class SecretEncoder2(nn.Module):
203
+ def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
204
+ ignore_keys=[],
205
+ image_key="image",
206
+ colorize_nlabels=None,
207
+ monitor=None,
208
+ ema_decay=None,
209
+ learn_logvar=False) -> None:
210
+ super().__init__()
211
+ log_resolution = int(np.log2(ddconfig.resolution))
212
+ self.secret_len = secret_len
213
+ self.learn_logvar = learn_logvar
214
+ self.image_key = image_key
215
+ self.encoder = Encoder(**ddconfig)
216
+ self.encoder.conv_out = zero_module(self.encoder.conv_out)
217
+ self.embed_dim = embed_dim
218
+
219
+ if colorize_nlabels is not None:
220
+ assert type(colorize_nlabels)==int
221
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
222
+
223
+ if monitor is not None:
224
+ self.monitor = monitor
225
+
226
+ self.secret_scaler = nn.Sequential(
227
+ nn.Linear(secret_len, 32*32*ddconfig.out_ch),
228
+ nn.SiLU(),
229
+ View(-1, ddconfig.out_ch, 32, 32),
230
+ nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
231
+ # zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
232
+ ) # secret len -> ch x res x res
233
+ # out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
234
+ # self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
235
+
236
+ self.use_ema = ema_decay is not None
237
+ if self.use_ema:
238
+ self.ema_decay = ema_decay
239
+ assert 0. < ema_decay < 1.
240
+ self.model_ema = LitEma(self, decay=ema_decay)
241
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
242
+
243
+ if ckpt_path is not None:
244
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
245
+
246
+
247
+ def init_from_ckpt(self, path, ignore_keys=list()):
248
+ sd = torch.load(path, map_location="cpu")["state_dict"]
249
+ keys = list(sd.keys())
250
+ for k in keys:
251
+ for ik in ignore_keys:
252
+ if k.startswith(ik):
253
+ print("Deleting key {} from state_dict.".format(k))
254
+ del sd[k]
255
+ misses, ignores = self.load_state_dict(sd, strict=False)
256
+ print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
257
+
258
+ def copy_encoder_weight(self, ae_model):
259
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
260
+ return None
261
+ self.encoder.load_state_dict(ae_model.encoder.state_dict())
262
+ self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
263
+
264
+ @contextmanager
265
+ def ema_scope(self, context=None):
266
+ if self.use_ema:
267
+ self.model_ema.store(self.parameters())
268
+ self.model_ema.copy_to(self)
269
+ if context is not None:
270
+ print(f"{context}: Switched to EMA weights")
271
+ try:
272
+ yield None
273
+ finally:
274
+ if self.use_ema:
275
+ self.model_ema.restore(self.parameters())
276
+ if context is not None:
277
+ print(f"{context}: Restored training weights")
278
+
279
+ def on_train_batch_end(self, *args, **kwargs):
280
+ if self.use_ema:
281
+ self.model_ema(self)
282
+
283
+ def encode(self, x):
284
+ h = self.encoder(x)
285
+ posterior = h
286
+ return posterior
287
+
288
+ def forward(self, x, c):
289
+ # x: [B, C, H, W], c: [B, secret_len]
290
+ c = self.secret_scaler(c)
291
+ x = torch.cat([x, c], dim=1)
292
+ z = self.encode(x)
293
+ # z = self.out_layer(z)
294
+ return z, None
295
+
296
+
297
+ class SecretEncoder7(nn.Module):
298
+ def __init__(self, secret_len, ddconfig, ckpt_path=None,
299
+ ignore_keys=[],embed_dim=3,
300
+ ema_decay=None) -> None:
301
+ super().__init__()
302
+ log_resolution = int(np.log2(ddconfig.resolution))
303
+ self.secret_len = secret_len
304
+ self.encoder = Encoder(**ddconfig)
305
+ # self.encoder.conv_out = zero_module(self.encoder.conv_out)
306
+ self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
307
+
308
+ self.secret_scaler = nn.Sequential(
309
+ nn.Linear(secret_len, 32*32*2),
310
+ nn.SiLU(),
311
+ View(-1, 2, 32, 32),
312
+ # nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
313
+ # zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
314
+ ) # secret len -> ch x res x res
315
+ # out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
316
+ # self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
317
+
318
+ self.use_ema = ema_decay is not None
319
+ if self.use_ema:
320
+ self.ema_decay = ema_decay
321
+ assert 0. < ema_decay < 1.
322
+ self.model_ema = LitEma(self, decay=ema_decay)
323
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
324
+
325
+ if ckpt_path is not None:
326
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
327
+
328
+
329
+ def init_from_ckpt(self, path, ignore_keys=list()):
330
+ sd = torch.load(path, map_location="cpu")["state_dict"]
331
+ keys = list(sd.keys())
332
+ for k in keys:
333
+ for ik in ignore_keys:
334
+ if k.startswith(ik):
335
+ print("Deleting key {} from state_dict.".format(k))
336
+ del sd[k]
337
+ misses, ignores = self.load_state_dict(sd, strict=False)
338
+ print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. Do not worry as we are not using the decoder and the secret encoder is a novel module.")
339
+
340
+ def copy_encoder_weight(self, ae_model):
341
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
342
+ # return None
343
+ self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
344
+ self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))
345
+
346
+ @contextmanager
347
+ def ema_scope(self, context=None):
348
+ if self.use_ema:
349
+ self.model_ema.store(self.parameters())
350
+ self.model_ema.copy_to(self)
351
+ if context is not None:
352
+ print(f"{context}: Switched to EMA weights")
353
+ try:
354
+ yield None
355
+ finally:
356
+ if self.use_ema:
357
+ self.model_ema.restore(self.parameters())
358
+ if context is not None:
359
+ print(f"{context}: Restored training weights")
360
+
361
+ def on_train_batch_end(self, *args, **kwargs):
362
+ if self.use_ema:
363
+ self.model_ema(self)
364
+
365
+ def encode(self, x):
366
+ h = self.encoder(x)
367
+ h = self.quant_conv(h)
368
+ return h
369
+
370
+ def forward(self, x, c):
371
+ # x: [B, C, H, W], c: [B, secret_len]
372
+ c = self.secret_scaler(c) # [B, 2, 32, 32]
373
+ # c = thf.interpolate(c, size=x.shape[-2:], mode="bilinear", align_corners=False)
374
+ c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
375
+ x = 0.2125 * x[:,0,...] + 0.7154 *x[:,1,...] + 0.0721 * x[:,2,...]
376
+ x = torch.cat([x.unsqueeze(1), c], dim=1)
377
+ z = self.encode(x)
378
+ # z = self.out_layer(z)
379
+ return z, None
380
+
381
+ class SecretEncoder(nn.Module):
382
+ def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
383
+ ignore_keys=[],
384
+ image_key="image",
385
+ colorize_nlabels=None,
386
+ monitor=None,
387
+ ema_decay=None,
388
+ learn_logvar=False) -> None:
389
+ super().__init__()
390
+ log_resolution = int(np.log2(ddconfig.resolution))
391
+ self.secret_len = secret_len
392
+ self.learn_logvar = learn_logvar
393
+ self.image_key = image_key
394
+ self.encoder = Encoder(**ddconfig)
395
+ assert ddconfig["double_z"]
396
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
397
+ self.embed_dim = embed_dim
398
+
399
+ if colorize_nlabels is not None:
400
+ assert type(colorize_nlabels)==int
401
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
402
+
403
+ if monitor is not None:
404
+ self.monitor = monitor
405
+
406
+ self.use_ema = ema_decay is not None
407
+ if self.use_ema:
408
+ self.ema_decay = ema_decay
409
+ assert 0. < ema_decay < 1.
410
+ self.model_ema = LitEma(self, decay=ema_decay)
411
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
412
+
413
+ if ckpt_path is not None:
414
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
415
+
416
+ self.secret_scaler = nn.Sequential(
417
+ nn.Linear(secret_len, 32*32*ddconfig.out_ch),
418
+ nn.SiLU(),
419
+ View(-1, ddconfig.out_ch, 32, 32),
420
+ nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
421
+ zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
422
+ ) # secret len -> ch x res x res
423
+ # out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
424
+ self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
425
+
426
+ def init_from_ckpt(self, path, ignore_keys=list()):
427
+ sd = torch.load(path, map_location="cpu")["state_dict"]
428
+ keys = list(sd.keys())
429
+ for k in keys:
430
+ for ik in ignore_keys:
431
+ if k.startswith(ik):
432
+ print("Deleting key {} from state_dict.".format(k))
433
+ del sd[k]
434
+ misses, ignores = self.load_state_dict(sd, strict=False)
435
+ print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
436
+
437
+ def copy_encoder_weight(self, ae_model):
438
+ # misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
439
+ self.encoder.load_state_dict(ae_model.encoder.state_dict())
440
+ self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
441
+
442
+ @contextmanager
443
+ def ema_scope(self, context=None):
444
+ if self.use_ema:
445
+ self.model_ema.store(self.parameters())
446
+ self.model_ema.copy_to(self)
447
+ if context is not None:
448
+ print(f"{context}: Switched to EMA weights")
449
+ try:
450
+ yield None
451
+ finally:
452
+ if self.use_ema:
453
+ self.model_ema.restore(self.parameters())
454
+ if context is not None:
455
+ print(f"{context}: Restored training weights")
456
+
457
+ def on_train_batch_end(self, *args, **kwargs):
458
+ if self.use_ema:
459
+ self.model_ema(self)
460
+
461
+ def encode(self, x):
462
+ h = self.encoder(x)
463
+ moments = self.quant_conv(h)
464
+ posterior = DiagonalGaussianDistribution(moments)
465
+ return posterior
466
+
467
+ def forward(self, x, c):
468
+ # x: [B, C, H, W], c: [B, secret_len]
469
+ c = self.secret_scaler(c)
470
+ x = x + c
471
+ posterior = self.encode(x)
472
+ z = posterior.sample()
473
+ z = self.out_layer(z)
474
+ return z, posterior
475
+
476
+
477
+ class ControlAE(pl.LightningModule):
478
+ def __init__(self,
479
+ first_stage_key,
480
+ first_stage_config,
481
+ control_key,
482
+ control_config,
483
+ decoder_config,
484
+ loss_config,
485
+ noise_config='__none__',
486
+ use_ema=False,
487
+ secret_warmup=False,
488
+ scale_factor=1.,
489
+ ckpt_path="__none__",
490
+ ):
491
+ super().__init__()
492
+ self.scale_factor = scale_factor
493
+ self.control_key = control_key
494
+ self.first_stage_key = first_stage_key
495
+ self.ae = instantiate_from_config(first_stage_config)
496
+ self.control = instantiate_from_config(control_config)
497
+ self.decoder = instantiate_from_config(decoder_config)
498
+ self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
499
+ if noise_config != '__none__':
500
+ print('Using noise')
501
+ self.noise = instantiate_from_config(noise_config)
502
+ # copy weights from first stage
503
+ self.control.copy_encoder_weight(self.ae)
504
+ # freeze first stage
505
+ self.ae.eval()
506
+ self.ae.train = disabled_train
507
+ for p in self.ae.parameters():
508
+ p.requires_grad = False
509
+
510
+ self.loss_layer = instantiate_from_config(loss_config)
511
+
512
+ # early training phase
513
+ # self.fixed_input = True
514
+ self.fixed_x = None
515
+ self.fixed_img = None
516
+ self.fixed_input_recon = None
517
+ self.fixed_control = None
518
+ self.register_buffer("fixed_input", torch.tensor(True))
519
+
520
+ # secret warmup
521
+ self.secret_warmup = secret_warmup
522
+ self.secret_baselen = 2
523
+ self.secret_len = control_config.params.secret_len
524
+ if self.secret_warmup:
525
+ assert self.secret_len == 2**(int(np.log2(self.secret_len)))
526
+
527
+ self.use_ema = use_ema
528
+ if self.use_ema:
529
+ print('Using EMA')
530
+ self.control_ema = LitEma(self.control)
531
+ self.decoder_ema = LitEma(self.decoder)
532
+ print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
533
+
534
+ if ckpt_path != '__none__':
535
+ self.init_from_ckpt(ckpt_path, ignore_keys=[])
536
+
537
+ def get_warmup_secret(self, old_secret):
538
+ # old_secret: [B, secret_len]
539
+ # new_secret: [B, secret_len]
540
+ if self.secret_warmup:
541
+ bsz = old_secret.shape[0]
542
+ nrepeats = self.secret_len // self.secret_baselen
543
+ new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
544
+ return new_secret.to(old_secret.device)
545
+ else:
546
+ return old_secret
547
+
548
+ def init_from_ckpt(self, path, ignore_keys=list()):
549
+ sd = torch.load(path, map_location="cpu")["state_dict"]
550
+ keys = list(sd.keys())
551
+ for k in keys:
552
+ for ik in ignore_keys:
553
+ if k.startswith(ik):
554
+ print("Deleting key {} from state_dict.".format(k))
555
+ del sd[k]
556
+ self.load_state_dict(sd, strict=False)
557
+ print(f"Restored from {path}")
558
+
559
+ @contextmanager
560
+ def ema_scope(self, context=None):
561
+ if self.use_ema:
562
+ self.control_ema.store(self.control.parameters())
563
+ self.decoder_ema.store(self.decoder.parameters())
564
+ self.control_ema.copy_to(self.control)
565
+ self.decoder_ema.copy_to(self.decoder)
566
+ if context is not None:
567
+ print(f"{context}: Switched to EMA weights")
568
+ try:
569
+ yield None
570
+ finally:
571
+ if self.use_ema:
572
+ self.control_ema.restore(self.control.parameters())
573
+ self.decoder_ema.restore(self.decoder.parameters())
574
+ if context is not None:
575
+ print(f"{context}: Restored training weights")
576
+
577
+ def on_train_batch_end(self, *args, **kwargs):
578
+ if self.use_ema:
579
+ self.control_ema(self.control)
580
+ self.decoder_ema(self.decoder)
581
+
582
+ def compute_loss(self, pred, target):
583
+ # return thf.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3))
584
+ lpips_loss = self.lpips_loss(pred, target).mean(dim=[1,2,3])
585
+ pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
586
+ target_yuv = color.rgb_to_yuv((target + 1) / 2)
587
+ yuv_loss = torch.mean((pred_yuv - target_yuv)**2, dim=[2,3])
588
+ yuv_loss = 1.5*torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
589
+ return lpips_loss + yuv_loss
590
+
591
+ def forward(self, x, image, c):
592
+ if self.control.__class__.__name__ == 'SecretEncoder6':
593
+ eps, posterior = self.control(x, c)
594
+ else:
595
+ eps, posterior = self.control(image, c)
596
+ return x + eps, posterior
597
+
598
+ @torch.no_grad()
599
+ def get_input(self, batch, return_first_stage=False, bs=None):
600
+ image = batch[self.first_stage_key]
601
+ control = batch[self.control_key]
602
+ control = self.get_warmup_secret(control)
603
+ if bs is not None:
604
+ image = image[:bs]
605
+ control = control[:bs]
606
+ else:
607
+ bs = image.shape[0]
608
+ # encode image 1st stage
609
+ image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
610
+ x = self.encode_first_stage(image).detach()
611
+ image_rec = self.decode_first_stage(x).detach()
612
+
613
+ # check if using fixed input (early training phase)
614
+ # if self.training and self.fixed_input:
615
+ if self.fixed_input:
616
+ if self.fixed_x is None: # first iteration
617
+ print('[TRAINING] Warmup - using fixed input image for now!')
618
+ self.fixed_x = x.detach().clone()[:bs]
619
+ self.fixed_img = image.detach().clone()[:bs]
620
+ self.fixed_input_recon = image_rec.detach().clone()[:bs]
621
+ self.fixed_control = control.detach().clone()[:bs] # use for log_images with fixed_input option only
622
+ x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon
623
+
624
+ out = [x, control]
625
+ if return_first_stage:
626
+ out.extend([image, image_rec])
627
+ return out
628
+
629
+ def decode_first_stage(self, z):
630
+ z = 1./self.scale_factor * z
631
+ image_rec = self.ae.decode(z)
632
+ return image_rec
633
+
634
+ def encode_first_stage(self, image):
635
+ encoder_posterior = self.ae.encode(image)
636
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
637
+ z = encoder_posterior.sample()
638
+ elif isinstance(encoder_posterior, torch.Tensor):
639
+ z = encoder_posterior
640
+ else:
641
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
642
+ return self.scale_factor * z
643
+
644
+ def shared_step(self, batch):
645
+ x, c, img, _ = self.get_input(batch, return_first_stage=True)
646
+ # import pdb; pdb.set_trace()
647
+ x, posterior = self(x, img, c)
648
+ image_rec = self.decode_first_stage(x)
649
+ # resize
650
+ if img.shape[-1] > 256:
651
+ img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
652
+ image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
653
+ if hasattr(self, 'noise') and self.noise.is_activated():
654
+ image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
655
+ else:
656
+ image_rec_noised = self.crop(image_rec) # center crop
657
+ image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
658
+ pred = self.decoder(image_rec_noised)
659
+
660
+ loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
661
+ bit_acc = loss_dict["bit_acc"]
662
+
663
+ bit_acc_ = bit_acc.item()
664
+
665
+ if (bit_acc_ > 0.98) and (not self.fixed_input) and self.noise.is_activated():
666
+ self.loss_layer.activate_ramp(self.global_step)
667
+
668
+ if (bit_acc_ > 0.95) and (not self.fixed_input): # ramp up image loss at late training stage
669
+ if hasattr(self, 'noise') and (not self.noise.is_activated()):
670
+ self.noise.activate(self.global_step)
671
+
672
+ if (bit_acc_ > 0.9) and self.fixed_input: # execute only once
673
+ print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
674
+ self.fixed_input = ~self.fixed_input
675
+ return loss, loss_dict
676
+
677
+ def training_step(self, batch, batch_idx):
678
+ loss, loss_dict = self.shared_step(batch)
679
+ loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
680
+ self.log_dict(loss_dict, prog_bar=True,
681
+ logger=True, on_step=True, on_epoch=True)
682
+
683
+ self.log("global_step", self.global_step,
684
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
685
+ # if self.use_scheduler:
686
+ # lr = self.optimizers().param_groups[0]['lr']
687
+ # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
688
+
689
+ return loss
690
+
691
+ @torch.no_grad()
692
+ def validation_step(self, batch, batch_idx):
693
+ _, loss_dict_no_ema = self.shared_step(batch)
694
+ loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
695
+ with self.ema_scope():
696
+ _, loss_dict_ema = self.shared_step(batch)
697
+ loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
698
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
699
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
700
+
701
+ @torch.no_grad()
702
+ def log_images(self, batch, fixed_input=False, **kwargs):
703
+ log = dict()
704
+ if fixed_input and self.fixed_img is not None:
705
+ x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
706
+ else:
707
+ x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
708
+ x, _ = self(x, img, c)
709
+ image_out = self.decode_first_stage(x)
710
+ if hasattr(self, 'noise') and self.noise.is_activated():
711
+ img_noise = self.noise(image_out, self.global_step, p=1.0)
712
+ log['noised'] = img_noise
713
+ log['input'] = img
714
+ log['output'] = image_out
715
+ log['recon'] = img_recon
716
+ return log
717
+
718
+ def configure_optimizers(self):
719
+ lr = self.learning_rate
720
+ params = list(self.control.parameters()) + list(self.decoder.parameters())
721
+ optimizer = torch.optim.AdamW(params, lr=lr)
722
+ return optimizer
723
+
724
+
725
+
726
+
727
+
cldm/cldm.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import einops
3
+ import torch
4
+ import torch as th
5
+ import torch.nn as nn
6
+ import torchvision
7
+ from ldm.modules.diffusionmodules.util import (
8
+ conv_nd,
9
+ linear,
10
+ zero_module,
11
+ timestep_embedding,
12
+ )
13
+
14
+ from einops import rearrange, repeat
15
+ from torchvision.utils import make_grid
16
+ from ldm.modules.attention import SpatialTransformer
17
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
18
+ from ldm.models.diffusion.ddpm import LatentDiffusion
19
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
20
+ from ldm.models.diffusion.ddim import DDIMSampler
21
+
22
+
23
+ class ControlledUnetModel(UNetModel):
24
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
25
+ hs = []
26
+ with torch.no_grad():
27
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
28
+ emb = self.time_embed(t_emb)
29
+ h = x.type(self.dtype)
30
+ for module in self.input_blocks:
31
+ h = module(h, emb, context)
32
+ hs.append(h)
33
+ h = self.middle_block(h, emb, context)
34
+
35
+ h += control.pop()
36
+
37
+ for i, module in enumerate(self.output_blocks):
38
+ if only_mid_control:
39
+ h = torch.cat([h, hs.pop()], dim=1)
40
+ else:
41
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
42
+ h = module(h, emb, context)
43
+
44
+ h = h.type(x.dtype)
45
+ return self.out(h)
46
+
47
+ class View(nn.Module):
48
+ def __init__(self, *shape):
49
+ super().__init__()
50
+ self.shape = shape
51
+
52
+ def forward(self, x):
53
+ return x.view(*self.shape)
54
+
55
+ class ControlNet(nn.Module):
56
+ def __init__(
57
+ self,
58
+ image_size,
59
+ in_channels,
60
+ model_channels,
61
+ hint_channels,
62
+ num_res_blocks,
63
+ attention_resolutions,
64
+ dropout=0,
65
+ channel_mult=(1, 2, 4, 8),
66
+ conv_resample=True,
67
+ dims=2,
68
+ use_checkpoint=False,
69
+ use_fp16=False,
70
+ num_heads=-1,
71
+ num_head_channels=-1,
72
+ num_heads_upsample=-1,
73
+ use_scale_shift_norm=False,
74
+ resblock_updown=False,
75
+ use_new_attention_order=False,
76
+ use_spatial_transformer=False, # custom transformer support
77
+ transformer_depth=1, # custom transformer support
78
+ context_dim=None, # custom transformer support
79
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
80
+ legacy=True,
81
+ disable_self_attentions=None,
82
+ num_attention_blocks=None,
83
+ disable_middle_self_attn=False,
84
+ use_linear_in_transformer=False,
85
+ secret_len = 0,
86
+ ):
87
+ super().__init__()
88
+ if use_spatial_transformer:
89
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
90
+
91
+ if context_dim is not None:
92
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
93
+ from omegaconf.listconfig import ListConfig
94
+ if type(context_dim) == ListConfig:
95
+ context_dim = list(context_dim)
96
+
97
+ if num_heads_upsample == -1:
98
+ num_heads_upsample = num_heads
99
+
100
+ if num_heads == -1:
101
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
102
+
103
+ if num_head_channels == -1:
104
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
105
+
106
+ self.dims = dims
107
+ self.image_size = image_size
108
+ self.in_channels = in_channels
109
+ self.model_channels = model_channels
110
+ if isinstance(num_res_blocks, int):
111
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
112
+ else:
113
+ if len(num_res_blocks) != len(channel_mult):
114
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
115
+ "as a list/tuple (per-level) with the same length as channel_mult")
116
+ self.num_res_blocks = num_res_blocks
117
+ if disable_self_attentions is not None:
118
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
119
+ assert len(disable_self_attentions) == len(channel_mult)
120
+ if num_attention_blocks is not None:
121
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
122
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
123
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
124
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
125
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
126
+ f"attention will still not be set.")
127
+
128
+ self.attention_resolutions = attention_resolutions
129
+ self.dropout = dropout
130
+ self.channel_mult = channel_mult
131
+ self.conv_resample = conv_resample
132
+ self.use_checkpoint = use_checkpoint
133
+ self.dtype = th.float16 if use_fp16 else th.float32
134
+ self.num_heads = num_heads
135
+ self.num_head_channels = num_head_channels
136
+ self.num_heads_upsample = num_heads_upsample
137
+ self.predict_codebook_ids = n_embed is not None
138
+
139
+ time_embed_dim = model_channels * 4
140
+ self.time_embed = nn.Sequential(
141
+ linear(model_channels, time_embed_dim),
142
+ nn.SiLU(),
143
+ linear(time_embed_dim, time_embed_dim),
144
+ )
145
+
146
+ self.input_blocks = nn.ModuleList(
147
+ [
148
+ TimestepEmbedSequential(
149
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
150
+ )
151
+ ]
152
+ )
153
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
154
+ self.secret_len = secret_len
155
+ if secret_len > 0:
156
+ log_resolution = int(np.log2(64))
157
+ self.input_hint_block = TimestepEmbedSequential(
158
+ nn.Linear(secret_len, 16*16*4),
159
+ nn.SiLU(),
160
+ View(-1, 4, 16, 16),
161
+ nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
162
+ conv_nd(dims, 4, 64, 3, padding=1),
163
+ nn.SiLU(),
164
+ conv_nd(dims, 64, 256, 3, padding=1),
165
+ nn.SiLU(),
166
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
167
+ )
168
+ else:
169
+ self.input_hint_block = TimestepEmbedSequential(
170
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
171
+ nn.SiLU(),
172
+ conv_nd(dims, 16, 16, 3, padding=1),
173
+ nn.SiLU(),
174
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
175
+ nn.SiLU(),
176
+ conv_nd(dims, 32, 32, 3, padding=1),
177
+ nn.SiLU(),
178
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
179
+ nn.SiLU(),
180
+ conv_nd(dims, 96, 96, 3, padding=1),
181
+ nn.SiLU(),
182
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
183
+ nn.SiLU(),
184
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
185
+ )
186
+
187
+ self._feature_size = model_channels
188
+ input_block_chans = [model_channels]
189
+ ch = model_channels
190
+ ds = 1
191
+ for level, mult in enumerate(channel_mult):
192
+ for nr in range(self.num_res_blocks[level]):
193
+ layers = [
194
+ ResBlock(
195
+ ch,
196
+ time_embed_dim,
197
+ dropout,
198
+ out_channels=mult * model_channels,
199
+ dims=dims,
200
+ use_checkpoint=use_checkpoint,
201
+ use_scale_shift_norm=use_scale_shift_norm,
202
+ )
203
+ ]
204
+ ch = mult * model_channels
205
+ if ds in attention_resolutions:
206
+ if num_head_channels == -1:
207
+ dim_head = ch // num_heads
208
+ else:
209
+ num_heads = ch // num_head_channels
210
+ dim_head = num_head_channels
211
+ if legacy:
212
+ #num_heads = 1
213
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
214
+ if exists(disable_self_attentions):
215
+ disabled_sa = disable_self_attentions[level]
216
+ else:
217
+ disabled_sa = False
218
+
219
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
220
+ layers.append(
221
+ AttentionBlock(
222
+ ch,
223
+ use_checkpoint=use_checkpoint,
224
+ num_heads=num_heads,
225
+ num_head_channels=dim_head,
226
+ use_new_attention_order=use_new_attention_order,
227
+ ) if not use_spatial_transformer else SpatialTransformer(
228
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
229
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
230
+ use_checkpoint=use_checkpoint
231
+ )
232
+ )
233
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
234
+ self.zero_convs.append(self.make_zero_conv(ch))
235
+ self._feature_size += ch
236
+ input_block_chans.append(ch)
237
+ if level != len(channel_mult) - 1:
238
+ out_ch = ch
239
+ self.input_blocks.append(
240
+ TimestepEmbedSequential(
241
+ ResBlock(
242
+ ch,
243
+ time_embed_dim,
244
+ dropout,
245
+ out_channels=out_ch,
246
+ dims=dims,
247
+ use_checkpoint=use_checkpoint,
248
+ use_scale_shift_norm=use_scale_shift_norm,
249
+ down=True,
250
+ )
251
+ if resblock_updown
252
+ else Downsample(
253
+ ch, conv_resample, dims=dims, out_channels=out_ch
254
+ )
255
+ )
256
+ )
257
+ ch = out_ch
258
+ input_block_chans.append(ch)
259
+ self.zero_convs.append(self.make_zero_conv(ch))
260
+ ds *= 2
261
+ self._feature_size += ch
262
+
263
+ if num_head_channels == -1:
264
+ dim_head = ch // num_heads
265
+ else:
266
+ num_heads = ch // num_head_channels
267
+ dim_head = num_head_channels
268
+ if legacy:
269
+ #num_heads = 1
270
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
271
+ self.middle_block = TimestepEmbedSequential(
272
+ ResBlock(
273
+ ch,
274
+ time_embed_dim,
275
+ dropout,
276
+ dims=dims,
277
+ use_checkpoint=use_checkpoint,
278
+ use_scale_shift_norm=use_scale_shift_norm,
279
+ ),
280
+ AttentionBlock(
281
+ ch,
282
+ use_checkpoint=use_checkpoint,
283
+ num_heads=num_heads,
284
+ num_head_channels=dim_head,
285
+ use_new_attention_order=use_new_attention_order,
286
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
287
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
288
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
289
+ use_checkpoint=use_checkpoint
290
+ ),
291
+ ResBlock(
292
+ ch,
293
+ time_embed_dim,
294
+ dropout,
295
+ dims=dims,
296
+ use_checkpoint=use_checkpoint,
297
+ use_scale_shift_norm=use_scale_shift_norm,
298
+ ),
299
+ )
300
+ self.middle_block_out = self.make_zero_conv(ch)
301
+ self._feature_size += ch
302
+
303
+ def make_zero_conv(self, channels):
304
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
305
+
306
+ def forward(self, x, hint, timesteps, context, **kwargs):
307
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
308
+ emb = self.time_embed(t_emb)
309
+ # import pdb; pdb.set_trace()
310
+ guided_hint = self.input_hint_block(hint, emb, context)
311
+
312
+ outs = []
313
+
314
+ h = x.type(self.dtype)
315
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
316
+ if guided_hint is not None:
317
+ h = module(h, emb, context)
318
+ h += guided_hint
319
+ guided_hint = None
320
+ else:
321
+ h = module(h, emb, context)
322
+ outs.append(zero_conv(h, emb, context))
323
+
324
+ h = self.middle_block(h, emb, context)
325
+ outs.append(self.middle_block_out(h, emb, context))
326
+
327
+ return outs
328
+
329
+
330
+ class SecretDecoder(nn.Module):
331
+ def __init__(self, arch='CNN', act='ReLU', norm='none', resolution=256, in_channels=3, secret_len=100):
332
+ super().__init__()
333
+ self.resolution = resolution
334
+ self.arch = arch
335
+ print(f'SecretDecoder arch: {arch}')
336
+ def activation(name = 'ReLU'):
337
+ if name == 'ReLU':
338
+ return nn.ReLU()
339
+ elif name == 'LeakyReLU':
340
+ return nn.LeakyReLU()
341
+ elif name == 'SiLU':
342
+ return nn.SiLU()
343
+
344
+ def normalisation(name, n):
345
+ if name == 'none':
346
+ return nn.Identity()
347
+ elif name == 'BatchNorm2D':
348
+ return nn.BatchNorm2d(n)
349
+ elif name == 'BatchNorm1d':
350
+ return nn.BatchNorm1d(n)
351
+ elif name == 'LayerNorm':
352
+ return nn.LayerNorm(n)
353
+
354
+ if arch=='CNN':
355
+ self.decoder = nn.Sequential(
356
+ nn.Conv2d(in_channels, 32, (3, 3), 2, 1), # 128
357
+ activation(act),
358
+ nn.Conv2d(32, 32, 3, 1, 1),
359
+ activation(act),
360
+ nn.Conv2d(32, 64, 3, 2, 1), # 64
361
+ activation(act),
362
+ nn.Conv2d(64, 64, 3, 1, 1),
363
+ activation(act),
364
+ nn.Conv2d(64, 64, 3, 2, 1), # 32
365
+ activation(act),
366
+ nn.Conv2d(64, 128, 3, 2, 1), # 16
367
+ activation(act),
368
+ nn.Conv2d(128, 128, (3, 3), 2, 1), # 8
369
+ activation(act),
370
+ )
371
+ self.dense = nn.Sequential(
372
+ nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
373
+ activation(act),
374
+ nn.Linear(512, secret_len)
375
+ )
376
+ elif arch == 'resnet50':
377
+ self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
378
+ self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
379
+ else:
380
+ raise NotImplementedError
381
+
382
+ def forward(self, image):
383
+ x = self.decoder(image)
384
+ if self.arch == 'CNN':
385
+ x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
386
+ x = self.dense(x)
387
+ return x
388
+
389
+
390
+ class ControlLDM(LatentDiffusion):
391
+
392
+ def __init__(self, control_stage_config, control_key, only_mid_control, secret_decoder_config, *args, **kwargs):
393
+ super().__init__(*args, **kwargs)
394
+ self.control_model = instantiate_from_config(control_stage_config)
395
+ self.control_key = control_key
396
+ self.only_mid_control = only_mid_control
397
+ if secret_decoder_config != 'none':
398
+ self.secret_decoder = instantiate_from_config(secret_decoder_config)
399
+
400
+ @torch.no_grad()
401
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
402
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
403
+ control = batch[self.control_key]
404
+ if bs is not None:
405
+ control = control[:bs]
406
+ control = control.to(self.device)
407
+ if self.control_key == 'hint':
408
+ control = einops.rearrange(control, 'b h w c -> b c h w')
409
+ control = control.to(memory_format=torch.contiguous_format).float()
410
+ return x, dict(c_crossattn=[c], c_concat=[control])
411
+
412
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
413
+ assert isinstance(cond, dict)
414
+ diffusion_model = self.model.diffusion_model
415
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
416
+ cond_hint = torch.cat(cond['c_concat'], 1)
417
+
418
+ control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
419
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
420
+
421
+ return eps
422
+
423
+ @torch.no_grad()
424
+ def get_unconditional_conditioning(self, N):
425
+ return self.get_learned_conditioning([""] * N)
426
+
427
+ @torch.no_grad()
428
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
429
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
430
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
431
+ use_ema_scope=True,
432
+ **kwargs):
433
+ use_ddim = ddim_steps is not None
434
+
435
+ log = dict()
436
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
437
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
438
+ N = min(z.shape[0], N)
439
+ n_row = min(z.shape[0], n_row)
440
+ log["reconstruction"] = self.decode_first_stage(z)
441
+ log["control"] = c_cat * 2.0 - 1.0
442
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
443
+
444
+ if plot_diffusion_rows:
445
+ # get diffusion row
446
+ diffusion_row = list()
447
+ z_start = z[:n_row]
448
+ for t in range(self.num_timesteps):
449
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
450
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
451
+ t = t.to(self.device).long()
452
+ noise = torch.randn_like(z_start)
453
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
454
+ diffusion_row.append(self.decode_first_stage(z_noisy))
455
+
456
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
457
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
458
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
459
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
460
+ log["diffusion_row"] = diffusion_grid
461
+
462
+ if sample:
463
+ # get denoise row
464
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
465
+ batch_size=N, ddim=use_ddim,
466
+ ddim_steps=ddim_steps, eta=ddim_eta)
467
+ x_samples = self.decode_first_stage(samples)
468
+ log["samples"] = x_samples
469
+ if plot_denoise_rows:
470
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
471
+ log["denoise_row"] = denoise_grid
472
+ # import pudb; pudb.set_trace()
473
+ if unconditional_guidance_scale > 1.0:
474
+ uc_cross = self.get_unconditional_conditioning(N)
475
+ uc_cat = c_cat # torch.zeros_like(c_cat)
476
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
477
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
478
+ batch_size=N, ddim=use_ddim,
479
+ ddim_steps=ddim_steps, eta=ddim_eta,
480
+ unconditional_guidance_scale=unconditional_guidance_scale,
481
+ unconditional_conditioning=uc_full,
482
+ )
483
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
484
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
485
+
486
+ return log
487
+
488
+ @torch.no_grad()
489
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
490
+ ddim_sampler = DDIMSampler(self)
491
+ # import pdb; pdb.set_trace()
492
+ # b, c, h, w = cond["c_concat"][0].shape
493
+ b, c, h, w = cond["c_concat"][0].shape[0], self.channels, self.image_size*8, self.image_size*8
494
+ shape = (self.channels, h // 8, w // 8)
495
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
496
+ return samples, intermediates
497
+
498
+ def configure_optimizers(self):
499
+ lr = self.learning_rate
500
+ params = list(self.control_model.parameters())
501
+ if not self.sd_locked:
502
+ params += list(self.model.diffusion_model.output_blocks.parameters())
503
+ params += list(self.model.diffusion_model.out.parameters())
504
+ opt = torch.optim.AdamW(params, lr=lr)
505
+ return opt
506
+
507
+ def low_vram_shift(self, is_diffusing):
508
+ if is_diffusing:
509
+ self.model = self.model.cuda()
510
+ self.control_model = self.control_model.cuda()
511
+ self.first_stage_model = self.first_stage_model.cpu()
512
+ self.cond_stage_model = self.cond_stage_model.cpu()
513
+ else:
514
+ self.model = self.model.cpu()
515
+ self.control_model = self.control_model.cpu()
516
+ self.first_stage_model = self.first_stage_model.cuda()
517
+ self.cond_stage_model = self.cond_stage_model.cuda()
cldm/diffsteg.py ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import einops
3
+ import torch
4
+ import torch as th
5
+ import torch.nn as nn
6
+ from torch.nn import functional as thf
7
+ import torchvision
8
+ from ldm.modules.diffusionmodules.util import (
9
+ conv_nd,
10
+ linear,
11
+ zero_module,
12
+ timestep_embedding,
13
+ )
14
+
15
+ from einops import rearrange, repeat
16
+ from torchvision.utils import make_grid
17
+ from ldm.modules.attention import SpatialTransformer
18
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
19
+ from ldm.models.diffusion.ddpm import LatentDiffusion
20
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config, default
21
+ from ldm.models.diffusion.ddim import DDIMSampler
22
+
23
+
24
+ # class CUNetModel(nn.Module):
25
+ # def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
26
+ # hs = []
27
+ # with torch.no_grad():
28
+ # t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
29
+ # emb = self.time_embed(t_emb)
30
+
31
+ # h = x.type(self.dtype)
32
+ # for module in self.input_blocks:
33
+ # h = module(h, emb, context)
34
+ # hs.append(h)
35
+
36
+ # h = self.middle_block(h, emb, context)
37
+ # h += control.pop(0)
38
+ # for module in self.output_blocks:
39
+ # if only_mid_control:
40
+ # h = th.cat([h, hs.pop()], dim=1)
41
+ # else:
42
+ # h = torch.cat([h, hs.pop() + control.pop(0)], dim=1)
43
+ # h = module(h, emb, context)
44
+ # h = h.type(x.dtype)
45
+ # return self.out(h)
46
+
47
+ class SecretNet(nn.Module):
48
+ def __init__(
49
+ self,
50
+ image_size,
51
+ in_channels,
52
+ model_channels,
53
+ hint_channels,
54
+ num_res_blocks,
55
+ attention_resolutions,
56
+ dropout=0,
57
+ channel_mult=(1, 2, 4, 8),
58
+ conv_resample=True,
59
+ dims=2,
60
+ use_checkpoint=False,
61
+ use_fp16=False,
62
+ num_heads=-1,
63
+ num_head_channels=-1,
64
+ num_heads_upsample=-1,
65
+ use_scale_shift_norm=False,
66
+ resblock_updown=False,
67
+ use_new_attention_order=False,
68
+ use_spatial_transformer=False, # custom transformer support
69
+ transformer_depth=1, # custom transformer support
70
+ context_dim=None, # custom transformer support
71
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
72
+ legacy=True,
73
+ disable_self_attentions=None,
74
+ num_attention_blocks=None,
75
+ disable_middle_self_attn=False,
76
+ use_linear_in_transformer=False,
77
+ secret_len = 0,
78
+ ):
79
+ super().__init__()
80
+ if use_spatial_transformer:
81
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
82
+
83
+ if context_dim is not None:
84
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
85
+ from omegaconf.listconfig import ListConfig
86
+ if type(context_dim) == ListConfig:
87
+ context_dim = list(context_dim)
88
+
89
+ if num_heads_upsample == -1:
90
+ num_heads_upsample = num_heads
91
+
92
+ if num_heads == -1:
93
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
94
+
95
+ if num_head_channels == -1:
96
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
97
+
98
+ self.dims = dims
99
+ self.image_size = image_size
100
+ self.in_channels = in_channels
101
+ self.model_channels = model_channels
102
+ if isinstance(num_res_blocks, int):
103
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
104
+ else:
105
+ if len(num_res_blocks) != len(channel_mult):
106
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
107
+ "as a list/tuple (per-level) with the same length as channel_mult")
108
+ self.num_res_blocks = num_res_blocks
109
+ if disable_self_attentions is not None:
110
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
111
+ assert len(disable_self_attentions) == len(channel_mult)
112
+ if num_attention_blocks is not None:
113
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
114
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
115
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
116
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
117
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
118
+ f"attention will still not be set.")
119
+
120
+ self.attention_resolutions = attention_resolutions
121
+ self.dropout = dropout
122
+ self.channel_mult = channel_mult
123
+ self.conv_resample = conv_resample
124
+ self.use_checkpoint = use_checkpoint
125
+ self.dtype = th.float16 if use_fp16 else th.float32
126
+ self.num_heads = num_heads
127
+ self.num_head_channels = num_head_channels
128
+ self.num_heads_upsample = num_heads_upsample
129
+ self.predict_codebook_ids = n_embed is not None
130
+
131
+ time_embed_dim = model_channels * 4
132
+ self.time_embed = nn.Sequential(
133
+ linear(model_channels, time_embed_dim),
134
+ nn.SiLU(),
135
+ linear(time_embed_dim, time_embed_dim),
136
+ )
137
+
138
+ # self.input_blocks = nn.ModuleList(
139
+ # [
140
+ # TimestepEmbedSequential(
141
+ # conv_nd(dims, in_channels, model_channels, 3, padding=1)
142
+ # )
143
+ # ]
144
+ # )
145
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
146
+ self.secret_len = secret_len
147
+ if secret_len > 0: # TODO: update for dec
148
+ log_resolution = int(np.log2(64))
149
+ self.input_hint_block = TimestepEmbedSequential(
150
+ nn.Linear(secret_len, 16*16*4),
151
+ nn.SiLU(),
152
+ View(-1, 4, 16, 16),
153
+ nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
154
+ conv_nd(dims, 4, 64, 3, padding=1),
155
+ nn.SiLU(),
156
+ conv_nd(dims, 64, 256, 3, padding=1),
157
+ nn.SiLU(),
158
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
159
+ )
160
+
161
+ self._feature_size = model_channels
162
+ input_block_chans = [model_channels]
163
+ ch = model_channels
164
+ ds = 1
165
+ for level, mult in enumerate(channel_mult):
166
+ for nr in range(self.num_res_blocks[level]):
167
+ layers = []
168
+ ch = mult * model_channels
169
+ if ds in attention_resolutions:
170
+ if num_head_channels == -1:
171
+ dim_head = ch // num_heads
172
+ else:
173
+ num_heads = ch // num_head_channels
174
+ dim_head = num_head_channels
175
+ if legacy:
176
+ #num_heads = 1
177
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
178
+ if exists(disable_self_attentions):
179
+ disabled_sa = disable_self_attentions[level]
180
+ else:
181
+ disabled_sa = False
182
+
183
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
184
+ layers.append(0)
185
+ # self.input_blocks.append(TimestepEmbedSequential(*layers))
186
+ # self.zero_convs.append(self.make_zero_conv(ch))
187
+ self._feature_size += ch
188
+ input_block_chans.append(ch)
189
+ if level != len(channel_mult) - 1:
190
+ out_ch = ch
191
+ self.input_blocks.append(
192
+ 0
193
+ )
194
+ ch = out_ch
195
+ input_block_chans.append(ch)
196
+ # self.zero_convs.append(self.make_zero_conv(ch))
197
+ ds *= 2
198
+ self._feature_size += ch
199
+
200
+ if num_head_channels == -1:
201
+ dim_head = ch // num_heads
202
+ else:
203
+ num_heads = ch // num_head_channels
204
+ dim_head = num_head_channels
205
+ if legacy:
206
+ #num_heads = 1
207
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
208
+ self.middle_block = TimestepEmbedSequential(
209
+ ResBlock(
210
+ ch,
211
+ time_embed_dim,
212
+ dropout,
213
+ dims=dims,
214
+ use_checkpoint=use_checkpoint,
215
+ use_scale_shift_norm=use_scale_shift_norm,
216
+ ),
217
+ AttentionBlock(
218
+ ch,
219
+ use_checkpoint=use_checkpoint,
220
+ num_heads=num_heads,
221
+ num_head_channels=dim_head,
222
+ use_new_attention_order=use_new_attention_order,
223
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
224
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
225
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
226
+ use_checkpoint=use_checkpoint
227
+ ),
228
+ ResBlock(
229
+ ch,
230
+ time_embed_dim,
231
+ dropout,
232
+ dims=dims,
233
+ use_checkpoint=use_checkpoint,
234
+ use_scale_shift_norm=use_scale_shift_norm,
235
+ ),
236
+ )
237
+ self.middle_block_out = self.make_zero_conv(ch)
238
+ self._feature_size += ch
239
+
240
+ def make_zero_conv(self, channels):
241
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
242
+
243
+ def forward(self, x, hint, timesteps, context, **kwargs):
244
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
245
+ emb = self.time_embed(t_emb)
246
+ guided_hint = self.input_hint_block(hint, emb, context)
247
+ # import pdb; pdb.set_trace()
248
+ outs = []
249
+
250
+ h = x.type(self.dtype)
251
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
252
+ if guided_hint is not None:
253
+ h = module(h, emb, context)
254
+ h += guided_hint
255
+ guided_hint = None
256
+ else:
257
+ h = module(h, emb, context)
258
+ outs.append(zero_conv(h, emb, context))
259
+
260
+ h = self.middle_block(h, emb, context)
261
+ outs.append(self.middle_block_out(h, emb, context))
262
+
263
+ return outs
264
+
265
+ class ControlledUnetModel(UNetModel):
266
+ def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
267
+ hs = []
268
+ with torch.no_grad():
269
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
270
+ emb = self.time_embed(t_emb)
271
+ h = x.type(self.dtype)
272
+ for module in self.input_blocks:
273
+ h = module(h, emb, context)
274
+ hs.append(h)
275
+ h = self.middle_block(h, emb, context)
276
+
277
+ h += control.pop()
278
+
279
+ for i, module in enumerate(self.output_blocks):
280
+ if only_mid_control:
281
+ h = torch.cat([h, hs.pop()], dim=1)
282
+ else:
283
+ h = torch.cat([h, hs.pop() + control.pop()], dim=1)
284
+ h = module(h, emb, context)
285
+
286
+ h = h.type(x.dtype)
287
+ return self.out(h)
288
+
289
+ class View(nn.Module):
290
+ def __init__(self, *shape):
291
+ super().__init__()
292
+ self.shape = shape
293
+
294
+ def forward(self, x):
295
+ return x.view(*self.shape)
296
+
297
+ class ControlNet(nn.Module):
298
+ def __init__(
299
+ self,
300
+ image_size,
301
+ in_channels,
302
+ model_channels,
303
+ hint_channels,
304
+ num_res_blocks,
305
+ attention_resolutions,
306
+ dropout=0,
307
+ channel_mult=(1, 2, 4, 8),
308
+ conv_resample=True,
309
+ dims=2,
310
+ use_checkpoint=False,
311
+ use_fp16=False,
312
+ num_heads=-1,
313
+ num_head_channels=-1,
314
+ num_heads_upsample=-1,
315
+ use_scale_shift_norm=False,
316
+ resblock_updown=False,
317
+ use_new_attention_order=False,
318
+ use_spatial_transformer=False, # custom transformer support
319
+ transformer_depth=1, # custom transformer support
320
+ context_dim=None, # custom transformer support
321
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
322
+ legacy=True,
323
+ disable_self_attentions=None,
324
+ num_attention_blocks=None,
325
+ disable_middle_self_attn=False,
326
+ use_linear_in_transformer=False,
327
+ secret_len = 0,
328
+ ):
329
+ super().__init__()
330
+ if use_spatial_transformer:
331
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
332
+
333
+ if context_dim is not None:
334
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
335
+ from omegaconf.listconfig import ListConfig
336
+ if type(context_dim) == ListConfig:
337
+ context_dim = list(context_dim)
338
+
339
+ if num_heads_upsample == -1:
340
+ num_heads_upsample = num_heads
341
+
342
+ if num_heads == -1:
343
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
344
+
345
+ if num_head_channels == -1:
346
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
347
+
348
+ self.dims = dims
349
+ self.image_size = image_size
350
+ self.in_channels = in_channels
351
+ self.model_channels = model_channels
352
+ if isinstance(num_res_blocks, int):
353
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
354
+ else:
355
+ if len(num_res_blocks) != len(channel_mult):
356
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
357
+ "as a list/tuple (per-level) with the same length as channel_mult")
358
+ self.num_res_blocks = num_res_blocks
359
+ if disable_self_attentions is not None:
360
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
361
+ assert len(disable_self_attentions) == len(channel_mult)
362
+ if num_attention_blocks is not None:
363
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
364
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
365
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
366
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
367
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
368
+ f"attention will still not be set.")
369
+
370
+ self.attention_resolutions = attention_resolutions
371
+ self.dropout = dropout
372
+ self.channel_mult = channel_mult
373
+ self.conv_resample = conv_resample
374
+ self.use_checkpoint = use_checkpoint
375
+ self.dtype = th.float16 if use_fp16 else th.float32
376
+ self.num_heads = num_heads
377
+ self.num_head_channels = num_head_channels
378
+ self.num_heads_upsample = num_heads_upsample
379
+ self.predict_codebook_ids = n_embed is not None
380
+
381
+ time_embed_dim = model_channels * 4
382
+ self.time_embed = nn.Sequential(
383
+ linear(model_channels, time_embed_dim),
384
+ nn.SiLU(),
385
+ linear(time_embed_dim, time_embed_dim),
386
+ )
387
+
388
+ self.input_blocks = nn.ModuleList(
389
+ [
390
+ TimestepEmbedSequential(
391
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
392
+ )
393
+ ]
394
+ )
395
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
396
+ self.secret_len = secret_len
397
+ if secret_len > 0:
398
+ log_resolution = int(np.log2(64))
399
+ self.input_hint_block = TimestepEmbedSequential(
400
+ nn.Linear(secret_len, 16*16*4),
401
+ nn.SiLU(),
402
+ View(-1, 4, 16, 16),
403
+ nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
404
+ conv_nd(dims, 4, 64, 3, padding=1),
405
+ nn.SiLU(),
406
+ conv_nd(dims, 64, 256, 3, padding=1),
407
+ nn.SiLU(),
408
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
409
+ )
410
+ else:
411
+ self.input_hint_block = TimestepEmbedSequential(
412
+ conv_nd(dims, hint_channels, 16, 3, padding=1),
413
+ nn.SiLU(),
414
+ conv_nd(dims, 16, 16, 3, padding=1),
415
+ nn.SiLU(),
416
+ conv_nd(dims, 16, 32, 3, padding=1, stride=2),
417
+ nn.SiLU(),
418
+ conv_nd(dims, 32, 32, 3, padding=1),
419
+ nn.SiLU(),
420
+ conv_nd(dims, 32, 96, 3, padding=1, stride=2),
421
+ nn.SiLU(),
422
+ conv_nd(dims, 96, 96, 3, padding=1),
423
+ nn.SiLU(),
424
+ conv_nd(dims, 96, 256, 3, padding=1, stride=2),
425
+ nn.SiLU(),
426
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
427
+ )
428
+
429
+ self._feature_size = model_channels
430
+ input_block_chans = [model_channels]
431
+ ch = model_channels
432
+ ds = 1
433
+ for level, mult in enumerate(channel_mult):
434
+ for nr in range(self.num_res_blocks[level]):
435
+ layers = [
436
+ ResBlock(
437
+ ch,
438
+ time_embed_dim,
439
+ dropout,
440
+ out_channels=mult * model_channels,
441
+ dims=dims,
442
+ use_checkpoint=use_checkpoint,
443
+ use_scale_shift_norm=use_scale_shift_norm,
444
+ )
445
+ ]
446
+ ch = mult * model_channels
447
+ if ds in attention_resolutions:
448
+ if num_head_channels == -1:
449
+ dim_head = ch // num_heads
450
+ else:
451
+ num_heads = ch // num_head_channels
452
+ dim_head = num_head_channels
453
+ if legacy:
454
+ #num_heads = 1
455
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
456
+ if exists(disable_self_attentions):
457
+ disabled_sa = disable_self_attentions[level]
458
+ else:
459
+ disabled_sa = False
460
+
461
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
462
+ layers.append(
463
+ AttentionBlock(
464
+ ch,
465
+ use_checkpoint=use_checkpoint,
466
+ num_heads=num_heads,
467
+ num_head_channels=dim_head,
468
+ use_new_attention_order=use_new_attention_order,
469
+ ) if not use_spatial_transformer else SpatialTransformer(
470
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
471
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
472
+ use_checkpoint=use_checkpoint
473
+ )
474
+ )
475
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
476
+ self.zero_convs.append(self.make_zero_conv(ch))
477
+ self._feature_size += ch
478
+ input_block_chans.append(ch)
479
+ if level != len(channel_mult) - 1:
480
+ out_ch = ch
481
+ self.input_blocks.append(
482
+ TimestepEmbedSequential(
483
+ ResBlock(
484
+ ch,
485
+ time_embed_dim,
486
+ dropout,
487
+ out_channels=out_ch,
488
+ dims=dims,
489
+ use_checkpoint=use_checkpoint,
490
+ use_scale_shift_norm=use_scale_shift_norm,
491
+ down=True,
492
+ )
493
+ if resblock_updown
494
+ else Downsample(
495
+ ch, conv_resample, dims=dims, out_channels=out_ch
496
+ )
497
+ )
498
+ )
499
+ ch = out_ch
500
+ input_block_chans.append(ch)
501
+ self.zero_convs.append(self.make_zero_conv(ch))
502
+ ds *= 2
503
+ self._feature_size += ch
504
+
505
+ if num_head_channels == -1:
506
+ dim_head = ch // num_heads
507
+ else:
508
+ num_heads = ch // num_head_channels
509
+ dim_head = num_head_channels
510
+ if legacy:
511
+ #num_heads = 1
512
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
513
+ self.middle_block = TimestepEmbedSequential(
514
+ ResBlock(
515
+ ch,
516
+ time_embed_dim,
517
+ dropout,
518
+ dims=dims,
519
+ use_checkpoint=use_checkpoint,
520
+ use_scale_shift_norm=use_scale_shift_norm,
521
+ ),
522
+ AttentionBlock(
523
+ ch,
524
+ use_checkpoint=use_checkpoint,
525
+ num_heads=num_heads,
526
+ num_head_channels=dim_head,
527
+ use_new_attention_order=use_new_attention_order,
528
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
529
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
530
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
531
+ use_checkpoint=use_checkpoint
532
+ ),
533
+ ResBlock(
534
+ ch,
535
+ time_embed_dim,
536
+ dropout,
537
+ dims=dims,
538
+ use_checkpoint=use_checkpoint,
539
+ use_scale_shift_norm=use_scale_shift_norm,
540
+ ),
541
+ )
542
+ self.middle_block_out = self.make_zero_conv(ch)
543
+ self._feature_size += ch
544
+
545
+ def make_zero_conv(self, channels):
546
+ return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
547
+
548
+ def forward(self, x, hint, timesteps, context, **kwargs):
549
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
550
+ emb = self.time_embed(t_emb)
551
+ guided_hint = self.input_hint_block(hint, emb, context)
552
+ # import pdb; pdb.set_trace()
553
+ outs = []
554
+
555
+ h = x.type(self.dtype)
556
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
557
+ if guided_hint is not None:
558
+ h = module(h, emb, context)
559
+ h += guided_hint
560
+ guided_hint = None
561
+ else:
562
+ h = module(h, emb, context)
563
+ outs.append(zero_conv(h, emb, context))
564
+
565
+ h = self.middle_block(h, emb, context)
566
+ outs.append(self.middle_block_out(h, emb, context))
567
+
568
+ return outs
569
+
570
+
571
+ class SecretDecoder(nn.Module):
572
+ def __init__(self, arch='resnet50', secret_len=100):
573
+ super().__init__()
574
+ self.arch = arch
575
+ print(f'SecretDecoder arch: {arch}')
576
+ self.resolution = 224
577
+ if arch == 'resnet50':
578
+ self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
579
+ self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
580
+ elif arch == 'resnet18':
581
+ self.decoder = torchvision.models.resnet18(pretrained=True, progress=False)
582
+ self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
583
+ else:
584
+ raise NotImplementedError
585
+
586
+ def forward(self, image):
587
+ if self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution:
588
+ image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
589
+ x = self.decoder(image)
590
+ return x
591
+
592
+
593
+ class ControlLDM(LatentDiffusion):
594
+
595
+ def __init__(self, control_stage_config, control_key, only_mid_control, secret_decoder_config, *args, **kwargs):
596
+ super().__init__(*args, **kwargs)
597
+ self.control_model = instantiate_from_config(control_stage_config)
598
+ self.control_key = control_key
599
+ self.only_mid_control = only_mid_control
600
+
601
+ self.secret_decoder = None if secret_decoder_config == 'none' else instantiate_from_config(secret_decoder_config)
602
+ self.secret_loss_layer = nn.BCEWithLogitsLoss()
603
+
604
+ @torch.no_grad()
605
+ def get_input(self, batch, k, bs=None, *args, **kwargs):
606
+ x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
607
+ control = batch[self.control_key]
608
+ if bs is not None:
609
+ control = control[:bs]
610
+ control = control.to(self.device)
611
+ if self.control_key == 'hint':
612
+ control = einops.rearrange(control, 'b h w c -> b c h w')
613
+ control = control.to(memory_format=torch.contiguous_format).float()
614
+ return x, dict(c_crossattn=[c], c_concat=[control])
615
+
616
+ def apply_model(self, x_noisy, t, cond, *args, **kwargs):
617
+ assert isinstance(cond, dict)
618
+ diffusion_model = self.model.diffusion_model
619
+ cond_txt = torch.cat(cond['c_crossattn'], 1)
620
+ cond_hint = torch.cat(cond['c_concat'], 1)
621
+
622
+ control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
623
+ eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
624
+
625
+ return eps
626
+
627
+ def p_losses(self, x_start, cond, t, noise=None):
628
+ noise = default(noise, lambda: torch.randn_like(x_start))
629
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
630
+ model_output = self.apply_model(x_noisy, t, cond)
631
+ loss_dict = {}
632
+ prefix = 'train' if self.training else 'val'
633
+
634
+ if self.parameterization == "x0":
635
+ target = x_start
636
+ x_recon = model_output
637
+ elif self.parameterization == "eps":
638
+ target = noise
639
+ x_recon = self.predict_start_from_noise(x_noisy, t, noise=model_output)
640
+ elif self.parameterization == "v":
641
+ target = self.get_v(x_start, noise, t)
642
+ else:
643
+ raise NotImplementedError()
644
+
645
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
646
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
647
+
648
+ logvar_t = self.logvar[t].to(self.device)
649
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
650
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
651
+ if self.learn_logvar:
652
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
653
+ loss_dict.update({'logvar': self.logvar.data.mean()})
654
+
655
+ loss = self.l_simple_weight * loss.mean()
656
+
657
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
658
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
659
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
660
+ loss += (self.original_elbo_weight * loss_vlb)
661
+ # secret decode
662
+ if self.secret_decoder is not None:
663
+ simple_loss_weight = 0.1
664
+ x_recon = self.differentiable_decode_first_stage(x_recon)
665
+ secret_pred = self.secret_decoder(x_recon)
666
+ secret = cond['c_concat'][0]
667
+ loss_secret = self.secret_loss_layer(secret_pred, secret)
668
+ bit_acc = ((secret_pred.detach() > 0).float() == secret).float().mean()
669
+ loss_dict.update({f'{prefix}/bit_acc': bit_acc})
670
+ loss_dict.update({f'{prefix}/loss_secret': loss_secret})
671
+ loss = (loss*simple_loss_weight + loss_secret) / (simple_loss_weight + 1)
672
+
673
+ loss_dict.update({f'{prefix}/loss': loss})
674
+ return loss, loss_dict
675
+
676
+ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
677
+ if predict_cids:
678
+ if z.dim() == 4:
679
+ z = torch.argmax(z.exp(), dim=1).long()
680
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
681
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
682
+
683
+ z = 1. / self.scale_factor * z
684
+ return self.first_stage_model.decode(z)
685
+
686
+ @torch.no_grad()
687
+ def get_unconditional_conditioning(self, N):
688
+ return self.get_learned_conditioning([""] * N)
689
+
690
+ @torch.no_grad()
691
+ def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
692
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
693
+ plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
694
+ use_ema_scope=True,
695
+ **kwargs):
696
+ use_ddim = ddim_steps is not None
697
+
698
+ log = dict()
699
+ z, c = self.get_input(batch, self.first_stage_key, bs=N)
700
+ c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
701
+ N = min(z.shape[0], N)
702
+ n_row = min(z.shape[0], n_row)
703
+ log["reconstruction"] = self.decode_first_stage(z)
704
+ # log["control"] = c_cat * 2.0 - 1.0
705
+ log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
706
+
707
+ if plot_diffusion_rows:
708
+ # get diffusion row
709
+ diffusion_row = list()
710
+ z_start = z[:n_row]
711
+ for t in range(self.num_timesteps):
712
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
713
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
714
+ t = t.to(self.device).long()
715
+ noise = torch.randn_like(z_start)
716
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
717
+ diffusion_row.append(self.decode_first_stage(z_noisy))
718
+
719
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
720
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
721
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
722
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
723
+ log["diffusion_row"] = diffusion_grid
724
+
725
+ if sample:
726
+ # get denoise row
727
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
728
+ batch_size=N, ddim=use_ddim,
729
+ ddim_steps=ddim_steps, eta=ddim_eta)
730
+ x_samples = self.decode_first_stage(samples)
731
+ log["samples"] = x_samples
732
+ if plot_denoise_rows:
733
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
734
+ log["denoise_row"] = denoise_grid
735
+ # import pudb; pudb.set_trace()
736
+ if unconditional_guidance_scale > 1.0:
737
+ uc_cross = self.get_unconditional_conditioning(N)
738
+ uc_cat = c_cat # torch.zeros_like(c_cat)
739
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
740
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
741
+ batch_size=N, ddim=use_ddim,
742
+ ddim_steps=ddim_steps, eta=ddim_eta,
743
+ unconditional_guidance_scale=unconditional_guidance_scale,
744
+ unconditional_conditioning=uc_full,
745
+ )
746
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
747
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
748
+
749
+ return log
750
+
751
+ @torch.no_grad()
752
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
753
+ ddim_sampler = DDIMSampler(self)
754
+ # import pdb; pdb.set_trace()
755
+ # b, c, h, w = cond["c_concat"][0].shape
756
+ b, c, h, w = cond["c_concat"][0].shape[0], self.channels, self.image_size*8, self.image_size*8
757
+ shape = (self.channels, h // 8, w // 8)
758
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
759
+ return samples, intermediates
760
+
761
+ def configure_optimizers(self):
762
+ lr = self.learning_rate
763
+ params = list(self.control_model.parameters())
764
+ if self.secret_decoder is not None:
765
+ params += list(self.secret_decoder.parameters())
766
+ if not self.sd_locked:
767
+ params += list(self.model.diffusion_model.output_blocks.parameters())
768
+ params += list(self.model.diffusion_model.out.parameters())
769
+ opt = torch.optim.AdamW(params, lr=lr)
770
+ return opt
771
+
772
+ def low_vram_shift(self, is_diffusing):
773
+ if is_diffusing:
774
+ self.model = self.model.cuda()
775
+ self.control_model = self.control_model.cuda()
776
+ self.first_stage_model = self.first_stage_model.cpu()
777
+ self.cond_stage_model = self.cond_stage_model.cpu()
778
+ else:
779
+ self.model = self.model.cpu()
780
+ self.control_model = self.control_model.cpu()
781
+ self.first_stage_model = self.first_stage_model.cuda()
782
+ self.cond_stage_model = self.cond_stage_model.cuda()
cldm/hack.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+ import warnings
10
+
11
+ def disable_verbosity():
12
+ logging.set_verbosity_error()
13
+ warnings.filterwarnings(action='ignore', category=DeprecationWarning)
14
+ warnings.filterwarnings(action='ignore', category=UserWarning)
15
+ print('logging improved.')
16
+ return
17
+
18
+
19
+ def enable_sliced_attention():
20
+ ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
21
+ print('Enabled sliced_attention.')
22
+ return
23
+
24
+
25
+ def hack_everything(clip_skip=0):
26
+ disable_verbosity()
27
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
28
+ ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
29
+ print('Enabled clip hacks.')
30
+ return
31
+
32
+
33
+ # Written by Lvmin
34
+ def _hacked_clip_forward(self, text):
35
+ PAD = self.tokenizer.pad_token_id
36
+ EOS = self.tokenizer.eos_token_id
37
+ BOS = self.tokenizer.bos_token_id
38
+
39
+ def tokenize(t):
40
+ return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
41
+
42
+ def transformer_encode(t):
43
+ if self.clip_skip > 1:
44
+ rt = self.transformer(input_ids=t, output_hidden_states=True)
45
+ return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
46
+ else:
47
+ return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
48
+
49
+ def split(x):
50
+ return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
51
+
52
+ def pad(x, p, i):
53
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
54
+
55
+ raw_tokens_list = tokenize(text)
56
+ tokens_list = []
57
+
58
+ for raw_tokens in raw_tokens_list:
59
+ raw_tokens_123 = split(raw_tokens)
60
+ raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
61
+ raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
62
+ tokens_list.append(raw_tokens_123)
63
+
64
+ tokens_list = torch.IntTensor(tokens_list).to(self.device)
65
+
66
+ feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
67
+ y = transformer_encode(feed)
68
+ z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
69
+
70
+ return z
71
+
72
+
73
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
74
+ def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
75
+ h = self.heads
76
+
77
+ q = self.to_q(x)
78
+ context = default(context, x)
79
+ k = self.to_k(context)
80
+ v = self.to_v(context)
81
+ del context, x
82
+
83
+ q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
84
+
85
+ limit = k.shape[0]
86
+ att_step = 1
87
+ q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
88
+ k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
89
+ v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
90
+
91
+ q_chunks.reverse()
92
+ k_chunks.reverse()
93
+ v_chunks.reverse()
94
+ sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
95
+ del k, q, v
96
+ for i in range(0, limit, att_step):
97
+ q_buffer = q_chunks.pop()
98
+ k_buffer = k_chunks.pop()
99
+ v_buffer = v_chunks.pop()
100
+ sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
101
+
102
+ del k_buffer, q_buffer
103
+ # attention, what we cannot get enough of, by chunks
104
+
105
+ sim_buffer = sim_buffer.softmax(dim=-1)
106
+
107
+ sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
108
+ del v_buffer
109
+ sim[i:i + att_step, :, :] = sim_buffer
110
+
111
+ del sim_buffer
112
+ sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
113
+ return self.to_out(sim)
cldm/logger.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from omegaconf import OmegaConf
3
+ import numpy as np
4
+ import torch
5
+ import torchvision
6
+ from PIL import Image
7
+ from pytorch_lightning.callbacks import Callback
8
+ from pytorch_lightning.utilities.distributed import rank_zero_only
9
+ from pytorch_lightning.utilities import rank_zero_info
10
+ import time
11
+
12
+
13
+ class CUDACallback(Callback):
14
+ # see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
15
+ def on_train_epoch_start(self, trainer, pl_module):
16
+ # Reset the memory use counter
17
+ torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
18
+ torch.cuda.synchronize(trainer.root_gpu)
19
+ self.start_time = time.time()
20
+
21
+ def on_train_epoch_end(self, trainer, pl_module, outputs):
22
+ torch.cuda.synchronize(trainer.root_gpu)
23
+ max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
24
+ epoch_time = (time.time() - self.start_time)/3600
25
+
26
+ try:
27
+ max_memory = trainer.training_type_plugin.reduce(max_memory)
28
+ epoch_time = trainer.training_type_plugin.reduce(epoch_time)
29
+
30
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} hours")
31
+ rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
32
+ except AttributeError:
33
+ pass
34
+
35
+
36
+ class SetupCallback(Callback):
37
+ def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
38
+ super().__init__()
39
+ self.resume = resume
40
+ self.now = now
41
+ self.logdir = logdir
42
+ self.ckptdir = ckptdir
43
+ self.cfgdir = cfgdir
44
+ self.config = config
45
+ self.lightning_config = lightning_config
46
+
47
+ def on_keyboard_interrupt(self, trainer, pl_module):
48
+ if trainer.global_rank == 0:
49
+ print("Summoning checkpoint.")
50
+ ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
51
+ trainer.save_checkpoint(ckpt_path)
52
+
53
+ def on_pretrain_routine_start(self, trainer, pl_module):
54
+ if trainer.global_rank == 0:
55
+ # Create logdirs and save configs
56
+ os.makedirs(self.logdir, exist_ok=True)
57
+ os.makedirs(self.ckptdir, exist_ok=True)
58
+ os.makedirs(self.cfgdir, exist_ok=True)
59
+
60
+ if "callbacks" in self.lightning_config:
61
+ if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
62
+ os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
63
+ print("Project config")
64
+ print(OmegaConf.to_yaml(self.config))
65
+ OmegaConf.save(self.config,
66
+ os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
67
+
68
+ print("Lightning config")
69
+ print(OmegaConf.to_yaml(self.lightning_config))
70
+ OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
71
+ os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
72
+
73
+ else:
74
+ # ModelCheckpoint callback created log directory --- remove it
75
+ if not self.resume and os.path.exists(self.logdir):
76
+ dst, name = os.path.split(self.logdir)
77
+ dst = os.path.join(dst, "child_runs", name)
78
+ os.makedirs(os.path.split(dst)[0], exist_ok=True)
79
+ try:
80
+ os.rename(self.logdir, dst)
81
+ except FileNotFoundError:
82
+ pass
83
+
84
+ class ImageLogger(Callback):
85
+ def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
86
+ rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
87
+ log_images_kwargs=None, fixed_input=False):
88
+ super().__init__()
89
+ self.rescale = rescale
90
+ self.batch_freq = batch_frequency
91
+ self.max_images = max_images
92
+ if not increase_log_steps:
93
+ self.log_steps = [self.batch_freq]
94
+ self.clamp = clamp
95
+ self.disabled = disabled
96
+ self.log_on_batch_idx = log_on_batch_idx
97
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
98
+ self.log_first_step = log_first_step
99
+ self.fixed_input = fixed_input
100
+
101
+ @rank_zero_only
102
+ def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
103
+ root = os.path.join(save_dir, "image_log", split)
104
+ for k in images:
105
+ grid = torchvision.utils.make_grid(images[k], nrow=4)
106
+ if self.rescale:
107
+ grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
108
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
109
+ grid = grid.numpy()
110
+ grid = (grid * 255).astype(np.uint8)
111
+ filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
112
+ path = os.path.join(root, filename)
113
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
114
+ Image.fromarray(grid).save(path)
115
+
116
+ def log_img(self, pl_module, batch, batch_idx, split="train"):
117
+ check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
118
+ if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
119
+ hasattr(pl_module, "log_images") and
120
+ callable(pl_module.log_images) and
121
+ self.max_images > 0):
122
+ logger = type(pl_module.logger)
123
+
124
+ is_train = pl_module.training
125
+ if is_train:
126
+ pl_module.eval()
127
+
128
+ with torch.no_grad():
129
+ images = pl_module.log_images(batch, fixed_input=self.fixed_input, split=split, **self.log_images_kwargs)
130
+
131
+ for k in images:
132
+ N = min(images[k].shape[0], self.max_images)
133
+ images[k] = images[k][:N]
134
+ if isinstance(images[k], torch.Tensor):
135
+ images[k] = images[k].detach().cpu()
136
+ if self.clamp:
137
+ images[k] = torch.clamp(images[k], -1., 1.)
138
+ self.log_local(pl_module.logger.save_dir, split, images,
139
+ pl_module.global_step, pl_module.current_epoch, batch_idx)
140
+
141
+ if is_train:
142
+ pl_module.train()
143
+
144
+ def check_frequency(self, check_idx):
145
+ return check_idx % self.batch_freq == 0
146
+
147
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
148
+ if not self.disabled:
149
+ self.log_img(pl_module, batch, batch_idx, split="train")
cldm/loss.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from lpips import LPIPS
4
+ from kornia import color
5
+ # from taming.modules.losses.vqperceptual import *
6
+
7
+ class ImageSecretLoss(nn.Module):
8
+ def __init__(self, recon_type='rgb', recon_weight=1., perceptual_weight=1.0, secret_weight=10., kl_weight=0.000001, logvar_init=0.0, ramp=100000, max_image_weight_ratio=2.) -> None:
9
+ super().__init__()
10
+ self.recon_type = recon_type
11
+ assert recon_type in ['rgb', 'yuv']
12
+ if recon_type == 'yuv':
13
+ self.register_buffer('yuv_scales', torch.tensor([1,100,100]).unsqueeze(1).float()) # [3,1]
14
+ self.recon_weight = recon_weight
15
+ self.perceptual_weight = perceptual_weight
16
+ self.secret_weight = secret_weight
17
+ self.kl_weight = kl_weight
18
+
19
+ self.ramp = ramp
20
+ self.max_image_weight = max_image_weight_ratio * secret_weight - 1
21
+ self.register_buffer('ramp_on', torch.tensor(False))
22
+ self.register_buffer('step0', torch.tensor(1e9)) # large number
23
+
24
+ self.perceptual_loss = LPIPS().eval()
25
+ self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
26
+ self.bce = nn.BCEWithLogitsLoss(reduction="none")
27
+
28
+ def activate_ramp(self, global_step):
29
+ if not self.ramp_on: # do not activate ramp twice
30
+ self.step0 = torch.tensor(global_step)
31
+ self.ramp_on = ~self.ramp_on
32
+ print('[TRAINING] Activate ramp for image loss at step ', global_step)
33
+
34
+ def compute_recon_loss(self, inputs, reconstructions):
35
+ if self.recon_type == 'rgb':
36
+ rec_loss = torch.abs(inputs - reconstructions).mean(dim=[1,2,3])
37
+ elif self.recon_type == 'yuv':
38
+ reconstructions_yuv = color.rgb_to_yuv((reconstructions + 1) / 2)
39
+ inputs_yuv = color.rgb_to_yuv((inputs + 1) / 2)
40
+ yuv_loss = torch.mean((reconstructions_yuv - inputs_yuv)**2, dim=[2,3])
41
+ rec_loss = torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
42
+ else:
43
+ raise ValueError(f"Unknown recon type {self.recon_type}")
44
+ return rec_loss
45
+
46
+ def forward(self, inputs, reconstructions, posteriors, secret_gt, secret_pred, global_step):
47
+ loss_dict = {}
48
+ rec_loss = self.compute_recon_loss(inputs.contiguous(), reconstructions.contiguous())
49
+
50
+ loss = rec_loss*self.recon_weight
51
+
52
+ if self.perceptual_weight > 0:
53
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()).mean(dim=[1,2,3])
54
+ loss += self.perceptual_weight * p_loss
55
+ loss_dict['p_loss'] = p_loss.mean()
56
+
57
+ loss = loss / torch.exp(self.logvar) + self.logvar
58
+ if self.kl_weight > 0:
59
+ kl_loss = posteriors.kl()
60
+ loss += kl_loss*self.kl_weight
61
+ loss_dict['kl_loss'] = kl_loss.mean()
62
+
63
+ image_weight = 1 + min(self.max_image_weight, max(0., self.max_image_weight*(global_step - self.step0.item())/self.ramp))
64
+
65
+ secret_loss = self.bce(secret_pred, secret_gt).mean(dim=1)
66
+ loss = (loss*image_weight + secret_loss*self.secret_weight) / (image_weight+self.secret_weight)
67
+
68
+ # loss dict update
69
+ bit_acc = ((secret_pred.detach() > 0).float() == secret_gt).float().mean()
70
+ loss_dict['bit_acc'] = bit_acc
71
+ loss_dict['loss'] = loss.mean()
72
+ loss_dict['img_lw'] = image_weight/self.secret_weight
73
+ loss_dict['rec_loss'] = rec_loss.mean()
74
+ loss_dict['secret_loss'] = secret_loss.mean()
75
+
76
+ return loss.mean(), loss_dict
77
+
78
+
cldm/loss_weight_scheduler.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+
5
+ @author: Tu Bui @University of Surrey
6
+ """
7
+
8
+ class SimpleLossWeightScheduler(object):
9
+ def __init__(self, simple_loss_weight_max=10., wait_steps=50000, ramp=100000) -> None:
10
+ self.simple_loss_weight_max = simple_loss_weight_max
11
+ self.wait_steps = wait_steps
12
+ self.ramp = ramp
13
+
14
+ def __call__(self, step):
15
+ max_weight = self.simple_loss_weight_max - 1
16
+ w = 1 + min(max_weight, max(0., max_weight*(step - self.wait_steps)/self.ramp))
17
+ return w
cldm/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
+ def get_state_dict(d):
9
+ return d.get('state_dict', d)
10
+
11
+
12
+ def load_state_dict(ckpt_path, location='cpu'):
13
+ _, extension = os.path.splitext(ckpt_path)
14
+ if extension.lower() == ".safetensors":
15
+ import safetensors.torch
16
+ state_dict = safetensors.torch.load_file(ckpt_path, device=location)
17
+ else:
18
+ state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
19
+ state_dict = get_state_dict(state_dict)
20
+ print(f'Loaded state_dict from [{ckpt_path}]')
21
+ return state_dict
22
+
23
+
24
+ def create_model(config_path):
25
+ config = OmegaConf.load(config_path)
26
+ model = instantiate_from_config(config.model).cpu()
27
+ print(f'Loaded model config from [{config_path}]')
28
+ return model
cldm/plms.py ADDED
@@ -0,0 +1,1481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SAMPLING ONLY."""
2
+ import os
3
+ import torch
4
+ from torch import nn
5
+ import torchvision
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+ from functools import partial
9
+ from PIL import Image
10
+ import shutil
11
+
12
+ from ldm.modules.diffusionmodules.util import (
13
+ make_ddim_sampling_parameters,
14
+ make_ddim_timesteps,
15
+ noise_like,
16
+ )
17
+ import clip
18
+ from einops import rearrange
19
+ import random
20
+
21
+
22
+ class VGGPerceptualLoss(torch.nn.Module):
23
+ def __init__(self, resize=True):
24
+ super(VGGPerceptualLoss, self).__init__()
25
+ blocks = []
26
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
27
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
28
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
29
+ blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
30
+ for bl in blocks:
31
+ for p in bl.parameters():
32
+ p.requires_grad = False
33
+ self.blocks = torch.nn.ModuleList(blocks)
34
+ self.transform = torch.nn.functional.interpolate
35
+ self.resize = resize
36
+ self.register_buffer(
37
+ "mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
38
+ )
39
+ self.register_buffer(
40
+ "std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
41
+ )
42
+
43
+ def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
44
+ input = (input - self.mean) / self.std
45
+ target = (target - self.mean) / self.std
46
+ if self.resize:
47
+ input = self.transform(
48
+ input, mode="bilinear", size=(224, 224), align_corners=False
49
+ )
50
+ target = self.transform(
51
+ target, mode="bilinear", size=(224, 224), align_corners=False
52
+ )
53
+ loss = 0.0
54
+ x = input
55
+ y = target
56
+ for i, block in enumerate(self.blocks):
57
+ x = block(x)
58
+ y = block(y)
59
+ if i in feature_layers:
60
+ loss += torch.nn.functional.l1_loss(x, y)
61
+ if i in style_layers:
62
+ act_x = x.reshape(x.shape[0], x.shape[1], -1)
63
+ act_y = y.reshape(y.shape[0], y.shape[1], -1)
64
+ gram_x = act_x @ act_x.permute(0, 2, 1)
65
+ gram_y = act_y @ act_y.permute(0, 2, 1)
66
+ loss += torch.nn.functional.l1_loss(gram_x, gram_y)
67
+ return loss
68
+
69
+
70
+ class DCLIPLoss(torch.nn.Module):
71
+ def __init__(self):
72
+ super(DCLIPLoss, self).__init__()
73
+ self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
74
+ self.upsample = torch.nn.Upsample(scale_factor=7)
75
+ self.avg_pool = torch.nn.AvgPool2d(kernel_size=16)
76
+
77
+ def forward(self, image1, image2, text1, text2):
78
+ text1 = clip.tokenize([text1]).to("cuda")
79
+ text2 = clip.tokenize([text2]).to("cuda")
80
+ image1 = image1.unsqueeze(0).cuda()
81
+ image2 = image2.unsqueeze(0)
82
+ image1 = self.avg_pool(self.upsample(image1))
83
+ image2 = self.avg_pool(self.upsample(image2))
84
+ image1_feat = self.model.encode_image(image1)
85
+ image2_feat = self.model.encode_image(image2)
86
+ text1_feat = self.model.encode_text(text1)
87
+ text2_feat = self.model.encode_text(text2)
88
+ d_image_feat = image1_feat - image2_feat
89
+ d_text_feat = text1_feat - text2_feat
90
+ similarity = torch.nn.CosineSimilarity()(d_image_feat, d_text_feat)
91
+ return 1 - similarity
92
+
93
+
94
+ class PLMSSampler(object):
95
+ def __init__(self, model, schedule="linear", **kwargs):
96
+ super().__init__()
97
+ self.model = model
98
+ self.ddpm_num_timesteps = model.num_timesteps
99
+ self.schedule = schedule
100
+
101
+ def register_buffer(self, name, attr):
102
+ if type(attr) == torch.Tensor:
103
+ if attr.device != torch.device("cuda"):
104
+ attr = attr.to(torch.device("cuda"))
105
+ setattr(self, name, attr)
106
+
107
+ def make_schedule(
108
+ self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
109
+ ):
110
+ if ddim_eta != 0:
111
+ raise ValueError("ddim_eta must be 0 for PLMS")
112
+ self.ddim_timesteps = make_ddim_timesteps(
113
+ ddim_discr_method=ddim_discretize,
114
+ num_ddim_timesteps=ddim_num_steps,
115
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
116
+ verbose=verbose,
117
+ )
118
+ alphas_cumprod = self.model.alphas_cumprod
119
+ assert (
120
+ alphas_cumprod.shape[0] == self.ddpm_num_timesteps
121
+ ), "alphas have to be defined for each timestep"
122
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
123
+
124
+ self.register_buffer("betas", to_torch(self.model.betas))
125
+ self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
126
+ self.register_buffer(
127
+ "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
128
+ )
129
+
130
+ # calculations for diffusion q(x_t | x_{t-1}) and others
131
+ self.register_buffer(
132
+ "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
133
+ )
134
+ self.register_buffer(
135
+ "sqrt_one_minus_alphas_cumprod",
136
+ to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
137
+ )
138
+ self.register_buffer(
139
+ "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
140
+ )
141
+ self.register_buffer(
142
+ "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
143
+ )
144
+ self.register_buffer(
145
+ "sqrt_recipm1_alphas_cumprod",
146
+ to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
147
+ )
148
+
149
+ # ddim sampling parameters
150
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
151
+ alphacums=alphas_cumprod.cpu(),
152
+ ddim_timesteps=self.ddim_timesteps,
153
+ eta=0.0,
154
+ verbose=verbose,
155
+ )
156
+ self.register_buffer("ddim_sigmas", ddim_sigmas)
157
+ self.register_buffer("ddim_alphas", ddim_alphas)
158
+ self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
159
+ self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
160
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
161
+ (1 - self.alphas_cumprod_prev)
162
+ / (1 - self.alphas_cumprod)
163
+ * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
164
+ )
165
+ self.register_buffer(
166
+ "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
167
+ )
168
+
169
+ @torch.no_grad()
170
+ def sample(self,
171
+ S,
172
+ batch_size,
173
+ shape,
174
+ conditioning=None,
175
+ callback=None,
176
+ normals_sequence=None,
177
+ img_callback=None,
178
+ quantize_x0=False,
179
+ eta=0.,
180
+ mask=None,
181
+ x0=None,
182
+ temperature=1.,
183
+ noise_dropout=0.,
184
+ score_corrector=None,
185
+ corrector_kwargs=None,
186
+ verbose=True,
187
+ x_T=None,
188
+ log_every_t=100,
189
+ unconditional_guidance_scale=1.,
190
+ unconditional_conditioning=None,
191
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
192
+ dynamic_threshold=None,
193
+ **kwargs
194
+ ):
195
+ if conditioning is not None:
196
+ if isinstance(conditioning, dict):
197
+ cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
198
+ if cbs != batch_size:
199
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
200
+ else:
201
+ if conditioning.shape[0] != batch_size:
202
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
203
+
204
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
205
+ # sampling
206
+ C, H, W = shape
207
+ size = (batch_size, C, H, W)
208
+ print(f'Data shape for PLMS sampling is {size}')
209
+
210
+ samples, intermediates = self.plms_sampling(conditioning, size,
211
+ callback=callback,
212
+ img_callback=img_callback,
213
+ quantize_denoised=quantize_x0,
214
+ mask=mask, x0=x0,
215
+ ddim_use_original_steps=False,
216
+ noise_dropout=noise_dropout,
217
+ temperature=temperature,
218
+ score_corrector=score_corrector,
219
+ corrector_kwargs=corrector_kwargs,
220
+ x_T=x_T,
221
+ log_every_t=log_every_t,
222
+ unconditional_guidance_scale=unconditional_guidance_scale,
223
+ unconditional_conditioning=unconditional_conditioning,
224
+ )
225
+ return samples, intermediates
226
+
227
+ @torch.no_grad()
228
+ def plms_sampling(
229
+ self,
230
+ cond,
231
+ shape,
232
+ x_T=None,
233
+ ddim_use_original_steps=False,
234
+ callback=None,
235
+ timesteps=None,
236
+ quantize_denoised=False,
237
+ mask=None,
238
+ x0=None,
239
+ img_callback=None,
240
+ log_every_t=100,
241
+ temperature=1.0,
242
+ noise_dropout=0.0,
243
+ score_corrector=None,
244
+ corrector_kwargs=None,
245
+ unconditional_guidance_scale=1.0,
246
+ unconditional_conditioning=None,
247
+ ):
248
+ device = self.model.betas.device
249
+ b = shape[0]
250
+ if x_T is None:
251
+ img = torch.randn(shape, device=device)
252
+ else:
253
+ img = x_T
254
+
255
+ if timesteps is None:
256
+ timesteps = (
257
+ self.ddpm_num_timesteps
258
+ if ddim_use_original_steps
259
+ else self.ddim_timesteps
260
+ )
261
+ elif timesteps is not None and not ddim_use_original_steps:
262
+ subset_end = (
263
+ int(
264
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
265
+ * self.ddim_timesteps.shape[0]
266
+ )
267
+ - 1
268
+ )
269
+ timesteps = self.ddim_timesteps[:subset_end]
270
+
271
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
272
+ time_range = (
273
+ list(reversed(range(0, timesteps)))
274
+ if ddim_use_original_steps
275
+ else np.flip(timesteps)
276
+ )
277
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
278
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
279
+
280
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
281
+ old_eps = []
282
+
283
+ for i, step in enumerate(iterator):
284
+ index = total_steps - i - 1
285
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
286
+ ts_next = torch.full(
287
+ (b,),
288
+ time_range[min(i + 1, len(time_range) - 1)],
289
+ device=device,
290
+ dtype=torch.long,
291
+ )
292
+
293
+ if mask is not None:
294
+ assert x0 is not None
295
+ # import ipdb; ipdb.set_trace()
296
+ img_orig = self.model.q_sample(
297
+ x0, ts
298
+ ) # TODO: deterministic forward pass?
299
+ img = img_orig * mask + (1.0 - mask) * img
300
+
301
+ outs = self.p_sample_plms(
302
+ img,
303
+ cond,
304
+ ts,
305
+ index=index,
306
+ use_original_steps=ddim_use_original_steps,
307
+ quantize_denoised=quantize_denoised,
308
+ temperature=temperature,
309
+ noise_dropout=noise_dropout,
310
+ score_corrector=score_corrector,
311
+ corrector_kwargs=corrector_kwargs,
312
+ unconditional_guidance_scale=unconditional_guidance_scale,
313
+ unconditional_conditioning=unconditional_conditioning,
314
+ old_eps=old_eps,
315
+ t_next=ts_next,
316
+ )
317
+ img, pred_x0, e_t = outs
318
+ old_eps.append(e_t)
319
+ if len(old_eps) >= 4:
320
+ old_eps.pop(0)
321
+ if callback:
322
+ callback(i)
323
+ if img_callback:
324
+ img_callback(pred_x0, i)
325
+
326
+ if index % 1 == 0 or index == total_steps - 1:
327
+ intermediates["x_inter"].append(img)
328
+ intermediates["pred_x0"].append(pred_x0)
329
+
330
+ return img, intermediates
331
+
332
+ @torch.no_grad()
333
+ def p_sample_plms(
334
+ self,
335
+ x,
336
+ c,
337
+ t,
338
+ index,
339
+ repeat_noise=False,
340
+ use_original_steps=False,
341
+ quantize_denoised=False,
342
+ temperature=1.0,
343
+ noise_dropout=0.0,
344
+ score_corrector=None,
345
+ corrector_kwargs=None,
346
+ unconditional_guidance_scale=1.0,
347
+ unconditional_conditioning=None,
348
+ old_eps=None,
349
+ t_next=None,
350
+ ):
351
+ b, *_, device = *x.shape, x.device
352
+
353
+ def get_model_output(x, t):
354
+ if (
355
+ unconditional_conditioning is None
356
+ or unconditional_guidance_scale == 1.0
357
+ ):
358
+ e_t = self.model.apply_model(x, t, c)
359
+ else:
360
+ x_in = torch.cat([x] * 2)
361
+ t_in = torch.cat([t] * 2)
362
+ if isinstance(c, dict):
363
+ c_in = {key: [torch.cat([unconditional_conditioning[key][0], c[key][0]])] for key in c}
364
+ else:
365
+ c_in = torch.cat([unconditional_conditioning, c])
366
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
367
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
368
+
369
+ if score_corrector is not None:
370
+ assert self.model.parameterization == "eps"
371
+ e_t = score_corrector.modify_score(
372
+ self.model, e_t, x, t, c, **corrector_kwargs
373
+ )
374
+
375
+ return e_t
376
+
377
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
378
+ alphas_prev = (
379
+ self.model.alphas_cumprod_prev
380
+ if use_original_steps
381
+ else self.ddim_alphas_prev
382
+ )
383
+ sqrt_one_minus_alphas = (
384
+ self.model.sqrt_one_minus_alphas_cumprod
385
+ if use_original_steps
386
+ else self.ddim_sqrt_one_minus_alphas
387
+ )
388
+ sigmas = (
389
+ self.model.ddim_sigmas_for_original_num_steps
390
+ if use_original_steps
391
+ else self.ddim_sigmas
392
+ )
393
+
394
+ def get_x_prev_and_pred_x0(e_t, index):
395
+ # select parameters corresponding to the currently considered timestep
396
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
397
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
398
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
399
+ sqrt_one_minus_at = torch.full(
400
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
401
+ )
402
+
403
+ # current prediction for x_0
404
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
405
+ if quantize_denoised:
406
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
407
+ # direction pointing to x_t
408
+ dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
409
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
410
+ if noise_dropout > 0.0:
411
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
412
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
413
+ return x_prev, pred_x0
414
+
415
+ e_t = get_model_output(x, t)
416
+ if len(old_eps) == 0:
417
+ # Pseudo Improved Euler (2nd order)
418
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
419
+ e_t_next = get_model_output(x_prev, t_next)
420
+ e_t_prime = (e_t + e_t_next) / 2
421
+ elif len(old_eps) == 1:
422
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
423
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
424
+ elif len(old_eps) == 2:
425
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
426
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
427
+ elif len(old_eps) >= 3:
428
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
429
+ e_t_prime = (
430
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
431
+ ) / 24
432
+
433
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
434
+
435
+ return x_prev, pred_x0, e_t
436
+
437
+ ###### Above are original stable-diffusion code ############
438
+
439
+ ###### Encode Image ########################################
440
+
441
+ @torch.no_grad()
442
+ def sample_encode_save_noise(
443
+ self,
444
+ S,
445
+ batch_size,
446
+ shape,
447
+ conditioning=None,
448
+ callback=None,
449
+ normals_sequence=None,
450
+ img_callback=None,
451
+ quantize_x0=False,
452
+ eta=0.0,
453
+ mask=None,
454
+ x0=None,
455
+ temperature=1.0,
456
+ noise_dropout=0.0,
457
+ score_corrector=None,
458
+ corrector_kwargs=None,
459
+ verbose=True,
460
+ x_T=None,
461
+ log_every_t=100,
462
+ unconditional_guidance_scale=1.0,
463
+ unconditional_conditioning=None,
464
+ input_image=None,
465
+ noise_save_path=None,
466
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
467
+ **kwargs,
468
+ ):
469
+ assert conditioning is not None
470
+ # assert not isinstance(conditioning, dict)
471
+
472
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
473
+ # sampling
474
+ C, H, W = shape
475
+ size = (batch_size, C, H, W)
476
+ if verbose:
477
+ print(f"Data shape for PLMS sampling is {size}")
478
+
479
+ samples, intermediates, x0_loop = self.plms_sampling_enc_save_noise(
480
+ conditioning,
481
+ size,
482
+ callback=callback,
483
+ img_callback=img_callback,
484
+ quantize_denoised=quantize_x0,
485
+ mask=mask,
486
+ x0=x0,
487
+ ddim_use_original_steps=False,
488
+ noise_dropout=noise_dropout,
489
+ temperature=temperature,
490
+ score_corrector=score_corrector,
491
+ corrector_kwargs=corrector_kwargs,
492
+ x_T=x_T,
493
+ log_every_t=log_every_t,
494
+ unconditional_guidance_scale=unconditional_guidance_scale,
495
+ unconditional_conditioning=unconditional_conditioning,
496
+ input_image=input_image,
497
+ noise_save_path=noise_save_path,
498
+ verbose=verbose
499
+ )
500
+ return samples, intermediates, x0_loop
501
+
502
+ @torch.no_grad()
503
+ def plms_sampling_enc_save_noise(
504
+ self,
505
+ cond,
506
+ shape,
507
+ x_T=None,
508
+ ddim_use_original_steps=False,
509
+ callback=None,
510
+ timesteps=None,
511
+ quantize_denoised=False,
512
+ mask=None,
513
+ x0=None,
514
+ img_callback=None,
515
+ log_every_t=100,
516
+ temperature=1.0,
517
+ noise_dropout=0.0,
518
+ score_corrector=None,
519
+ corrector_kwargs=None,
520
+ unconditional_guidance_scale=1.0,
521
+ unconditional_conditioning=None,
522
+ input_image=None,
523
+ noise_save_path=None,
524
+ verbose=True,
525
+ ):
526
+ device = self.model.betas.device
527
+
528
+ b = shape[0]
529
+ if x_T is None:
530
+ img = torch.randn(shape, device=device)
531
+ else:
532
+ img = x_T
533
+
534
+ if timesteps is None:
535
+ timesteps = (
536
+ self.ddpm_num_timesteps
537
+ if ddim_use_original_steps
538
+ else self.ddim_timesteps
539
+ )
540
+ elif timesteps is not None and not ddim_use_original_steps:
541
+ subset_end = (
542
+ int(
543
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
544
+ * self.ddim_timesteps.shape[0]
545
+ )
546
+ - 1
547
+ )
548
+ timesteps = self.ddim_timesteps[:subset_end]
549
+
550
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
551
+ time_range = (
552
+ list(reversed(range(0, timesteps)))
553
+ if ddim_use_original_steps
554
+ else np.flip(timesteps)
555
+ )
556
+ time_range = list(range(0, timesteps)) if ddim_use_original_steps else timesteps
557
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
558
+ if verbose:
559
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
560
+ iterator = tqdm(time_range[:-1], desc='PLMS Sampler', total=total_steps)
561
+ else:
562
+ iterator = time_range[:-1]
563
+ old_eps = []
564
+ noise_images = []
565
+ for each_time in time_range:
566
+ noised_image = self.model.q_sample(
567
+ input_image, torch.tensor([each_time]).to(device)
568
+ )
569
+ noise_images.append(noised_image)
570
+ # torch.save(noised_image, noise_save_path + "_image_time%d.pt" % (each_time))
571
+ # import pudb; pudb.set_trace()
572
+ x0_loop = input_image.clone()
573
+ alphas = (
574
+ self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas
575
+ )
576
+ alphas_prev = (
577
+ self.model.alphas_cumprod_prev
578
+ if ddim_use_original_steps
579
+ else self.ddim_alphas_prev
580
+ )
581
+ sqrt_one_minus_alphas = (
582
+ self.model.sqrt_one_minus_alphas_cumprod
583
+ if ddim_use_original_steps
584
+ else self.ddim_sqrt_one_minus_alphas
585
+ )
586
+ sigmas = (
587
+ self.model.ddim_sigmas_for_original_num_steps
588
+ if ddim_use_original_steps
589
+ else self.ddim_sigmas
590
+ )
591
+
592
+ def get_model_output(x, t):
593
+ x_in = torch.cat([x] * 2)
594
+ t_in = torch.cat([t] * 2)
595
+ if isinstance(cond, dict):
596
+ c_in = {key: [torch.cat([unconditional_conditioning[key][0], cond[key][0]])] for key in cond}
597
+ else:
598
+ c_in = torch.cat([unconditional_conditioning, cond])
599
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
600
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
601
+ return e_t
602
+
603
+ def get_x_prev_and_pred_x0(e_t, index, curr_x0):
604
+ # select parameters corresponding to the currently considered timestep
605
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
606
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
607
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
608
+ sqrt_one_minus_at = torch.full(
609
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
610
+ )
611
+
612
+ # current prediction for x_0
613
+ pred_x0 = (curr_x0 - sqrt_one_minus_at * e_t) / a_t.sqrt()
614
+
615
+ a_t = torch.full((b, 1, 1, 1), alphas[index + 1], device=device)
616
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index + 1], device=device)
617
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index + 1], device=device)
618
+ sqrt_one_minus_at = torch.full(
619
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index + 1], device=device
620
+ )
621
+
622
+ dir_xt = (1.0 - a_t - sigma_t ** 2).sqrt() * e_t
623
+
624
+ x_prev = a_t.sqrt() * pred_x0 + dir_xt
625
+
626
+ return x_prev, pred_x0
627
+
628
+ for i, step in enumerate(iterator):
629
+ index = i
630
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
631
+ ts_next = torch.full(
632
+ (b,),
633
+ time_range[min(i + 1, len(time_range) - 1)],
634
+ device=device,
635
+ dtype=torch.long,
636
+ )
637
+ e_t = get_model_output(x0_loop, ts)
638
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index, x0_loop)
639
+ x0_loop = x_prev
640
+ # torch.save(x0_loop, noise_save_path + "_final_latent.pt")
641
+
642
+ # Reconstruction
643
+ img = x0_loop.clone()
644
+ time_range = (
645
+ list(reversed(range(0, timesteps)))
646
+ if ddim_use_original_steps
647
+ else np.flip(timesteps)
648
+ )
649
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
650
+ if verbose:
651
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
652
+ iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps, miniters=total_steps+1, mininterval=600)
653
+ else:
654
+ iterator = time_range
655
+ old_eps = []
656
+ for i, step in enumerate(iterator):
657
+ index = total_steps - i - 1
658
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
659
+ ts_next = torch.full(
660
+ (b,),
661
+ time_range[min(i + 1, len(time_range) - 1)],
662
+ device=device,
663
+ dtype=torch.long,
664
+ )
665
+
666
+ if mask is not None:
667
+ assert x0 is not None
668
+ img_orig = self.model.q_sample(
669
+ x0, ts
670
+ ) # TODO: deterministic forward pass?
671
+ img = img_orig * mask + (1.0 - mask) * img
672
+
673
+ outs = self.p_sample_plms_dec_save_noise(
674
+ img,
675
+ cond,
676
+ ts,
677
+ index=index,
678
+ use_original_steps=ddim_use_original_steps,
679
+ quantize_denoised=quantize_denoised,
680
+ temperature=temperature,
681
+ noise_dropout=noise_dropout,
682
+ score_corrector=score_corrector,
683
+ corrector_kwargs=corrector_kwargs,
684
+ unconditional_guidance_scale=unconditional_guidance_scale,
685
+ unconditional_conditioning=unconditional_conditioning,
686
+ old_eps=old_eps,
687
+ t_next=ts_next,
688
+ input_image=input_image,
689
+ noise_save_path=noise_save_path,
690
+ noise_image=noise_images.pop(),
691
+ )
692
+ img, pred_x0, e_t = outs
693
+
694
+ old_eps.append(e_t)
695
+ if len(old_eps) >= 4:
696
+ old_eps.pop(0)
697
+ if callback:
698
+ callback(i)
699
+ if img_callback:
700
+ img_callback(pred_x0, i)
701
+
702
+ if index % log_every_t == 0 or index == total_steps - 1:
703
+ intermediates["x_inter"].append(img)
704
+ intermediates["pred_x0"].append(pred_x0)
705
+
706
+ return img, intermediates, x0_loop
707
+
708
+ @torch.no_grad()
709
+ def p_sample_plms_dec_save_noise(
710
+ self,
711
+ x,
712
+ c1,
713
+ t,
714
+ index,
715
+ repeat_noise=False,
716
+ use_original_steps=False,
717
+ quantize_denoised=False,
718
+ temperature=1.0,
719
+ noise_dropout=0.0,
720
+ score_corrector=None,
721
+ corrector_kwargs=None,
722
+ unconditional_guidance_scale=1.0,
723
+ unconditional_conditioning=None,
724
+ old_eps=None,
725
+ t_next=None,
726
+ input_image=None,
727
+ noise_save_path=None,
728
+ noise_image=None,
729
+ ):
730
+ b, *_, device = *x.shape, x.device
731
+
732
+ def get_model_output(x, t):
733
+ if (
734
+ unconditional_conditioning is None
735
+ or unconditional_guidance_scale == 1.0
736
+ ):
737
+ e_t = self.model.apply_model(x, t, c1)
738
+ else:
739
+ x_in = torch.cat([x] * 2)
740
+ t_in = torch.cat([t] * 2)
741
+ if isinstance(c1, dict):
742
+ c_in = {key: [torch.cat([unconditional_conditioning[key][0], c1[key][0]])] for key in c1}
743
+ else:
744
+ c_in = torch.cat([unconditional_conditioning, c1])
745
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
746
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
747
+ return e_t
748
+
749
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
750
+ alphas_prev = (
751
+ self.model.alphas_cumprod_prev
752
+ if use_original_steps
753
+ else self.ddim_alphas_prev
754
+ )
755
+ sqrt_one_minus_alphas = (
756
+ self.model.sqrt_one_minus_alphas_cumprod
757
+ if use_original_steps
758
+ else self.ddim_sqrt_one_minus_alphas
759
+ )
760
+ sigmas = (
761
+ self.model.ddim_sigmas_for_original_num_steps
762
+ if use_original_steps
763
+ else self.ddim_sigmas
764
+ )
765
+
766
+ def get_x_prev_and_pred_x0(e_t, index):
767
+ # select parameters corresponding to the currently considered timestep
768
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
769
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
770
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
771
+ sqrt_one_minus_at = torch.full(
772
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
773
+ )
774
+
775
+ # current prediction for x_0
776
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
777
+ if quantize_denoised:
778
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
779
+ # direction pointing to x_t
780
+ dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
781
+ time_curr = index * 20 + 1
782
+ # img_prev = torch.load(noise_save_path + "_image_time%d.pt" % (time_curr))
783
+ img_prev = noise_image
784
+ noise = img_prev - a_prev.sqrt() * pred_x0 - dir_xt
785
+ # torch.save(noise, noise_save_path + "_time%d.pt" % (time_curr))
786
+
787
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
788
+ return x_prev, pred_x0
789
+
790
+ e_t = get_model_output(x, t)
791
+ if len(old_eps) == 0:
792
+ # Pseudo Improved Euler (2nd order)
793
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
794
+ e_t_next = get_model_output(x_prev, t_next)
795
+ e_t_prime = (e_t + e_t_next) / 2
796
+ elif len(old_eps) == 1:
797
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
798
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
799
+ elif len(old_eps) == 2:
800
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
801
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
802
+ elif len(old_eps) >= 3:
803
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
804
+ e_t_prime = (
805
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
806
+ ) / 24
807
+
808
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
809
+
810
+ return x_prev, pred_x0, e_t
811
+
812
+ ################## Encode Image End ###############################
813
+
814
+ def p_sample_plms_sampling(
815
+ self,
816
+ x,
817
+ c1,
818
+ c2,
819
+ t,
820
+ index,
821
+ repeat_noise=False,
822
+ use_original_steps=False,
823
+ quantize_denoised=False,
824
+ temperature=1.0,
825
+ noise_dropout=0.0,
826
+ score_corrector=None,
827
+ corrector_kwargs=None,
828
+ unconditional_guidance_scale=1.0,
829
+ unconditional_conditioning=None,
830
+ old_eps=None,
831
+ t_next=None,
832
+ input_image=None,
833
+ optimizing_weight=None,
834
+ noise_save_path=None,
835
+ ):
836
+ b, *_, device = *x.shape, x.device
837
+
838
+ def optimize_model_output(x, t):
839
+ # weight_for_pencil = torch.nn.Sigmoid()(optimizing_weight)
840
+ # condition = weight_for_pencil * c1 + (1 - weight_for_pencil) * c2
841
+ condition = optimizing_weight * c1 + (1 - optimizing_weight) * c2
842
+ if (
843
+ unconditional_conditioning is None
844
+ or unconditional_guidance_scale == 1.0
845
+ ):
846
+ e_t = self.model.apply_model(x, t, condition)
847
+ else:
848
+ x_in = torch.cat([x] * 2)
849
+ t_in = torch.cat([t] * 2)
850
+ if isinstance(condition, dict):
851
+ c_in = {key: [torch.cat([unconditional_conditioning[key][0], condition[key][0]])] for key in condition}
852
+ else:
853
+ c_in = torch.cat([unconditional_conditioning, condition])
854
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
855
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
856
+ return e_t
857
+
858
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
859
+ alphas_prev = (
860
+ self.model.alphas_cumprod_prev
861
+ if use_original_steps
862
+ else self.ddim_alphas_prev
863
+ )
864
+ sqrt_one_minus_alphas = (
865
+ self.model.sqrt_one_minus_alphas_cumprod
866
+ if use_original_steps
867
+ else self.ddim_sqrt_one_minus_alphas
868
+ )
869
+ sigmas = (
870
+ self.model.ddim_sigmas_for_original_num_steps
871
+ if use_original_steps
872
+ else self.ddim_sigmas
873
+ )
874
+
875
+ def get_x_prev_and_pred_x0(e_t, index):
876
+ # select parameters corresponding to the currently considered timestep
877
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
878
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
879
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
880
+ sqrt_one_minus_at = torch.full(
881
+ (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
882
+ )
883
+
884
+ # current prediction for x_0
885
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
886
+ if quantize_denoised:
887
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
888
+ # direction pointing to x_t
889
+ dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
890
+ time_curr = index * 20 + 1
891
+ if noise_save_path and index > 16:
892
+ noise = torch.load(noise_save_path + "_time%d.pt" % (time_curr))[:1]
893
+ else:
894
+ noise = torch.zeros_like(dir_xt)
895
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
896
+ return x_prev, pred_x0
897
+
898
+ e_t = optimize_model_output(x, t)
899
+ if len(old_eps) == 0:
900
+ # Pseudo Improved Euler (2nd order)
901
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
902
+ # e_t_next = get_model_output(x_prev, t_next)
903
+ e_t_next = optimize_model_output(x_prev, t_next)
904
+ e_t_prime = (e_t + e_t_next) / 2
905
+ elif len(old_eps) == 1:
906
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
907
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
908
+ elif len(old_eps) == 2:
909
+ # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
910
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
911
+ elif len(old_eps) >= 3:
912
+ # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
913
+ e_t_prime = (
914
+ 55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
915
+ ) / 24
916
+
917
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
918
+
919
+ return x_prev, pred_x0, e_t
920
+
921
+ ################## Edit Input Image ###############################
922
+
923
+ def sample_optimize_intrinsic_edit(
924
+ self,
925
+ S,
926
+ batch_size,
927
+ shape,
928
+ conditioning1=None,
929
+ conditioning2=None,
930
+ callback=None,
931
+ normals_sequence=None,
932
+ img_callback=None,
933
+ quantize_x0=False,
934
+ eta=0.0,
935
+ mask=None,
936
+ x0=None,
937
+ temperature=1.0,
938
+ noise_dropout=0.0,
939
+ score_corrector=None,
940
+ corrector_kwargs=None,
941
+ verbose=True,
942
+ x_T=None,
943
+ log_every_t=100,
944
+ unconditional_guidance_scale=1.0,
945
+ unconditional_conditioning=None,
946
+ input_image=None,
947
+ noise_save_path=None,
948
+ lambda_t=None,
949
+ lambda_save_path=None,
950
+ image_save_path=None,
951
+ original_text=None,
952
+ new_text=None,
953
+ otext=None,
954
+ noise_saved_path=None,
955
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
956
+ **kwargs,
957
+ ):
958
+ assert conditioning1 is not None
959
+ assert conditioning2 is not None
960
+
961
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
962
+ # sampling
963
+ C, H, W = shape
964
+ size = (batch_size, C, H, W)
965
+ print(f"Data shape for PLMS sampling is {size}")
966
+
967
+ self.plms_sampling_optimize_intrinsic_edit(
968
+ conditioning1,
969
+ conditioning2,
970
+ size,
971
+ callback=callback,
972
+ img_callback=img_callback,
973
+ quantize_denoised=quantize_x0,
974
+ mask=mask,
975
+ x0=x0,
976
+ ddim_use_original_steps=False,
977
+ noise_dropout=noise_dropout,
978
+ temperature=temperature,
979
+ score_corrector=score_corrector,
980
+ corrector_kwargs=corrector_kwargs,
981
+ x_T=x_T,
982
+ log_every_t=log_every_t,
983
+ unconditional_guidance_scale=unconditional_guidance_scale,
984
+ unconditional_conditioning=unconditional_conditioning,
985
+ input_image=input_image,
986
+ noise_save_path=noise_save_path,
987
+ lambda_t=lambda_t,
988
+ lambda_save_path=lambda_save_path,
989
+ image_save_path=image_save_path,
990
+ original_text=original_text,
991
+ new_text=new_text,
992
+ otext=otext,
993
+ noise_saved_path=noise_saved_path,
994
+ )
995
+ return None
996
+
997
+ def plms_sampling_optimize_intrinsic_edit(
998
+ self,
999
+ cond1,
1000
+ cond2,
1001
+ shape,
1002
+ x_T=None,
1003
+ ddim_use_original_steps=False,
1004
+ callback=None,
1005
+ timesteps=None,
1006
+ quantize_denoised=False,
1007
+ mask=None,
1008
+ x0=None,
1009
+ img_callback=None,
1010
+ log_every_t=100,
1011
+ temperature=1.0,
1012
+ noise_dropout=0.0,
1013
+ score_corrector=None,
1014
+ corrector_kwargs=None,
1015
+ unconditional_guidance_scale=1.0,
1016
+ unconditional_conditioning=None,
1017
+ input_image=None,
1018
+ noise_save_path=None,
1019
+ lambda_t=None,
1020
+ lambda_save_path=None,
1021
+ image_save_path=None,
1022
+ original_text=None,
1023
+ new_text=None,
1024
+ otext=None,
1025
+ noise_saved_path=None,
1026
+ ):
1027
+ # Different from above, the intrinsic edit version needs
1028
+ device = self.model.betas.device
1029
+
1030
+ b = shape[0]
1031
+ if x_T is None:
1032
+ img = torch.randn(shape, device=device)
1033
+ else:
1034
+ img = x_T
1035
+ img_clone = img.clone()
1036
+
1037
+ if timesteps is None:
1038
+ timesteps = (
1039
+ self.ddpm_num_timesteps
1040
+ if ddim_use_original_steps
1041
+ else self.ddim_timesteps
1042
+ )
1043
+ elif timesteps is not None and not ddim_use_original_steps:
1044
+ subset_end = (
1045
+ int(
1046
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
1047
+ * self.ddim_timesteps.shape[0]
1048
+ )
1049
+ - 1
1050
+ )
1051
+ timesteps = self.ddim_timesteps[:subset_end]
1052
+
1053
+ intermediates = {"x_inter": [img], "pred_x0": [img]}
1054
+ time_range = (
1055
+ list(reversed(range(0, timesteps)))
1056
+ if ddim_use_original_steps
1057
+ else np.flip(timesteps)
1058
+ )
1059
+
1060
+ weighting_parameter = lambda_t
1061
+ weighting_parameter.requires_grad = True
1062
+ from torch import optim
1063
+
1064
+ optimizer = optim.Adam([weighting_parameter], lr=0.05)
1065
+
1066
+ print("Original image")
1067
+ with torch.no_grad():
1068
+ img = img_clone.clone()
1069
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
1070
+ iterator = time_range
1071
+ old_eps = []
1072
+
1073
+ for i, step in enumerate(iterator):
1074
+ index = total_steps - i - 1
1075
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
1076
+ ts_next = torch.full(
1077
+ (b,),
1078
+ time_range[min(i + 1, len(time_range) - 1)],
1079
+ device=device,
1080
+ dtype=torch.long,
1081
+ )
1082
+
1083
+ outs = self.p_sample_plms_sampling(
1084
+ img,
1085
+ cond1,
1086
+ cond2,
1087
+ ts,
1088
+ index=index,
1089
+ use_original_steps=ddim_use_original_steps,
1090
+ quantize_denoised=quantize_denoised,
1091
+ temperature=temperature,
1092
+ noise_dropout=noise_dropout,
1093
+ score_corrector=score_corrector,
1094
+ corrector_kwargs=corrector_kwargs,
1095
+ unconditional_guidance_scale=unconditional_guidance_scale,
1096
+ unconditional_conditioning=unconditional_conditioning,
1097
+ old_eps=old_eps,
1098
+ t_next=ts_next,
1099
+ input_image=input_image,
1100
+ optimizing_weight=torch.ones(50)[i],
1101
+ noise_save_path=noise_saved_path,
1102
+ )
1103
+ img, pred_x0, e_t = outs
1104
+ old_eps.append(e_t)
1105
+ if len(old_eps) >= 4:
1106
+ old_eps.pop(0)
1107
+ img_temp = self.model.decode_first_stage(img)
1108
+ img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
1109
+ img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)
1110
+ # save image
1111
+ with torch.no_grad():
1112
+ x_sample = 255.0 * rearrange(
1113
+ img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
1114
+ )
1115
+ imgsave = Image.fromarray(x_sample.astype(np.uint8))
1116
+ imgsave.save(image_save_path + "original.png")
1117
+ readed_image = (
1118
+ torchvision.io.read_image(image_save_path + "original.png").float()
1119
+ / 255
1120
+ )
1121
+ print("Optimizing start")
1122
+ for epoch in tqdm(range(10)):
1123
+ img = img_clone.clone()
1124
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
1125
+ iterator = time_range
1126
+ old_eps = []
1127
+
1128
+ for i, step in enumerate(iterator):
1129
+ index = total_steps - i - 1
1130
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
1131
+ ts_next = torch.full(
1132
+ (b,),
1133
+ time_range[min(i + 1, len(time_range) - 1)],
1134
+ device=device,
1135
+ dtype=torch.long,
1136
+ )
1137
+
1138
+ outs = self.p_sample_plms_sampling(
1139
+ img,
1140
+ cond1,
1141
+ cond2,
1142
+ ts,
1143
+ index=index,
1144
+ use_original_steps=ddim_use_original_steps,
1145
+ quantize_denoised=quantize_denoised,
1146
+ temperature=temperature,
1147
+ noise_dropout=noise_dropout,
1148
+ score_corrector=score_corrector,
1149
+ corrector_kwargs=corrector_kwargs,
1150
+ unconditional_guidance_scale=unconditional_guidance_scale,
1151
+ unconditional_conditioning=unconditional_conditioning,
1152
+ old_eps=old_eps,
1153
+ t_next=ts_next,
1154
+ input_image=input_image,
1155
+ optimizing_weight=weighting_parameter[i],
1156
+ noise_save_path=noise_saved_path,
1157
+ )
1158
+ img, pred_x0, e_t = outs
1159
+ old_eps.append(e_t)
1160
+ if len(old_eps) >= 4:
1161
+ old_eps.pop(0)
1162
+ img_temp = self.model.decode_first_stage(img)
1163
+ img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
1164
+ img_temp_ddim = img_temp_ddim.cpu()
1165
+
1166
+ # save image
1167
+ # with torch.no_grad():
1168
+ # x_sample = 255.0 * rearrange(
1169
+ # img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
1170
+ # )
1171
+ # imgsave = Image.fromarray(x_sample.astype(np.uint8))
1172
+ # imgsave.save(image_save_path + "/%d.png" % (epoch))
1173
+
1174
+ loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
1175
+ loss2 = DCLIPLoss()(
1176
+ readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
1177
+ )
1178
+ loss = 0.05 * loss1 + loss2
1179
+ optimizer.zero_grad()
1180
+ loss.backward()
1181
+ optimizer.step()
1182
+ # torch.save(
1183
+ # weighting_parameter, lambda_save_path + "/weightingParam%d.pt" % (epoch)
1184
+ # )
1185
+ if epoch < 9:
1186
+ del img
1187
+ else:
1188
+ # save image
1189
+ with torch.no_grad():
1190
+ x_sample = 255.0 * rearrange(
1191
+ img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
1192
+ )
1193
+ imgsave = Image.fromarray(x_sample.astype(np.uint8))
1194
+ imgsave.save(image_save_path + "/final.png")
1195
+ torch.save(
1196
+ weighting_parameter, lambda_save_path + "/weightingParam_final.pt"
1197
+ )
1198
+
1199
+ torch.cuda.empty_cache()
1200
+ # shutil.rmtree("noise")
1201
+ return None
1202
+
1203
+ ################ Edit Image End ######################
1204
+
1205
+ ################ Disentangle #########################
1206
+
1207
+ def sample_optimize_intrinsic(
1208
+ self,
1209
+ S,
1210
+ batch_size,
1211
+ shape,
1212
+ conditioning1=None,
1213
+ conditioning2=None,
1214
+ callback=None,
1215
+ normals_sequence=None,
1216
+ img_callback=None,
1217
+ quantize_x0=False,
1218
+ eta=0.0,
1219
+ mask=None,
1220
+ x0=None,
1221
+ temperature=1.0,
1222
+ noise_dropout=0.0,
1223
+ score_corrector=None,
1224
+ corrector_kwargs=None,
1225
+ verbose=True,
1226
+ x_T=None,
1227
+ log_every_t=100,
1228
+ unconditional_guidance_scale=1.0,
1229
+ unconditional_conditioning=None,
1230
+ input_image=None,
1231
+ noise_save_path=None,
1232
+ lambda_t=None,
1233
+ lambda_save_path=None,
1234
+ image_save_path=None,
1235
+ original_text=None,
1236
+ new_text=None,
1237
+ otext=None,
1238
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
1239
+ **kwargs,
1240
+ ):
1241
+ assert conditioning1 is not None
1242
+ assert conditioning2 is not None
1243
+
1244
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
1245
+ # sampling
1246
+ C, H, W = shape
1247
+ size = (batch_size, C, H, W)
1248
+ print(f"Data shape for PLMS sampling is {size}")
1249
+
1250
+ self.plms_sampling_optimize_intrinsic(
1251
+ conditioning1,
1252
+ conditioning2,
1253
+ size,
1254
+ callback=callback,
1255
+ img_callback=img_callback,
1256
+ quantize_denoised=quantize_x0,
1257
+ mask=mask,
1258
+ x0=x0,
1259
+ ddim_use_original_steps=False,
1260
+ noise_dropout=noise_dropout,
1261
+ temperature=temperature,
1262
+ score_corrector=score_corrector,
1263
+ corrector_kwargs=corrector_kwargs,
1264
+ x_T=x_T,
1265
+ log_every_t=log_every_t,
1266
+ unconditional_guidance_scale=unconditional_guidance_scale,
1267
+ unconditional_conditioning=unconditional_conditioning,
1268
+ input_image=input_image,
1269
+ noise_save_path=noise_save_path,
1270
+ lambda_t=lambda_t,
1271
+ lambda_save_path=lambda_save_path,
1272
+ image_save_path=image_save_path,
1273
+ original_text=original_text,
1274
+ new_text=new_text,
1275
+ otext=otext,
1276
+ )
1277
+ return None
1278
+
1279
+ def plms_sampling_optimize_intrinsic(
1280
+ self,
1281
+ cond1,
1282
+ cond2,
1283
+ shape,
1284
+ x_T=None,
1285
+ ddim_use_original_steps=False,
1286
+ callback=None,
1287
+ timesteps=None,
1288
+ quantize_denoised=False,
1289
+ mask=None,
1290
+ x0=None,
1291
+ img_callback=None,
1292
+ log_every_t=100,
1293
+ temperature=1.0,
1294
+ noise_dropout=0.0,
1295
+ score_corrector=None,
1296
+ corrector_kwargs=None,
1297
+ unconditional_guidance_scale=1.0,
1298
+ unconditional_conditioning=None,
1299
+ input_image=None,
1300
+ noise_save_path=None,
1301
+ lambda_t=None,
1302
+ lambda_save_path=None,
1303
+ image_save_path=None,
1304
+ original_text=None,
1305
+ new_text=None,
1306
+ otext=None,
1307
+ ):
1308
+ device = self.model.betas.device
1309
+
1310
+ b = shape[0]
1311
+ if x_T is None:
1312
+ img = torch.randn(shape, device=device)
1313
+ else:
1314
+ img = x_T
1315
+ img_clone = img.clone()
1316
+
1317
+ if timesteps is None:
1318
+ timesteps = (
1319
+ self.ddpm_num_timesteps
1320
+ if ddim_use_original_steps
1321
+ else self.ddim_timesteps
1322
+ )
1323
+ elif timesteps is not None and not ddim_use_original_steps:
1324
+ subset_end = (
1325
+ int(
1326
+ min(timesteps / self.ddim_timesteps.shape[0], 1)
1327
+ * self.ddim_timesteps.shape[0]
1328
+ )
1329
+ - 1
1330
+ )
1331
+ timesteps = self.ddim_timesteps[:subset_end]
1332
+
1333
+ time_range = (
1334
+ list(reversed(range(0, timesteps)))
1335
+ if ddim_use_original_steps
1336
+ else np.flip(timesteps)
1337
+ )
1338
+ weighting_parameter = lambda_t
1339
+ weighting_parameter.requires_grad = True
1340
+ from torch import optim
1341
+
1342
+ optimizer = optim.Adam([weighting_parameter], lr=0.05)
1343
+
1344
+ print("Original image")
1345
+ with torch.no_grad():
1346
+ img = img_clone.clone()
1347
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
1348
+ iterator = time_range
1349
+ old_eps = []
1350
+
1351
+ for i, step in enumerate(iterator):
1352
+ index = total_steps - i - 1
1353
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
1354
+ ts_next = torch.full(
1355
+ (b,),
1356
+ time_range[min(i + 1, len(time_range) - 1)],
1357
+ device=device,
1358
+ dtype=torch.long,
1359
+ )
1360
+
1361
+ outs = self.p_sample_plms_sampling(
1362
+ img,
1363
+ cond1,
1364
+ cond2,
1365
+ ts,
1366
+ index=index,
1367
+ use_original_steps=ddim_use_original_steps,
1368
+ quantize_denoised=quantize_denoised,
1369
+ temperature=temperature,
1370
+ noise_dropout=noise_dropout,
1371
+ score_corrector=score_corrector,
1372
+ corrector_kwargs=corrector_kwargs,
1373
+ unconditional_guidance_scale=unconditional_guidance_scale,
1374
+ unconditional_conditioning=unconditional_conditioning,
1375
+ old_eps=old_eps,
1376
+ t_next=ts_next,
1377
+ input_image=input_image,
1378
+ optimizing_weight=torch.ones(50)[i],
1379
+ noise_save_path=noise_save_path,
1380
+ )
1381
+ img, pred_x0, e_t = outs
1382
+ old_eps.append(e_t)
1383
+ if len(old_eps) >= 4:
1384
+ old_eps.pop(0)
1385
+ img_temp = self.model.decode_first_stage(img)
1386
+ del img
1387
+ img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
1388
+ img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)
1389
+ # save image
1390
+ with torch.no_grad():
1391
+ x_sample = 255.0 * rearrange(
1392
+ img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
1393
+ )
1394
+ imgsave = Image.fromarray(x_sample.astype(np.uint8))
1395
+ imgsave.save(image_save_path + "original.png")
1396
+
1397
+ readed_image = (
1398
+ torchvision.io.read_image(image_save_path + "original.png").float()
1399
+ / 255
1400
+ )
1401
+
1402
+ print("Optimizing start")
1403
+ for epoch in tqdm(range(10)):
1404
+ img = img_clone.clone()
1405
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
1406
+ iterator = time_range
1407
+ old_eps = []
1408
+
1409
+ for i, step in enumerate(iterator):
1410
+ index = total_steps - i - 1
1411
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
1412
+ ts_next = torch.full(
1413
+ (b,),
1414
+ time_range[min(i + 1, len(time_range) - 1)],
1415
+ device=device,
1416
+ dtype=torch.long,
1417
+ )
1418
+
1419
+ outs = self.p_sample_plms_sampling(
1420
+ img,
1421
+ cond1,
1422
+ cond2,
1423
+ ts,
1424
+ index=index,
1425
+ use_original_steps=ddim_use_original_steps,
1426
+ quantize_denoised=quantize_denoised,
1427
+ temperature=temperature,
1428
+ noise_dropout=noise_dropout,
1429
+ score_corrector=score_corrector,
1430
+ corrector_kwargs=corrector_kwargs,
1431
+ unconditional_guidance_scale=unconditional_guidance_scale,
1432
+ unconditional_conditioning=unconditional_conditioning,
1433
+ old_eps=old_eps,
1434
+ t_next=ts_next,
1435
+ input_image=input_image,
1436
+ optimizing_weight=weighting_parameter[i],
1437
+ noise_save_path=noise_save_path,
1438
+ )
1439
+ img, _, e_t = outs
1440
+ old_eps.append(e_t)
1441
+ if len(old_eps) >= 4:
1442
+ old_eps.pop(0)
1443
+ img_temp = self.model.decode_first_stage(img)
1444
+ del img
1445
+ img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
1446
+ img_temp_ddim = img_temp_ddim.cpu()
1447
+
1448
+ # # save image
1449
+ # with torch.no_grad():
1450
+ # x_sample = 255. * rearrange(img_temp_ddim[0].detach().cpu().numpy(), 'c h w -> h w c')
1451
+ # imgsave = Image.fromarray(x_sample.astype(np.uint8))
1452
+ # imgsave.save(image_save_path + "/%d.png"%(epoch))
1453
+
1454
+ loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
1455
+ loss2 = DCLIPLoss()(
1456
+ readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
1457
+ )
1458
+ loss = (
1459
+ 0.05 * loss1 + loss2
1460
+ ) # 0.05 or 0.03. Adjust according to attributes on scenes or people.
1461
+ optimizer.zero_grad()
1462
+ loss.backward()
1463
+ optimizer.step()
1464
+ # torch.save(weighting_parameter, lambda_save_path+"/weightingParam%d.pt"%(epoch))
1465
+ with torch.no_grad():
1466
+ if epoch == 9:
1467
+ # save image
1468
+ x_sample = 255.0 * rearrange(
1469
+ img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
1470
+ )
1471
+ imgsave = Image.fromarray(x_sample.astype(np.uint8))
1472
+ imgsave.save(image_save_path + "/final.png")
1473
+ torch.save(
1474
+ weighting_parameter,
1475
+ lambda_save_path + "/weightingParam_final.pt",
1476
+ )
1477
+ torch.cuda.empty_cache()
1478
+ return None
1479
+
1480
+
1481
+ ################ Disentangle End #########################
cldm/tmp.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ use kornia and albumentations for transformations
5
+ @author: Tu Bui @University of Surrey
6
+ """
7
+ import os
8
+ from . import utils
9
+ import torch
10
+ import numpy as np
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+ from PIL import Image
14
+ import kornia as ko
15
+ import albumentations as ab
16
+
17
+
18
+ class IdentityAugment(nn.Module):
19
+ def __init__(self):
20
+ super().__init__()
21
+
22
+ def forward(self, x, **kwargs):
23
+ return x
24
+
25
+
26
+ class RandomCompress(nn.Module):
27
+ def __init__(self, severity='medium', p=0.5):
28
+ super().__init__()
29
+ self.p = p
30
+ if severity == 'low':
31
+ self.jpeg_quality = 70
32
+ elif severity == 'medium':
33
+ self.jpeg_quality = 50
34
+ elif severity == 'high':
35
+ self.jpeg_quality = 40
36
+
37
+ def forward(self, x, ramp=1.):
38
+ # x (B, C, H, W) in range [0, 1]
39
+ # ramp: adjust the ramping of the compression, 1.0 means min quality = self.jpeg_quality
40
+ if torch.rand(1)[0] >= self.p:
41
+ return x
42
+ jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
43
+ x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
44
+ return x
45
+
46
+
47
+ class RandomBoxBlur(nn.Module):
48
+ def __init__(self, severity='medium', border_type='reflect', normalize=True, p=0.5):
49
+ super().__init__()
50
+ self.p = p
51
+ if severity == 'low':
52
+ kernel_size = 3
53
+ elif severity == 'medium':
54
+ kernel_size = 5
55
+ elif severity == 'high':
56
+ kernel_size = 7
57
+
58
+ self.tform = ko.augmentation.RandomBoxBlur(kernel_size=(kernel_size, kernel_size), border_type=border_type, normalize=normalize, p=self.p)
59
+
60
+ def forward(self, x, **kwargs):
61
+ return self.tform(x)
62
+
63
+ class RandomMedianBlur(nn.Module):
64
+ def __init__(self, severity='medium', p=0.5):
65
+ super().__init__()
66
+ self.p = p
67
+ self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3,3), p=p)
68
+
69
+ def forward(self, x, **kwargs):
70
+ return self.tform(x)
71
+
72
+
73
+ class RandomBrightness(nn.Module):
74
+ def __init__(self, severity='medium', p=0.5):
75
+ super().__init__()
76
+ self.p = p
77
+ if severity == 'low':
78
+ brightness = (0.9, 1.1)
79
+ elif severity == 'medium':
80
+ brightness = (0.75, 1.25)
81
+ elif severity == 'high':
82
+ brightness = (0.5, 1.5)
83
+ self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)
84
+
85
+ def forward(self, x, **kwargs):
86
+ return self.tform(x)
87
+
88
+
89
+ class RandomContrast(nn.Module):
90
+ def __init__(self, severity='medium', p=0.5):
91
+ super().__init__()
92
+ self.p = p
93
+ if severity == 'low':
94
+ contrast = (0.9, 1.1)
95
+ elif severity == 'medium':
96
+ contrast = (0.75, 1.25)
97
+ elif severity == 'high':
98
+ contrast = (0.5, 1.5)
99
+ self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)
100
+
101
+ def forward(self, x, **kwargs):
102
+ return self.tform(x)
103
+
104
+
105
+ class RandomSaturation(nn.Module):
106
+ def __init__(self, severity='medium', p=0.5):
107
+ super().__init__()
108
+ self.p = p
109
+ if severity == 'low':
110
+ sat = (0.9, 1.1)
111
+ elif severity == 'medium':
112
+ sat = (0.75, 1.25)
113
+ elif severity == 'high':
114
+ sat = (0.5, 1.5)
115
+ self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)
116
+
117
+ def forward(self, x, **kwargs):
118
+ return self.tform(x)
119
+
120
+ class RandomSharpness(nn.Module):
121
+ def __init__(self, severity='medium', p=0.5):
122
+ super().__init__()
123
+ self.p = p
124
+ if severity == 'low':
125
+ sharpness = 0.5
126
+ elif severity == 'medium':
127
+ sharpness = 1.0
128
+ elif severity == 'high':
129
+ sharpness = 2.5
130
+ self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)
131
+
132
+ def forward(self, x, **kwargs):
133
+ return self.tform(x)
134
+
135
+ class RandomColorJiggle(nn.Module):
136
+ def __init__(self, severity='medium', p=0.5):
137
+ super().__init__()
138
+ self.p = p
139
+ if severity == 'low':
140
+ factor = (0.05, 0.05, 0.05, 0.01)
141
+ elif severity == 'medium':
142
+ factor = (0.1, 0.1, 0.1, 0.02)
143
+ elif severity == 'high':
144
+ factor = (0.1, 0.1, 0.1, 0.05)
145
+ self.tform = ko.augmentation.ColorJiggle(*factor, p=p)
146
+
147
+ def forward(self, x, **kwargs):
148
+ return self.tform(x)
149
+
150
+ class RandomHue(nn.Module):
151
+ def __init__(self, severity='medium', p=0.5):
152
+ super().__init__()
153
+ self.p = p
154
+ if severity == 'low':
155
+ hue = 0.01
156
+ elif severity == 'medium':
157
+ hue = 0.02
158
+ elif severity == 'high':
159
+ hue = 0.05
160
+ self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)
161
+
162
+ def forward(self, x, **kwargs):
163
+ return self.tform(x)
164
+
165
+ class RandomGamma(nn.Module):
166
+ def __init__(self, severity='medium', p=0.5):
167
+ super().__init__()
168
+ self.p = p
169
+ if severity == 'low':
170
+ gamma, gain = (0.9, 1.1), (0.9,1.1)
171
+ elif severity == 'medium':
172
+ gamma, gain = (0.75, 1.25), (0.75,1.25)
173
+ elif severity == 'high':
174
+ gamma, gain = (0.5, 1.5), (0.5,1.5)
175
+ self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)
176
+
177
+ def forward(self, x, **kwargs):
178
+ return self.tform(x)
179
+
180
+ class RandomGaussianBlur(nn.Module):
181
+ def __init__(self, severity='medium', p=0.5):
182
+ super().__init__()
183
+ self.p = p
184
+ if severity == 'low':
185
+ kernel_size, sigma = 3, (0.1, 1.0)
186
+ elif severity == 'medium':
187
+ kernel_size, sigma = 5, (0.1, 1.5)
188
+ elif severity == 'high':
189
+ kernel_size, sigma = 7, (0.1, 2.0)
190
+ self.tform = ko.augmentation.RandomGaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)
191
+
192
+ def forward(self, x, **kwargs):
193
+ return self.tform(x)
194
+
195
+ class RandomGaussianNoise(nn.Module):
196
+ def __init__(self, severity='medium', p=0.5):
197
+ super().__init__()
198
+ self.p = p
199
+ if severity == 'low':
200
+ std = 0.02
201
+ elif severity == 'medium':
202
+ std = 0.04
203
+ elif severity == 'high':
204
+ std = 0.08
205
+ self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)
206
+
207
+ def forward(self, x, **kwargs):
208
+ return self.tform(x)
209
+
210
+ class RandomMotionBlur(nn.Module):
211
+ def __init__(self, severity='medium', p=0.5):
212
+ super().__init__()
213
+ self.p = p
214
+ if severity == 'low':
215
+ kernel_size, angle, direction = (3, 5), (-25, 25), (-0.25, 0.25)
216
+ elif severity == 'medium':
217
+ kernel_size, angle, direction = (3, 7), (-45, 45), (-0.5, 0.5)
218
+ elif severity == 'high':
219
+ kernel_size, angle, direction = (3, 9), (-90, 90), (-1.0, 1.0)
220
+ self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)
221
+
222
+ def forward(self, x, **kwargs):
223
+ return self.tform(x)
224
+
225
+ class RandomPosterize(nn.Module):
226
+ def __init__(self, severity='medium', p=0.5):
227
+ super().__init__()
228
+ self.p = p
229
+ if severity == 'low':
230
+ bits = 5
231
+ elif severity == 'medium':
232
+ bits = 4
233
+ elif severity == 'high':
234
+ bits = 3
235
+ self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)
236
+
237
+ def forward(self, x, **kwargs):
238
+ return self.tform(x)
239
+
240
+ class RandomRGBShift(nn.Module):
241
+ def __init__(self, severity='medium', p=0.5):
242
+ super().__init__()
243
+ self.p = p
244
+ if severity == 'low':
245
+ rgb = 0.02
246
+ elif severity == 'medium':
247
+ rgb = 0.05
248
+ elif severity == 'high':
249
+ rgb = 0.1
250
+ self.tform = ko.augmentation.RandomRGBShift(r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)
251
+
252
+ def forward(self, x, **kwargs):
253
+ return self.tform(x)
254
+
255
+
256
+
257
+ class TransformNet(nn.Module):
258
+ def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True, contrast=True, color_jiggle=True, gamma=True, grayscale=True, gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True, posterize=True, rgb_shift=True, saturation=True, sharpness=True, median_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
259
+ super().__init__()
260
+ self.n_optional = n_optional
261
+ self.p = p
262
+ p_flip = 0.5 if flip else 0
263
+ rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
264
+ self.ramp = ramp
265
+ self.register_buffer('step0', torch.tensor(0))
266
+
267
+ assert crop_mode in ['random_crop', 'resized_crop']
268
+ if crop_mode == 'random_crop':
269
+ rnd_crop_layer = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
270
+ elif crop_mode == 'resized_crop':
271
+ rnd_crop_layer = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
272
+
273
+ self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]
274
+ self.optional_transforms = []
275
+ if compress:
276
+ self.optional_transforms.append(RandomCompress(severity, p=p))
277
+ if brightness:
278
+ self.optional_transforms.append(RandomBrightness(severity, p=p))
279
+ if contrast:
280
+ self.optional_transforms.append(RandomContrast(severity, p=p))
281
+ if color_jiggle:
282
+ self.optional_transforms.append(RandomColorJiggle(severity, p=p))
283
+ if gamma:
284
+ self.optional_transforms.append(RandomGamma(severity, p=p))
285
+ if grayscale:
286
+ self.optional_transforms.append(ko.augmentation.RandomGrayscale(p=p/4))
287
+ if gaussian_blur:
288
+ self.optional_transforms.append(RandomGaussianBlur(severity, p=p))
289
+ if gaussian_noise:
290
+ self.optional_transforms.append(RandomGaussianNoise(severity, p=p))
291
+ if hue:
292
+ self.optional_transforms.append(RandomHue(severity, p=p))
293
+ if motion_blur:
294
+ self.optional_transforms.append(RandomMotionBlur(severity, p=p))
295
+ if posterize:
296
+ self.optional_transforms.append(RandomPosterize(severity, p=p))
297
+ if rgb_shift:
298
+ self.optional_transforms.append(RandomRGBShift(severity, p=p))
299
+ if saturation:
300
+ self.optional_transforms.append(RandomSaturation(severity, p=p))
301
+ if sharpness:
302
+ self.optional_transforms.append(RandomSharpness(severity, p=p))
303
+ if median_blur:
304
+ self.optional_transforms.append(RandomMedianBlur(severity, p=p))
305
+
306
+ def activate(self, global_step):
307
+ if self.step0 == 0:
308
+ print(f'[TRAINING] Activating TransformNet at step {global_step}')
309
+ self.step0 = torch.tensor(global_step)
310
+
311
+ def is_activated(self):
312
+ return self.step0 > 0
313
+
314
+ def forward(self, x, global_step, p=0.9):
315
+ # x: [batch_size, 3, H, W] in range [-1, 1]
316
+ x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
317
+ # fixed transforms
318
+ for tform in self.fixed_transforms:
319
+ x = tform(x)
320
+ if isinstance(x, tuple):
321
+ x = x[0]
322
+
323
+ # optional transforms
324
+ ramp = np.min([(global_step-self.step0.cpu().item()) / self.ramp, 1.])
325
+ try:
326
+ if len(self.optional_transforms) > 0:
327
+ tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).numpy()
328
+ for tform_id in tform_ids:
329
+ tform = self.optional_transforms[tform_id]
330
+ x = tform(x, ramp=ramp)
331
+ if isinstance(x, tuple):
332
+ x = x[0]
333
+ except Exception as e:
334
+ print(tform_id, ramp)
335
+ import pdb; pdb.set_trace()
336
+ return x * 2 - 1 # [0, 1] -> [-1, 1]
337
+
338
+
339
+ if __name__ == '__main__':
340
+ pass
cldm/transformations.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from . import utils
3
+ import torch
4
+ import numpy as np
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from tools.augment_imagenetc import RandomImagenetC
8
+ from PIL import Image
9
+ import kornia as ko
10
+ # from kornia.augmentation import RandomHorizontalFlip, RandomCrop
11
+
12
+
13
+ class TransformNet(nn.Module):
14
+ def __init__(self, rnd_bri=0.3, rnd_hue=0.1, do_jpeg=False, jpeg_quality=50, rnd_noise=0.02, rnd_sat=1.0, rnd_trans=0.1,contrast=[0.5, 1.5], rnd_flip=False, ramp=1000, imagenetc_level=0, crop_mode='crop') -> None:
15
+ super().__init__()
16
+ self.rnd_bri = rnd_bri
17
+ self.rnd_hue = rnd_hue
18
+ self.jpeg_quality = jpeg_quality
19
+ self.rnd_noise = rnd_noise
20
+ self.rnd_sat = rnd_sat
21
+ self.rnd_trans = rnd_trans
22
+ self.contrast_low, self.contrast_high = contrast
23
+ self.do_jpeg = do_jpeg
24
+ p_flip = 0.5 if rnd_flip else 0
25
+ self.rnd_flip = ko.augmentation.RandomHorizontalFlip(p_flip)
26
+ self.ramp = ramp
27
+ self.register_buffer('step0', torch.tensor(0)) # large number
28
+ assert crop_mode in ['crop', 'resized_crop']
29
+ if crop_mode == 'crop':
30
+ self.rnd_crop = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
31
+ elif crop_mode == 'resized_crop':
32
+ self.rnd_crop = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
33
+ if imagenetc_level > 0:
34
+ self.imagenetc = ImagenetCTransform(max_severity=imagenetc_level)
35
+
36
+ def activate(self, global_step):
37
+ if self.step0 == 0:
38
+ print(f'[TRAINING] Activating TransformNet at step {global_step}')
39
+ self.step0 = torch.tensor(global_step)
40
+
41
+ def is_activated(self):
42
+ return self.step0 > 0
43
+
44
+ def forward(self, x, global_step, p=0.9):
45
+ # x: [batch_size, 3, H, W] in range [-1, 1]
46
+ x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
47
+
48
+ # flip
49
+ x = self.rnd_flip(x)
50
+ # random crop
51
+ x = self.rnd_crop(x)
52
+ if isinstance(x, tuple):
53
+ x = x[0] # weird bug in kornia 0.6.0 that returns transform matrix occasionally
54
+
55
+ if torch.rand(1)[0] >= p:
56
+ return x * 2 - 1 # [0, 1] -> [-1, 1]
57
+ if hasattr(self, 'imagenetc') and torch.rand(1)[0] < 0.5:
58
+ x = self.imagenetc(x * 2 - 1) # [0, 1] -> [-1, 1])
59
+ return x
60
+
61
+ batch_size, sh, device = x.shape[0], x.size(), x.device
62
+ # x0 = x.clone().detach()
63
+ ramp_fn = lambda ramp: np.min([(global_step-self.step0.cpu().item()) / ramp, 1.])
64
+
65
+ rnd_bri = ramp_fn(self.ramp) * self.rnd_bri
66
+ rnd_hue = ramp_fn(self.ramp) * self.rnd_hue
67
+ rnd_brightness = utils.get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size).to(device) # [batch_size, 3, 1, 1]
68
+ rnd_noise = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_noise
69
+
70
+ contrast_low = 1. - (1. - self.contrast_low) * ramp_fn(self.ramp)
71
+ contrast_high = 1. + (self.contrast_high - 1.) * ramp_fn(self.ramp)
72
+ contrast_params = [contrast_low, contrast_high]
73
+
74
+ # blur
75
+ N_blur = 7
76
+ f = utils.random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.],
77
+ wmin_line=3).to(device)
78
+ x = F.conv2d(x, f, bias=None, padding=int((N_blur - 1) / 2))
79
+
80
+ # noise
81
+ noise = torch.normal(mean=0, std=rnd_noise, size=x.size(), dtype=torch.float32).to(device)
82
+ x = x + noise
83
+ x = torch.clamp(x, 0, 1)
84
+
85
+ # contrast & brightness
86
+ contrast_scale = torch.Tensor(x.size()[0]).uniform_(contrast_params[0], contrast_params[1])
87
+ contrast_scale = contrast_scale.reshape(x.size()[0], 1, 1, 1).to(device)
88
+ x = x * contrast_scale
89
+ x = x + rnd_brightness
90
+ x = torch.clamp(x, 0, 1)
91
+
92
+ # saturation
93
+ # rnd_sat = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_sat
94
+ # sat_weight = torch.FloatTensor([.3, .6, .1]).reshape(1, 3, 1, 1).to(device)
95
+ # encoded_image_lum = torch.mean(x * sat_weight, dim=1).unsqueeze_(1)
96
+ # x = (1 - rnd_sat) * x + rnd_sat * encoded_image_lum
97
+ rnd_sat = (torch.rand(1)[0]*2.0 - 1.0)*ramp_fn(self.ramp) * self.rnd_sat + 1.0
98
+ x = ko.enhance.adjust.adjust_saturation(x, rnd_sat)
99
+
100
+ # jpeg
101
+ x = x.reshape(sh)
102
+ if self.do_jpeg:
103
+ jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.ramp) * (100. - self.jpeg_quality)
104
+ x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
105
+
106
+ x = x * 2 - 1 # [0, 1] -> [-1, 1]
107
+ return x
108
+
109
+
110
+ class ImagenetCTransform(nn.Module):
111
+ def __init__(self, max_severity=5) -> None:
112
+ super().__init__()
113
+ self.max_severity = max_severity
114
+ self.tform = RandomImagenetC(max_severity=max_severity, phase='train')
115
+
116
+ def forward(self, x):
117
+ # x: [batch_size, 3, H, W] in range [-1, 1]
118
+ img0 = x.detach().cpu().numpy()
119
+ img = img0 * 127.5 + 127.5 # [-1, 1] -> [0, 255]
120
+ img = img.transpose(0, 2, 3, 1).astype(np.uint8)
121
+ img = [Image.fromarray(i) for i in img]
122
+ img = [self.tform(i) for i in img]
123
+ img = np.array([np.array(i) for i in img], dtype=np.float32)
124
+ img = img.transpose(0, 3, 1, 2) / 127.5 - 1. # [0, 255] -> [-1, 1]
125
+ residual = torch.from_numpy(img - img0).to(x.device)
126
+ x = x + residual
127
+ return x
cldm/transformations2.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ use kornia and albumentations for transformations
5
+ @author: Tu Bui @University of Surrey
6
+ """
7
+ import os
8
+ from . import utils
9
+ import torch
10
+ import numpy as np
11
+ from torch import nn
12
+ import torch.nn.functional as thf
13
+ from PIL import Image
14
+ import kornia as ko
15
+ import albumentations as ab
16
+ from torchvision import transforms
17
+
18
+
19
+ class IdentityAugment(nn.Module):
20
+ def __init__(self):
21
+ super().__init__()
22
+
23
+ def forward(self, x, **kwargs):
24
+ return x
25
+
26
+
27
+ class RandomCompress(nn.Module):
28
+ def __init__(self, severity='medium', p=0.5):
29
+ super().__init__()
30
+ self.p = p
31
+ if severity == 'low':
32
+ self.jpeg_quality = 70
33
+ elif severity == 'medium':
34
+ self.jpeg_quality = 50
35
+ elif severity == 'high':
36
+ self.jpeg_quality = 40
37
+
38
+ def forward(self, x, ramp=1.):
39
+ # x (B, C, H, W) in range [0, 1]
40
+ # ramp: adjust the ramping of the compression, 1.0 means min quality = self.jpeg_quality
41
+ if torch.rand(1)[0] >= self.p:
42
+ return x
43
+ jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
44
+ x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
45
+ return x
46
+
47
+
48
+ class RandomBoxBlur(nn.Module):
49
+ def __init__(self, severity='medium', border_type='reflect', normalized=True, p=0.5):
50
+ super().__init__()
51
+ self.p = p
52
+ if severity == 'low':
53
+ kernel_size = 3
54
+ elif severity == 'medium':
55
+ kernel_size = 5
56
+ elif severity == 'high':
57
+ kernel_size = 7
58
+
59
+ self.tform = ko.augmentation.RandomBoxBlur(kernel_size=(kernel_size, kernel_size), border_type=border_type, normalized=normalized, p=self.p)
60
+
61
+ def forward(self, x, **kwargs):
62
+ return self.tform(x)
63
+
64
+ class RandomMedianBlur(nn.Module):
65
+ def __init__(self, severity='medium', p=0.5):
66
+ super().__init__()
67
+ self.p = p
68
+ self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3,3), p=p)
69
+
70
+ def forward(self, x, **kwargs):
71
+ return self.tform(x)
72
+
73
+
74
+ class RandomBrightness(nn.Module):
75
+ def __init__(self, severity='medium', p=0.5):
76
+ super().__init__()
77
+ self.p = p
78
+ if severity == 'low':
79
+ brightness = (0.9, 1.1)
80
+ elif severity == 'medium':
81
+ brightness = (0.75, 1.25)
82
+ elif severity == 'high':
83
+ brightness = (0.5, 1.5)
84
+ self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)
85
+
86
+ def forward(self, x, **kwargs):
87
+ return self.tform(x)
88
+
89
+
90
+ class RandomContrast(nn.Module):
91
+ def __init__(self, severity='medium', p=0.5):
92
+ super().__init__()
93
+ self.p = p
94
+ if severity == 'low':
95
+ contrast = (0.9, 1.1)
96
+ elif severity == 'medium':
97
+ contrast = (0.75, 1.25)
98
+ elif severity == 'high':
99
+ contrast = (0.5, 1.5)
100
+ self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)
101
+
102
+ def forward(self, x, **kwargs):
103
+ return self.tform(x)
104
+
105
+
106
+ class RandomSaturation(nn.Module):
107
+ def __init__(self, severity='medium', p=0.5):
108
+ super().__init__()
109
+ self.p = p
110
+ if severity == 'low':
111
+ sat = (0.9, 1.1)
112
+ elif severity == 'medium':
113
+ sat = (0.75, 1.25)
114
+ elif severity == 'high':
115
+ sat = (0.5, 1.5)
116
+ self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)
117
+
118
+ def forward(self, x, **kwargs):
119
+ return self.tform(x)
120
+
121
+ class RandomSharpness(nn.Module):
122
+ def __init__(self, severity='medium', p=0.5):
123
+ super().__init__()
124
+ self.p = p
125
+ if severity == 'low':
126
+ sharpness = 0.5
127
+ elif severity == 'medium':
128
+ sharpness = 1.0
129
+ elif severity == 'high':
130
+ sharpness = 2.5
131
+ self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)
132
+
133
+ def forward(self, x, **kwargs):
134
+ return self.tform(x)
135
+
136
+ class RandomColorJiggle(nn.Module):
137
+ def __init__(self, severity='medium', p=0.5):
138
+ super().__init__()
139
+ self.p = p
140
+ if severity == 'low':
141
+ factor = (0.05, 0.05, 0.05, 0.01)
142
+ elif severity == 'medium':
143
+ factor = (0.1, 0.1, 0.1, 0.02)
144
+ elif severity == 'high':
145
+ factor = (0.1, 0.1, 0.1, 0.05)
146
+ self.tform = ko.augmentation.ColorJiggle(*factor, p=p)
147
+
148
+ def forward(self, x, **kwargs):
149
+ return self.tform(x)
150
+
151
+ class RandomHue(nn.Module):
152
+ def __init__(self, severity='medium', p=0.5):
153
+ super().__init__()
154
+ self.p = p
155
+ if severity == 'low':
156
+ hue = 0.01
157
+ elif severity == 'medium':
158
+ hue = 0.02
159
+ elif severity == 'high':
160
+ hue = 0.05
161
+ self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)
162
+
163
+ def forward(self, x, **kwargs):
164
+ return self.tform(x)
165
+
166
+ class RandomGamma(nn.Module):
167
+ def __init__(self, severity='medium', p=0.5):
168
+ super().__init__()
169
+ self.p = p
170
+ if severity == 'low':
171
+ gamma, gain = (0.9, 1.1), (0.9,1.1)
172
+ elif severity == 'medium':
173
+ gamma, gain = (0.75, 1.25), (0.75,1.25)
174
+ elif severity == 'high':
175
+ gamma, gain = (0.5, 1.5), (0.5,1.5)
176
+ self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)
177
+
178
+ def forward(self, x, **kwargs):
179
+ return self.tform(x)
180
+
181
+ class RandomGaussianBlur(nn.Module):
182
+ def __init__(self, severity='medium', p=0.5):
183
+ super().__init__()
184
+ self.p = p
185
+ if severity == 'low':
186
+ kernel_size, sigma = 3, (0.1, 1.0)
187
+ elif severity == 'medium':
188
+ kernel_size, sigma = 5, (0.1, 1.5)
189
+ elif severity == 'high':
190
+ kernel_size, sigma = 7, (0.1, 2.0)
191
+ self.tform = ko.augmentation.RandomGaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)
192
+
193
+ def forward(self, x, **kwargs):
194
+ return self.tform(x)
195
+
196
+ class RandomGaussianNoise(nn.Module):
197
+ def __init__(self, severity='medium', p=0.5):
198
+ super().__init__()
199
+ self.p = p
200
+ if severity == 'low':
201
+ std = 0.02
202
+ elif severity == 'medium':
203
+ std = 0.04
204
+ elif severity == 'high':
205
+ std = 0.08
206
+ self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)
207
+
208
+ def forward(self, x, **kwargs):
209
+ return self.tform(x)
210
+
211
+ class RandomMotionBlur(nn.Module):
212
+ def __init__(self, severity='medium', p=0.5):
213
+ super().__init__()
214
+ self.p = p
215
+ if severity == 'low':
216
+ kernel_size, angle, direction = (3, 5), (-25, 25), (-0.25, 0.25)
217
+ elif severity == 'medium':
218
+ kernel_size, angle, direction = (3, 7), (-45, 45), (-0.5, 0.5)
219
+ elif severity == 'high':
220
+ kernel_size, angle, direction = (3, 9), (-90, 90), (-1.0, 1.0)
221
+ self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)
222
+
223
+ def forward(self, x, **kwargs):
224
+ return self.tform(x)
225
+
226
+ class RandomPosterize(nn.Module):
227
+ def __init__(self, severity='medium', p=0.5):
228
+ super().__init__()
229
+ self.p = p
230
+ if severity == 'low':
231
+ bits = 5
232
+ elif severity == 'medium':
233
+ bits = 4
234
+ elif severity == 'high':
235
+ bits = 3
236
+ self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)
237
+
238
+ def forward(self, x, **kwargs):
239
+ return self.tform(x)
240
+
241
+ class RandomRGBShift(nn.Module):
242
+ def __init__(self, severity='medium', p=0.5):
243
+ super().__init__()
244
+ self.p = p
245
+ if severity == 'low':
246
+ rgb = 0.02
247
+ elif severity == 'medium':
248
+ rgb = 0.05
249
+ elif severity == 'high':
250
+ rgb = 0.1
251
+ self.tform = ko.augmentation.RandomRGBShift(r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)
252
+
253
+ def forward(self, x, **kwargs):
254
+ return self.tform(x)
255
+
256
+
257
+
258
+ class TransformNet(nn.Module):
259
+ def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True, contrast=True, color_jiggle=True, gamma=False, grayscale=True, gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True, posterize=True, rgb_shift=True, saturation=True, sharpness=True, median_blur=True, box_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
260
+ super().__init__()
261
+ self.n_optional = n_optional
262
+ self.p = p
263
+ p_flip = 0.5 if flip else 0
264
+ rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
265
+ self.ramp = ramp
266
+ self.register_buffer('step0', torch.tensor(0))
267
+
268
+ self.crop_mode = crop_mode
269
+ assert crop_mode in ['random_crop', 'resized_crop']
270
+ if crop_mode == 'random_crop':
271
+ rnd_crop_layer = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
272
+ elif crop_mode == 'resized_crop':
273
+ rnd_crop_layer = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
274
+
275
+ self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]
276
+ if compress:
277
+ self.register(RandomCompress(severity, p=p), 'Random Compress')
278
+ if brightness:
279
+ self.register(RandomBrightness(severity, p=p), 'Random Brightness')
280
+ if contrast:
281
+ self.register(RandomContrast(severity, p=p), 'Random Contrast')
282
+ if color_jiggle:
283
+ self.register(RandomColorJiggle(severity, p=p), 'Random Color')
284
+ if gamma:
285
+ self.register(RandomGamma(severity, p=p), 'Random Gamma')
286
+ if grayscale:
287
+ self.register(ko.augmentation.RandomGrayscale(p=p), 'Grayscale')
288
+ if gaussian_blur:
289
+ self.register(RandomGaussianBlur(severity, p=p), 'Random Gaussian Blur')
290
+ if gaussian_noise:
291
+ self.register(RandomGaussianNoise(severity, p=p), 'Random Gaussian Noise')
292
+ if hue:
293
+ self.register(RandomHue(severity, p=p), 'Random Hue')
294
+ if motion_blur:
295
+ self.register(RandomMotionBlur(severity, p=p), 'Random Motion Blur')
296
+ if posterize:
297
+ self.register(RandomPosterize(severity, p=p), 'Random Posterize')
298
+ if rgb_shift:
299
+ self.register(RandomRGBShift(severity, p=p), 'Random RGB Shift')
300
+ if saturation:
301
+ self.register(RandomSaturation(severity, p=p), 'Random Saturation')
302
+ if sharpness:
303
+ self.register(RandomSharpness(severity, p=p), 'Random Sharpness')
304
+ if median_blur:
305
+ self.register(RandomMedianBlur(severity, p=p), 'Random Median Blur')
306
+ if box_blur:
307
+ self.register(RandomBoxBlur(severity, p=p), 'Random Box Blur')
308
+
309
+ def register(self, tform, name):
310
+ # register a new (optional) transform
311
+ if not hasattr(self, 'optional_transforms'):
312
+ self.optional_transforms = []
313
+ self.optional_names = []
314
+ self.optional_transforms.append(tform)
315
+ self.optional_names.append(name)
316
+
317
+ def activate(self, global_step):
318
+ if self.step0 == 0:
319
+ print(f'[TRAINING] Activating TransformNet at step {global_step}')
320
+ self.step0 = torch.tensor(global_step)
321
+
322
+ def is_activated(self):
323
+ return self.step0 > 0
324
+
325
+ def forward(self, x, global_step, p=0.9):
326
+ # x: [batch_size, 3, H, W] in range [-1, 1]
327
+ x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
328
+ # fixed transforms
329
+ for tform in self.fixed_transforms:
330
+ x = tform(x)
331
+ if isinstance(x, tuple):
332
+ x = x[0]
333
+
334
+ # optional transforms
335
+ ramp = np.min([(global_step-self.step0.cpu().item()) / self.ramp, 1.])
336
+ if len(self.optional_transforms) > 0:
337
+ tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).numpy()
338
+ for tform_id in tform_ids:
339
+ tform = self.optional_transforms[tform_id]
340
+ x = tform(x, ramp=ramp)
341
+ if isinstance(x, tuple):
342
+ x = x[0]
343
+
344
+ return x * 2 - 1 # [0, 1] -> [-1, 1]
345
+
346
+ def transform_by_id(self, x, tform_id):
347
+ # x: [batch_size, 3, H, W] in range [-1, 1]
348
+ x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
349
+ # fixed transforms
350
+ for tform in self.fixed_transforms:
351
+ x = tform(x)
352
+ if isinstance(x, tuple):
353
+ x = x[0]
354
+
355
+ # optional transforms
356
+ tform = self.optional_transforms[tform_id]
357
+ x = tform(x)
358
+ if isinstance(x, tuple):
359
+ x = x[0]
360
+ return x * 2 - 1 # [0, 1] -> [-1, 1]
361
+
362
+ def transform_by_name(self, x, tform_name):
363
+ assert tform_name in self.optional_names
364
+ tform_id = self.optional_names.index(tform_name)
365
+ return self.transform_by_id(x, tform_id)
366
+
367
+ def apply_transform_on_pil_image(self, x, tform_name):
368
+ # x: PIL image
369
+ # return: PIL image
370
+ assert tform_name in self.optional_names + ['Random Crop', 'Random Flip']
371
+ # if tform_name == 'Random Crop': # the only transform dependent on image size
372
+ # # crop equivalent to 224/256
373
+ # w, h = x.size
374
+ # new_w, new_h = int(224 / 256 * w), int(224 / 256 * h)
375
+ # x = transforms.RandomCrop((new_h, new_w))(x)
376
+ # return x
377
+
378
+ # x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
379
+ # x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
380
+ # if tform_name == 'Random Flip':
381
+ # x = self.fixed_transforms[0](x)
382
+ # else:
383
+ # tform_id = self.optional_names.index(tform_name)
384
+ # tform = self.optional_transforms[tform_id]
385
+ # x = tform(x)
386
+ # if isinstance(x, tuple):
387
+ # x = x[0]
388
+ # x = x.detach().squeeze(0).permute(1, 2, 0).numpy() * 255 # [0, 1] -> [0, 255]
389
+ # return Image.fromarray(x.astype(np.uint8))
390
+
391
+ w, h = x.size
392
+ x = x.resize((256, 256), Image.BILINEAR)
393
+ x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
394
+ x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
395
+ if tform_name == 'Random Flip':
396
+ x = self.fixed_transforms[0](x)
397
+ elif tform_name == 'Random Crop':
398
+ x = self.fixed_transforms[1](x)
399
+ else:
400
+ tform_id = self.optional_names.index(tform_name)
401
+ tform = self.optional_transforms[tform_id]
402
+ x = tform(x)
403
+ if isinstance(x, tuple):
404
+ x = x[0]
405
+ x = x.detach().squeeze(0).permute(1, 2, 0).numpy() * 255 # [0, 1] -> [0, 255]
406
+ x = Image.fromarray(x.astype(np.uint8))
407
+ if (tform_name == 'Random Crop') and (self.crop_mode == 'random_crop'):
408
+ w, h = int(224 / 256 * w), int(224 / 256 * h)
409
+ x = x.resize((w, h), Image.BILINEAR)
410
+ return x
411
+
412
+
413
+ if __name__ == '__main__':
414
+ pass
cldm/utils.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import itertools
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torch.nn as nn
8
+
9
+ from PIL import Image, ImageOps
10
+ import matplotlib.pyplot as plt
11
+
12
+ def random_blur_kernel(probs, N_blur, sigrange_gauss, sigrange_line, wmin_line):
13
+ N = N_blur
14
+ coords = torch.from_numpy(np.stack(np.meshgrid(range(N_blur), range(N_blur), indexing='ij'), axis=-1)) - (0.5 * (N-1)) # (7,7,2)
15
+ manhat = torch.sum(torch.abs(coords), dim=-1) # (7, 7)
16
+
17
+ # nothing, default
18
+ vals_nothing = (manhat < 0.5).float() # (7, 7)
19
+
20
+ # gauss
21
+ sig_gauss = torch.rand(1)[0] * (sigrange_gauss[1] - sigrange_gauss[0]) + sigrange_gauss[0]
22
+ vals_gauss = torch.exp(-torch.sum(coords ** 2, dim=-1) /2. / sig_gauss ** 2)
23
+
24
+ # line
25
+ theta = torch.rand(1)[0] * 2.* np.pi
26
+ v = torch.FloatTensor([torch.cos(theta), torch.sin(theta)]) # (2)
27
+ dists = torch.sum(coords * v, dim=-1) # (7, 7)
28
+
29
+ sig_line = torch.rand(1)[0] * (sigrange_line[1] - sigrange_line[0]) + sigrange_line[0]
30
+ w_line = torch.rand(1)[0] * (0.5 * (N-1) + 0.1 - wmin_line) + wmin_line
31
+
32
+ vals_line = torch.exp(-dists ** 2 / 2. / sig_line ** 2) * (manhat < w_line) # (7, 7)
33
+
34
+ t = torch.rand(1)[0]
35
+ vals = vals_nothing
36
+ if t < (probs[0] + probs[1]):
37
+ vals = vals_line
38
+ else:
39
+ vals = vals
40
+ if t < probs[0]:
41
+ vals = vals_gauss
42
+ else:
43
+ vals = vals
44
+
45
+ v = vals / torch.sum(vals) # 归一化 (7, 7)
46
+ z = torch.zeros_like(v)
47
+ f = torch.stack([v,z,z, z,v,z, z,z,v], dim=0).reshape([3, 3, N, N])
48
+ return f
49
+
50
+
51
+ def get_rand_transform_matrix(image_size, d, batch_size):
52
+ Ms = np.zeros((batch_size, 2, 3, 3))
53
+ for i in range(batch_size):
54
+ tl_x = random.uniform(-d, d) # Top left corner, top
55
+ tl_y = random.uniform(-d, d) # Top left corner, left
56
+ bl_x = random.uniform(-d, d) # Bot left corner, bot
57
+ bl_y = random.uniform(-d, d) # Bot left corner, left
58
+ tr_x = random.uniform(-d, d) # Top right corner, top
59
+ tr_y = random.uniform(-d, d) # Top right corner, right
60
+ br_x = random.uniform(-d, d) # Bot right corner, bot
61
+ br_y = random.uniform(-d, d) # Bot right corner, right
62
+
63
+ rect = np.array([
64
+ [tl_x, tl_y],
65
+ [tr_x + image_size, tr_y],
66
+ [br_x + image_size, br_y + image_size],
67
+ [bl_x, bl_y + image_size]], dtype = "float32")
68
+
69
+ dst = np.array([
70
+ [0, 0],
71
+ [image_size, 0],
72
+ [image_size, image_size],
73
+ [0, image_size]], dtype = "float32")
74
+
75
+ M = cv2.getPerspectiveTransform(rect, dst)
76
+ M_inv = np.linalg.inv(M)
77
+ Ms[i, 0, :, :] = M_inv
78
+ Ms[i, 1, :, :] = M
79
+ Ms = torch.from_numpy(Ms).float()
80
+
81
+ return Ms
82
+
83
+
84
+ def get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size):
85
+ rnd_hue = torch.FloatTensor(batch_size, 3, 1, 1).uniform_(-rnd_hue, rnd_hue)
86
+ rnd_brightness = torch.FloatTensor(batch_size, 1, 1, 1).uniform_(-rnd_bri, rnd_bri)
87
+ return rnd_hue + rnd_brightness
88
+
89
+
90
+ # reference: https://github.com/mlomnitz/DiffJPEG.git
91
+ y_table = np.array(
92
+ [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60,
93
+ 55], [14, 13, 16, 24, 40, 57, 69, 56],
94
+ [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103,
95
+ 77], [24, 35, 55, 64, 81, 104, 113, 92],
96
+ [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
97
+ dtype=np.float32).T
98
+
99
+ y_table = nn.Parameter(torch.from_numpy(y_table))
100
+ c_table = np.empty((8, 8), dtype=np.float32)
101
+ c_table.fill(99)
102
+ c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66],
103
+ [24, 26, 56, 99], [47, 66, 99, 99]]).T
104
+ c_table = nn.Parameter(torch.from_numpy(c_table))
105
+
106
+ # 1. RGB -> YCbCr
107
+ class rgb_to_ycbcr_jpeg(nn.Module):
108
+ """ Converts RGB image to YCbCr
109
+ Input:
110
+ image(tensor): batch x 3 x height x width
111
+ Outpput:
112
+ result(tensor): batch x height x width x 3
113
+ """
114
+ def __init__(self):
115
+ super(rgb_to_ycbcr_jpeg, self).__init__()
116
+ matrix = np.array(
117
+ [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
118
+ [0.5, -0.418688, -0.081312]], dtype=np.float32).T
119
+ self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
120
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
121
+
122
+ def forward(self, image):
123
+ image = image.permute(0, 2, 3, 1)
124
+ result = torch.tensordot(image, self.matrix, dims=1) + self.shift
125
+ result.view(image.shape)
126
+ return result
127
+
128
+ # 2. Chroma subsampling
129
+ class chroma_subsampling(nn.Module):
130
+ """ Chroma subsampling on CbCv channels
131
+ Input:
132
+ image(tensor): batch x height x width x 3
133
+ Output:
134
+ y(tensor): batch x height x width
135
+ cb(tensor): batch x height/2 x width/2
136
+ cr(tensor): batch x height/2 x width/2
137
+ """
138
+ def __init__(self):
139
+ super(chroma_subsampling, self).__init__()
140
+
141
+ def forward(self, image):
142
+ image_2 = image.permute(0, 3, 1, 2).clone()
143
+ avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2),
144
+ count_include_pad=False)
145
+ cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
146
+ cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
147
+ cb = cb.permute(0, 2, 3, 1)
148
+ cr = cr.permute(0, 2, 3, 1)
149
+ return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
150
+
151
+ # 3. Block splitting
152
+ class block_splitting(nn.Module):
153
+ """ Splitting image into patches
154
+ Input:
155
+ image(tensor): batch x height x width
156
+ Output:
157
+ patch(tensor): batch x h*w/64 x h x w
158
+ """
159
+ def __init__(self):
160
+ super(block_splitting, self).__init__()
161
+ self.k = 8
162
+
163
+ def forward(self, image):
164
+ height, width = image.shape[1:3]
165
+ batch_size = image.shape[0]
166
+ image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
167
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
168
+ return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
169
+
170
+ # 4. DCT
171
+ class dct_8x8(nn.Module):
172
+ """ Discrete Cosine Transformation
173
+ Input:
174
+ image(tensor): batch x height x width
175
+ Output:
176
+ dcp(tensor): batch x height x width
177
+ """
178
+ def __init__(self):
179
+ super(dct_8x8, self).__init__()
180
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
181
+ for x, y, u, v in itertools.product(range(8), repeat=4):
182
+ tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
183
+ (2 * y + 1) * v * np.pi / 16)
184
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
185
+ #
186
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
187
+ self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() )
188
+
189
+ def forward(self, image):
190
+ image = image - 128
191
+ result = self.scale * torch.tensordot(image, self.tensor, dims=2)
192
+ result.view(image.shape)
193
+ return result
194
+
195
+ # 5. Quantization
196
+ class y_quantize(nn.Module):
197
+ """ JPEG Quantization for Y channel
198
+ Input:
199
+ image(tensor): batch x height x width
200
+ rounding(function): rounding function to use
201
+ factor(float): Degree of compression
202
+ Output:
203
+ image(tensor): batch x height x width
204
+ """
205
+ def __init__(self, rounding, factor=1):
206
+ super(y_quantize, self).__init__()
207
+ self.rounding = rounding
208
+ self.factor = factor
209
+ self.y_table = y_table
210
+
211
+ def forward(self, image):
212
+ image = image.float() / (self.y_table * self.factor)
213
+ image = self.rounding(image)
214
+ return image
215
+
216
+
217
+ class c_quantize(nn.Module):
218
+ """ JPEG Quantization for CrCb channels
219
+ Input:
220
+ image(tensor): batch x height x width
221
+ rounding(function): rounding function to use
222
+ factor(float): Degree of compression
223
+ Output:
224
+ image(tensor): batch x height x width
225
+ """
226
+ def __init__(self, rounding, factor=1):
227
+ super(c_quantize, self).__init__()
228
+ self.rounding = rounding
229
+ self.factor = factor
230
+ self.c_table = c_table
231
+
232
+ def forward(self, image):
233
+ image = image.float() / (self.c_table * self.factor)
234
+ image = self.rounding(image)
235
+ return image
236
+
237
+
238
+ class compress_jpeg(nn.Module):
239
+ """ Full JPEG compression algortihm
240
+ Input:
241
+ imgs(tensor): batch x 3 x height x width
242
+ rounding(function): rounding function to use
243
+ factor(float): Compression factor
244
+ Ouput:
245
+ compressed(dict(tensor)): batch x h*w/64 x 8 x 8
246
+ """
247
+ def __init__(self, rounding=torch.round, factor=1):
248
+ super(compress_jpeg, self).__init__()
249
+ self.l1 = nn.Sequential(
250
+ rgb_to_ycbcr_jpeg(),
251
+ chroma_subsampling()
252
+ )
253
+ self.l2 = nn.Sequential(
254
+ block_splitting(),
255
+ dct_8x8()
256
+ )
257
+ self.c_quantize = c_quantize(rounding=rounding, factor=factor)
258
+ self.y_quantize = y_quantize(rounding=rounding, factor=factor)
259
+
260
+ def forward(self, image):
261
+ y, cb, cr = self.l1(image*255)
262
+ components = {'y': y, 'cb': cb, 'cr': cr}
263
+ for k in components.keys():
264
+ comp = self.l2(components[k])
265
+ if k in ('cb', 'cr'):
266
+ comp = self.c_quantize(comp)
267
+ else:
268
+ comp = self.y_quantize(comp)
269
+
270
+ components[k] = comp
271
+
272
+ return components['y'], components['cb'], components['cr']
273
+
274
+ # -5. Dequantization
275
+ class y_dequantize(nn.Module):
276
+ """ Dequantize Y channel
277
+ Inputs:
278
+ image(tensor): batch x height x width
279
+ factor(float): compression factor
280
+ Outputs:
281
+ image(tensor): batch x height x width
282
+ """
283
+ def __init__(self, factor=1):
284
+ super(y_dequantize, self).__init__()
285
+ self.y_table = y_table
286
+ self.factor = factor
287
+
288
+ def forward(self, image):
289
+ return image * (self.y_table * self.factor)
290
+
291
+
292
+ class c_dequantize(nn.Module):
293
+ """ Dequantize CbCr channel
294
+ Inputs:
295
+ image(tensor): batch x height x width
296
+ factor(float): compression factor
297
+ Outputs:
298
+ image(tensor): batch x height x width
299
+ """
300
+ def __init__(self, factor=1):
301
+ super(c_dequantize, self).__init__()
302
+ self.factor = factor
303
+ self.c_table = c_table
304
+
305
+ def forward(self, image):
306
+ return image * (self.c_table * self.factor)
307
+
308
+ # -4. Inverse DCT
309
+ class idct_8x8(nn.Module):
310
+ """ Inverse discrete Cosine Transformation
311
+ Input:
312
+ dcp(tensor): batch x height x width
313
+ Output:
314
+ image(tensor): batch x height x width
315
+ """
316
+ def __init__(self):
317
+ super(idct_8x8, self).__init__()
318
+ alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
319
+ self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
320
+ tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
321
+ for x, y, u, v in itertools.product(range(8), repeat=4):
322
+ tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
323
+ (2 * v + 1) * y * np.pi / 16)
324
+ self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
325
+
326
+ def forward(self, image):
327
+ image = image * self.alpha
328
+ result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
329
+ result.view(image.shape)
330
+ return result
331
+
332
+ # -3. Block joining
333
+ class block_merging(nn.Module):
334
+ """ Merge pathces into image
335
+ Inputs:
336
+ patches(tensor) batch x height*width/64, height x width
337
+ height(int)
338
+ width(int)
339
+ Output:
340
+ image(tensor): batch x height x width
341
+ """
342
+ def __init__(self):
343
+ super(block_merging, self).__init__()
344
+
345
+ def forward(self, patches, height, width):
346
+ k = 8
347
+ batch_size = patches.shape[0]
348
+ image_reshaped = patches.view(batch_size, height//k, width//k, k, k)
349
+ image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
350
+ return image_transposed.contiguous().view(batch_size, height, width)
351
+
352
+ # -2. Chroma upsampling
353
+ class chroma_upsampling(nn.Module):
354
+ """ Upsample chroma layers
355
+ Input:
356
+ y(tensor): y channel image
357
+ cb(tensor): cb channel
358
+ cr(tensor): cr channel
359
+ Ouput:
360
+ image(tensor): batch x height x width x 3
361
+ """
362
+ def __init__(self):
363
+ super(chroma_upsampling, self).__init__()
364
+
365
+ def forward(self, y, cb, cr):
366
+ def repeat(x, k=2):
367
+ height, width = x.shape[1:3]
368
+ x = x.unsqueeze(-1)
369
+ x = x.repeat(1, 1, k, k)
370
+ x = x.view(-1, height * k, width * k)
371
+ return x
372
+
373
+ cb = repeat(cb)
374
+ cr = repeat(cr)
375
+
376
+ return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
377
+
378
+ # -1: YCbCr -> RGB
379
+ class ycbcr_to_rgb_jpeg(nn.Module):
380
+ """ Converts YCbCr image to RGB JPEG
381
+ Input:
382
+ image(tensor): batch x height x width x 3
383
+ Outpput:
384
+ result(tensor): batch x 3 x height x width
385
+ """
386
+ def __init__(self):
387
+ super(ycbcr_to_rgb_jpeg, self).__init__()
388
+
389
+ matrix = np.array(
390
+ [[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
391
+ dtype=np.float32).T
392
+ self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
393
+ self.matrix = nn.Parameter(torch.from_numpy(matrix))
394
+
395
+ def forward(self, image):
396
+ result = torch.tensordot(image + self.shift, self.matrix, dims=1)
397
+ result.view(image.shape)
398
+ return result.permute(0, 3, 1, 2)
399
+
400
+
401
+ class decompress_jpeg(nn.Module):
402
+ """ Full JPEG decompression algortihm
403
+ Input:
404
+ compressed(dict(tensor)): batch x h*w/64 x 8 x 8
405
+ rounding(function): rounding function to use
406
+ factor(float): Compression factor
407
+ Ouput:
408
+ image(tensor): batch x 3 x height x width
409
+ """
410
+ def __init__(self, height, width, rounding=torch.round, factor=1):
411
+ super(decompress_jpeg, self).__init__()
412
+ self.c_dequantize = c_dequantize(factor=factor)
413
+ self.y_dequantize = y_dequantize(factor=factor)
414
+ self.idct = idct_8x8()
415
+ self.merging = block_merging()
416
+ self.chroma = chroma_upsampling()
417
+ self.colors = ycbcr_to_rgb_jpeg()
418
+
419
+ self.height, self.width = height, width
420
+
421
+ def forward(self, y, cb, cr):
422
+ components = {'y': y, 'cb': cb, 'cr': cr}
423
+ for k in components.keys():
424
+ if k in ('cb', 'cr'):
425
+ comp = self.c_dequantize(components[k])
426
+ height, width = int(self.height/2), int(self.width/2)
427
+ else:
428
+ comp = self.y_dequantize(components[k])
429
+ height, width = self.height, self.width
430
+ comp = self.idct(comp)
431
+ components[k] = self.merging(comp, height, width)
432
+ #
433
+ image = self.chroma(components['y'], components['cb'], components['cr'])
434
+ image = self.colors(image)
435
+
436
+ image = torch.min(255*torch.ones_like(image),
437
+ torch.max(torch.zeros_like(image), image))
438
+ return image/255
439
+
440
+ def diff_round(x):
441
+ """ Differentiable rounding function
442
+ Input:
443
+ x(tensor)
444
+ Output:
445
+ x(tensor)
446
+ """
447
+ return torch.round(x) + (x - torch.round(x))**3
448
+
449
+ def round_only_at_0(x):
450
+ cond = (torch.abs(x) < 0.5).float()
451
+ return cond * (x ** 3) + (1 - cond) * x
452
+
453
+ def quality_to_factor(quality):
454
+ """ Calculate factor corresponding to quality
455
+ Input:
456
+ quality(float): Quality for jpeg compression
457
+ Output:
458
+ factor(float): Compression factor
459
+ """
460
+ if quality < 50:
461
+ quality = 5000. / quality
462
+ else:
463
+ quality = 200. - quality*2
464
+ return quality / 100.
465
+
466
+ def jpeg_compress_decompress(image,
467
+ # downsample_c=True,
468
+ rounding=round_only_at_0,
469
+ quality=80):
470
+ # image_r = image * 255
471
+ height, width = image.shape[2:4]
472
+ # orig_height, orig_width = height, width
473
+ # if height % 16 != 0 or width % 16 != 0:
474
+ # # Round up to next multiple of 16
475
+ # height = ((height - 1) // 16 + 1) * 16
476
+ # width = ((width - 1) // 16 + 1) * 16
477
+
478
+ # vpad = height - orig_height
479
+ # wpad = width - orig_width
480
+ # top = vpad // 2
481
+ # bottom = vpad - top
482
+ # left = wpad // 2
483
+ # right = wpad - left
484
+ # #image = tf.pad(image, [[0, 0], [top, bottom], [left, right], [0, 0]], 'SYMMETRIC')
485
+ # image = torch.pad(image, [[0, 0], [0, vpad], [0, wpad], [0, 0]], 'reflect')
486
+
487
+ factor = quality_to_factor(quality)
488
+
489
+ compress = compress_jpeg(rounding=rounding, factor=factor).to(image.device)
490
+ decompress = decompress_jpeg(height, width, rounding=rounding, factor=factor).to(image.device)
491
+
492
+ y, cb, cr = compress(image)
493
+ recovered = decompress(y, cb, cr)
494
+
495
+ return recovered.contiguous()
496
+
497
+
498
+ if __name__ == '__main__':
499
+ ''' test JPEG compress and decompress'''
500
+ # img = Image.open('house.jpg')
501
+ # img = np.array(img) / 255.
502
+ # img_r = np.transpose(img, [2, 0, 1])
503
+ # img_tensor = torch.from_numpy(img_r).unsqueeze(0).float()
504
+
505
+ # recover = jpeg_compress_decompress(img_tensor)
506
+
507
+ # recover_arr = recover.detach().squeeze(0).numpy()
508
+ # recover_arr = np.transpose(recover_arr, [1, 2, 0])
509
+
510
+ # plt.subplot(121)
511
+ # plt.imshow(img)
512
+ # plt.subplot(122)
513
+ # plt.imshow(recover_arr)
514
+ # plt.show()
515
+
516
+ ''' test blur '''
517
+ # blur
518
+
519
+ img = Image.open('house.jpg')
520
+ img = np.array(img) / 255.
521
+ img_r = np.transpose(img, [2, 0, 1])
522
+ img_tensor = torch.from_numpy(img_r).unsqueeze(0).float()
523
+ print(img_tensor.shape)
524
+
525
+ N_blur=7
526
+ f = random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.], wmin_line=3)
527
+ # print(f.shape)
528
+ # print(type(f))
529
+ encoded_image = F.conv2d(img_tensor, f, bias=None, padding=int((N_blur-1)/2))
530
+
531
+ encoded_image = encoded_image.detach().squeeze(0).numpy()
532
+ encoded_image = np.transpose(encoded_image, [1, 2, 0])
533
+
534
+ plt.subplot(121)
535
+ plt.imshow(img)
536
+ plt.subplot(122)
537
+ plt.imshow(encoded_image)
538
+ plt.show()
539
+
flae/models.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as thf
5
+ import pytorch_lightning as pl
6
+ from ldm.util import instantiate_from_config
7
+ import einops
8
+ import kornia
9
+ import numpy as np
10
+ import torchvision
11
+ from contextlib import contextmanager
12
+ from ldm.modules.ema import LitEma
13
+
14
+
15
+ class FlAE(pl.LightningModule):
16
+ def __init__(self,
17
+ cover_key,
18
+ secret_key,
19
+ secret_len,
20
+ resolution,
21
+ secret_encoder_config,
22
+ secret_decoder_config,
23
+ loss_config,
24
+ noise_config='__none__',
25
+ ckpt_path="__none__",
26
+ use_ema=False
27
+ ):
28
+ super().__init__()
29
+ self.cover_key = cover_key
30
+ self.secret_key = secret_key
31
+ secret_encoder_config.params.secret_len = secret_len
32
+ secret_decoder_config.params.secret_len = secret_len
33
+ secret_encoder_config.params.resolution = resolution
34
+ secret_decoder_config.params.resolution = 224
35
+ self.encoder = instantiate_from_config(secret_encoder_config)
36
+ self.decoder = instantiate_from_config(secret_decoder_config)
37
+ self.loss_layer = instantiate_from_config(loss_config)
38
+ if noise_config != '__none__':
39
+ print('Using noise')
40
+ self.noise = instantiate_from_config(noise_config)
41
+
42
+ self.use_ema = use_ema
43
+ if self.use_ema:
44
+ print('Using EMA')
45
+ self.encoder_ema = LitEma(self.encoder)
46
+ self.decoder_ema = LitEma(self.decoder)
47
+ print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
48
+
49
+ if ckpt_path != "__none__":
50
+ self.init_from_ckpt(ckpt_path, ignore_keys=[])
51
+
52
+ # early training phase
53
+ self.fixed_img = None
54
+ self.fixed_secret = None
55
+ self.register_buffer("fixed_input", torch.tensor(True))
56
+ self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
57
+
58
+ def init_from_ckpt(self, path, ignore_keys=list()):
59
+ sd = torch.load(path, map_location="cpu")["state_dict"]
60
+ keys = list(sd.keys())
61
+ for k in keys:
62
+ for ik in ignore_keys:
63
+ if k.startswith(ik):
64
+ print("Deleting key {} from state_dict.".format(k))
65
+ del sd[k]
66
+ self.load_state_dict(sd, strict=False)
67
+ print(f"Restored from {path}")
68
+
69
+ @contextmanager
70
+ def ema_scope(self, context=None):
71
+ if self.use_ema:
72
+ self.encoder_ema.store(self.encoder.parameters())
73
+ self.decoder_ema.store(self.decoder.parameters())
74
+ self.encoder_ema.copy_to(self.encoder)
75
+ self.decoder_ema.copy_to(self.decoder)
76
+ if context is not None:
77
+ print(f"{context}: Switched to EMA weights")
78
+ try:
79
+ yield None
80
+ finally:
81
+ if self.use_ema:
82
+ self.encoder_ema.restore(self.encoder.parameters())
83
+ self.decoder_ema.restore(self.decoder.parameters())
84
+ if context is not None:
85
+ print(f"{context}: Restored training weights")
86
+
87
+ def on_train_batch_end(self, *args, **kwargs):
88
+ if self.use_ema:
89
+ self.encoder_ema(self.encoder)
90
+ self.decoder_ema(self.decoder)
91
+
92
+ @torch.no_grad()
93
+ def get_input(self, batch, bs=None):
94
+ image = batch[self.cover_key]
95
+ secret = batch[self.secret_key]
96
+ if bs is not None:
97
+ image = image[:bs]
98
+ secret = secret[:bs]
99
+ else:
100
+ bs = image.shape[0]
101
+ # encode image 1st stage
102
+ image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
103
+
104
+ # check if using fixed input (early training phase)
105
+ # if self.training and self.fixed_input:
106
+ if self.fixed_input:
107
+ if self.fixed_img is None: # first iteration
108
+ print('[TRAINING] Warmup - using fixed input image for now!')
109
+ self.fixed_img = image.detach().clone()[:bs]
110
+ self.fixed_secret = secret.detach().clone()[:bs] # use for log_images with fixed_input option only
111
+ image = self.fixed_img
112
+ new_bs = min(secret.shape[0], image.shape[0])
113
+ image, secret = image[:new_bs], secret[:new_bs]
114
+
115
+ out = [image, secret]
116
+ return out
117
+
118
+ def forward(self, cover, secret):
119
+ # return a tuple (stego, residual)
120
+ enc_out = self.encoder(cover, secret)
121
+ if self.encoder.return_residual:
122
+ return cover + enc_out, enc_out
123
+ else:
124
+ return enc_out, enc_out - cover
125
+
126
+ def shared_step(self, batch):
127
+ x, s = self.get_input(batch)
128
+ stego, residual = self(x, s)
129
+ if hasattr(self, "noise") and self.noise.is_activated():
130
+ stego_noised = self.noise(stego, self.global_step, p=0.9)
131
+ else:
132
+ stego_noised = self.crop(stego)
133
+ stego_noised = torch.clamp(stego_noised, -1, 1)
134
+ spred = self.decoder(stego_noised)
135
+
136
+ loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
137
+ bit_acc = loss_dict["bit_acc"]
138
+
139
+ bit_acc_ = bit_acc.item()
140
+
141
+ if (bit_acc_ > 0.98) and (not self.fixed_input) and self.noise.is_activated():
142
+ self.loss_layer.activate_ramp(self.global_step)
143
+
144
+ if (bit_acc_ > 0.95) and (not self.fixed_input): # ramp up image loss at late training stage
145
+ if hasattr(self, 'noise') and (not self.noise.is_activated()):
146
+ self.noise.activate(self.global_step)
147
+
148
+ if (bit_acc_ > 0.9) and self.fixed_input: # execute only once
149
+ print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
150
+ self.fixed_input = ~self.fixed_input
151
+ return loss, loss_dict
152
+
153
+ def training_step(self, batch, batch_idx):
154
+ loss, loss_dict = self.shared_step(batch)
155
+ loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
156
+ self.log_dict(loss_dict, prog_bar=True,
157
+ logger=True, on_step=True, on_epoch=True)
158
+
159
+ self.log("global_step", self.global_step,
160
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
161
+ # if self.use_scheduler:
162
+ # lr = self.optimizers().param_groups[0]['lr']
163
+ # self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
164
+
165
+ return loss
166
+
167
+ @torch.no_grad()
168
+ def validation_step(self, batch, batch_idx):
169
+ _, loss_dict_no_ema = self.shared_step(batch)
170
+ loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
171
+ with self.ema_scope():
172
+ _, loss_dict_ema = self.shared_step(batch)
173
+ loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
174
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
175
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
176
+
177
+ @torch.no_grad()
178
+ def log_images(self, batch, fixed_input=False, **kwargs):
179
+ log = dict()
180
+ if fixed_input and self.fixed_img is not None:
181
+ x, s = self.fixed_img, self.fixed_secret
182
+ else:
183
+ x, s = self.get_input(batch)
184
+ stego, residual = self(x, s)
185
+ if hasattr(self, 'noise') and self.noise.is_activated():
186
+ img_noise = self.noise(stego, self.global_step, p=1.0)
187
+ log['noised'] = img_noise
188
+ log['input'] = x
189
+ log['stego'] = stego
190
+ log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
191
+ return log
192
+
193
+ def configure_optimizers(self):
194
+ lr = self.learning_rate
195
+ params = list(self.encoder.parameters()) + list(self.decoder.parameters())
196
+ optimizer = torch.optim.AdamW(params, lr=lr)
197
+ return optimizer
198
+
199
+
200
+
201
+
202
+ class SecretEncoder(nn.Module):
203
+ def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None:
204
+ super().__init__()
205
+ self.secret_len = secret_len
206
+ self.return_residual = return_residual
207
+ self.act_fn = lambda x: torch.tanh(x) if act == 'tanh' else thf.sigmoid(x) * 2.0 -1.0
208
+ self.secret_dense = nn.Linear(secret_len, 16*16*3)
209
+ log_resolution = int(math.log(resolution, 2))
210
+ assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
211
+ self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))
212
+ self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1)
213
+ self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
214
+ self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
215
+ self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
216
+ self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)
217
+ self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
218
+ self.up6 = nn.Conv2d(256, 128, 2, 1)
219
+ self.upsample6 = nn.Upsample(scale_factor=(2, 2))
220
+ self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
221
+ self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
222
+ self.up7 = nn.Conv2d(128, 64, 2, 1)
223
+ self.upsample7 = nn.Upsample(scale_factor=(2, 2))
224
+ self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
225
+ self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
226
+ self.up8 = nn.Conv2d(64, 32, 2, 1)
227
+ self.upsample8 = nn.Upsample(scale_factor=(2, 2))
228
+ self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
229
+ self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
230
+ self.up9 = nn.Conv2d(32, 32, 2, 1)
231
+ self.upsample9 = nn.Upsample(scale_factor=(2, 2))
232
+ self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1)
233
+ self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
234
+ self.residual = nn.Conv2d(32, 3, 1)
235
+
236
+ def forward(self, image, secret):
237
+ fingerprint = thf.relu(self.secret_dense(secret))
238
+ fingerprint = fingerprint.view((-1, 3, 16, 16))
239
+ fingerprint_enlarged = self.secret_upsample(fingerprint)
240
+ # try:
241
+ inputs = torch.cat([fingerprint_enlarged, image], dim=1)
242
+ # except:
243
+ # print(fingerprint_enlarged.shape, image.shape, fingerprint.shape)
244
+ # import pdb; pdb.set_trace()
245
+ conv1 = thf.relu(self.conv1(inputs))
246
+ conv2 = thf.relu(self.conv2(conv1))
247
+ conv3 = thf.relu(self.conv3(conv2))
248
+ conv4 = thf.relu(self.conv4(conv3))
249
+ conv5 = thf.relu(self.conv5(conv4))
250
+ up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
251
+ merge6 = torch.cat([conv4, up6], dim=1)
252
+ conv6 = thf.relu(self.conv6(merge6))
253
+ up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
254
+ merge7 = torch.cat([conv3, up7], dim=1)
255
+ conv7 = thf.relu(self.conv7(merge7))
256
+ up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
257
+ merge8 = torch.cat([conv2, up8], dim=1)
258
+ conv8 = thf.relu(self.conv8(merge8))
259
+ up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
260
+ merge9 = torch.cat([conv1, up9, inputs], dim=1)
261
+ conv9 = thf.relu(self.conv9(merge9))
262
+ conv10 = thf.relu(self.conv10(conv9))
263
+ residual = self.residual(conv10)
264
+ residual = self.act_fn(residual)
265
+ return residual
266
+
267
+
268
+ class SecretEncoder1(nn.Module):
269
+ def __init__(self, resolution=256, secret_len=100) -> None:
270
+ pass
271
+
272
+ class SecretDecoder(nn.Module):
273
+ def __init__(self, arch='resnet18', resolution=224, secret_len=100):
274
+ super().__init__()
275
+ self.resolution = resolution
276
+ self.arch = arch
277
+ if arch == 'resnet18':
278
+ self.decoder = torchvision.models.resnet18(pretrained=True, progress=False)
279
+ self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
280
+ elif arch == 'resnet50':
281
+ self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
282
+ self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
283
+ elif arch == 'simple':
284
+ self.decoder = SimpleCNN(resolution, secret_len)
285
+ else:
286
+ raise ValueError('Unknown architecture')
287
+
288
+ def forward(self, image):
289
+ if self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution:
290
+ image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
291
+ x = self.decoder(image)
292
+ return x
293
+
294
+
295
+ class SimpleCNN(nn.Module):
296
+ def __init__(self, resolution=224, secret_len=100):
297
+ super().__init__()
298
+ self.resolution = resolution
299
+ self.IMAGE_CHANNELS = 3
300
+ self.decoder = nn.Sequential(
301
+ nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1), # resolution / 2
302
+ nn.ReLU(),
303
+ nn.Conv2d(32, 32, 3, 1, 1),
304
+ nn.ReLU(),
305
+ nn.Conv2d(32, 64, 3, 2, 1), # resolution / 4
306
+ nn.ReLU(),
307
+ nn.Conv2d(64, 64, 3, 1, 1),
308
+ nn.ReLU(),
309
+ nn.Conv2d(64, 64, 3, 2, 1), # resolution / 8
310
+ nn.ReLU(),
311
+ nn.Conv2d(64, 128, 3, 2, 1), # resolution / 16
312
+ nn.ReLU(),
313
+ nn.Conv2d(128, 128, (3, 3), 2, 1), # resolution / 32
314
+ nn.ReLU(),
315
+ )
316
+ self.dense = nn.Sequential(
317
+ nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
318
+ nn.ReLU(),
319
+ nn.Linear(512, secret_len),
320
+ )
321
+
322
+ def forward(self, image):
323
+ x = self.decoder(image)
324
+ x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
325
+ return self.dense(x)
flae/munit.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3
+ Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4
+ """
5
+ from torch import nn
6
+ from torch.autograd import Variable
7
+ import torch
8
+ import torch.nn.functional as F
9
+ try:
10
+ from itertools import izip as zip
11
+ except ImportError: # will be 3.x series
12
+ pass
13
+
14
+ ##################################################################################
15
+ # Discriminator
16
+ ##################################################################################
17
+
18
+ class MsImageDis(nn.Module):
19
+ # Multi-scale discriminator architecture
20
+ def __init__(self, input_dim, params):
21
+ super(MsImageDis, self).__init__()
22
+ self.n_layer = params['n_layer']
23
+ self.gan_type = params['gan_type']
24
+ self.dim = params['dim']
25
+ self.norm = params['norm']
26
+ self.activ = params['activ']
27
+ self.num_scales = params['num_scales']
28
+ self.pad_type = params['pad_type']
29
+ self.input_dim = input_dim
30
+ self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
31
+ self.cnns = nn.ModuleList()
32
+ for _ in range(self.num_scales):
33
+ self.cnns.append(self._make_net())
34
+
35
+ def _make_net(self):
36
+ dim = self.dim
37
+ cnn_x = []
38
+ cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
39
+ for i in range(self.n_layer - 1):
40
+ cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
41
+ dim *= 2
42
+ cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
43
+ cnn_x = nn.Sequential(*cnn_x)
44
+ return cnn_x
45
+
46
+ def forward(self, x):
47
+ outputs = []
48
+ for model in self.cnns:
49
+ outputs.append(model(x))
50
+ x = self.downsample(x)
51
+ return outputs
52
+
53
+ def calc_dis_loss(self, input_fake, input_real):
54
+ # calculate the loss to train D
55
+ outs0 = self.forward(input_fake)
56
+ outs1 = self.forward(input_real)
57
+ loss = 0
58
+
59
+ for it, (out0, out1) in enumerate(zip(outs0, outs1)):
60
+ if self.gan_type == 'lsgan':
61
+ loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
62
+ elif self.gan_type == 'nsgan':
63
+ all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
64
+ all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
65
+ loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
66
+ F.binary_cross_entropy(F.sigmoid(out1), all1))
67
+ else:
68
+ assert 0, "Unsupported GAN type: {}".format(self.gan_type)
69
+ return loss
70
+
71
+ def calc_gen_loss(self, input_fake):
72
+ # calculate the loss to train G
73
+ outs0 = self.forward(input_fake)
74
+ loss = 0
75
+ for it, (out0) in enumerate(outs0):
76
+ if self.gan_type == 'lsgan':
77
+ loss += torch.mean((out0 - 1)**2) # LSGAN
78
+ elif self.gan_type == 'nsgan':
79
+ all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
80
+ loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
81
+ else:
82
+ assert 0, "Unsupported GAN type: {}".format(self.gan_type)
83
+ return loss
84
+
85
+ ##################################################################################
86
+ # Generator
87
+ ##################################################################################
88
+
89
+ class AdaINGen(nn.Module):
90
+ # AdaIN auto-encoder architecture
91
+ def __init__(self, input_dim, params):
92
+ super(AdaINGen, self).__init__()
93
+ dim = params['dim']
94
+ style_dim = params['style_dim']
95
+ n_downsample = params['n_downsample']
96
+ n_res = params['n_res']
97
+ activ = params['activ']
98
+ pad_type = params['pad_type']
99
+ mlp_dim = params['mlp_dim']
100
+
101
+ # style encoder
102
+ self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
103
+
104
+ # content encoder
105
+ self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
106
+ self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)
107
+
108
+ # MLP to generate AdaIN parameters
109
+ self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
110
+
111
+ def forward(self, images):
112
+ # reconstruct an image
113
+ content, style_fake = self.encode(images)
114
+ images_recon = self.decode(content, style_fake)
115
+ return images_recon
116
+
117
+ def encode(self, images):
118
+ # encode an image to its content and style codes
119
+ style_fake = self.enc_style(images)
120
+ content = self.enc_content(images)
121
+ return content, style_fake
122
+
123
+ def decode(self, content, style):
124
+ # decode content and style codes to an image
125
+ adain_params = self.mlp(style)
126
+ self.assign_adain_params(adain_params, self.dec)
127
+ images = self.dec(content)
128
+ return images
129
+
130
+ def assign_adain_params(self, adain_params, model):
131
+ # assign the adain_params to the AdaIN layers in model
132
+ for m in model.modules():
133
+ if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
134
+ mean = adain_params[:, :m.num_features]
135
+ std = adain_params[:, m.num_features:2*m.num_features]
136
+ m.bias = mean.contiguous().view(-1)
137
+ m.weight = std.contiguous().view(-1)
138
+ if adain_params.size(1) > 2*m.num_features:
139
+ adain_params = adain_params[:, 2*m.num_features:]
140
+
141
+ def get_num_adain_params(self, model):
142
+ # return the number of AdaIN parameters needed by the model
143
+ num_adain_params = 0
144
+ for m in model.modules():
145
+ if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
146
+ num_adain_params += 2*m.num_features
147
+ return num_adain_params
148
+
149
+
150
+ class VAEGen(nn.Module):
151
+ # VAE architecture
152
+ def __init__(self, input_dim, params):
153
+ super(VAEGen, self).__init__()
154
+ dim = params['dim']
155
+ n_downsample = params['n_downsample']
156
+ n_res = params['n_res']
157
+ activ = params['activ']
158
+ pad_type = params['pad_type']
159
+
160
+ # content encoder
161
+ self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
162
+ self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type)
163
+
164
+ def forward(self, images):
165
+ # This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
166
+ hiddens = self.encode(images)
167
+ if self.training == True:
168
+ noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
169
+ images_recon = self.decode(hiddens + noise)
170
+ else:
171
+ images_recon = self.decode(hiddens)
172
+ return images_recon, hiddens
173
+
174
+ def encode(self, images):
175
+ hiddens = self.enc(images)
176
+ noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
177
+ return hiddens, noise
178
+
179
+ def decode(self, hiddens):
180
+ images = self.dec(hiddens)
181
+ return images
182
+
183
+
184
+ ##################################################################################
185
+ # Encoder and Decoders
186
+ ##################################################################################
187
+
188
+ class StyleEncoder(nn.Module):
189
+ def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
190
+ super(StyleEncoder, self).__init__()
191
+ self.model = []
192
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
193
+ for i in range(2):
194
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
195
+ dim *= 2
196
+ for i in range(n_downsample - 2):
197
+ self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
198
+ self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
199
+ self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
200
+ self.model = nn.Sequential(*self.model)
201
+ self.output_dim = dim
202
+
203
+ def forward(self, x):
204
+ return self.model(x)
205
+
206
+ class ContentEncoder(nn.Module):
207
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
208
+ super(ContentEncoder, self).__init__()
209
+ self.model = []
210
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
211
+ # downsampling blocks
212
+ for i in range(n_downsample):
213
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
214
+ dim *= 2
215
+ # residual blocks
216
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
217
+ self.model = nn.Sequential(*self.model)
218
+ self.output_dim = dim
219
+
220
+ def forward(self, x):
221
+ return self.model(x)
222
+
223
+ class Decoder(nn.Module):
224
+ def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
225
+ super(Decoder, self).__init__()
226
+
227
+ self.model = []
228
+ # AdaIN residual blocks
229
+ self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
230
+ # upsampling blocks
231
+ for i in range(n_upsample):
232
+ self.model += [nn.Upsample(scale_factor=2),
233
+ Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
234
+ dim //= 2
235
+ # use reflection padding in the last conv layer
236
+ self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
237
+ self.model = nn.Sequential(*self.model)
238
+
239
+ def forward(self, x):
240
+ return self.model(x)
241
+
242
+ ##################################################################################
243
+ # Sequential Models
244
+ ##################################################################################
245
+ class ResBlocks(nn.Module):
246
+ def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
247
+ super(ResBlocks, self).__init__()
248
+ self.model = []
249
+ for i in range(num_blocks):
250
+ self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
251
+ self.model = nn.Sequential(*self.model)
252
+
253
+ def forward(self, x):
254
+ return self.model(x)
255
+
256
+ class MLP(nn.Module):
257
+ def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
258
+
259
+ super(MLP, self).__init__()
260
+ self.model = []
261
+ self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
262
+ for i in range(n_blk - 2):
263
+ self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
264
+ self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
265
+ self.model = nn.Sequential(*self.model)
266
+
267
+ def forward(self, x):
268
+ return self.model(x.view(x.size(0), -1))
269
+
270
+ ##################################################################################
271
+ # Basic Blocks
272
+ ##################################################################################
273
+ class ResBlock(nn.Module):
274
+ def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
275
+ super(ResBlock, self).__init__()
276
+
277
+ model = []
278
+ model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
279
+ model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
280
+ self.model = nn.Sequential(*model)
281
+
282
+ def forward(self, x):
283
+ residual = x
284
+ out = self.model(x)
285
+ out += residual
286
+ return out
287
+
288
+ class Conv2dBlock(nn.Module):
289
+ def __init__(self, input_dim ,output_dim, kernel_size, stride,
290
+ padding=0, norm='none', activation='relu', pad_type='zero'):
291
+ super(Conv2dBlock, self).__init__()
292
+ self.use_bias = True
293
+ # initialize padding
294
+ if pad_type == 'reflect':
295
+ self.pad = nn.ReflectionPad2d(padding)
296
+ elif pad_type == 'replicate':
297
+ self.pad = nn.ReplicationPad2d(padding)
298
+ elif pad_type == 'zero':
299
+ self.pad = nn.ZeroPad2d(padding)
300
+ else:
301
+ assert 0, "Unsupported padding type: {}".format(pad_type)
302
+
303
+ # initialize normalization
304
+ norm_dim = output_dim
305
+ if norm == 'bn':
306
+ self.norm = nn.BatchNorm2d(norm_dim)
307
+ elif norm == 'in':
308
+ #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
309
+ self.norm = nn.InstanceNorm2d(norm_dim)
310
+ elif norm == 'ln':
311
+ self.norm = LayerNorm(norm_dim)
312
+ elif norm == 'adain':
313
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
314
+ elif norm == 'none' or norm == 'sn':
315
+ self.norm = None
316
+ else:
317
+ assert 0, "Unsupported normalization: {}".format(norm)
318
+
319
+ # initialize activation
320
+ if activation == 'relu':
321
+ self.activation = nn.ReLU(inplace=True)
322
+ elif activation == 'lrelu':
323
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
324
+ elif activation == 'prelu':
325
+ self.activation = nn.PReLU()
326
+ elif activation == 'selu':
327
+ self.activation = nn.SELU(inplace=True)
328
+ elif activation == 'tanh':
329
+ self.activation = nn.Tanh()
330
+ elif activation == 'none':
331
+ self.activation = None
332
+ else:
333
+ assert 0, "Unsupported activation: {}".format(activation)
334
+
335
+ # initialize convolution
336
+ if norm == 'sn':
337
+ self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
338
+ else:
339
+ self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
340
+
341
+ def forward(self, x):
342
+ x = self.conv(self.pad(x))
343
+ if self.norm:
344
+ x = self.norm(x)
345
+ if self.activation:
346
+ x = self.activation(x)
347
+ return x
348
+
349
+ class LinearBlock(nn.Module):
350
+ def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
351
+ super(LinearBlock, self).__init__()
352
+ use_bias = True
353
+ # initialize fully connected layer
354
+ if norm == 'sn':
355
+ self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
356
+ else:
357
+ self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
358
+
359
+ # initialize normalization
360
+ norm_dim = output_dim
361
+ if norm == 'bn':
362
+ self.norm = nn.BatchNorm1d(norm_dim)
363
+ elif norm == 'in':
364
+ self.norm = nn.InstanceNorm1d(norm_dim)
365
+ elif norm == 'ln':
366
+ self.norm = LayerNorm(norm_dim)
367
+ elif norm == 'none' or norm == 'sn':
368
+ self.norm = None
369
+ else:
370
+ assert 0, "Unsupported normalization: {}".format(norm)
371
+
372
+ # initialize activation
373
+ if activation == 'relu':
374
+ self.activation = nn.ReLU(inplace=True)
375
+ elif activation == 'lrelu':
376
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
377
+ elif activation == 'prelu':
378
+ self.activation = nn.PReLU()
379
+ elif activation == 'selu':
380
+ self.activation = nn.SELU(inplace=True)
381
+ elif activation == 'tanh':
382
+ self.activation = nn.Tanh()
383
+ elif activation == 'none':
384
+ self.activation = None
385
+ else:
386
+ assert 0, "Unsupported activation: {}".format(activation)
387
+
388
+ def forward(self, x):
389
+ out = self.fc(x)
390
+ if self.norm:
391
+ out = self.norm(out)
392
+ if self.activation:
393
+ out = self.activation(out)
394
+ return out
395
+
396
+ ##################################################################################
397
+ # VGG network definition
398
+ ##################################################################################
399
+ class Vgg16(nn.Module):
400
+ def __init__(self):
401
+ super(Vgg16, self).__init__()
402
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
403
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
404
+
405
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
406
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
407
+
408
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
409
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
410
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
411
+
412
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
413
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
414
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
415
+
416
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
417
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
418
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
419
+
420
+ def forward(self, X):
421
+ h = F.relu(self.conv1_1(X), inplace=True)
422
+ h = F.relu(self.conv1_2(h), inplace=True)
423
+ # relu1_2 = h
424
+ h = F.max_pool2d(h, kernel_size=2, stride=2)
425
+
426
+ h = F.relu(self.conv2_1(h), inplace=True)
427
+ h = F.relu(self.conv2_2(h), inplace=True)
428
+ # relu2_2 = h
429
+ h = F.max_pool2d(h, kernel_size=2, stride=2)
430
+
431
+ h = F.relu(self.conv3_1(h), inplace=True)
432
+ h = F.relu(self.conv3_2(h), inplace=True)
433
+ h = F.relu(self.conv3_3(h), inplace=True)
434
+ # relu3_3 = h
435
+ h = F.max_pool2d(h, kernel_size=2, stride=2)
436
+
437
+ h = F.relu(self.conv4_1(h), inplace=True)
438
+ h = F.relu(self.conv4_2(h), inplace=True)
439
+ h = F.relu(self.conv4_3(h), inplace=True)
440
+ # relu4_3 = h
441
+
442
+ h = F.relu(self.conv5_1(h), inplace=True)
443
+ h = F.relu(self.conv5_2(h), inplace=True)
444
+ h = F.relu(self.conv5_3(h), inplace=True)
445
+ relu5_3 = h
446
+
447
+ return relu5_3
448
+ # return [relu1_2, relu2_2, relu3_3, relu4_3]
449
+
450
+ ##################################################################################
451
+ # Normalization layers
452
+ ##################################################################################
453
+ class AdaptiveInstanceNorm2d(nn.Module):
454
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
455
+ super(AdaptiveInstanceNorm2d, self).__init__()
456
+ self.num_features = num_features
457
+ self.eps = eps
458
+ self.momentum = momentum
459
+ # weight and bias are dynamically assigned
460
+ self.weight = None
461
+ self.bias = None
462
+ # just dummy buffers, not used
463
+ self.register_buffer('running_mean', torch.zeros(num_features))
464
+ self.register_buffer('running_var', torch.ones(num_features))
465
+
466
+ def forward(self, x):
467
+ assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
468
+ b, c = x.size(0), x.size(1)
469
+ running_mean = self.running_mean.repeat(b)
470
+ running_var = self.running_var.repeat(b)
471
+
472
+ # Apply instance norm
473
+ x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
474
+
475
+ out = F.batch_norm(
476
+ x_reshaped, running_mean, running_var, self.weight, self.bias,
477
+ True, self.momentum, self.eps)
478
+
479
+ return out.view(b, c, *x.size()[2:])
480
+
481
+ def __repr__(self):
482
+ return self.__class__.__name__ + '(' + str(self.num_features) + ')'
483
+
484
+
485
+ class LayerNorm(nn.Module):
486
+ def __init__(self, num_features, eps=1e-5, affine=True):
487
+ super(LayerNorm, self).__init__()
488
+ self.num_features = num_features
489
+ self.affine = affine
490
+ self.eps = eps
491
+
492
+ if self.affine:
493
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
494
+ self.beta = nn.Parameter(torch.zeros(num_features))
495
+
496
+ def forward(self, x):
497
+ shape = [-1] + [1] * (x.dim() - 1)
498
+ # print(x.size())
499
+ if x.size(0) == 1:
500
+ # These two lines run much faster in pytorch 0.4 than the two lines listed below.
501
+ mean = x.view(-1).mean().view(*shape)
502
+ std = x.view(-1).std().view(*shape)
503
+ else:
504
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
505
+ std = x.view(x.size(0), -1).std(1).view(*shape)
506
+
507
+ x = (x - mean) / (std + self.eps)
508
+
509
+ if self.affine:
510
+ shape = [1, -1] + [1] * (x.dim() - 2)
511
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
512
+ return x
513
+
514
+ def l2normalize(v, eps=1e-12):
515
+ return v / (v.norm() + eps)
516
+
517
+
518
+ class SpectralNorm(nn.Module):
519
+ """
520
+ Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
521
+ and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
522
+ """
523
+ def __init__(self, module, name='weight', power_iterations=1):
524
+ super(SpectralNorm, self).__init__()
525
+ self.module = module
526
+ self.name = name
527
+ self.power_iterations = power_iterations
528
+ if not self._made_params():
529
+ self._make_params()
530
+
531
+ def _update_u_v(self):
532
+ u = getattr(self.module, self.name + "_u")
533
+ v = getattr(self.module, self.name + "_v")
534
+ w = getattr(self.module, self.name + "_bar")
535
+
536
+ height = w.data.shape[0]
537
+ for _ in range(self.power_iterations):
538
+ v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
539
+ u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
540
+
541
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
542
+ sigma = u.dot(w.view(height, -1).mv(v))
543
+ setattr(self.module, self.name, w / sigma.expand_as(w))
544
+
545
+ def _made_params(self):
546
+ try:
547
+ u = getattr(self.module, self.name + "_u")
548
+ v = getattr(self.module, self.name + "_v")
549
+ w = getattr(self.module, self.name + "_bar")
550
+ return True
551
+ except AttributeError:
552
+ return False
553
+
554
+
555
+ def _make_params(self):
556
+ w = getattr(self.module, self.name)
557
+
558
+ height = w.data.shape[0]
559
+ width = w.view(height, -1).data.shape[1]
560
+
561
+ u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
562
+ v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
563
+ u.data = l2normalize(u.data)
564
+ v.data = l2normalize(v.data)
565
+ w_bar = nn.Parameter(w.data)
566
+
567
+ del self.module._parameters[self.name]
568
+
569
+ self.module.register_parameter(self.name + "_u", u)
570
+ self.module.register_parameter(self.name + "_v", v)
571
+ self.module.register_parameter(self.name + "_bar", w_bar)
572
+
573
+
574
+ def forward(self, *args):
575
+ self._update_u_v()
576
+ return self.module.forward(*args)
flae/unet.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ from torch.autograd import Variable
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from .munit import ResBlocks, Conv2dBlock
6
+ import math
7
+
8
+
9
+ class Unet(nn.Module):
10
+ def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
11
+ super().__init__()
12
+ self.secret_len = secret_len
13
+ self.return_residual = return_residual
14
+ self.secret_dense = nn.Linear(secret_len, 16*16*3)
15
+ log_resolution = int(math.log(resolution, 2))
16
+ assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
17
+ self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))
18
+
19
+ self.enc = Encoder(2, 4, 6, 64, 'bn' , 'relu', 'reflect')
20
+ self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')
21
+
22
+ def forward(self, image, secret):
23
+ # import pdb; pdb.set_trace()
24
+ fingerprint = F.relu(self.secret_dense(secret))
25
+ fingerprint = fingerprint.view((-1, 3, 16, 16))
26
+ fingerprint_enlarged = self.secret_upsample(fingerprint)
27
+ inputs = torch.cat([fingerprint_enlarged, image], dim=1)
28
+ emb = self.enc(inputs)
29
+ # import pdb; pdb.set_trace()
30
+ out = self.dec(emb)
31
+ return out
32
+
33
+ class Encoder(nn.Module):
34
+ def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
35
+ super().__init__()
36
+ self.model = []
37
+ self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
38
+ # downsampling blocks
39
+ for i in range(n_downsample):
40
+ self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
41
+ dim *= 2
42
+ # residual blocks
43
+ self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
44
+ # self.model = nn.(*self.model)
45
+ self.model = nn.ModuleList(self.model)
46
+ self.output_dim = dim
47
+
48
+ def forward(self, x):
49
+ out = []
50
+ for block in self.model:
51
+ x = block(x)
52
+ out.append(x)
53
+ # print(x.shape)
54
+ return out
55
+
56
+
57
+ class Decoder(nn.Module):
58
+ def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
59
+ super(Decoder, self).__init__()
60
+
61
+ self.model = []
62
+ # AdaIN residual blocks
63
+ self.model += [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
64
+ # upsampling blocks
65
+ for i in range(n_upsample):
66
+ self.model += [DecoderBlock('upsample', dim, dim//2,'bn', activ, pad_type)
67
+ ]
68
+ dim //= 2
69
+ # use reflection padding in the last conv layer
70
+ self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
71
+ # self.model = nn.Sequential(*self.model)
72
+ self.model = nn.ModuleList(self.model)
73
+
74
+ def forward(self, x):
75
+ x1 = x.pop()
76
+ for block in self.model:
77
+ x2 = x.pop()
78
+ # print(x1.shape, x2.shape)
79
+ x1 = block(x1, x2)
80
+ x1 = self.output_layer(x1)
81
+ return x1
82
+
83
+
84
+ class Merge(nn.Module):
85
+ def __init__(self, dim, activation='relu'):
86
+ super().__init__()
87
+ self.conv = nn.Conv2d(2*dim, dim, 3, 1, 1)
88
+ # initialize activation
89
+ if activation == 'relu':
90
+ self.activation = nn.ReLU(inplace=True)
91
+ elif activation == 'lrelu':
92
+ self.activation = nn.LeakyReLU(0.2, inplace=True)
93
+ elif activation == 'prelu':
94
+ self.activation = nn.PReLU()
95
+ elif activation == 'selu':
96
+ self.activation = nn.SELU(inplace=True)
97
+ elif activation == 'tanh':
98
+ self.activation = nn.Tanh()
99
+ elif activation == 'none':
100
+ self.activation = None
101
+ else:
102
+ assert 0, "Unsupported activation: {}".format(activation)
103
+ def forward(self, x1, x2):
104
+ x = torch.cat([x1, x2], dim=1) # 2xdim
105
+ x = self.conv(x) # B,dim,H,W
106
+ x = self.activation(x)
107
+ return x
108
+
109
+ class DecoderBlock(nn.Module):
110
+ def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
111
+ super().__init__()
112
+ assert block_type in ['resblock', 'upsample']
113
+ if block_type == 'resblock':
114
+ self.core_layer = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
115
+ else:
116
+ assert out_dim == in_dim//2
117
+ self.core_layer = nn.Sequential(nn.Upsample(scale_factor=2),
118
+ Conv2dBlock(in_dim, out_dim, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type))
119
+ self.merge = Merge(out_dim, activ)
120
+
121
+ def forward(self, x1, x2):
122
+ x1 = self.core_layer(x1)
123
+ return self.merge(x1, x2)
ldm/util.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib
2
+
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+
7
+ from inspect import isfunction
8
+ from PIL import Image, ImageDraw, ImageFont
9
+
10
+
11
+ def log_txt_as_img(wh, xc, size=10):
12
+ # wh a tuple of (width, height)
13
+ # xc a list of captions to plot
14
+ b = len(xc)
15
+ txts = list()
16
+ for bi in range(b):
17
+ txt = Image.new("RGB", wh, color="white")
18
+ draw = ImageDraw.Draw(txt)
19
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
20
+ nc = int(40 * (wh[0] / 256))
21
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
22
+
23
+ try:
24
+ draw.text((0, 0), lines, fill="black", font=font)
25
+ except UnicodeEncodeError:
26
+ print("Cant encode string for logging. Skipping.")
27
+
28
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
29
+ txts.append(txt)
30
+ txts = np.stack(txts)
31
+ txts = torch.tensor(txts)
32
+ return txts
33
+
34
+
35
+ def ismap(x):
36
+ if not isinstance(x, torch.Tensor):
37
+ return False
38
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
39
+
40
+
41
+ def isimage(x):
42
+ if not isinstance(x,torch.Tensor):
43
+ return False
44
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
45
+
46
+
47
+ def exists(x):
48
+ return x is not None
49
+
50
+
51
+ def default(val, d):
52
+ if exists(val):
53
+ return val
54
+ return d() if isfunction(d) else d
55
+
56
+
57
+ def mean_flat(tensor):
58
+ """
59
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
60
+ Take the mean over all non-batch dimensions.
61
+ """
62
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
63
+
64
+
65
+ def count_params(model, verbose=False):
66
+ total_params = sum(p.numel() for p in model.parameters())
67
+ if verbose:
68
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
69
+ return total_params
70
+
71
+
72
+ def instantiate_from_config(config):
73
+ if not "target" in config:
74
+ if config == '__is_first_stage__':
75
+ return None
76
+ elif config == "__is_unconditional__":
77
+ return None
78
+ raise KeyError("Expected key `target` to instantiate.")
79
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
80
+
81
+
82
+ def get_obj_from_str(string, reload=False):
83
+ module, cls = string.rsplit(".", 1)
84
+ if reload:
85
+ module_imp = importlib.import_module(module)
86
+ importlib.reload(module_imp)
87
+ return getattr(importlib.import_module(module, package=None), cls)
88
+
89
+
90
+ class AdamWwithEMAandWings(optim.Optimizer):
91
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
92
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
93
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
94
+ ema_power=1., param_names=()):
95
+ """AdamW that saves EMA versions of the parameters."""
96
+ if not 0.0 <= lr:
97
+ raise ValueError("Invalid learning rate: {}".format(lr))
98
+ if not 0.0 <= eps:
99
+ raise ValueError("Invalid epsilon value: {}".format(eps))
100
+ if not 0.0 <= betas[0] < 1.0:
101
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
102
+ if not 0.0 <= betas[1] < 1.0:
103
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
104
+ if not 0.0 <= weight_decay:
105
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
106
+ if not 0.0 <= ema_decay <= 1.0:
107
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
108
+ defaults = dict(lr=lr, betas=betas, eps=eps,
109
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
110
+ ema_power=ema_power, param_names=param_names)
111
+ super().__init__(params, defaults)
112
+
113
+ def __setstate__(self, state):
114
+ super().__setstate__(state)
115
+ for group in self.param_groups:
116
+ group.setdefault('amsgrad', False)
117
+
118
+ @torch.no_grad()
119
+ def step(self, closure=None):
120
+ """Performs a single optimization step.
121
+ Args:
122
+ closure (callable, optional): A closure that reevaluates the model
123
+ and returns the loss.
124
+ """
125
+ loss = None
126
+ if closure is not None:
127
+ with torch.enable_grad():
128
+ loss = closure()
129
+
130
+ for group in self.param_groups:
131
+ params_with_grad = []
132
+ grads = []
133
+ exp_avgs = []
134
+ exp_avg_sqs = []
135
+ ema_params_with_grad = []
136
+ state_sums = []
137
+ max_exp_avg_sqs = []
138
+ state_steps = []
139
+ amsgrad = group['amsgrad']
140
+ beta1, beta2 = group['betas']
141
+ ema_decay = group['ema_decay']
142
+ ema_power = group['ema_power']
143
+
144
+ for p in group['params']:
145
+ if p.grad is None:
146
+ continue
147
+ params_with_grad.append(p)
148
+ if p.grad.is_sparse:
149
+ raise RuntimeError('AdamW does not support sparse gradients')
150
+ grads.append(p.grad)
151
+
152
+ state = self.state[p]
153
+
154
+ # State initialization
155
+ if len(state) == 0:
156
+ state['step'] = 0
157
+ # Exponential moving average of gradient values
158
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
159
+ # Exponential moving average of squared gradient values
160
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
161
+ if amsgrad:
162
+ # Maintains max of all exp. moving avg. of sq. grad. values
163
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
164
+ # Exponential moving average of parameter values
165
+ state['param_exp_avg'] = p.detach().float().clone()
166
+
167
+ exp_avgs.append(state['exp_avg'])
168
+ exp_avg_sqs.append(state['exp_avg_sq'])
169
+ ema_params_with_grad.append(state['param_exp_avg'])
170
+
171
+ if amsgrad:
172
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
173
+
174
+ # update the steps for each param group update
175
+ state['step'] += 1
176
+ # record the step after step update
177
+ state_steps.append(state['step'])
178
+
179
+ optim._functional.adamw(params_with_grad,
180
+ grads,
181
+ exp_avgs,
182
+ exp_avg_sqs,
183
+ max_exp_avg_sqs,
184
+ state_steps,
185
+ amsgrad=amsgrad,
186
+ beta1=beta1,
187
+ beta2=beta2,
188
+ lr=group['lr'],
189
+ weight_decay=group['weight_decay'],
190
+ eps=group['eps'],
191
+ maximize=False)
192
+
193
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
194
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
195
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
196
+
197
+ return loss
pages/Extract_Secret.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ streamlit app demo
5
+ how to run:
6
+ streamlit run app.py --server.port 8501
7
+
8
+ @author: Tu Bui @surrey.ac.uk
9
+ """
10
+ import os, sys, torch
11
+ import inspect
12
+ cdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
13
+ sys.path.insert(1, os.path.join(cdir, '../'))
14
+ import argparse
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import pickle
18
+ import pytorch_lightning as pl
19
+ from torchvision import transforms
20
+ import argparse
21
+ from ldm.util import instantiate_from_config
22
+ from omegaconf import OmegaConf
23
+ from PIL import Image
24
+ from tools.augment_imagenetc import RandomImagenetC
25
+ from cldm.transformations2 import TransformNet
26
+ from io import BytesIO
27
+ from tools.helpers import welcome_message
28
+ from tools.ecc import BCH, RSC
29
+ import streamlit as st
30
+ from Embed_Secret import load_ecc, load_model, decode_secret, to_bytes, model_names, SECRET_LEN
31
+
32
+
33
+ # model_names = ['RoSteALS', 'UNet']
34
+ # SECRET_LEN = 100
35
+
36
+ def app():
37
+ st.title('Watermarking Demo')
38
+ # setup model
39
+ model_name = st.selectbox("Choose the model", model_names)
40
+ model, tform_emb, tform_det = load_model(model_name)
41
+ display_width = 300
42
+ ecc = load_ecc('BCH')
43
+ noise = TransformNet(p=1.0, crop_mode='resized_crop')
44
+ noise_names = noise.optional_names
45
+
46
+ # setup st
47
+ st.subheader("Input")
48
+ image_file = None
49
+ image_file = st.file_uploader("Upload stego image", type=["png","jpg","jpeg"])
50
+ if image_file is not None:
51
+ im = Image.open(image_file).convert('RGB')
52
+ ext = image_file.name.split('.')[-1]
53
+ st.image(im, width=display_width)
54
+
55
+
56
+ # add crop
57
+ st.subheader("Corruptions")
58
+ crop_button = st.button('Regenerate Crop', key='crop')
59
+ if image_file is not None:
60
+ im_crop = noise.apply_transform_on_pil_image(im, 'Random Crop')
61
+ if crop_button:
62
+ im_crop = noise.apply_transform_on_pil_image(im, 'Random Crop')
63
+ # st.image(im_crop, width=display_width)
64
+
65
+ # add noise source 1
66
+ corrupt_method1 = st.selectbox("Choose noise source #1", ['None'] + noise_names, key='noise1')
67
+ if image_file is not None:
68
+ if corrupt_method1=='None':
69
+ im_noise1 = im_crop
70
+ else:
71
+ im_noise1 = noise.apply_transform_on_pil_image(im_crop, corrupt_method1)
72
+ # st.image(im_noise1, width=display_width)
73
+
74
+ # add noise source 2
75
+ corrupt_method2 = st.selectbox("Choose noise source #2", ['None'] + noise_names, key='noise2')
76
+ if image_file is not None:
77
+ if corrupt_method2=='None':
78
+ im_noise2 = im_noise1
79
+ else:
80
+ im_noise2 = noise.apply_transform_on_pil_image(im_noise1, corrupt_method2)
81
+
82
+ st.subheader("Output")
83
+ if image_file is not None:
84
+ st.image(im_noise2, width=display_width)
85
+ mime='image/jpeg' if ext=='jpg' else f'image/{ext}'
86
+ im_noise2_bytes = to_bytes(np.uint8(im_noise2), mime)
87
+ st.download_button(label='Download image', data=im_noise2_bytes, file_name=f'corrupted.{ext}', mime=mime)
88
+
89
+ # prediction
90
+ st.subheader('Extract Secret From Output')
91
+ status = st.empty()
92
+ if image_file is not None:
93
+ secret_pred = decode_secret(model_name, model, im_noise2, tform_det)
94
+ secret_decoded = ecc.decode_text(secret_pred)[0]
95
+ status.markdown(f'Predicted secret: **{secret_decoded}**', unsafe_allow_html=True)
96
+
97
+ # bit acc
98
+ st.subheader('Accuracy')
99
+ secret_text = st.text_input('Input groundtruth secret')
100
+ bit_acc_status = st.empty()
101
+ if image_file is not None and secret_text:
102
+ secret = ecc.encode_text([secret_text]) # (1, 100)
103
+ bit_acc = (secret_pred == secret).mean()
104
+ # bit_acc_status.markdown('**Bit Accuracy**: {:.2f}%'.format(bit_acc*100), unsafe_allow_html=True)
105
+ word_acc = int(secret_decoded == secret_text)
106
+ bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
107
+
108
+ if __name__ == '__main__':
109
+ app()
110
+
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .helpers import *
2
+ from .hparams import HParams
3
+ from .slack_bot import Notifier
tools/augment_imagenetc.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ wrapper for imagenet-c transformations
5
+ @author: Tu Bui @surrey.ac.uk
6
+ """
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+ import os
11
+ import sys
12
+ import random
13
+ import numpy as np
14
+ from PIL import Image
15
+ from imagenet_c import corrupt, corruption_dict
16
+
17
+
18
+ class IdentityAugment(object):
19
+ def __call__(self, x):
20
+ return x
21
+
22
+ def __repr__(self):
23
+ s = f'()'
24
+ return self.__class__.__name__ + s
25
+
26
+ class RandomImagenetC(object):
27
+ # transform id 5 (motion blur) and 7 (snow) requires WandImage which is not fork-safe, while id 4 (glass blur) and 6 (zoom blur) are super slow thus we move it to validation (unseen), 12 (elastic transform) is non realistic
28
+ methods = {'train': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18]),#np.arange(15),
29
+ 'val': np.array([4, 5, 6, 7, 12]),
30
+ 'test': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18])
31
+ }
32
+ method_names = list(corruption_dict.keys())
33
+ def __init__(self, min_severity=1, max_severity=5, phase='all', p=1.0,n=19):
34
+ assert phase in ['train', 'val', 'test', 'all'], ValueError(f'{phase} not recognised. Must be one of [train, val, all]')
35
+ if phase == 'all':
36
+ self.corrupt_ids = np.concatenate(list(self.methods.values()))
37
+ else:
38
+ self.corrupt_ids = self.methods[phase]
39
+ self.corrupt_ids = self.corrupt_ids[:n] # first n tforms
40
+ self.phase = phase
41
+ self.severity = np.arange(min_severity, max_severity+1)
42
+ self.p = p # probability to apply a transformation
43
+
44
+ def __call__(self, x, corrupt_id=None, corrupt_strength=None):
45
+ # input: x PIL image
46
+ if corrupt_id is None:
47
+ if len(self.corrupt_ids)==0: # do nothing
48
+ return x
49
+ corrupt_id = np.random.choice(self.corrupt_ids)
50
+ else:
51
+ assert corrupt_id in range(19)
52
+
53
+ severity = np.random.choice(self.severity) if corrupt_strength is None else corrupt_strength
54
+ assert severity in self.severity, f'Error! Corrupt strength {severity} isnt supported.'
55
+
56
+ if np.random.rand() < self.p:
57
+ org_size = x.size
58
+ x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
59
+ x = corrupt(x, severity, corruption_number=corrupt_id)
60
+ x = Image.fromarray(x[:,:,::-1])
61
+ if x.size != org_size:
62
+ x = x.resize(org_size, Image.BILINEAR)
63
+ return x
64
+
65
+ def transform_with_fixed_severity(self, x, severity, corrupt_id=None):
66
+ if corrupt_id is None:
67
+ corrupt_id = np.random.choice(self.corrupt_ids)
68
+ else:
69
+ assert corrupt_id in self.corrupt_ids
70
+ assert severity > 0 and severity < 6
71
+ org_size = x.size
72
+ x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
73
+ x = corrupt(x, severity, corruption_number=corrupt_id)
74
+ x = Image.fromarray(x[:,:,::-1])
75
+ if x.size != org_size:
76
+ x = x.resize(org_size, Image.BILINEAR)
77
+ return x
78
+
79
+ def __repr__(self):
80
+ s = f'(severity={self.severity}, phase={self.phase}, p={self.p},ids={self.corrupt_ids})'
81
+ return self.__class__.__name__ + s
82
+
83
+
84
+ class NoiseResidual(object):
85
+ def __init__(self, k=16):
86
+ self.k = k
87
+ def __call__(self, x):
88
+ h, w = x.height, x.width
89
+ x1 = x.resize((w//self.k,h//self.k), Image.BILINEAR).resize((w, h), Image.BILINEAR)
90
+ x1 = np.abs(np.array(x).astype(np.float32) - np.array(x1).astype(np.float32))
91
+ x1 = (x1 - x1.min())/(x1.max() - x1.min() + np.finfo(np.float32).eps)
92
+ x1 = Image.fromarray((x1*255).astype(np.uint8))
93
+ return x1
94
+ def __repr__(self):
95
+ s = f'(k={self.k}'
96
+ return self.__class__.__name__ + s
97
+
98
+
99
+ def get_transforms(img_mean=[0.5, 0.5, 0.5], img_std=[0.5, 0.5, 0.5], rsize=256, csize=224, pertubation=True, dct=False, residual=False, max_c=19):
100
+ from torchvision import transforms
101
+ prep = transforms.Compose([
102
+ transforms.Resize(rsize),
103
+ transforms.RandomHorizontalFlip(),
104
+ transforms.RandomCrop(csize)])
105
+ if pertubation:
106
+ pertubation_train = RandomImagenetC(max_severity=5, phase='train', p=0.95,n=max_c)
107
+ pertubation_val = RandomImagenetC(max_severity=5, phase='train', p=1.0,n=max_c)
108
+ pertubation_test = RandomImagenetC(max_severity=5, phase='val', p=1.0,n=max_c)
109
+ else:
110
+ pertubation_train = pertubation_val = pertubation_test = IdentityAugment()
111
+ if dct:
112
+ from .image_tools import DCT
113
+ norm = [
114
+ DCT(),
115
+ transforms.ToTensor(),
116
+ transforms.Normalize(mean=img_mean, std=img_std)]
117
+ else:
118
+ norm = [
119
+ transforms.ToTensor(),
120
+ transforms.Normalize(mean=img_mean, std=img_std)]
121
+ if residual:
122
+ norm.insert(0, NoiseResidual())
123
+
124
+ preprocess = {
125
+ 'train': [prep, pertubation_train, transforms.Compose(norm)],
126
+
127
+ 'val': [prep, pertubation_val, transforms.Compose(norm)],
128
+
129
+ 'test_unseen': [prep, pertubation_test, transforms.Compose(norm)],
130
+
131
+ 'clean': transforms.Compose([transforms.Resize(csize)] + norm)
132
+ }
133
+ return preprocess
134
+
135
+
136
+ # ## example
137
+ # from PIL import Image
138
+ # import numpy as np
139
+ # import time
140
+ # from imagenet_c import corrupt, corruption_dict
141
+ # im = Image.open('/vol/research/tubui1/projects/gan_prov/gan_models/stargan2/test.jpg').convert('RGB').resize((224,224), Image.BILINEAR)
142
+ # im.save('original.jpg')
143
+ # im = np.array(im)[:,:,::-1] # BRG
144
+ # t = np.zeros(19)
145
+ # for i, key in enumerate(corruption_dict.keys()):
146
+ # begin = time.time()
147
+ # for j in range(10):
148
+ # out = corrupt(im, 5, corruption_number=i)
149
+ # end = time.time()
150
+ # t[i] = end-begin
151
+ # # Image.fromarray(out[:,:,::-1]).save(f'imc_{key}.jpg')
152
+ # print(f'{i} - {key}: {end-begin}')
153
+
154
+ # for i,k in enumerate(corruption_dict.keys()):
155
+ # print(i, k, t[i])
tools/base_lmdb.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Union
2
+ from pathlib import Path
3
+ import os
4
+ import io
5
+ import lmdb
6
+ import pickle
7
+ import gzip
8
+ import bz2
9
+ import lzma
10
+ import shutil
11
+ from tqdm import tqdm
12
+ import pandas as pd
13
+ import numpy as np
14
+ from numpy import ndarray
15
+ import time
16
+ import torch
17
+ from torch import Tensor
18
+ from distutils.dir_util import copy_tree
19
+ from PIL import Image
20
+ from PIL import ImageFile
21
+
22
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
23
+
24
+
25
+ def _default_encode(data: Any, protocol: int) -> bytes:
26
+ return pickle.dumps(data, protocol=protocol)
27
+
28
+
29
+ def _ascii_encode(data: str) -> bytes:
30
+ return data.encode("ascii")
31
+
32
+
33
+ def _default_decode(data: bytes) -> Any:
34
+ return pickle.loads(data)
35
+
36
+
37
+ def _default_decompress(data: bytes) -> bytes:
38
+ return data
39
+
40
+
41
+ def _decompress(compression: Optional[str]):
42
+ if compression is None:
43
+ _decompress = _default_decompress
44
+ elif compression == "gzip":
45
+ _decompress = gzip.decompress
46
+ elif compression == "bz2":
47
+ _decompress = bz2.decompress
48
+ elif compression == "lzma":
49
+ _decompress = lzma.decompress
50
+ else:
51
+ raise ValueError(f"Unknown compression algorithm: {compression}")
52
+
53
+ return _decompress
54
+
55
+
56
+ class BaseLMDB(object):
57
+ _database = None
58
+ _protocol = None
59
+ _length = None
60
+
61
+ def __init__(
62
+ self,
63
+ path: Union[str, Path],
64
+ readahead: bool = False,
65
+ pre_open: bool = False,
66
+ compression: Optional[str] = None
67
+ ):
68
+ """
69
+ Base class for LMDB-backed databases.
70
+
71
+ :param path: Path to the database.
72
+ :param readahead: Enables the filesystem readahead mechanism.
73
+ :param pre_open: If set to True, the first iterations will be faster, but it will raise error when doing multi-gpu training. If set to False, the database will open when you will retrieve the first item.
74
+ """
75
+ if not isinstance(path, str):
76
+ path = str(path)
77
+
78
+ self.path = path
79
+ self.readahead = readahead
80
+ self.pre_open = pre_open
81
+ self._decompress = _decompress(compression)
82
+ self._has_fetched_an_item = False
83
+
84
+ @property
85
+ def database(self):
86
+ if self._database is None:
87
+ self._database = lmdb.open(
88
+ path=self.path,
89
+ readonly=True,
90
+ readahead=self.readahead,
91
+ max_spare_txns=256,
92
+ lock=False,
93
+ )
94
+ return self._database
95
+
96
+ @database.deleter
97
+ def database(self):
98
+ if self._database is not None:
99
+ self._database.close()
100
+ self._database = None
101
+
102
+ @property
103
+ def protocol(self):
104
+ """
105
+ Read the pickle protocol contained in the database.
106
+
107
+ :return: The set of available keys.
108
+ """
109
+ if self._protocol is None:
110
+ self._protocol = self._get(
111
+ item="protocol",
112
+ encode_key=_ascii_encode,
113
+ decompress_value=_default_decompress,
114
+ decode_value=_default_decode,
115
+ )
116
+ return self._protocol
117
+
118
+ @property
119
+ def keys(self):
120
+ """
121
+ Read the keys contained in the database.
122
+
123
+ :return: The set of available keys.
124
+ """
125
+ protocol = self.protocol
126
+ keys = self._get(
127
+ item="keys",
128
+ encode_key=lambda key: _default_encode(key, protocol=protocol),
129
+ decompress_value=_default_decompress,
130
+ decode_value=_default_decode,
131
+ )
132
+ return keys
133
+
134
+ def __len__(self):
135
+ """
136
+ Returns the number of keys available in the database.
137
+
138
+ :return: The number of keys.
139
+ """
140
+ if self._length is None:
141
+ self._length = len(self.keys)
142
+ return self._length
143
+
144
+ def __getitem__(self, item):
145
+ """
146
+ Retrieves an item or a list of items from the database.
147
+
148
+ :param item: A key or a list of keys.
149
+ :return: A value or a list of values.
150
+ """
151
+ self._has_fetched_an_item = True
152
+ if not isinstance(item, list):
153
+ item = self._get(
154
+ item=item,
155
+ encode_key=self._encode_key,
156
+ decompress_value=self._decompress_value,
157
+ decode_value=self._decode_value,
158
+ )
159
+ else:
160
+ item = self._gets(
161
+ items=item,
162
+ encode_keys=self._encode_keys,
163
+ decompress_values=self._decompress_values,
164
+ decode_values=self._decode_values,
165
+ )
166
+ return item
167
+
168
+ def _get(self, item, encode_key, decompress_value, decode_value):
169
+ """
170
+ Instantiates a transaction and its associated cursor to fetch an item.
171
+
172
+ :param item: A key.
173
+ :param encode_key:
174
+ :param decode_value:
175
+ :return:
176
+ """
177
+ with self.database.begin() as txn:
178
+ with txn.cursor() as cursor:
179
+ item = self._fetch(
180
+ cursor=cursor,
181
+ key=item,
182
+ encode_key=encode_key,
183
+ decompress_value=decompress_value,
184
+ decode_value=decode_value,
185
+ )
186
+ self._keep_database()
187
+ return item
188
+
189
+ def _gets(self, items, encode_keys, decompress_values, decode_values):
190
+ """
191
+ Instantiates a transaction and its associated cursor to fetch a list of items.
192
+
193
+ :param items: A list of keys.
194
+ :param encode_keys:
195
+ :param decode_values:
196
+ :return:
197
+ """
198
+ with self.database.begin() as txn:
199
+ with txn.cursor() as cursor:
200
+ items = self._fetchs(
201
+ cursor=cursor,
202
+ keys=items,
203
+ encode_keys=encode_keys,
204
+ decompress_values=decompress_values,
205
+ decode_values=decode_values,
206
+ )
207
+ self._keep_database()
208
+ return items
209
+
210
+ def _fetch(self, cursor, key, encode_key, decompress_value, decode_value):
211
+ """
212
+ Retrieve a value given a key.
213
+
214
+ :param cursor:
215
+ :param key: A key.
216
+ :param encode_key:
217
+ :param decode_value:
218
+ :return: A value.
219
+ """
220
+ key = encode_key(key)
221
+ value = cursor.get(key)
222
+ value = decompress_value(value)
223
+ value = decode_value(value)
224
+ return value
225
+
226
+ def _fetchs(self, cursor, keys, encode_keys, decompress_values, decode_values):
227
+ """
228
+ Retrieve a list of values given a list of keys.
229
+
230
+ :param cursor:
231
+ :param keys: A list of keys.
232
+ :param encode_keys:
233
+ :param decode_values:
234
+ :return: A list of values.
235
+ """
236
+ keys = encode_keys(keys)
237
+ _, values = list(zip(*cursor.getmulti(keys)))
238
+ values = decompress_values(values)
239
+ values = decode_values(values)
240
+ return values
241
+
242
+ def _encode_key(self, key: Any) -> bytes:
243
+ """
244
+ Converts a key into a byte key.
245
+
246
+ :param key: A key.
247
+ :return: A byte key.
248
+ """
249
+ return pickle.dumps(key, protocol=self.protocol)
250
+
251
+ def _encode_keys(self, keys: list) -> list:
252
+ """
253
+ Converts keys into byte keys.
254
+
255
+ :param keys: A list of keys.
256
+ :return: A list of byte keys.
257
+ """
258
+ return [self._encode_key(key=key) for key in keys]
259
+
260
+ def _decompress_value(self, value: bytes) -> bytes:
261
+ return self._decompress(value)
262
+
263
+ def _decompress_values(self, values: list) -> list:
264
+ return [self._decompress_value(value=value) for value in values]
265
+
266
+ def _decode_value(self, value: bytes) -> Any:
267
+ """
268
+ Converts a byte value back into a value.
269
+
270
+ :param value: A byte value.
271
+ :return: A value
272
+ """
273
+ return pickle.loads(value)
274
+
275
+ def _decode_values(self, values: list) -> list:
276
+ """
277
+ Converts bytes values back into values.
278
+
279
+ :param values: A list of byte values.
280
+ :return: A list of values.
281
+ """
282
+ return [self._decode_value(value=value) for value in values]
283
+
284
+ def _keep_database(self):
285
+ """
286
+ Checks if the database must be deleted.
287
+
288
+ :return:
289
+ """
290
+ if not self.pre_open and not self._has_fetched_an_item:
291
+ del self.database
292
+
293
+ def __iter__(self):
294
+ """
295
+ Provides an iterator over the keys when iterating over the database.
296
+
297
+ :return: An iterator on the keys.
298
+ """
299
+ return iter(self.keys)
300
+
301
+ def __del__(self):
302
+ """
303
+ Closes the database properly.
304
+ """
305
+ del self.database
306
+
307
+ @staticmethod
308
+ def write(data_lst, indir, outdir):
309
+ raise NotImplementedError
310
+
311
+
312
+ class PILlmdb(BaseLMDB):
313
+ def __init__(
314
+ self,
315
+ lmdb_dir: Union[str, Path],
316
+ image_list: Union[str, Path, pd.DataFrame]=None,
317
+ index_key='id',
318
+ **kwargs
319
+ ):
320
+ super().__init__(path=lmdb_dir, **kwargs)
321
+ if image_list is None:
322
+ self.ids = list(range(len(self.keys)))
323
+ self.labels = list(range(len(self.ids)))
324
+ else:
325
+ df = pd.read_csv(str(image_list))
326
+ assert index_key in df, f'[PILlmdb] Error! {image_list} must have id keys.'
327
+ self.ids = df[index_key].tolist()
328
+ assert max(self.ids) < len(self.keys)
329
+ if 'label' in df:
330
+ self.labels = df['label'].tolist()
331
+ else: # all numeric keys other than 'id' are labels
332
+ keys = [key for key in df if (key!=index_key and type(df[key][0]) in [int, np.int64])]
333
+ # df = df.drop('id', axis=1)
334
+ self.labels = df[keys].to_numpy()
335
+ self._length = len(self.ids)
336
+
337
+ def __len__(self):
338
+ return self._length
339
+
340
+ def __iter__(self):
341
+ return iter([self.keys[i] for i in self.ids])
342
+
343
+ def __getitem__(self, index):
344
+ key = self.keys[self.ids[index]]
345
+ return super().__getitem__(key)
346
+
347
+ def set_ids(self, ids):
348
+ self.ids = [self.ids[i] for i in ids]
349
+ self.labels = [self.labels[i] for i in ids]
350
+ self._length = len(self.ids)
351
+
352
+ def _decode_value(self, value: bytes):
353
+ """
354
+ Converts a byte image back into a PIL Image.
355
+
356
+ :param value: A byte image.
357
+ :return: A PIL Image image.
358
+ """
359
+ return Image.open(io.BytesIO(value))
360
+
361
+ @staticmethod
362
+ def write(indir, outdir, data_lst=None, transform=None):
363
+ """
364
+ create lmdb given data directory and list of image paths; or an iterator
365
+ :param data_lst None or csv file containing 'path' key to store relative paths to the images
366
+ :param indir root directory of the images
367
+ :param outdir output lmdb, data.mdb and lock.mdb will be written here
368
+ """
369
+
370
+ outdir = Path(outdir)
371
+ outdir.mkdir(parents=True, exist_ok=True)
372
+ tmp_dir = Path("/tmp") / f"TEMP_{time.time()}"
373
+ tmp_dir.mkdir(parents=True, exist_ok=True)
374
+ dtype = {'str': False, 'pil': False}
375
+ if isinstance(indir, str) or isinstance(indir, Path):
376
+ indir = Path(indir)
377
+ if data_lst is None: # grab all images in this dir
378
+ lst = list(indir.glob('**/*.jpg')) + list(indir.glob('**/*.png'))
379
+ else:
380
+ lst = pd.read_csv(data_lst)['path'].tolist()
381
+ lst = [indir/p for p in lst]
382
+ assert len(lst) > 0, f'Couldnt find any image in {indir} (Support only .jpg and .png) or list (must have path field).'
383
+ n = len(lst)
384
+ dtype['str'] = True
385
+ else: # iterator
386
+ n = len(indir)
387
+ lst = iter(indir)
388
+ dtype['pil'] = True
389
+
390
+ with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env:
391
+ # Add the protocol to the database.
392
+ with env.begin(write=True) as txn:
393
+ key = "protocol".encode("ascii")
394
+ value = pickle.dumps(pickle.DEFAULT_PROTOCOL)
395
+ txn.put(key=key, value=value, dupdata=False)
396
+ # Add the keys to the database.
397
+ with env.begin(write=True) as txn:
398
+ key = pickle.dumps("keys")
399
+ value = pickle.dumps(list(range(n)))
400
+ txn.put(key=key, value=value, dupdata=False)
401
+ # Add the images to the database.
402
+ for key, value in tqdm(enumerate(lst), total=n, miniters=n//100, mininterval=300):
403
+ with env.begin(write=True) as txn:
404
+ key = pickle.dumps(key)
405
+ if dtype['str']:
406
+ with value.open("rb") as file:
407
+ byteimg = file.read()
408
+ else: # PIL
409
+ data = io.BytesIO()
410
+ value.save(data, 'png')
411
+ byteimg = data.getvalue()
412
+
413
+ if transform is not None:
414
+ im = Image.open(io.BytesIO(byteimg))
415
+ im = transform(im)
416
+ data = io.BytesIO()
417
+ im.save(data, 'png')
418
+ byteimg = data.getvalue()
419
+ txn.put(key=key, value=byteimg, dupdata=False)
420
+
421
+ # Move the database to its destination.
422
+ copy_tree(str(tmp_dir), str(outdir))
423
+ shutil.rmtree(str(tmp_dir))
424
+
425
+
426
+
427
+ class MaskDatabase(PILlmdb):
428
+ def _decode_value(self, value: bytes):
429
+ """
430
+ Converts a byte image back into a PIL Image.
431
+
432
+ :param value: A byte image.
433
+ :return: A PIL Image image.
434
+ """
435
+ return Image.open(io.BytesIO(value)).convert("1")
436
+
437
+
438
+ class LabelDatabase(BaseLMDB):
439
+ pass
440
+
441
+
442
+ class ArrayDatabase(BaseLMDB):
443
+ _dtype = None
444
+ _shape = None
445
+
446
+ def __init__(
447
+ self,
448
+ lmdb_dir: Union[str, Path],
449
+ image_list: Union[str, Path, pd.DataFrame]=None,
450
+ **kwargs
451
+ ):
452
+ super().__init__(path=lmdb_dir, **kwargs)
453
+ if image_list is None:
454
+ self.ids = list(range(len(self.keys)))
455
+ self.labels = list(range(len(self.ids)))
456
+ else:
457
+ df = pd.read_csv(str(image_list))
458
+ assert 'id' in df, f'[ArrayDatabase] Error! {image_list} must have id keys.'
459
+ self.ids = df['id'].tolist()
460
+ assert max(self.ids) < len(self.keys)
461
+ if 'label' in df:
462
+ self.labels = df['label'].tolist()
463
+ else: # all numeric keys other than 'id' are labels
464
+ keys = [key for key in df if (key!='id' and type(df[key][0]) in [int, np.int64])]
465
+ # df = df.drop('id', axis=1)
466
+ self.labels = df[keys].to_numpy()
467
+ self._length = len(self.ids)
468
+
469
+ def set_ids(self, ids):
470
+ self.ids = [self.ids[i] for i in ids]
471
+ self.labels = [self.labels[i] for i in ids]
472
+ self._length = len(self.ids)
473
+
474
+ def __len__(self):
475
+ return self._length
476
+
477
+ def __iter__(self):
478
+ return iter([self.keys[i] for i in self.ids])
479
+
480
+ def __getitem__(self, index):
481
+ key = self.keys[self.ids[index]]
482
+ return super().__getitem__(key)
483
+
484
+ @property
485
+ def dtype(self):
486
+ if self._dtype is None:
487
+ protocol = self.protocol
488
+ self._dtype = self._get(
489
+ item="dtype",
490
+ encode_key=lambda key: _default_encode(key, protocol=protocol),
491
+ decompress_value=_default_decompress,
492
+ decode_value=_default_decode,
493
+ )
494
+ return self._dtype
495
+
496
+ @property
497
+ def shape(self):
498
+ if self._shape is None:
499
+ protocol = self.protocol
500
+ self._shape = self._get(
501
+ item="shape",
502
+ encode_key=lambda key: _default_encode(key, protocol=protocol),
503
+ decompress_value=_default_decompress,
504
+ decode_value=_default_decode,
505
+ )
506
+ return self._shape
507
+
508
+ def _decode_value(self, value: bytes) -> ndarray:
509
+ value = super()._decode_value(value)
510
+ return np.frombuffer(value, dtype=self.dtype).reshape(self.shape)
511
+
512
+ def _decode_values(self, values: list) -> ndarray:
513
+ shape = (len(values),) + self.shape
514
+ return np.frombuffer(b"".join(values), dtype=self.dtype).reshape(shape)
515
+
516
+ @staticmethod
517
+ def write(diter, outdir):
518
+ """
519
+ diter is an iterator that has __len__ method
520
+ class Myiter():
521
+ def __init__(self, data):
522
+ self.data = data
523
+ def __iter__(self):
524
+ self.counter = 0
525
+ return self
526
+ def __len__(self):
527
+ return len(self.data)
528
+ def __next__(self):
529
+ if self.counter < len(self):
530
+ out = self.data[self.counter]
531
+ self.counter+=1
532
+ return out
533
+ else:
534
+ raise StopIteration
535
+ a = iter(Myiter([1,2,3]))
536
+ for i in a:
537
+ print(i)
538
+ """
539
+ outdir = Path(outdir)
540
+ outdir.mkdir(parents=True, exist_ok=True)
541
+ tmp_dir = Path("/tmp") / f"TEMP_{time.time()}"
542
+ tmp_dir.mkdir(parents=True, exist_ok=True)
543
+ # Create the database.
544
+ n = len(diter)
545
+ with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env:
546
+ # Add the protocol to the database.
547
+ with env.begin(write=True) as txn:
548
+ key = "protocol".encode("ascii")
549
+ value = pickle.dumps(pickle.DEFAULT_PROTOCOL)
550
+ txn.put(key=key, value=value, dupdata=False)
551
+ # Add the keys to the database.
552
+ with env.begin(write=True) as txn:
553
+ key = pickle.dumps("keys")
554
+ value = pickle.dumps(list(range(n)))
555
+ txn.put(key=key, value=value, dupdata=False)
556
+ # Extract the shape and dtype of the values.
557
+ value = next(iter(diter))
558
+ shape = value.shape
559
+ dtype = value.dtype
560
+ # Add the shape to the database.
561
+ with env.begin(write=True) as txn:
562
+ key = pickle.dumps("shape")
563
+ value = pickle.dumps(shape)
564
+ txn.put(key=key, value=value, dupdata=False)
565
+ # Add the dtype to the database.
566
+ with env.begin(write=True) as txn:
567
+ key = pickle.dumps("dtype")
568
+ value = pickle.dumps(dtype)
569
+ txn.put(key=key, value=value, dupdata=False)
570
+ # Add the values to the database.
571
+ with env.begin(write=True) as txn:
572
+ for key, value in tqdm(enumerate(iter(diter)), total=n, miniters=n//100, mininterval=300):
573
+ key = pickle.dumps(key)
574
+ value = pickle.dumps(value)
575
+ txn.put(key=key, value=value, dupdata=False)
576
+
577
+ # Move the database to its destination.
578
+ copy_tree(str(tmp_dir), str(outdir))
579
+ shutil.rmtree(str(tmp_dir))
580
+
581
+
582
+
583
+ class TensorDatabase(ArrayDatabase):
584
+ def _decode_value(self, value: bytes) -> Tensor:
585
+ return torch.from_numpy(super(TensorDatabase, self)._decode_value(value))
586
+
587
+ def _decode_values(self, values: list) -> Tensor:
588
+ return torch.from_numpy(super(TensorDatabase, self)._decode_values(values))
tools/ecc.py ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bchlib
2
+ import numpy as np
3
+ from typing import List, Tuple
4
+ import random
5
+ from copy import deepcopy
6
+
7
+ class RSC(object):
8
+ def __init__(self, data_bytes=16, ecc_bytes=4, verbose=False, **kwargs):
9
+ from reedsolo import RSCodec
10
+ self.rs = RSCodec(ecc_bytes)
11
+ if verbose:
12
+ print(f'Reed-Solomon ECC len: {ecc_bytes*8} bits')
13
+ self.data_len = data_bytes
14
+ self.dlen = data_bytes * 8 # data length in bits
15
+ self.ecc_len = ecc_bytes * 8 # ecc length in bits
16
+
17
+ def get_total_len(self):
18
+ return self.dlen + self.ecc_len
19
+
20
+ def encode_text(self, text: List[str]):
21
+ return np.array([self._encode_text(t) for t in text])
22
+
23
+ def _encode_text(self, text: str):
24
+ text = text + ' ' * (self.dlen // 8 - len(text))
25
+ out = self.rs.encode(text.encode('utf-8')) # bytearray
26
+ out = ''.join(format(x, '08b') for x in out) # bit string
27
+ out = np.array([int(x) for x in out], dtype=np.float32)
28
+ return out
29
+
30
+ def decode_text(self, data: np.array):
31
+ assert len(data.shape)==2
32
+ return [self._decode_text(d) for d in data]
33
+
34
+ def _decode_text(self, data: np.array):
35
+ assert len(data.shape)==1
36
+ data = ''.join([str(int(bit)) for bit in data])
37
+ data = bytes(int(data[i: i + 8], 2) for i in range(0, len(data), 8))
38
+ data = bytearray(data)
39
+ try:
40
+ data = self.rs.decode(data)[0]
41
+ data = data.decode('utf-8').strip()
42
+ except:
43
+ print('Error: Decode failed')
44
+ data = get_random_unicode(self.get_total_len()//8)
45
+
46
+ return data
47
+
48
+ def get_random_unicode(length):
49
+ # Update this to include code point ranges to be sampled
50
+ include_ranges = [
51
+ ( 0x0021, 0x0021 ),
52
+ ( 0x0023, 0x0026 ),
53
+ ( 0x0028, 0x007E ),
54
+ ( 0x00A1, 0x00AC ),
55
+ ( 0x00AE, 0x00FF ),
56
+ ( 0x0100, 0x017F ),
57
+ ( 0x0180, 0x024F ),
58
+ ( 0x2C60, 0x2C7F ),
59
+ ( 0x16A0, 0x16F0 ),
60
+ ( 0x0370, 0x0377 ),
61
+ ( 0x037A, 0x037E ),
62
+ ( 0x0384, 0x038A ),
63
+ ( 0x038C, 0x038C ),
64
+ ]
65
+ alphabet = [
66
+ chr(code_point) for current_range in include_ranges
67
+ for code_point in range(current_range[0], current_range[1] + 1)
68
+ ]
69
+ return ''.join(random.choice(alphabet) for i in range(length))
70
+
71
+
72
+ class BCH(object):
73
+ def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, payload_len=100, verbose=True,**kwargs):
74
+ self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
75
+ self.payload_len = payload_len # in bits
76
+ self.data_len = (self.payload_len - self.bch.ecc_bytes*8)//7 # in ascii characters
77
+ assert self.data_len*7+self.bch.ecc_bytes*8 <= self.bch.n, f'Error! BCH with poly {BCH_POLYNOMIAL} and bits {BCH_BITS} can only encode max {self.bch.n//8} bytes of total payload'
78
+ if verbose:
79
+ print(f'BCH: POLYNOMIAL={BCH_POLYNOMIAL}, protected bits={BCH_BITS}, payload_len={payload_len} bits, data_len={self.data_len*7} bits ({self.data_len} ascii chars), ecc len={self.bch.ecc_bytes*8} bits')
80
+
81
+ def get_total_len(self):
82
+ return self.payload_len
83
+
84
+ def encode_text(self, text: List[str]):
85
+ return np.array([self._encode_text(t) for t in text])
86
+
87
+ def _encode_text(self, text: str):
88
+ text = text + ' ' * (self.data_len - len(text))
89
+ # data = text.encode('utf-8') # bytearray
90
+ data = encode_text_ascii(text) # bytearray
91
+ ecc = self.bch.encode(data) # bytearray
92
+ packet = data + ecc # payload in bytearray
93
+ packet = ''.join(format(x, '08b') for x in packet)
94
+ packet = [int(x) for x in packet]
95
+ packet.extend([0]*(self.payload_len - len(packet)))
96
+ packet = np.array(packet, dtype=np.float32)
97
+ return packet
98
+
99
+ def decode_text(self, data: np.array):
100
+ assert len(data.shape)==2
101
+ return [self._decode_text(d) for d in data]
102
+
103
+ def _decode_text(self, packet: np.array):
104
+ assert len(packet.shape)==1
105
+ packet = ''.join([str(int(bit)) for bit in packet]) # bit string
106
+ packet = packet[:(len(packet)//8*8)] # trim to multiple of 8 bits
107
+ packet = bytes(int(packet[i: i + 8], 2) for i in range(0, len(packet), 8))
108
+ packet = bytearray(packet)
109
+ # assert len(packet) == self.data_len + self.bch.ecc_bytes
110
+ data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
111
+ data0 = decode_text_ascii(deepcopy(data)).strip()
112
+ bitflips = self.bch.decode_inplace(data, ecc)
113
+ if bitflips == -1: # error, return random text
114
+ data = data0
115
+ else:
116
+ # data = data.decode('utf-8').strip()
117
+ data = decode_text_ascii(data).strip()
118
+ return data
119
+
120
+
121
+ def encode_text_ascii(text: str):
122
+ # encode text to 7-bit ascii
123
+ # input: text, str
124
+ # output: encoded text, bytearray
125
+ text_int7 = [ord(t) & 127 for t in text]
126
+ text_bitstr = ''.join(format(t,'07b') for t in text_int7)
127
+ if len(text_bitstr) % 8 != 0:
128
+ text_bitstr = '0'*(8-len(text_bitstr)%8) + text_bitstr # pad to multiple of 8
129
+ text_int8 = [int(text_bitstr[i:i+8], 2) for i in range(0, len(text_bitstr), 8)]
130
+ return bytearray(text_int8)
131
+
132
+
133
+ def decode_text_ascii(text: bytearray):
134
+ # decode text from 7-bit ascii
135
+ # input: text, bytearray
136
+ # output: decoded text, str
137
+ text_bitstr = ''.join(format(t,'08b') for t in text) # bit string
138
+ pad = len(text_bitstr) % 7
139
+ if pad != 0: # has padding, remove
140
+ text_bitstr = text_bitstr[pad:]
141
+ text_int7 = [int(text_bitstr[i:i+7], 2) for i in range(0, len(text_bitstr), 7)]
142
+ text_bytes = bytes(text_int7)
143
+ return text_bytes.decode('utf-8')
144
+
145
+
146
+ class ECC(object):
147
+ def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, **kwargs):
148
+ self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
149
+
150
+ def get_total_len(self):
151
+ return 100
152
+
153
+ def _encode(self, x):
154
+ # x: 56 bits, {0, 1}, np.array
155
+ # return: 100 bits, {0, 1}, np.array
156
+ dlen = len(x)
157
+ data_str = ''.join(str(x) for x in x.astype(int))
158
+ packet = bytes(int(data_str[i: i + 8], 2) for i in range(0, dlen, 8))
159
+ packet = bytearray(packet)
160
+ ecc = self.bch.encode(packet)
161
+ packet = packet + ecc # 96 bits
162
+ packet = ''.join(format(x, '08b') for x in packet)
163
+ packet = [int(x) for x in packet]
164
+ packet.extend([0, 0, 0, 0])
165
+ packet = np.array(packet, dtype=np.float32) # 100
166
+ return packet
167
+
168
+ def _decode(self, x):
169
+ # x: 100 bits, {0, 1}, np.array
170
+ # return: 56 bits, {0, 1}, np.array
171
+ packet_binary = "".join([str(int(bit)) for bit in x])
172
+ packet = bytes(int(packet_binary[i: i + 8], 2) for i in range(0, len(packet_binary), 8))
173
+ packet = bytearray(packet)
174
+
175
+ data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
176
+ bitflips = self.bch.decode_inplace(data, ecc)
177
+ if bitflips == -1: # error, return random data
178
+ data = np.random.binomial(1, .5, 56)
179
+ else:
180
+ data = ''.join(format(x, '08b') for x in data)
181
+ data = np.array([int(x) for x in data], dtype=np.float32)
182
+ return data # 56 bits
183
+
184
+ def _generate(self):
185
+ dlen = 56
186
+ data= np.random.binomial(1, .5, dlen)
187
+ packet = self._encode(data)
188
+ return packet, data
189
+
190
+ def generate(self, nsamples=1):
191
+ # generate random 56 bit secret
192
+ data = [self._generate() for _ in range(nsamples)]
193
+ data = (np.array([d[0] for d in data]), np.array([d[1] for d in data]))
194
+ return data # data with ecc, data org
195
+
196
+ def _to_text(self, data):
197
+ # data: {0, 1}, np.array
198
+ # return: str
199
+ data = ''.join([str(int(bit)) for bit in data])
200
+ all_bytes = [ data[i: i+8] for i in range(0, len(data), 8) ]
201
+ text = ''.join([chr(int(byte, 2)) for byte in all_bytes])
202
+ return text.strip()
203
+
204
+ def _to_binary(self, s):
205
+ if isinstance(s, str):
206
+ out = ''.join([ format(ord(i), "08b") for i in s ])
207
+ elif isinstance(s, bytes):
208
+ out = ''.join([ format(i, "08b") for i in s ])
209
+ elif isinstance(s, np.ndarray) and s.dtype is np.dtype(bool):
210
+ out = ''.join([chr(int(i)) for i in s])
211
+ elif isinstance(s, int) or isinstance(s, np.uint8):
212
+ out = format(s, "08b")
213
+ elif isinstance(s, np.ndarray):
214
+ out = [ format(i, "08b") for i in s ]
215
+ else:
216
+ raise TypeError("Type not supported.")
217
+
218
+ return np.array([float(i) for i in out], dtype=np.float32)
219
+
220
+ def _encode_text(self, s):
221
+ s = s + ' '*(7-len(s)) # 7 chars
222
+ s = self._to_binary(s) # 56 bits
223
+ packet = self._encode(s) # 100 bits
224
+ return packet, s
225
+
226
+ def encode_text(self, secret_list, return_pre_ecc=False):
227
+ """encode secret with BCH ECC.
228
+ Input: secret (list of strings)
229
+ Output: secret (np array) with shape (B, 100) type float23, val {0,1}"""
230
+ assert np.all(np.array([len(s) for s in secret_list]) <= 7), 'Error! all strings must be less than 7 characters'
231
+ secret_list = [self._encode_text(s) for s in secret_list]
232
+ ecc = np.array([s[0] for s in secret_list], dtype=np.float32)
233
+ if return_pre_ecc:
234
+ return ecc, np.array([s[1] for s in secret_list], dtype=np.float32)
235
+ return ecc
236
+
237
+ def decode_text(self, data):
238
+ """Decode secret with BCH ECC and convert to string.
239
+ Input: secret (torch.tensor) with shape (B, 100) type bool
240
+ Output: secret (B, 56)"""
241
+ data = self.decode(data)
242
+ data = [self._to_text(d) for d in data]
243
+ return data
244
+
245
+ def decode(self, data):
246
+ """Decode secret with BCH ECC and convert to string.
247
+ Input: secret (torch.tensor) with shape (B, 100) type bool
248
+ Output: secret (B, 56)"""
249
+ data = data[:, :96]
250
+ data = [self._decode(d) for d in data]
251
+ return np.array(data)
252
+
253
+ def test_ecc():
254
+ ecc = ECC()
255
+ batch_size = 10
256
+ secret_ecc, secret_org = ecc.generate(batch_size) # 10x100 ecc secret, 10x56 org secret
257
+ # modify secret_ecc
258
+ secret_pred = secret_ecc.copy()
259
+ secret_pred[:,3:6] = 1 - secret_pred[:,3:6]
260
+ # pass secret_ecc to model and get predicted as secret_pred
261
+ secret_pred_org = ecc.decode(secret_pred) # 10x56
262
+ assert np.all(secret_pred_org == secret_org) # 10
263
+
264
+
265
+ def test_bch():
266
+ # test 100 bit
267
+ def check(text, poly, k, l):
268
+ bch = BCH(poly, k, l)
269
+ # text = 'secrets'
270
+ encode = bch.encode_text([text])
271
+ for ind in np.random.choice(l, k):
272
+ encode[0, ind] = 1 - encode[0, ind]
273
+ text_recon = bch.decode_text(encode)[0]
274
+ assert text==text_recon
275
+
276
+ check('secrets', 137, 5, 100)
277
+ check('some secret', 285, 10, 160)
278
+
279
+ if __name__ == '__main__':
280
+ test_ecc()
281
+ test_bch()
tools/eval_metrics.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import skimage.metrics
4
+ import lpips
5
+ from PIL import Image
6
+ from .sifid import SIFID
7
+
8
+
9
+ def resize_array(x, size=256):
10
+ """
11
+ Resize image array to given size.
12
+ Args:
13
+ x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
14
+ size (int): Size of output image.
15
+ Returns:
16
+ (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
17
+ """
18
+ if x.shape[1] != size:
19
+ x = [Image.fromarray(x[i]).resize((size, size), resample=Image.BILINEAR) for i in range(x.shape[0])]
20
+ x = np.array([np.array(i) for i in x])
21
+ return x
22
+
23
+
24
+ def resize_tensor(x, size=256):
25
+ """
26
+ Resize image tensor to given size.
27
+ Args:
28
+ x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
29
+ size (int): Size of output image.
30
+ Returns:
31
+ (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
32
+ """
33
+ if x.shape[2] != size:
34
+ x = torch.nn.functional.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)
35
+ return x
36
+
37
+
38
+ def normalise(x):
39
+ """
40
+ Normalise image array to range [-1, 1] and tensor.
41
+ Args:
42
+ x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
43
+ Returns:
44
+ (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
45
+ """
46
+ x = x.astype(np.float32)
47
+ x = x / 255
48
+ x = (x - 0.5) / 0.5
49
+ x = torch.from_numpy(x)
50
+ x = x.permute(0, 3, 1, 2)
51
+ return x
52
+
53
+
54
+ def unormalise(x, vrange=[-1, 1]):
55
+ """
56
+ Unormalise image tensor to range [0, 255] and RGB array.
57
+ Args:
58
+ x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
59
+ Returns:
60
+ (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
61
+ """
62
+ x = (x - vrange[0])/(vrange[1] - vrange[0])
63
+ x = x * 255
64
+ x = x.permute(0, 2, 3, 1)
65
+ x = x.cpu().numpy().astype(np.uint8)
66
+ return x
67
+
68
+
69
+ def compute_mse(x, y):
70
+ """
71
+ Compute mean squared error between two image arrays.
72
+ Args:
73
+ x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
74
+ y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
75
+ Returns:
76
+ (1darray): Mean squared error.
77
+ """
78
+ return np.square(x - y).reshape(x.shape[0], -1).mean(axis=1)
79
+
80
+
81
+ def compute_psnr(x, y):
82
+ """
83
+ Compute peak signal-to-noise ratio between two images.
84
+ Args:
85
+ x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
86
+ y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
87
+ Returns:
88
+ (float): Peak signal-to-noise ratio.
89
+ """
90
+ return 10 * np.log10(255 ** 2 / compute_mse(x, y))
91
+
92
+
93
+ def compute_ssim(x, y):
94
+ """
95
+ Compute structural similarity index between two images.
96
+ Args:
97
+ x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
98
+ y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
99
+ Returns:
100
+ (float): Structural similarity index.
101
+ """
102
+ return np.array([skimage.metrics.structural_similarity(xi, yi, channel_axis=2, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=255) for xi, yi in zip(x, y)])
103
+
104
+
105
+ def compute_lpips(x, y, net='alex'):
106
+ """
107
+ Compute LPIPS between two images.
108
+ Args:
109
+ x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
110
+ y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
111
+ Returns:
112
+ (float): LPIPS.
113
+ """
114
+ lpips_fn = lpips.LPIPS(net=net, verbose=False).cuda() if isinstance(net, str) else net
115
+ x, y = x.cuda(), y.cuda()
116
+ return lpips_fn(x, y).detach().cpu().numpy().squeeze()
117
+
118
+
119
+ def compute_sifid(x, y, net=None):
120
+ """
121
+ Compute SIFID between two images.
122
+ Args:
123
+ x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
124
+ y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
125
+ Returns:
126
+ (float): SIFID.
127
+ """
128
+ fn = SIFID() if net is None else net
129
+ out = [fn(xi, yi) for xi, yi in zip(x, y)]
130
+ return np.array(out)
tools/fid.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torchvision.transforms as TF
41
+ from PIL import Image
42
+ from scipy import linalg
43
+ from torch.nn.functional import adaptive_avg_pool2d
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ import torchvision
47
+
48
+ try:
49
+ from tqdm import tqdm
50
+ except ImportError:
51
+ # If tqdm is not available, provide a mock version of it
52
+ def tqdm(x):
53
+ return x
54
+
55
+
56
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
57
+ 'tif', 'tiff', 'webp'}
58
+
59
+
60
+ try:
61
+ from torchvision.models.utils import load_state_dict_from_url
62
+ except ImportError:
63
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
64
+
65
+ # Inception weights ported to Pytorch from
66
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
67
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
68
+
69
+
70
+ class InceptionV3(nn.Module):
71
+ """Pretrained InceptionV3 network returning feature maps"""
72
+
73
+ # Index of default block of inception to return,
74
+ # corresponds to output of final average pooling
75
+ DEFAULT_BLOCK_INDEX = 3
76
+
77
+ # Maps feature dimensionality to their output blocks indices
78
+ BLOCK_INDEX_BY_DIM = {
79
+ 64: 0, # First max pooling features
80
+ 192: 1, # Second max pooling featurs
81
+ 768: 2, # Pre-aux classifier features
82
+ 2048: 3 # Final average pooling features
83
+ }
84
+
85
+ def __init__(self,
86
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
87
+ resize_input=True,
88
+ normalize_input=True,
89
+ requires_grad=False,
90
+ use_fid_inception=True):
91
+ """Build pretrained InceptionV3
92
+
93
+ Parameters
94
+ ----------
95
+ output_blocks : list of int
96
+ Indices of blocks to return features of. Possible values are:
97
+ - 0: corresponds to output of first max pooling
98
+ - 1: corresponds to output of second max pooling
99
+ - 2: corresponds to output which is fed to aux classifier
100
+ - 3: corresponds to output of final average pooling
101
+ resize_input : bool
102
+ If true, bilinearly resizes input to width and height 299 before
103
+ feeding input to model. As the network without fully connected
104
+ layers is fully convolutional, it should be able to handle inputs
105
+ of arbitrary size, so resizing might not be strictly needed
106
+ normalize_input : bool
107
+ If true, scales the input from range (0, 1) to the range the
108
+ pretrained Inception network expects, namely (-1, 1)
109
+ requires_grad : bool
110
+ If true, parameters of the model require gradients. Possibly useful
111
+ for finetuning the network
112
+ use_fid_inception : bool
113
+ If true, uses the pretrained Inception model used in Tensorflow's
114
+ FID implementation. If false, uses the pretrained Inception model
115
+ available in torchvision. The FID Inception model has different
116
+ weights and a slightly different structure from torchvision's
117
+ Inception model. If you want to compute FID scores, you are
118
+ strongly advised to set this parameter to true to get comparable
119
+ results.
120
+ """
121
+ super(InceptionV3, self).__init__()
122
+
123
+ self.resize_input = resize_input
124
+ self.normalize_input = normalize_input
125
+ self.output_blocks = sorted(output_blocks)
126
+ self.last_needed_block = max(output_blocks)
127
+
128
+ assert self.last_needed_block <= 3, \
129
+ 'Last possible output block index is 3'
130
+
131
+ self.blocks = nn.ModuleList()
132
+
133
+ if use_fid_inception:
134
+ inception = fid_inception_v3()
135
+ else:
136
+ inception = _inception_v3(weights='DEFAULT')
137
+
138
+ # Block 0: input to maxpool1
139
+ block0 = [
140
+ inception.Conv2d_1a_3x3,
141
+ inception.Conv2d_2a_3x3,
142
+ inception.Conv2d_2b_3x3,
143
+ nn.MaxPool2d(kernel_size=3, stride=2)
144
+ ]
145
+ self.blocks.append(nn.Sequential(*block0))
146
+
147
+ # Block 1: maxpool1 to maxpool2
148
+ if self.last_needed_block >= 1:
149
+ block1 = [
150
+ inception.Conv2d_3b_1x1,
151
+ inception.Conv2d_4a_3x3,
152
+ nn.MaxPool2d(kernel_size=3, stride=2)
153
+ ]
154
+ self.blocks.append(nn.Sequential(*block1))
155
+
156
+ # Block 2: maxpool2 to aux classifier
157
+ if self.last_needed_block >= 2:
158
+ block2 = [
159
+ inception.Mixed_5b,
160
+ inception.Mixed_5c,
161
+ inception.Mixed_5d,
162
+ inception.Mixed_6a,
163
+ inception.Mixed_6b,
164
+ inception.Mixed_6c,
165
+ inception.Mixed_6d,
166
+ inception.Mixed_6e,
167
+ ]
168
+ self.blocks.append(nn.Sequential(*block2))
169
+
170
+ # Block 3: aux classifier to final avgpool
171
+ if self.last_needed_block >= 3:
172
+ block3 = [
173
+ inception.Mixed_7a,
174
+ inception.Mixed_7b,
175
+ inception.Mixed_7c,
176
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
177
+ ]
178
+ self.blocks.append(nn.Sequential(*block3))
179
+
180
+ for param in self.parameters():
181
+ param.requires_grad = requires_grad
182
+
183
+ def forward(self, inp):
184
+ """Get Inception feature maps
185
+
186
+ Parameters
187
+ ----------
188
+ inp : torch.autograd.Variable
189
+ Input tensor of shape Bx3xHxW. Values are expected to be in
190
+ range (0, 1)
191
+
192
+ Returns
193
+ -------
194
+ List of torch.autograd.Variable, corresponding to the selected output
195
+ block, sorted ascending by index
196
+ """
197
+ outp = []
198
+ x = inp
199
+
200
+ if self.resize_input:
201
+ x = F.interpolate(x,
202
+ size=(299, 299),
203
+ mode='bilinear',
204
+ align_corners=False)
205
+
206
+ if self.normalize_input:
207
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
208
+
209
+ for idx, block in enumerate(self.blocks):
210
+ x = block(x)
211
+ if idx in self.output_blocks:
212
+ outp.append(x)
213
+
214
+ if idx == self.last_needed_block:
215
+ break
216
+
217
+ return outp
218
+
219
+
220
+ def _inception_v3(*args, **kwargs):
221
+ """Wraps `torchvision.models.inception_v3`"""
222
+ try:
223
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
224
+ except ValueError:
225
+ # Just a caution against weird version strings
226
+ version = (0,)
227
+
228
+ # Skips default weight inititialization if supported by torchvision
229
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
230
+ if version >= (0, 6):
231
+ kwargs['init_weights'] = False
232
+
233
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
234
+ # argument prior to version 0.13.
235
+ if version < (0, 13) and 'weights' in kwargs:
236
+ if kwargs['weights'] == 'DEFAULT':
237
+ kwargs['pretrained'] = True
238
+ elif kwargs['weights'] is None:
239
+ kwargs['pretrained'] = False
240
+ else:
241
+ raise ValueError(
242
+ 'weights=={} not supported in torchvision {}'.format(
243
+ kwargs['weights'], torchvision.__version__
244
+ )
245
+ )
246
+ del kwargs['weights']
247
+
248
+ return torchvision.models.inception_v3(*args, **kwargs)
249
+
250
+
251
+ def fid_inception_v3():
252
+ """Build pretrained Inception model for FID computation
253
+
254
+ The Inception model for FID computation uses a different set of weights
255
+ and has a slightly different structure than torchvision's Inception.
256
+
257
+ This method first constructs torchvision's Inception and then patches the
258
+ necessary parts that are different in the FID Inception model.
259
+ """
260
+ inception = _inception_v3(num_classes=1008,
261
+ aux_logits=False,
262
+ weights=None)
263
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
264
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
265
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
266
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
267
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
268
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
269
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
270
+ inception.Mixed_7b = FIDInceptionE_1(1280)
271
+ inception.Mixed_7c = FIDInceptionE_2(2048)
272
+
273
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
274
+ inception.load_state_dict(state_dict)
275
+ return inception
276
+
277
+
278
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
279
+ """InceptionA block patched for FID computation"""
280
+ def __init__(self, in_channels, pool_features):
281
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
282
+
283
+ def forward(self, x):
284
+ branch1x1 = self.branch1x1(x)
285
+
286
+ branch5x5 = self.branch5x5_1(x)
287
+ branch5x5 = self.branch5x5_2(branch5x5)
288
+
289
+ branch3x3dbl = self.branch3x3dbl_1(x)
290
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
291
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
292
+
293
+ # Patch: Tensorflow's average pool does not use the padded zero's in
294
+ # its average calculation
295
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
296
+ count_include_pad=False)
297
+ branch_pool = self.branch_pool(branch_pool)
298
+
299
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
300
+ return torch.cat(outputs, 1)
301
+
302
+
303
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
304
+ """InceptionC block patched for FID computation"""
305
+ def __init__(self, in_channels, channels_7x7):
306
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
307
+
308
+ def forward(self, x):
309
+ branch1x1 = self.branch1x1(x)
310
+
311
+ branch7x7 = self.branch7x7_1(x)
312
+ branch7x7 = self.branch7x7_2(branch7x7)
313
+ branch7x7 = self.branch7x7_3(branch7x7)
314
+
315
+ branch7x7dbl = self.branch7x7dbl_1(x)
316
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
317
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
318
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
319
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
320
+
321
+ # Patch: Tensorflow's average pool does not use the padded zero's in
322
+ # its average calculation
323
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
324
+ count_include_pad=False)
325
+ branch_pool = self.branch_pool(branch_pool)
326
+
327
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
328
+ return torch.cat(outputs, 1)
329
+
330
+
331
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
332
+ """First InceptionE block patched for FID computation"""
333
+ def __init__(self, in_channels):
334
+ super(FIDInceptionE_1, self).__init__(in_channels)
335
+
336
+ def forward(self, x):
337
+ branch1x1 = self.branch1x1(x)
338
+
339
+ branch3x3 = self.branch3x3_1(x)
340
+ branch3x3 = [
341
+ self.branch3x3_2a(branch3x3),
342
+ self.branch3x3_2b(branch3x3),
343
+ ]
344
+ branch3x3 = torch.cat(branch3x3, 1)
345
+
346
+ branch3x3dbl = self.branch3x3dbl_1(x)
347
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
348
+ branch3x3dbl = [
349
+ self.branch3x3dbl_3a(branch3x3dbl),
350
+ self.branch3x3dbl_3b(branch3x3dbl),
351
+ ]
352
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
353
+
354
+ # Patch: Tensorflow's average pool does not use the padded zero's in
355
+ # its average calculation
356
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
357
+ count_include_pad=False)
358
+ branch_pool = self.branch_pool(branch_pool)
359
+
360
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
361
+ return torch.cat(outputs, 1)
362
+
363
+
364
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
365
+ """Second InceptionE block patched for FID computation"""
366
+ def __init__(self, in_channels):
367
+ super(FIDInceptionE_2, self).__init__(in_channels)
368
+
369
+ def forward(self, x):
370
+ branch1x1 = self.branch1x1(x)
371
+
372
+ branch3x3 = self.branch3x3_1(x)
373
+ branch3x3 = [
374
+ self.branch3x3_2a(branch3x3),
375
+ self.branch3x3_2b(branch3x3),
376
+ ]
377
+ branch3x3 = torch.cat(branch3x3, 1)
378
+
379
+ branch3x3dbl = self.branch3x3dbl_1(x)
380
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
381
+ branch3x3dbl = [
382
+ self.branch3x3dbl_3a(branch3x3dbl),
383
+ self.branch3x3dbl_3b(branch3x3dbl),
384
+ ]
385
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
386
+
387
+ # Patch: The FID Inception model uses max pooling instead of average
388
+ # pooling. This is likely an error in this specific Inception
389
+ # implementation, as other Inception models use average pooling here
390
+ # (which matches the description in the paper).
391
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
392
+ branch_pool = self.branch_pool(branch_pool)
393
+
394
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
395
+ return torch.cat(outputs, 1)
396
+
397
+ class ImagePathDataset(torch.utils.data.Dataset):
398
+ def __init__(self, files, transforms=None):
399
+ self.files = files
400
+ self.transforms = transforms
401
+
402
+ def __len__(self):
403
+ return len(self.files)
404
+
405
+ def __getitem__(self, i):
406
+ path = self.files[i]
407
+ img = Image.open(path).convert('RGB')
408
+ if self.transforms is not None:
409
+ img = self.transforms(img)
410
+ return img
411
+
412
+
413
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
414
+ num_workers=1, resize=0):
415
+ """Calculates the activations of the pool_3 layer for all images.
416
+
417
+ Params:
418
+ -- files : List of image files paths
419
+ -- model : Instance of inception model
420
+ -- batch_size : Batch size of images for the model to process at once.
421
+ Make sure that the number of samples is a multiple of
422
+ the batch size, otherwise some samples are ignored. This
423
+ behavior is retained to match the original FID score
424
+ implementation.
425
+ -- dims : Dimensionality of features returned by Inception
426
+ -- device : Device to run calculations
427
+ -- num_workers : Number of parallel dataloader workers
428
+
429
+ Returns:
430
+ -- A numpy array of dimension (num images, dims) that contains the
431
+ activations of the given tensor when feeding inception with the
432
+ query tensor.
433
+ """
434
+ model.eval()
435
+
436
+ if batch_size > len(files):
437
+ print(('Warning: batch size is bigger than the data size. '
438
+ 'Setting batch size to data size'))
439
+ batch_size = len(files)
440
+ if resize > 0:
441
+ tform = TF.Compose([TF.Resize((resize, resize)), TF.ToTensor()])
442
+ else:
443
+ tform = TF.ToTensor()
444
+ dataset = ImagePathDataset(files, transforms=tform)
445
+ dataloader = torch.utils.data.DataLoader(dataset,
446
+ batch_size=batch_size,
447
+ shuffle=False,
448
+ drop_last=False,
449
+ num_workers=num_workers)
450
+
451
+ pred_arr = np.empty((len(files), dims))
452
+
453
+ start_idx = 0
454
+
455
+ for batch in tqdm(dataloader):
456
+ batch = batch.to(device)
457
+
458
+ with torch.no_grad():
459
+ pred = model(batch)[0]
460
+
461
+ # If model output is not scalar, apply global spatial average pooling.
462
+ # This happens if you choose a dimensionality not equal 2048.
463
+ if pred.size(2) != 1 or pred.size(3) != 1:
464
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
465
+
466
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
467
+
468
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
469
+
470
+ start_idx = start_idx + pred.shape[0]
471
+
472
+ return pred_arr
473
+
474
+
475
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
476
+ """Numpy implementation of the Frechet Distance.
477
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
478
+ and X_2 ~ N(mu_2, C_2) is
479
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
480
+
481
+ Stable version by Dougal J. Sutherland.
482
+
483
+ Params:
484
+ -- mu1 : Numpy array containing the activations of a layer of the
485
+ inception net (like returned by the function 'get_predictions')
486
+ for generated samples.
487
+ -- mu2 : The sample mean over activations, precalculated on an
488
+ representative data set.
489
+ -- sigma1: The covariance matrix over activations for generated samples.
490
+ -- sigma2: The covariance matrix over activations, precalculated on an
491
+ representative data set.
492
+
493
+ Returns:
494
+ -- : The Frechet Distance.
495
+ """
496
+
497
+ mu1 = np.atleast_1d(mu1)
498
+ mu2 = np.atleast_1d(mu2)
499
+
500
+ sigma1 = np.atleast_2d(sigma1)
501
+ sigma2 = np.atleast_2d(sigma2)
502
+
503
+ assert mu1.shape == mu2.shape, \
504
+ 'Training and test mean vectors have different lengths'
505
+ assert sigma1.shape == sigma2.shape, \
506
+ 'Training and test covariances have different dimensions'
507
+
508
+ diff = mu1 - mu2
509
+
510
+ # Product might be almost singular
511
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
512
+ if not np.isfinite(covmean).all():
513
+ msg = ('fid calculation produces singular product; '
514
+ 'adding %s to diagonal of cov estimates') % eps
515
+ print(msg)
516
+ offset = np.eye(sigma1.shape[0]) * eps
517
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
518
+
519
+ # Numerical error might give slight imaginary component
520
+ if np.iscomplexobj(covmean):
521
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
522
+ m = np.max(np.abs(covmean.imag))
523
+ raise ValueError('Imaginary component {}'.format(m))
524
+ covmean = covmean.real
525
+
526
+ tr_covmean = np.trace(covmean)
527
+
528
+ return (diff.dot(diff) + np.trace(sigma1)
529
+ + np.trace(sigma2) - 2 * tr_covmean)
530
+
531
+
532
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
533
+ device='cpu', num_workers=1, resize=0):
534
+ """Calculation of the statistics used by the FID.
535
+ Params:
536
+ -- files : List of image files paths
537
+ -- model : Instance of inception model
538
+ -- batch_size : The images numpy array is split into batches with
539
+ batch size batch_size. A reasonable batch size
540
+ depends on the hardware.
541
+ -- dims : Dimensionality of features returned by Inception
542
+ -- device : Device to run calculations
543
+ -- num_workers : Number of parallel dataloader workers
544
+
545
+ Returns:
546
+ -- mu : The mean over samples of the activations of the pool_3 layer of
547
+ the inception model.
548
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
549
+ the inception model.
550
+ """
551
+ act = get_activations(files, model, batch_size, dims, device, num_workers, resize)
552
+ mu = np.mean(act, axis=0)
553
+ sigma = np.cov(act, rowvar=False)
554
+ return mu, sigma
555
+
556
+
557
+ def compute_statistics_of_path(path, model, batch_size, dims, device,
558
+ num_workers=1, nimages=None, resize=0):
559
+ if path.endswith('.npz'):
560
+ with np.load(path) as f:
561
+ m, s = f['mu'][:], f['sigma'][:]
562
+ else:
563
+ path = pathlib.Path(path)
564
+
565
+ files = sorted([file for ext in IMAGE_EXTENSIONS
566
+ for file in path.glob('**/*.{}'.format(ext))])
567
+ nfiles = len(files)
568
+ n = nfiles if nimages is None else min(nimages, nfiles)
569
+ print(f'Found {nfiles} images. Computing FID with {n} images.')
570
+ files = files[:n]
571
+ m, s = calculate_activation_statistics(files, model, batch_size,
572
+ dims, device, num_workers, resize)
573
+
574
+ return m, s
575
+
576
+
577
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
578
+ """Calculates the FID of two paths"""
579
+ for p in paths:
580
+ if not os.path.exists(p):
581
+ raise RuntimeError('Invalid path: %s' % p)
582
+
583
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
584
+
585
+ model = InceptionV3([block_idx]).to(device)
586
+
587
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
588
+ dims, device, num_workers, nimages, resize)
589
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
590
+ dims, device, num_workers, nimages, resize)
591
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
592
+
593
+ return fid_value
594
+
595
+
596
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
597
+ """Calculates the FID of two paths"""
598
+ if not os.path.exists(paths[0]):
599
+ raise RuntimeError('Invalid path: %s' % paths[0])
600
+
601
+ if os.path.exists(paths[1]):
602
+ raise RuntimeError('Existing output file: %s' % paths[1])
603
+
604
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
605
+
606
+ model = InceptionV3([block_idx]).to(device)
607
+
608
+ print(f"Saving statistics for {paths[0]}")
609
+
610
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
611
+ dims, device, num_workers, nimages, resize=0)
612
+
613
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
614
+
615
+
616
+ def main():
617
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
618
+ parser.add_argument('--batch-size', type=int, default=20,
619
+ help='Batch size to use')
620
+ parser.add_argument('--num-workers', type=int,
621
+ help=('Number of processes to use for data loading. '
622
+ 'Defaults to `min(8, num_cpus)`'))
623
+ parser.add_argument('--device', type=str, default='cuda:0',
624
+ help='Device to use. Like cuda, cuda:0 or cpu')
625
+ parser.add_argument('--dims', type=int, default=2048,
626
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
627
+ help=('Dimensionality of Inception features to use. '
628
+ 'By default, uses pool3 features'))
629
+ parser.add_argument('--nimages', type=int, default=50000, help='max number of images to use')
630
+ parser.add_argument('--resize', type=int, default=0, help='resize images to this size, 0 mean keep original size')
631
+ parser.add_argument('--save-stats', action='store_true',
632
+ help=('Generate an npz archive from a directory of samples. '
633
+ 'The first path is used as input and the second as output.'))
634
+ parser.add_argument('path', type=str, nargs=2,
635
+ help=('Paths to the generated images or '
636
+ 'to .npz statistic files'))
637
+ args = parser.parse_args()
638
+
639
+ if args.device is None:
640
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
641
+ else:
642
+ device = torch.device(args.device)
643
+
644
+ if args.num_workers is None:
645
+ try:
646
+ num_cpus = len(os.sched_getaffinity(0))
647
+ except AttributeError:
648
+ # os.sched_getaffinity is not available under Windows, use
649
+ # os.cpu_count instead (which may not return the *available* number
650
+ # of CPUs).
651
+ num_cpus = os.cpu_count()
652
+
653
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
654
+ else:
655
+ num_workers = args.num_workers
656
+
657
+ if args.save_stats:
658
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers, args.nimages, args.resize)
659
+ return
660
+
661
+ fid_value = calculate_fid_given_paths(args.path,
662
+ args.batch_size,
663
+ device,
664
+ args.dims,
665
+ num_workers,
666
+ args.nimages,
667
+ args.resize)
668
+ print('FID: ', fid_value)
669
+
670
+
671
+ if __name__ == '__main__':
672
+ main()
tools/fid_lmdb.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculates the Frechet Inception Distance (FID) to evalulate GANs
2
+
3
+ The FID metric calculates the distance between two distributions of images.
4
+ Typically, we have summary statistics (mean & covariance matrix) of one
5
+ of these distributions, while the 2nd distribution is given by a GAN.
6
+
7
+ When run as a stand-alone program, it compares the distribution of
8
+ images that are stored as PNG/JPEG at a specified location with a
9
+ distribution given by summary statistics (in pickle format).
10
+
11
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
12
+ the pool_3 layer of the inception net for generated samples and real world
13
+ samples respectively.
14
+
15
+ See --help to see further details.
16
+
17
+ Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
18
+ of Tensorflow
19
+
20
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
21
+
22
+ Licensed under the Apache License, Version 2.0 (the "License");
23
+ you may not use this file except in compliance with the License.
24
+ You may obtain a copy of the License at
25
+
26
+ http://www.apache.org/licenses/LICENSE-2.0
27
+
28
+ Unless required by applicable law or agreed to in writing, software
29
+ distributed under the License is distributed on an "AS IS" BASIS,
30
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
31
+ See the License for the specific language governing permissions and
32
+ limitations under the License.
33
+ """
34
+ import os
35
+ import pathlib
36
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
37
+
38
+ import numpy as np
39
+ import torch
40
+ import torchvision.transforms as TF
41
+ from PIL import Image
42
+ from scipy import linalg
43
+ from torch.nn.functional import adaptive_avg_pool2d
44
+ import torch.nn as nn
45
+ import torch.nn.functional as F
46
+ import torchvision
47
+ import sys
48
+ sys.path.insert(1, '/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet')
49
+ from tools.image_dataset import ImageDataset
50
+ try:
51
+ from tqdm import tqdm
52
+ except ImportError:
53
+ # If tqdm is not available, provide a mock version of it
54
+ def tqdm(x):
55
+ return x
56
+
57
+
58
+ IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
59
+ 'tif', 'tiff', 'webp'}
60
+
61
+
62
+ try:
63
+ from torchvision.models.utils import load_state_dict_from_url
64
+ except ImportError:
65
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
66
+
67
+ # Inception weights ported to Pytorch from
68
+ # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
69
+ FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
70
+
71
+
72
+ class InceptionV3(nn.Module):
73
+ """Pretrained InceptionV3 network returning feature maps"""
74
+
75
+ # Index of default block of inception to return,
76
+ # corresponds to output of final average pooling
77
+ DEFAULT_BLOCK_INDEX = 3
78
+
79
+ # Maps feature dimensionality to their output blocks indices
80
+ BLOCK_INDEX_BY_DIM = {
81
+ 64: 0, # First max pooling features
82
+ 192: 1, # Second max pooling featurs
83
+ 768: 2, # Pre-aux classifier features
84
+ 2048: 3 # Final average pooling features
85
+ }
86
+
87
+ def __init__(self,
88
+ output_blocks=(DEFAULT_BLOCK_INDEX,),
89
+ resize_input=True,
90
+ normalize_input=True,
91
+ requires_grad=False,
92
+ use_fid_inception=True):
93
+ """Build pretrained InceptionV3
94
+
95
+ Parameters
96
+ ----------
97
+ output_blocks : list of int
98
+ Indices of blocks to return features of. Possible values are:
99
+ - 0: corresponds to output of first max pooling
100
+ - 1: corresponds to output of second max pooling
101
+ - 2: corresponds to output which is fed to aux classifier
102
+ - 3: corresponds to output of final average pooling
103
+ resize_input : bool
104
+ If true, bilinearly resizes input to width and height 299 before
105
+ feeding input to model. As the network without fully connected
106
+ layers is fully convolutional, it should be able to handle inputs
107
+ of arbitrary size, so resizing might not be strictly needed
108
+ normalize_input : bool
109
+ If true, scales the input from range (0, 1) to the range the
110
+ pretrained Inception network expects, namely (-1, 1)
111
+ requires_grad : bool
112
+ If true, parameters of the model require gradients. Possibly useful
113
+ for finetuning the network
114
+ use_fid_inception : bool
115
+ If true, uses the pretrained Inception model used in Tensorflow's
116
+ FID implementation. If false, uses the pretrained Inception model
117
+ available in torchvision. The FID Inception model has different
118
+ weights and a slightly different structure from torchvision's
119
+ Inception model. If you want to compute FID scores, you are
120
+ strongly advised to set this parameter to true to get comparable
121
+ results.
122
+ """
123
+ super(InceptionV3, self).__init__()
124
+
125
+ self.resize_input = resize_input
126
+ self.normalize_input = normalize_input
127
+ self.output_blocks = sorted(output_blocks)
128
+ self.last_needed_block = max(output_blocks)
129
+
130
+ assert self.last_needed_block <= 3, \
131
+ 'Last possible output block index is 3'
132
+
133
+ self.blocks = nn.ModuleList()
134
+
135
+ if use_fid_inception:
136
+ inception = fid_inception_v3()
137
+ else:
138
+ inception = _inception_v3(weights='DEFAULT')
139
+
140
+ # Block 0: input to maxpool1
141
+ block0 = [
142
+ inception.Conv2d_1a_3x3,
143
+ inception.Conv2d_2a_3x3,
144
+ inception.Conv2d_2b_3x3,
145
+ nn.MaxPool2d(kernel_size=3, stride=2)
146
+ ]
147
+ self.blocks.append(nn.Sequential(*block0))
148
+
149
+ # Block 1: maxpool1 to maxpool2
150
+ if self.last_needed_block >= 1:
151
+ block1 = [
152
+ inception.Conv2d_3b_1x1,
153
+ inception.Conv2d_4a_3x3,
154
+ nn.MaxPool2d(kernel_size=3, stride=2)
155
+ ]
156
+ self.blocks.append(nn.Sequential(*block1))
157
+
158
+ # Block 2: maxpool2 to aux classifier
159
+ if self.last_needed_block >= 2:
160
+ block2 = [
161
+ inception.Mixed_5b,
162
+ inception.Mixed_5c,
163
+ inception.Mixed_5d,
164
+ inception.Mixed_6a,
165
+ inception.Mixed_6b,
166
+ inception.Mixed_6c,
167
+ inception.Mixed_6d,
168
+ inception.Mixed_6e,
169
+ ]
170
+ self.blocks.append(nn.Sequential(*block2))
171
+
172
+ # Block 3: aux classifier to final avgpool
173
+ if self.last_needed_block >= 3:
174
+ block3 = [
175
+ inception.Mixed_7a,
176
+ inception.Mixed_7b,
177
+ inception.Mixed_7c,
178
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
179
+ ]
180
+ self.blocks.append(nn.Sequential(*block3))
181
+
182
+ for param in self.parameters():
183
+ param.requires_grad = requires_grad
184
+
185
+ def forward(self, inp):
186
+ """Get Inception feature maps
187
+
188
+ Parameters
189
+ ----------
190
+ inp : torch.autograd.Variable
191
+ Input tensor of shape Bx3xHxW. Values are expected to be in
192
+ range (0, 1)
193
+
194
+ Returns
195
+ -------
196
+ List of torch.autograd.Variable, corresponding to the selected output
197
+ block, sorted ascending by index
198
+ """
199
+ outp = []
200
+ x = inp
201
+
202
+ if self.resize_input:
203
+ x = F.interpolate(x,
204
+ size=(299, 299),
205
+ mode='bilinear',
206
+ align_corners=False)
207
+
208
+ if self.normalize_input:
209
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
210
+
211
+ for idx, block in enumerate(self.blocks):
212
+ x = block(x)
213
+ if idx in self.output_blocks:
214
+ outp.append(x)
215
+
216
+ if idx == self.last_needed_block:
217
+ break
218
+
219
+ return outp
220
+
221
+
222
+ def _inception_v3(*args, **kwargs):
223
+ """Wraps `torchvision.models.inception_v3`"""
224
+ try:
225
+ version = tuple(map(int, torchvision.__version__.split('.')[:2]))
226
+ except ValueError:
227
+ # Just a caution against weird version strings
228
+ version = (0,)
229
+
230
+ # Skips default weight inititialization if supported by torchvision
231
+ # version. See https://github.com/mseitzer/pytorch-fid/issues/28.
232
+ if version >= (0, 6):
233
+ kwargs['init_weights'] = False
234
+
235
+ # Backwards compatibility: `weights` argument was handled by `pretrained`
236
+ # argument prior to version 0.13.
237
+ if version < (0, 13) and 'weights' in kwargs:
238
+ if kwargs['weights'] == 'DEFAULT':
239
+ kwargs['pretrained'] = True
240
+ elif kwargs['weights'] is None:
241
+ kwargs['pretrained'] = False
242
+ else:
243
+ raise ValueError(
244
+ 'weights=={} not supported in torchvision {}'.format(
245
+ kwargs['weights'], torchvision.__version__
246
+ )
247
+ )
248
+ del kwargs['weights']
249
+
250
+ return torchvision.models.inception_v3(*args, **kwargs)
251
+
252
+
253
+ def fid_inception_v3():
254
+ """Build pretrained Inception model for FID computation
255
+
256
+ The Inception model for FID computation uses a different set of weights
257
+ and has a slightly different structure than torchvision's Inception.
258
+
259
+ This method first constructs torchvision's Inception and then patches the
260
+ necessary parts that are different in the FID Inception model.
261
+ """
262
+ inception = _inception_v3(num_classes=1008,
263
+ aux_logits=False,
264
+ weights=None)
265
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
266
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
267
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
268
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
269
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
270
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
271
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
272
+ inception.Mixed_7b = FIDInceptionE_1(1280)
273
+ inception.Mixed_7c = FIDInceptionE_2(2048)
274
+
275
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
276
+ inception.load_state_dict(state_dict)
277
+ return inception
278
+
279
+
280
+ class FIDInceptionA(torchvision.models.inception.InceptionA):
281
+ """InceptionA block patched for FID computation"""
282
+ def __init__(self, in_channels, pool_features):
283
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
284
+
285
+ def forward(self, x):
286
+ branch1x1 = self.branch1x1(x)
287
+
288
+ branch5x5 = self.branch5x5_1(x)
289
+ branch5x5 = self.branch5x5_2(branch5x5)
290
+
291
+ branch3x3dbl = self.branch3x3dbl_1(x)
292
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
293
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
294
+
295
+ # Patch: Tensorflow's average pool does not use the padded zero's in
296
+ # its average calculation
297
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
298
+ count_include_pad=False)
299
+ branch_pool = self.branch_pool(branch_pool)
300
+
301
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
302
+ return torch.cat(outputs, 1)
303
+
304
+
305
+ class FIDInceptionC(torchvision.models.inception.InceptionC):
306
+ """InceptionC block patched for FID computation"""
307
+ def __init__(self, in_channels, channels_7x7):
308
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
309
+
310
+ def forward(self, x):
311
+ branch1x1 = self.branch1x1(x)
312
+
313
+ branch7x7 = self.branch7x7_1(x)
314
+ branch7x7 = self.branch7x7_2(branch7x7)
315
+ branch7x7 = self.branch7x7_3(branch7x7)
316
+
317
+ branch7x7dbl = self.branch7x7dbl_1(x)
318
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
319
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
320
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
321
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
322
+
323
+ # Patch: Tensorflow's average pool does not use the padded zero's in
324
+ # its average calculation
325
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
326
+ count_include_pad=False)
327
+ branch_pool = self.branch_pool(branch_pool)
328
+
329
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
330
+ return torch.cat(outputs, 1)
331
+
332
+
333
+ class FIDInceptionE_1(torchvision.models.inception.InceptionE):
334
+ """First InceptionE block patched for FID computation"""
335
+ def __init__(self, in_channels):
336
+ super(FIDInceptionE_1, self).__init__(in_channels)
337
+
338
+ def forward(self, x):
339
+ branch1x1 = self.branch1x1(x)
340
+
341
+ branch3x3 = self.branch3x3_1(x)
342
+ branch3x3 = [
343
+ self.branch3x3_2a(branch3x3),
344
+ self.branch3x3_2b(branch3x3),
345
+ ]
346
+ branch3x3 = torch.cat(branch3x3, 1)
347
+
348
+ branch3x3dbl = self.branch3x3dbl_1(x)
349
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
350
+ branch3x3dbl = [
351
+ self.branch3x3dbl_3a(branch3x3dbl),
352
+ self.branch3x3dbl_3b(branch3x3dbl),
353
+ ]
354
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
355
+
356
+ # Patch: Tensorflow's average pool does not use the padded zero's in
357
+ # its average calculation
358
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
359
+ count_include_pad=False)
360
+ branch_pool = self.branch_pool(branch_pool)
361
+
362
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
363
+ return torch.cat(outputs, 1)
364
+
365
+
366
+ class FIDInceptionE_2(torchvision.models.inception.InceptionE):
367
+ """Second InceptionE block patched for FID computation"""
368
+ def __init__(self, in_channels):
369
+ super(FIDInceptionE_2, self).__init__(in_channels)
370
+
371
+ def forward(self, x):
372
+ branch1x1 = self.branch1x1(x)
373
+
374
+ branch3x3 = self.branch3x3_1(x)
375
+ branch3x3 = [
376
+ self.branch3x3_2a(branch3x3),
377
+ self.branch3x3_2b(branch3x3),
378
+ ]
379
+ branch3x3 = torch.cat(branch3x3, 1)
380
+
381
+ branch3x3dbl = self.branch3x3dbl_1(x)
382
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
383
+ branch3x3dbl = [
384
+ self.branch3x3dbl_3a(branch3x3dbl),
385
+ self.branch3x3dbl_3b(branch3x3dbl),
386
+ ]
387
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
388
+
389
+ # Patch: The FID Inception model uses max pooling instead of average
390
+ # pooling. This is likely an error in this specific Inception
391
+ # implementation, as other Inception models use average pooling here
392
+ # (which matches the description in the paper).
393
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
394
+ branch_pool = self.branch_pool(branch_pool)
395
+
396
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
397
+ return torch.cat(outputs, 1)
398
+
399
+ class ImagePathDataset(torch.utils.data.Dataset):
400
+ def __init__(self, files, transforms=None):
401
+ self.files = files
402
+ self.transforms = transforms
403
+
404
+ def __len__(self):
405
+ return len(self.files)
406
+
407
+ def __getitem__(self, i):
408
+ path = self.files[i]
409
+ img = Image.open(path).convert('RGB')
410
+ if self.transforms is not None:
411
+ img = self.transforms(img)
412
+ return img
413
+
414
+
415
+ def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
416
+ num_workers=1, resize=0):
417
+ """Calculates the activations of the pool_3 layer for all images.
418
+
419
+ Params:
420
+ -- files : List of image files paths
421
+ -- model : Instance of inception model
422
+ -- batch_size : Batch size of images for the model to process at once.
423
+ Make sure that the number of samples is a multiple of
424
+ the batch size, otherwise some samples are ignored. This
425
+ behavior is retained to match the original FID score
426
+ implementation.
427
+ -- dims : Dimensionality of features returned by Inception
428
+ -- device : Device to run calculations
429
+ -- num_workers : Number of parallel dataloader workers
430
+
431
+ Returns:
432
+ -- A numpy array of dimension (num images, dims) that contains the
433
+ activations of the given tensor when feeding inception with the
434
+ query tensor.
435
+ """
436
+ model.eval()
437
+
438
+ if batch_size > len(files):
439
+ print(('Warning: batch size is bigger than the data size. '
440
+ 'Setting batch size to data size'))
441
+ batch_size = len(files)
442
+ if resize > 0:
443
+ tform = TF.Compose([TF.Resize((resize, resize)), TF.ToTensor()])
444
+ else:
445
+ tform = TF.ToTensor()
446
+ if isinstance(files, list):
447
+ dataset = ImagePathDataset(files, transforms=tform)
448
+ else:
449
+ files.set_transform(tform)
450
+ dataset = files
451
+ dataloader = torch.utils.data.DataLoader(dataset,
452
+ batch_size=batch_size,
453
+ shuffle=False,
454
+ drop_last=False,
455
+ num_workers=num_workers)
456
+
457
+ pred_arr = np.empty((len(files), dims))
458
+
459
+ start_idx = 0
460
+
461
+ for batch in tqdm(dataloader):
462
+ batch = batch['image'].to(device)
463
+
464
+ with torch.no_grad():
465
+ pred = model(batch)[0]
466
+
467
+ # If model output is not scalar, apply global spatial average pooling.
468
+ # This happens if you choose a dimensionality not equal 2048.
469
+ if pred.size(2) != 1 or pred.size(3) != 1:
470
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
471
+
472
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
473
+
474
+ pred_arr[start_idx:start_idx + pred.shape[0]] = pred
475
+
476
+ start_idx = start_idx + pred.shape[0]
477
+
478
+ return pred_arr
479
+
480
+
481
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
482
+ """Numpy implementation of the Frechet Distance.
483
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
484
+ and X_2 ~ N(mu_2, C_2) is
485
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
486
+
487
+ Stable version by Dougal J. Sutherland.
488
+
489
+ Params:
490
+ -- mu1 : Numpy array containing the activations of a layer of the
491
+ inception net (like returned by the function 'get_predictions')
492
+ for generated samples.
493
+ -- mu2 : The sample mean over activations, precalculated on an
494
+ representative data set.
495
+ -- sigma1: The covariance matrix over activations for generated samples.
496
+ -- sigma2: The covariance matrix over activations, precalculated on an
497
+ representative data set.
498
+
499
+ Returns:
500
+ -- : The Frechet Distance.
501
+ """
502
+
503
+ mu1 = np.atleast_1d(mu1)
504
+ mu2 = np.atleast_1d(mu2)
505
+
506
+ sigma1 = np.atleast_2d(sigma1)
507
+ sigma2 = np.atleast_2d(sigma2)
508
+
509
+ assert mu1.shape == mu2.shape, \
510
+ 'Training and test mean vectors have different lengths'
511
+ assert sigma1.shape == sigma2.shape, \
512
+ 'Training and test covariances have different dimensions'
513
+
514
+ diff = mu1 - mu2
515
+
516
+ # Product might be almost singular
517
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
518
+ if not np.isfinite(covmean).all():
519
+ msg = ('fid calculation produces singular product; '
520
+ 'adding %s to diagonal of cov estimates') % eps
521
+ print(msg)
522
+ offset = np.eye(sigma1.shape[0]) * eps
523
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
524
+
525
+ # Numerical error might give slight imaginary component
526
+ if np.iscomplexobj(covmean):
527
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
528
+ m = np.max(np.abs(covmean.imag))
529
+ raise ValueError('Imaginary component {}'.format(m))
530
+ covmean = covmean.real
531
+
532
+ tr_covmean = np.trace(covmean)
533
+
534
+ return (diff.dot(diff) + np.trace(sigma1)
535
+ + np.trace(sigma2) - 2 * tr_covmean)
536
+
537
+
538
+ def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
539
+ device='cpu', num_workers=1, resize=0):
540
+ """Calculation of the statistics used by the FID.
541
+ Params:
542
+ -- files : List of image files paths
543
+ -- model : Instance of inception model
544
+ -- batch_size : The images numpy array is split into batches with
545
+ batch size batch_size. A reasonable batch size
546
+ depends on the hardware.
547
+ -- dims : Dimensionality of features returned by Inception
548
+ -- device : Device to run calculations
549
+ -- num_workers : Number of parallel dataloader workers
550
+
551
+ Returns:
552
+ -- mu : The mean over samples of the activations of the pool_3 layer of
553
+ the inception model.
554
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
555
+ the inception model.
556
+ """
557
+ act = get_activations(files, model, batch_size, dims, device, num_workers, resize)
558
+ mu = np.mean(act, axis=0)
559
+ sigma = np.cov(act, rowvar=False)
560
+ return mu, sigma
561
+
562
+
563
+ def compute_statistics_of_path(path, model, batch_size, dims, device,
564
+ num_workers=1, nimages=None, resize=0):
565
+ if path.endswith('.npz'):
566
+ with np.load(path) as f:
567
+ m, s = f['mu'][:], f['sigma'][:]
568
+ else:
569
+ path = pathlib.Path(path)
570
+ if (path/'data.mdb').exists():
571
+ files = ImageDataset(path, None)
572
+ nfiles = len(files)
573
+ n = nfiles if nimages is None else min(nimages, nfiles)
574
+ files.set_ids(range(n))
575
+ else:
576
+ files = sorted([file for ext in IMAGE_EXTENSIONS
577
+ for file in path.glob('**/*.{}'.format(ext))])
578
+ nfiles = len(files)
579
+ n = nfiles if nimages is None else min(nimages, nfiles)
580
+ files = files[:n]
581
+ print(f'Found {nfiles} images. Computing FID with {n} images.')
582
+ m, s = calculate_activation_statistics(files, model, batch_size,
583
+ dims, device, num_workers, resize)
584
+
585
+ return m, s
586
+
587
+
588
+ def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
589
+ """Calculates the FID of two paths"""
590
+ for p in paths:
591
+ if not os.path.exists(p):
592
+ raise RuntimeError('Invalid path: %s' % p)
593
+
594
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
595
+
596
+ model = InceptionV3([block_idx]).to(device)
597
+
598
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
599
+ dims, device, num_workers, nimages, resize)
600
+ m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
601
+ dims, device, num_workers, nimages, resize)
602
+ fid_value = calculate_frechet_distance(m1, s1, m2, s2)
603
+
604
+ return fid_value
605
+
606
+
607
+ def save_fid_stats(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
608
+ """Calculates the FID of two paths"""
609
+ if not os.path.exists(paths[0]):
610
+ raise RuntimeError('Invalid path: %s' % paths[0])
611
+
612
+ if os.path.exists(paths[1]):
613
+ raise RuntimeError('Existing output file: %s' % paths[1])
614
+
615
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
616
+
617
+ model = InceptionV3([block_idx]).to(device)
618
+
619
+ print(f"Saving statistics for {paths[0]}")
620
+
621
+ m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
622
+ dims, device, num_workers, nimages, resize=0)
623
+
624
+ np.savez_compressed(paths[1], mu=m1, sigma=s1)
625
+
626
+
627
+ def main():
628
+ parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
629
+ parser.add_argument('--batch-size', type=int, default=20,
630
+ help='Batch size to use')
631
+ parser.add_argument('--num-workers', type=int,
632
+ help=('Number of processes to use for data loading. '
633
+ 'Defaults to `min(8, num_cpus)`'))
634
+ parser.add_argument('--device', type=str, default='cuda:0',
635
+ help='Device to use. Like cuda, cuda:0 or cpu')
636
+ parser.add_argument('--dims', type=int, default=2048,
637
+ choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
638
+ help=('Dimensionality of Inception features to use. '
639
+ 'By default, uses pool3 features'))
640
+ parser.add_argument('--nimages', type=int, default=50000, help='max number of images to use')
641
+ parser.add_argument('--resize', type=int, default=0, help='resize images to this size, 0 mean keep original size')
642
+ parser.add_argument('--save-stats', action='store_true',
643
+ help=('Generate an npz archive from a directory of samples. '
644
+ 'The first path is used as input and the second as output.'))
645
+ parser.add_argument('path', type=str, nargs=2,
646
+ help=('Paths to the generated images or '
647
+ 'to .npz statistic files'))
648
+ args = parser.parse_args()
649
+
650
+ if args.device is None:
651
+ device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
652
+ else:
653
+ device = torch.device(args.device)
654
+
655
+ if args.num_workers is None:
656
+ try:
657
+ num_cpus = len(os.sched_getaffinity(0))
658
+ except AttributeError:
659
+ # os.sched_getaffinity is not available under Windows, use
660
+ # os.cpu_count instead (which may not return the *available* number
661
+ # of CPUs).
662
+ num_cpus = os.cpu_count()
663
+
664
+ num_workers = min(num_cpus, 8) if num_cpus is not None else 0
665
+ else:
666
+ num_workers = args.num_workers
667
+
668
+ if args.save_stats:
669
+ save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers, args.nimages, args.resize)
670
+ return
671
+
672
+ fid_value = calculate_fid_given_paths(args.path,
673
+ args.batch_size,
674
+ device,
675
+ args.dims,
676
+ num_workers,
677
+ args.nimages,
678
+ args.resize)
679
+ print('FID: ', fid_value)
680
+
681
+
682
+ if __name__ == '__main__':
683
+ main()
tools/gradcam.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ gradcam visualisation for each GAN class
5
+ @author: Tu Bui @surrey.ac.uk
6
+ """
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+ import os
11
+ import sys
12
+ import inspect
13
+ import argparse
14
+ import torch
15
+ import numpy as np
16
+ import matplotlib
17
+ matplotlib.use('Agg')
18
+ import matplotlib.pyplot as plt
19
+ import cv2
20
+ from PIL import Image, ImageDraw, ImageFont
21
+ import torch
22
+ import torchvision
23
+ from torch.autograd import Function
24
+ import torch.nn.functional as F
25
+
26
+
27
+ def show_cam_on_image(img, cam, cmap='jet'):
28
+ """
29
+ Args:
30
+ img PIL image (H,W,3)
31
+ cam heatmap (H, W), range [0,1]
32
+ Returns:
33
+ PIL image with heatmap applied.
34
+ """
35
+ cm = plt.get_cmap(cmap)
36
+ cam = cm(cam)[...,:3] # RGB [0,1]
37
+ cam = np.array(img, dtype=np.float32)/255. + cam
38
+ cam /= cam.max()
39
+ cam = np.uint8(cam*255)
40
+ return Image.fromarray(cam)
41
+
42
+
43
+ class HookedModel(object):
44
+ def __init__(self, model, feature_layer_name):
45
+ self.model = model
46
+ self.feature_trees = feature_layer_name.split('.')
47
+
48
+ def __call__(self, x):
49
+ x = feedforward(x, self.model, self.feature_trees)
50
+ return x
51
+
52
+
53
+ def feedforward(x, module, layer_names):
54
+ for name, submodule in module._modules.items():
55
+ # print(f'Forwarding {name} ...')
56
+ if name == layer_names[0]:
57
+ if len(layer_names) == 1: # leaf node reached
58
+ # print(f' Hook {name}')
59
+ x = submodule(x)
60
+ x.register_hook(save_gradients)
61
+ save_features(x)
62
+ else:
63
+ # print(f' Stepping into {name}:')
64
+ x = feedforward(x, submodule, layer_names[1:])
65
+ else:
66
+ x = submodule(x)
67
+ if name == 'avgpool': # specific for resnet50
68
+ x = x.view(x.size(0), -1)
69
+ return x
70
+
71
+
72
+ basket = dict(grads=[], feature_maps=[]) # global variable to hold the gradients and output features of the layers of interest
73
+
74
+ def empty_basket():
75
+ basket = dict(grads=[], feature_maps=[])
76
+
77
+ def save_gradients(grad):
78
+ basket['grads'].append(grad)
79
+
80
+ def save_features(feat):
81
+ basket['feature_maps'].append(feat)
82
+
83
+
84
+ class GradCam(object):
85
+ def __init__(self, model, feature_layer_name, use_cuda=True):
86
+ self.model = model
87
+ self.hooked_model = HookedModel(model, feature_layer_name)
88
+ self.cuda = use_cuda
89
+ if self.cuda:
90
+ self.model = model.cuda()
91
+ self.model.eval()
92
+
93
+ def __call__(self, x, target, act=None):
94
+ empty_basket()
95
+ target = torch.as_tensor(target, dtype=torch.float)
96
+ if self.cuda:
97
+ x = x.cuda()
98
+ target = target.cuda()
99
+ z = self.hooked_model(x)
100
+ if act is not None:
101
+ z = act(z)
102
+ criteria = F.cosine_similarity(z, target)
103
+ self.model.zero_grad()
104
+ criteria.backward(retain_graph=True)
105
+ gradients = [grad.cpu().data.numpy() for grad in basket['grads'][::-1]] # gradients appear in reversed order
106
+ feature_maps = [feat.cpu().data.numpy() for feat in basket['feature_maps']]
107
+ cams = []
108
+ for feat, grad in zip(feature_maps, gradients):
109
+ # feat and grad have shape (1, C, H, W)
110
+ weight = np.mean(grad, axis=(2,3), keepdims=True)[0] # (C,1,1)
111
+ cam = np.sum(weight * feat[0], axis=0) # (H,w)
112
+ cam = cv2.resize(cam, x.shape[2:])
113
+ cam = cam - np.min(cam)
114
+ cam = cam / (np.max(cam) + np.finfo(np.float32).eps)
115
+ cams.append(cam)
116
+ cams = np.array(cams).mean(axis=0) # (H,W)
117
+ return cams
118
+
119
+
120
+ def gradcam_demo():
121
+ from torchvision import transforms
122
+ model = torchvision.models.resnet50(pretrained=True)
123
+ model.eval()
124
+ gradcam = GradCam(model, 'layer4.2', True)
125
+ tform = [
126
+ transforms.Resize((224, 224)),
127
+ # transforms.CenterCrop(224),
128
+ transforms.ToTensor(),
129
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
130
+ ]
131
+ preprocess = transforms.Compose(tform)
132
+ im0 = Image.open('/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet/examples/catdog.jpg').convert('RGB')
133
+ im = preprocess(im0).unsqueeze(0)
134
+ target = np.zeros((1,1000), dtype=np.float32)
135
+ target[0, 285] = 1 # cat
136
+ cam = gradcam(im, target)
137
+
138
+ im0 = tform[0](im0)
139
+ out = show_cam_on_image(im0, cam)
140
+ out.save('test.jpg')
141
+ print('done')
142
+
143
+
144
+ def make_target_vector(nclass, target_class_id):
145
+ out = np.zeros((1, nclass), dtype=np.float32)
146
+ out[0, target_class_id] = 1
147
+ return out
148
+
149
+
150
+
151
+ if __name__ == '__main__':
152
+ gradcam_demo()
tools/helpers.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 12 11:05:57 2016
4
+ some help functions to perform basic tasks
5
+ @author: tb00083
6
+ """
7
+ import os
8
+ import sys
9
+ import csv
10
+ import socket
11
+ import numpy as np
12
+ import json
13
+ import pickle # python3.x
14
+ import time
15
+ from datetime import timedelta, datetime
16
+ from typing import Any, List, Tuple, Union
17
+ import subprocess
18
+ import struct
19
+ import errno
20
+ from pprint import pprint
21
+ import glob
22
+ from threading import Thread
23
+
24
+
25
+ def welcome_message():
26
+ """
27
+ get welcome message including hostname and command line arguments
28
+ """
29
+ hostname = socket.gethostname()
30
+ all_args = ' '.join(sys.argv)
31
+ out_text = 'On server {}: {}\n'.format(hostname, all_args)
32
+ return out_text
33
+
34
+
35
+ class EasyDict(dict):
36
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
37
+ def __init__(self, dict_to_convert=None):
38
+ if dict_to_convert is not None:
39
+ for key, val in dict_to_convert.items():
40
+ self[key] = val
41
+
42
+ def __getattr__(self, name: str) -> Any:
43
+ try:
44
+ return self[name]
45
+ except KeyError:
46
+ raise AttributeError(name)
47
+
48
+ def __setattr__(self, name: str, value: Any) -> None:
49
+ self[name] = value
50
+
51
+ def __delattr__(self, name: str) -> None:
52
+ del self[name]
53
+
54
+
55
+ def get_time_id_str():
56
+ """
57
+ returns a string with DDHHM format, where M is the minutes cut to the tenths
58
+ """
59
+ now = datetime.now()
60
+ time_str = "{:02d}{:02d}{:02d}".format(now.day, now.hour, now.minute)
61
+ time_str = time_str[:-1]
62
+ return time_str
63
+
64
+
65
+ def time_format(t):
66
+ m, s = divmod(t, 60)
67
+ h, m = divmod(m, 60)
68
+ m, h, s = int(m), int(h), int(s)
69
+
70
+ if m == 0 and h == 0:
71
+ return "{}s".format(s)
72
+ elif h == 0:
73
+ return "{}m{}s".format(m, s)
74
+ else:
75
+ return "{}h{}m{}s".format(h, m, s)
76
+
77
+
78
+ def get_all_files(dir_path, trim=0, extension=''):
79
+ """
80
+ Recursively get list of all files in the given directory
81
+ trim = 1 : trim the dir_path from results, 0 otherwise
82
+ extension: get files with specific format
83
+ """
84
+ file_paths = [] # List which will store all of the full filepaths.
85
+
86
+ # Walk the tree.
87
+ for root, directories, files in os.walk(dir_path):
88
+ for filename in files:
89
+ # Join the two strings in order to form the full filepath.
90
+ filepath = os.path.join(root, filename)
91
+ file_paths.append(filepath) # Add it to the list.
92
+
93
+ if trim == 1: # trim dir_path from results
94
+ if dir_path[-1] != os.sep:
95
+ dir_path += os.sep
96
+ trim_len = len(dir_path)
97
+ file_paths = [x[trim_len:] for x in file_paths]
98
+
99
+ if extension: # select only file with specific extension
100
+ extension = extension.lower()
101
+ tlen = len(extension)
102
+ file_paths = [x for x in file_paths if x[-tlen:] == extension]
103
+
104
+ return file_paths # Self-explanatory.
105
+
106
+
107
+ def get_all_dirs(dir_path, trim=0):
108
+ """
109
+ Recursively get list of all directories in the given directory
110
+ excluding the '.' and '..' directories
111
+ trim = 1 : trim the dir_path from results, 0 otherwise
112
+ """
113
+ out = []
114
+ # Walk the tree.
115
+ for root, directories, files in os.walk(dir_path):
116
+ for dirname in directories:
117
+ # Join the two strings in order to form the full filepath.
118
+ dir_full = os.path.join(root, dirname)
119
+ out.append(dir_full) # Add it to the list.
120
+
121
+ if trim == 1: # trim dir_path from results
122
+ if dir_path[-1] != os.sep:
123
+ dir_path += os.sep
124
+ trim_len = len(dir_path)
125
+ out = [x[trim_len:] for x in out]
126
+
127
+ return out
128
+
129
+
130
+ def read_list(file_path, delimeter=' ', keep_original=True):
131
+ """
132
+ read list column wise
133
+ deprecated, should use pandas instead
134
+ """
135
+ out = []
136
+ with open(file_path, 'r') as f:
137
+ reader = csv.reader(f, delimiter=delimeter)
138
+ for row in reader:
139
+ out.append(row)
140
+ out = zip(*out)
141
+
142
+ if not keep_original:
143
+ for col in range(len(out)):
144
+ if out[col][0].isdigit(): # attempt to convert to numerical array
145
+ out[col] = np.array(out[col]).astype(np.int64)
146
+
147
+ return out
148
+
149
+
150
+ def save_pickle2(file_path, **kwargs):
151
+ """
152
+ save variables to file (using pickle)
153
+ """
154
+ # check if any variable is a dict
155
+ var_count = 0
156
+ for key in kwargs:
157
+ var_count += 1
158
+ if isinstance(kwargs[key], dict):
159
+ sys.stderr.write('Opps! Cannot write a dictionary into pickle')
160
+ sys.exit(1)
161
+ with open(file_path, 'wb') as f:
162
+ pickler = pickle.Pickler(f, -1)
163
+ pickler.dump(var_count)
164
+ for key in kwargs:
165
+ pickler.dump(key)
166
+ pickler.dump(kwargs[key])
167
+
168
+
169
+ def load_pickle2(file_path, varnum=0):
170
+ """
171
+ load variables that previously saved using self.save()
172
+ varnum : number of variables u want to load (0 mean it will load all)
173
+ Note: if you are loading class instance(s), you must have it defined in advance
174
+ """
175
+ with open(file_path, 'rb') as f:
176
+ pickler = pickle.Unpickler(f)
177
+ var_count = pickler.load()
178
+ if varnum:
179
+ var_count = min([var_count, varnum])
180
+ out = {}
181
+ for i in range(var_count):
182
+ key = pickler.load()
183
+ out[key] = pickler.load()
184
+
185
+ return out
186
+
187
+
188
+ def save_pickle(path, obj):
189
+ """
190
+ simple method to save a picklable object
191
+ :param path: path to save
192
+ :param obj: a picklable object
193
+ :return: None
194
+ """
195
+ with open(path, 'wb') as f:
196
+ pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
197
+
198
+
199
+ def load_pickle(path):
200
+ """
201
+ load a pickled object
202
+ :param path: .pkl path
203
+ :return: the pickled object
204
+ """
205
+ with open(path, 'rb') as f:
206
+ return pickle.load(f)
207
+
208
+
209
+ def make_new_dir(dir_path, remove_existing=False, mode=511):
210
+ """note: default mode in ubuntu is 511"""
211
+ if not os.path.exists(dir_path):
212
+ try:
213
+ if mode == 777:
214
+ oldmask = os.umask(000)
215
+ os.makedirs(dir_path, 0o777)
216
+ os.umask(oldmask)
217
+ else:
218
+ os.makedirs(dir_path, mode)
219
+ except OSError as exc: # Python >2.5
220
+ if exc.errno == errno.EEXIST and os.path.isdir(dir_path):
221
+ pass
222
+ else:
223
+ raise
224
+ if remove_existing:
225
+ for file_obj in os.listdir(dir_path):
226
+ file_path = os.path.join(dir_path, file_obj)
227
+ if os.path.isfile(file_path):
228
+ os.unlink(file_path)
229
+
230
+
231
+ def get_latest_file(root, pattern):
232
+ """
233
+ get the latest file in a directory that match the provided pattern
234
+ useful for getting the last checkpoint
235
+ :param root: search directory
236
+ :param pattern: search pattern containing 1 wild card representing a number e.g. 'ckpt_*.tar'
237
+ :return: full path of the file with largest number in wild card, None if not found
238
+ """
239
+ out = None
240
+ parts = pattern.split('*')
241
+ max_id = - np.inf
242
+ for path in glob.glob(os.path.join(root, pattern)):
243
+ id_ = os.path.basename(path)
244
+ for part in parts:
245
+ id_ = id_.replace(part, '')
246
+ try:
247
+ id_ = int(id_)
248
+ if id_ > max_id:
249
+ max_id = id_
250
+ out = path
251
+ except:
252
+ continue
253
+ return out
254
+
255
+
256
+ class Locker(object):
257
+ """place a lock file in specified location
258
+ useful for distributed computing"""
259
+
260
+ def __init__(self, name='lock.txt', mode=511):
261
+ """INPUT: name default file name to be created as a lock
262
+ mode if a directory has to be created, set its permission to mode"""
263
+ self.name = name
264
+ self.mode = mode
265
+
266
+ def lock(self, path):
267
+ make_new_dir(path, False, self.mode)
268
+ with open(os.path.join(path, self.name), 'w') as f:
269
+ f.write('progress')
270
+
271
+ def finish(self, path):
272
+ make_new_dir(path, False, self.mode)
273
+ with open(os.path.join(path, self.name), 'w') as f:
274
+ f.write('finish')
275
+
276
+ def customise(self, path, text):
277
+ make_new_dir(path, False, self.mode)
278
+ with open(os.path.join(path, self.name), 'w') as f:
279
+ f.write(text)
280
+
281
+ def is_locked(self, path):
282
+ out = False
283
+ check_path = os.path.join(path, self.name)
284
+ if os.path.exists(check_path):
285
+ text = open(check_path, 'r').readline().strip()
286
+ out = True if text == 'progress' else False
287
+ return out
288
+
289
+ def is_finished(self, path):
290
+ out = False
291
+ check_path = os.path.join(path, self.name)
292
+ if os.path.exists(check_path):
293
+ text = open(check_path, 'r').readline().strip()
294
+ out = True if text == 'finish' else False
295
+ return out
296
+
297
+ def is_locked_or_finished(self, path):
298
+ return self.is_locked(path) | self.is_finished(path)
299
+
300
+ def clean(self, path):
301
+ check_path = os.path.join(path, self.name)
302
+ if os.path.exists(check_path):
303
+ try:
304
+ os.remove(check_path)
305
+ except Exception as e:
306
+ print('Unable to remove %s: %s.' % (check_path, e))
307
+
308
+
309
+ class ProgressBar(object):
310
+ """show progress"""
311
+
312
+ def __init__(self, total, increment=5):
313
+ self.total = total
314
+ self.point = self.total / 100.0
315
+ self.increment = increment
316
+ self.interval = int(self.total * self.increment / 100)
317
+ self.milestones = list(range(0, total, self.interval)) + [self.total, ]
318
+ self.id = 0
319
+
320
+ def show_progress(self, i):
321
+ if i >= self.milestones[self.id]:
322
+ while i >= self.milestones[self.id]:
323
+ self.id += 1
324
+ sys.stdout.write("\r[" + "=" * int(i / self.interval) +
325
+ " " * int((self.total - i) / self.interval) + "]" + str(int((i + 1) / self.point)) + "%")
326
+ sys.stdout.flush()
327
+
328
+
329
+ class Timer(object):
330
+
331
+ def __init__(self):
332
+ self.start_t = time.time()
333
+ self.last_t = self.start_t
334
+
335
+ def time(self, lap=False):
336
+ end_t = time.time()
337
+ if lap:
338
+ out = timedelta(seconds=int(end_t - self.last_t)) # count from last stop point
339
+ else:
340
+ out = timedelta(seconds=int(end_t - self.start_t)) # count from beginning
341
+ self.last_t = end_t
342
+ return out
343
+
344
+
345
+ class ExThread(Thread):
346
+ def run(self):
347
+ self.exc = None
348
+ try:
349
+ if hasattr(self, '_Thread__target'):
350
+ # Thread uses name mangling prior to Python 3.
351
+ self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs)
352
+ else:
353
+ self.ret = self._target(*self._args, **self._kwargs)
354
+ except BaseException as e:
355
+ self.exc = e
356
+
357
+ def join(self):
358
+ super(ExThread, self).join()
359
+ if self.exc:
360
+ raise RuntimeError('Exception in thread.') from self.exc
361
+ return self.ret
362
+
363
+
364
+ def get_gpu_free_mem():
365
+ """return a list of free GPU memory"""
366
+ sp = subprocess.Popen(['nvidia-smi', '-q'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
367
+ out_str = sp.communicate()
368
+ out_list = out_str[0].decode("utf-8") .split('\n')
369
+
370
+ out = []
371
+ for i in range(len(out_list)):
372
+ item = out_list[i]
373
+ if item.strip() == 'FB Memory Usage':
374
+ free_mem = int(out_list[i + 3].split(':')[1].strip().split(' ')[0])
375
+ out.append(free_mem)
376
+ return out
377
+
378
+
379
+ def float2hex(x):
380
+ """
381
+ x: a vector
382
+ return: x in hex
383
+ """
384
+ f = np.float32(x)
385
+ out = ''
386
+ if f.size == 1: # just a single number
387
+ f = [f, ]
388
+ for e in f:
389
+ h = hex(struct.unpack('<I', struct.pack('<f', e))[0])
390
+ out += h[2:].zfill(8)
391
+ return out
392
+
393
+
394
+ def hex2float(x):
395
+ """
396
+ x: a string with len divided by 8
397
+ return x as array of float32
398
+ """
399
+ assert len(x) % 8 == 0, 'Error! string len = {} not divided by 8'.format(len(x))
400
+ l = len(x) / 8
401
+ out = np.empty(l, dtype=np.float32)
402
+ x = [x[i:i + 8] for i in range(0, len(x), 8)]
403
+ for i, e in enumerate(x):
404
+ out[i] = struct.unpack('!f', e.decode('hex'))[0]
405
+ return out
406
+
407
+
408
+ def nice_print(inputs, stream=sys.stdout):
409
+ """print a list of string to file stream"""
410
+ if type(inputs) is not list:
411
+ tstrings = inputs.split('\n')
412
+ pprint(tstrings, stream=stream)
413
+ else:
414
+ for string in inputs:
415
+ nice_print(string, stream=stream)
416
+ stream.flush()
tools/hparams.py ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2019 The Tensor2Tensor Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # source: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/hparam.py
16
+ # Forked with minor changes from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
17
+ """Hyperparameter values."""
18
+ from __future__ import absolute_import
19
+ from __future__ import division
20
+ from __future__ import print_function
21
+
22
+ import json
23
+ import numbers
24
+ import re
25
+ import six
26
+ import numpy as np
27
+
28
+ # Define the regular expression for parsing a single clause of the input
29
+ # (delimited by commas). A legal clause looks like:
30
+ # <variable name>[<index>]? = <rhs>
31
+ # where <rhs> is either a single token or [] enclosed list of tokens.
32
+ # For example: "var[1] = a" or "x = [1,2,3]"
33
+ PARAM_RE = re.compile(r"""
34
+ (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
35
+ (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
36
+ \s*=\s*
37
+ ((?P<val>[^,\[]*) # single value: "a" or None
38
+ |
39
+ \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
40
+ ($|,\s*)""", re.VERBOSE)
41
+
42
+
43
+ def copy_hparams(hparams):
44
+ """Return a copy of an HParams instance."""
45
+ return HParams(**hparams.values())
46
+
47
+
48
+ def print_config(hps):
49
+ for key, val in six.iteritems(hps.values()):
50
+ print('%s = %s' % (key, str(val)))
51
+
52
+
53
+ def save_config(output_file, hps, verbose=True):
54
+ def convert(o): # json cannot serialize integer in np.int64 format
55
+ if isinstance(o, np.int64):
56
+ return int(o)
57
+ raise TypeError
58
+ if verbose:
59
+ print_config(hps)
60
+ with open(output_file, 'w') as f:
61
+ json.dump(hps.values(), f, indent=True, default=convert)
62
+
63
+
64
+ def load_config(hps, config_file, verbose=True):
65
+ """
66
+ parse hparams from config file
67
+ :param hps: hparams object whose values to be updated
68
+ :param config_file: json config file
69
+ :param verbose: print out values
70
+ """
71
+ try:
72
+ with open(config_file, 'r') as fin:
73
+ hps.parse_json(fin.read())
74
+ if verbose:
75
+ print_config(hps)
76
+ except Exception as e:
77
+ print('Error reading config file %s: %s.\nConfig will not be updated.' % (config_file, e))
78
+ # return hps
79
+
80
+
81
+ def _parse_fail(name, var_type, value, values):
82
+ """Helper function for raising a value error for bad assignment."""
83
+ raise ValueError(
84
+ 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
85
+ (name, var_type.__name__, value, values))
86
+
87
+
88
+ def _reuse_fail(name, values):
89
+ """Helper function for raising a value error for reuse of name."""
90
+ raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
91
+ values))
92
+
93
+
94
+ def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
95
+ results_dictionary):
96
+ """Update results_dictionary with a scalar value.
97
+
98
+ Used to update the results_dictionary to be returned by parse_values when
99
+ encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
100
+
101
+ Mutates results_dictionary.
102
+
103
+ Args:
104
+ name: Name of variable in assignment ("s" or "arr").
105
+ parse_fn: Function for parsing the actual value.
106
+ var_type: Type of named variable.
107
+ m_dict: Dictionary constructed from regex parsing.
108
+ m_dict['val']: RHS value (scalar)
109
+ m_dict['index']: List index value (or None)
110
+ values: Full expression being parsed
111
+ results_dictionary: The dictionary being updated for return by the parsing
112
+ function.
113
+
114
+ Raises:
115
+ ValueError: If the name has already been used.
116
+ """
117
+ try:
118
+ parsed_value = parse_fn(m_dict['val'])
119
+ except ValueError:
120
+ _parse_fail(name, var_type, m_dict['val'], values)
121
+
122
+ # If no index is provided
123
+ if not m_dict['index']:
124
+ if name in results_dictionary:
125
+ _reuse_fail(name, values)
126
+ results_dictionary[name] = parsed_value
127
+ else:
128
+ if name in results_dictionary:
129
+ # The name has already been used as a scalar, then it
130
+ # will be in this dictionary and map to a non-dictionary.
131
+ if not isinstance(results_dictionary.get(name), dict):
132
+ _reuse_fail(name, values)
133
+ else:
134
+ results_dictionary[name] = {}
135
+
136
+ index = int(m_dict['index'])
137
+ # Make sure the index position hasn't already been assigned a value.
138
+ if index in results_dictionary[name]:
139
+ _reuse_fail('{}[{}]'.format(name, index), values)
140
+ results_dictionary[name][index] = parsed_value
141
+
142
+
143
+ def _process_list_value(name, parse_fn, var_type, m_dict, values,
144
+ results_dictionary):
145
+ """Update results_dictionary from a list of values.
146
+
147
+ Used to update results_dictionary to be returned by parse_values when
148
+ encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
149
+
150
+ Mutates results_dictionary.
151
+
152
+ Args:
153
+ name: Name of variable in assignment ("arr").
154
+ parse_fn: Function for parsing individual values.
155
+ var_type: Type of named variable.
156
+ m_dict: Dictionary constructed from regex parsing.
157
+ m_dict['val']: RHS value (scalar)
158
+ values: Full expression being parsed
159
+ results_dictionary: The dictionary being updated for return by the parsing
160
+ function.
161
+
162
+ Raises:
163
+ ValueError: If the name has an index or the values cannot be parsed.
164
+ """
165
+ if m_dict['index'] is not None:
166
+ raise ValueError('Assignment of a list to a list index.')
167
+ elements = filter(None, re.split('[ ,]', m_dict['vals']))
168
+ # Make sure the name hasn't already been assigned a value
169
+ if name in results_dictionary:
170
+ raise _reuse_fail(name, values)
171
+ try:
172
+ results_dictionary[name] = [parse_fn(e) for e in elements]
173
+ except ValueError:
174
+ _parse_fail(name, var_type, m_dict['vals'], values)
175
+
176
+
177
+ def _cast_to_type_if_compatible(name, param_type, value):
178
+ """Cast hparam to the provided type, if compatible.
179
+
180
+ Args:
181
+ name: Name of the hparam to be cast.
182
+ param_type: The type of the hparam.
183
+ value: The value to be cast, if compatible.
184
+
185
+ Returns:
186
+ The result of casting `value` to `param_type`.
187
+
188
+ Raises:
189
+ ValueError: If the type of `value` is not compatible with param_type.
190
+ * If `param_type` is a string type, but `value` is not.
191
+ * If `param_type` is a boolean, but `value` is not, or vice versa.
192
+ * If `param_type` is an integer type, but `value` is not.
193
+ * If `param_type` is a float type, but `value` is not a numeric type.
194
+ """
195
+ fail_msg = (
196
+ "Could not cast hparam '%s' of type '%s' from value %r" %
197
+ (name, param_type, value))
198
+
199
+ # Some callers use None, for which we can't do any casting/checking. :(
200
+ if issubclass(param_type, type(None)):
201
+ return value
202
+
203
+ # Avoid converting a non-string type to a string.
204
+ if (issubclass(param_type, (six.string_types, six.binary_type)) and
205
+ not isinstance(value, (six.string_types, six.binary_type))):
206
+ raise ValueError(fail_msg)
207
+
208
+ # Avoid converting a number or string type to a boolean or vice versa.
209
+ if issubclass(param_type, bool) != isinstance(value, bool):
210
+ raise ValueError(fail_msg)
211
+
212
+ # Avoid converting float to an integer (the reverse is fine).
213
+ if (issubclass(param_type, numbers.Integral) and
214
+ not isinstance(value, numbers.Integral)):
215
+ raise ValueError(fail_msg)
216
+
217
+ # Avoid converting a non-numeric type to a numeric type.
218
+ if (issubclass(param_type, numbers.Number) and
219
+ not isinstance(value, numbers.Number)):
220
+ raise ValueError(fail_msg)
221
+
222
+ return param_type(value)
223
+
224
+
225
+ def parse_values(values, type_map, ignore_unknown=False):
226
+ """Parses hyperparameter values from a string into a python map.
227
+
228
+ `values` is a string containing comma-separated `name=value` pairs.
229
+ For each pair, the value of the hyperparameter named `name` is set to
230
+ `value`.
231
+
232
+ If a hyperparameter name appears multiple times in `values`, a ValueError
233
+ is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
234
+
235
+ If a hyperparameter name in both an index assignment and scalar assignment,
236
+ a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
237
+
238
+ The hyperparameter name may contain '.' symbols, which will result in an
239
+ attribute name that is only accessible through the getattr and setattr
240
+ functions. (And must be first explicit added through add_hparam.)
241
+
242
+ WARNING: Use of '.' in your variable names is allowed, but is not well
243
+ supported and not recommended.
244
+
245
+ The `value` in `name=value` must follows the syntax according to the
246
+ type of the parameter:
247
+
248
+ * Scalar integer: A Python-parsable integer point value. E.g.: 1,
249
+ 100, -12.
250
+ * Scalar float: A Python-parsable floating point value. E.g.: 1.0,
251
+ -.54e89.
252
+ * Boolean: Either true or false.
253
+ * Scalar string: A non-empty sequence of characters, excluding comma,
254
+ spaces, and square brackets. E.g.: foo, bar_1.
255
+ * List: A comma separated list of scalar values of the parameter type
256
+ enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
257
+
258
+ When index assignment is used, the corresponding type_map key should be the
259
+ list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
260
+ "arr[1]").
261
+
262
+ Args:
263
+ values: String. Comma separated list of `name=value` pairs where
264
+ 'value' must follow the syntax described above.
265
+ type_map: A dictionary mapping hyperparameter names to types. Note every
266
+ parameter name in values must be a key in type_map. The values must
267
+ conform to the types indicated, where a value V is said to conform to a
268
+ type T if either V has type T, or V is a list of elements of type T.
269
+ Hence, for a multidimensional parameter 'x' taking float values,
270
+ 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
271
+ ignore_unknown: Bool. Whether values that are missing a type in type_map
272
+ should be ignored. If set to True, a ValueError will not be raised for
273
+ unknown hyperparameter type.
274
+
275
+ Returns:
276
+ A python map mapping each name to either:
277
+ * A scalar value.
278
+ * A list of scalar values.
279
+ * A dictionary mapping index numbers to scalar values.
280
+ (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
281
+
282
+ Raises:
283
+ ValueError: If there is a problem with input.
284
+ * If `values` cannot be parsed.
285
+ * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
286
+ * If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
287
+ 'a[1]=1,a[1]=2', or 'a=1,a=[1]')
288
+ """
289
+ results_dictionary = {}
290
+ pos = 0
291
+ while pos < len(values):
292
+ m = PARAM_RE.match(values, pos)
293
+ if not m:
294
+ raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
295
+ # Check that there is a comma between parameters and move past it.
296
+ pos = m.end()
297
+ # Parse the values.
298
+ m_dict = m.groupdict()
299
+ name = m_dict['name']
300
+ if name not in type_map:
301
+ if ignore_unknown:
302
+ continue
303
+ raise ValueError('Unknown hyperparameter type for %s' % name)
304
+ type_ = type_map[name]
305
+
306
+ # Set up correct parsing function (depending on whether type_ is a bool)
307
+ if type_ == bool:
308
+ def parse_bool(value):
309
+ if value in ['true', 'True']:
310
+ return True
311
+ elif value in ['false', 'False']:
312
+ return False
313
+ else:
314
+ try:
315
+ return bool(int(value))
316
+ except ValueError:
317
+ _parse_fail(name, type_, value, values)
318
+
319
+ parse = parse_bool
320
+ else:
321
+ parse = type_
322
+
323
+ # If a singe value is provided
324
+ if m_dict['val'] is not None:
325
+ _process_scalar_value(name, parse, type_, m_dict, values,
326
+ results_dictionary)
327
+
328
+ # If the assigned value is a list:
329
+ elif m_dict['vals'] is not None:
330
+ _process_list_value(name, parse, type_, m_dict, values,
331
+ results_dictionary)
332
+
333
+ else: # Not assigned a list or value
334
+ _parse_fail(name, type_, '', values)
335
+
336
+ return results_dictionary
337
+
338
+
339
+ class HParams(object):
340
+ """Class to hold a set of hyperparameters as name-value pairs.
341
+
342
+ A `HParams` object holds hyperparameters used to build and train a model,
343
+ such as the number of hidden units in a neural net layer or the learning rate
344
+ to use when training.
345
+
346
+ You first create a `HParams` object by specifying the names and values of the
347
+ hyperparameters.
348
+
349
+ To make them easily accessible the parameter names are added as direct
350
+ attributes of the class. A typical usage is as follows:
351
+
352
+ ```python
353
+ # Create a HParams object specifying names and values of the model
354
+ # hyperparameters:
355
+ hparams = HParams(learning_rate=0.1, num_hidden_units=100)
356
+
357
+ # The hyperparameter are available as attributes of the HParams object:
358
+ hparams.learning_rate ==> 0.1
359
+ hparams.num_hidden_units ==> 100
360
+ ```
361
+
362
+ Hyperparameters have type, which is inferred from the type of their value
363
+ passed at construction type. The currently supported types are: integer,
364
+ float, boolean, string, and list of integer, float, boolean, or string.
365
+
366
+ You can override hyperparameter values by calling the
367
+ [`parse()`](#HParams.parse) method, passing a string of comma separated
368
+ `name=value` pairs. This is intended to make it possible to override
369
+ any hyperparameter values from a single command-line flag to which
370
+ the user passes 'hyper-param=value' pairs. It avoids having to define
371
+ one flag for each hyperparameter.
372
+
373
+ The syntax expected for each value depends on the type of the parameter.
374
+ See `parse()` for a description of the syntax.
375
+
376
+ Example:
377
+
378
+ ```python
379
+ # Define a command line flag to pass name=value pairs.
380
+ # For example using argparse:
381
+ import argparse
382
+ parser = argparse.ArgumentParser(description='Train my model.')
383
+ parser.add_argument('--hparams', type=str,
384
+ help='Comma separated list of "name=value" pairs.')
385
+ args = parser.parse_args()
386
+ ...
387
+ def my_program():
388
+ # Create a HParams object specifying the names and values of the
389
+ # model hyperparameters:
390
+ hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
391
+ activations=['relu', 'tanh'])
392
+
393
+ # Override hyperparameters values by parsing the command line
394
+ hparams.parse(args.hparams)
395
+
396
+ # If the user passed `--hparams=learning_rate=0.3` on the command line
397
+ # then 'hparams' has the following attributes:
398
+ hparams.learning_rate ==> 0.3
399
+ hparams.num_hidden_units ==> 100
400
+ hparams.activations ==> ['relu', 'tanh']
401
+
402
+ # If the hyperparameters are in json format use parse_json:
403
+ hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
404
+ ```
405
+ """
406
+
407
+ _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
408
+
409
+ def __init__(self, model_structure=None, **kwargs):
410
+ """Create an instance of `HParams` from keyword arguments.
411
+
412
+ The keyword arguments specify name-values pairs for the hyperparameters.
413
+ The parameter types are inferred from the type of the values passed.
414
+
415
+ The parameter names are added as attributes of `HParams` object, so they
416
+ can be accessed directly with the dot notation `hparams._name_`.
417
+
418
+ Example:
419
+
420
+ ```python
421
+ # Define 3 hyperparameters: 'learning_rate' is a float parameter,
422
+ # 'num_hidden_units' an integer parameter, and 'activation' a string
423
+ # parameter.
424
+ hparams = tf.HParams(
425
+ learning_rate=0.1, num_hidden_units=100, activation='relu')
426
+
427
+ hparams.activation ==> 'relu'
428
+ ```
429
+
430
+ Note that a few names are reserved and cannot be used as hyperparameter
431
+ names. If you use one of the reserved name the constructor raises a
432
+ `ValueError`.
433
+
434
+ Args:
435
+ model_structure: An instance of ModelStructure, defining the feature
436
+ crosses to be used in the Trial.
437
+ **kwargs: Key-value pairs where the key is the hyperparameter name and
438
+ the value is the value for the parameter.
439
+
440
+ Raises:
441
+ ValueError: If both `hparam_def` and initialization values are provided,
442
+ or if one of the arguments is invalid.
443
+
444
+ """
445
+ # Register the hyperparameters and their type in _hparam_types.
446
+ # This simplifies the implementation of parse().
447
+ # _hparam_types maps the parameter name to a tuple (type, bool).
448
+ # The type value is the type of the parameter for scalar hyperparameters,
449
+ # or the type of the list elements for multidimensional hyperparameters.
450
+ # The bool value is True if the value is a list, False otherwise.
451
+ self._hparam_types = {}
452
+ self._model_structure = model_structure
453
+ for name, value in six.iteritems(kwargs):
454
+ self.add_hparam(name, value)
455
+
456
+ def __add__(self, other):
457
+ """
458
+ addition operation keeping key order
459
+ """
460
+ out = HParams()
461
+ for key in self._hparam_types.keys():
462
+ out.add_hparam(key, getattr(self, key))
463
+ for key in other._hparam_types.keys():
464
+ if getattr(out, key, None) is None: # add new param
465
+ out.add_hparam(key, getattr(other, key))
466
+ else: # update existing param
467
+ out.set_hparam(key, getattr(other, key))
468
+ return out
469
+
470
+ def __str__(self):
471
+ s = 'HParams(\n'
472
+ for key, val in six.iteritems(self.values()):
473
+ s += f'\t{key} = {val}\n'
474
+ # print('%s = %s' % (key, str(val)))
475
+ s += ')'
476
+ return s
477
+
478
+ def __repr__(self):
479
+ return self.__str__()
480
+
481
+ def add_hparam(self, name, value):
482
+ """Adds {name, value} pair to hyperparameters.
483
+
484
+ Args:
485
+ name: Name of the hyperparameter.
486
+ value: Value of the hyperparameter. Can be one of the following types:
487
+ int, float, string, int list, float list, or string list.
488
+
489
+ Raises:
490
+ ValueError: if one of the arguments is invalid.
491
+ """
492
+ # Keys in kwargs are unique, but 'name' could the name of a pre-existing
493
+ # attribute of this object. In that case we refuse to use it as a
494
+ # hyperparameter name.
495
+ if getattr(self, name, None) is not None:
496
+ raise ValueError('Hyperparameter name is reserved: %s' % name)
497
+ if isinstance(value, (list, tuple)):
498
+ if not value:
499
+ raise ValueError(
500
+ 'Multi-valued hyperparameters cannot be empty: %s' % name)
501
+ self._hparam_types[name] = (type(value[0]), True)
502
+ else:
503
+ self._hparam_types[name] = (type(value), False)
504
+ setattr(self, name, value)
505
+
506
+ def set_hparam(self, name, value):
507
+ """Set the value of an existing hyperparameter.
508
+
509
+ This function verifies that the type of the value matches the type of the
510
+ existing hyperparameter.
511
+
512
+ Args:
513
+ name: Name of the hyperparameter.
514
+ value: New value of the hyperparameter.
515
+
516
+ Raises:
517
+ KeyError: If the hyperparameter doesn't exist.
518
+ ValueError: If there is a type mismatch.
519
+ """
520
+ param_type, is_list = self._hparam_types[name]
521
+ if isinstance(value, list):
522
+ if not is_list:
523
+ raise ValueError(
524
+ 'Must not pass a list for single-valued parameter: %s' % name)
525
+ setattr(self, name, [
526
+ _cast_to_type_if_compatible(name, param_type, v) for v in value])
527
+ else:
528
+ if is_list:
529
+ raise ValueError(
530
+ 'Must pass a list for multi-valued parameter: %s.' % name)
531
+ setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
532
+
533
+ def del_hparam(self, name):
534
+ """Removes the hyperparameter with key 'name'.
535
+
536
+ Does nothing if it isn't present.
537
+
538
+ Args:
539
+ name: Name of the hyperparameter.
540
+ """
541
+ if hasattr(self, name):
542
+ delattr(self, name)
543
+ del self._hparam_types[name]
544
+
545
+ def parse(self, values):
546
+ """Override existing hyperparameter values, parsing new values from a string.
547
+
548
+ See parse_values for more detail on the allowed format for values.
549
+
550
+ Args:
551
+ values: String. Comma separated list of `name=value` pairs where 'value'
552
+ must follow the syntax described above.
553
+
554
+ Returns:
555
+ The `HParams` instance.
556
+
557
+ Raises:
558
+ ValueError: If `values` cannot be parsed or a hyperparameter in `values`
559
+ doesn't exist.
560
+ """
561
+ type_map = {}
562
+ for name, t in self._hparam_types.items():
563
+ param_type, _ = t
564
+ type_map[name] = param_type
565
+
566
+ values_map = parse_values(values, type_map)
567
+ return self.override_from_dict(values_map)
568
+
569
+ def override_from_dict(self, values_dict):
570
+ """Override existing hyperparameter values, parsing new values from a dictionary.
571
+
572
+ Args:
573
+ values_dict: Dictionary of name:value pairs.
574
+
575
+ Returns:
576
+ The `HParams` instance.
577
+
578
+ Raises:
579
+ KeyError: If a hyperparameter in `values_dict` doesn't exist.
580
+ ValueError: If `values_dict` cannot be parsed.
581
+ """
582
+ for name, value in values_dict.items():
583
+ self.set_hparam(name, value)
584
+ return self
585
+
586
+ def set_model_structure(self, model_structure):
587
+ self._model_structure = model_structure
588
+
589
+ def get_model_structure(self):
590
+ return self._model_structure
591
+
592
+ def to_json(self, indent=None, separators=None, sort_keys=False):
593
+ """Serializes the hyperparameters into JSON.
594
+
595
+ Args:
596
+ indent: If a non-negative integer, JSON array elements and object members
597
+ will be pretty-printed with that indent level. An indent level of 0, or
598
+ negative, will only insert newlines. `None` (the default) selects the
599
+ most compact representation.
600
+ separators: Optional `(item_separator, key_separator)` tuple. Default is
601
+ `(', ', ': ')`.
602
+ sort_keys: If `True`, the output dictionaries will be sorted by key.
603
+
604
+ Returns:
605
+ A JSON string.
606
+ """
607
+ def remove_callables(x):
608
+ """Omit callable elements from input with arbitrary nesting."""
609
+ if isinstance(x, dict):
610
+ return {k: remove_callables(v) for k, v in six.iteritems(x)
611
+ if not callable(v)}
612
+ elif isinstance(x, list):
613
+ return [remove_callables(i) for i in x if not callable(i)]
614
+ return x
615
+ return json.dumps(
616
+ remove_callables(self.values()),
617
+ indent=indent,
618
+ separators=separators,
619
+ sort_keys=sort_keys)
620
+
621
+ def parse_json(self, values_json):
622
+ """Override existing hyperparameter values, parsing new values from a json object.
623
+
624
+ Args:
625
+ values_json: String containing a json object of name:value pairs.
626
+
627
+ Returns:
628
+ The `HParams` instance.
629
+
630
+ Raises:
631
+ KeyError: If a hyperparameter in `values_json` doesn't exist.
632
+ ValueError: If `values_json` cannot be parsed.
633
+ """
634
+ values_map = json.loads(values_json)
635
+ return self.override_from_dict(values_map)
636
+
637
+ def values(self):
638
+ """Return the hyperparameter values as a Python dictionary.
639
+
640
+ Returns:
641
+ A dictionary with hyperparameter names as keys. The values are the
642
+ hyperparameter values.
643
+ """
644
+ return {n: getattr(self, n) for n in self._hparam_types.keys()}
645
+
646
+ def get(self, key, default=None):
647
+ """Returns the value of `key` if it exists, else `default`."""
648
+ if key in self._hparam_types:
649
+ # Ensure that default is compatible with the parameter type.
650
+ if default is not None:
651
+ param_type, is_param_list = self._hparam_types[key]
652
+ type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
653
+ fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
654
+ 'default=%s' % (key, type_str, default))
655
+
656
+ is_default_list = isinstance(default, list)
657
+ if is_param_list != is_default_list:
658
+ raise ValueError(fail_msg)
659
+
660
+ try:
661
+ if is_default_list:
662
+ for value in default:
663
+ _cast_to_type_if_compatible(key, param_type, value)
664
+ else:
665
+ _cast_to_type_if_compatible(key, param_type, default)
666
+ except ValueError as e:
667
+ raise ValueError('%s. %s' % (fail_msg, e))
668
+
669
+ return getattr(self, key)
670
+
671
+ return default
672
+
673
+ def __contains__(self, key):
674
+ return key in self._hparam_types
675
+
676
+ @staticmethod
677
+ def _get_kind_name(param_type, is_list):
678
+ """Returns the field name given parameter type and is_list.
679
+
680
+ Args:
681
+ param_type: Data type of the hparam.
682
+ is_list: Whether this is a list.
683
+
684
+ Returns:
685
+ A string representation of the field name.
686
+
687
+ Raises:
688
+ ValueError: If parameter type is not recognized.
689
+ """
690
+ if issubclass(param_type, bool):
691
+ # This check must happen before issubclass(param_type, six.integer_types),
692
+ # since Python considers bool to be a subclass of int.
693
+ typename = 'bool'
694
+ elif issubclass(param_type, six.integer_types):
695
+ # Setting 'int' and 'long' types to be 'int64' to ensure the type is
696
+ # compatible with both Python2 and Python3.
697
+ typename = 'int64'
698
+ elif issubclass(param_type, (six.string_types, six.binary_type)):
699
+ # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
700
+ # compatible with both Python2 and Python3.
701
+ typename = 'bytes'
702
+ elif issubclass(param_type, float):
703
+ typename = 'float'
704
+ else:
705
+ raise ValueError('Unsupported parameter type: %s' % str(param_type))
706
+
707
+ suffix = 'list' if is_list else 'value'
708
+ return '_'.join([typename, suffix])
709
+
710
+ @staticmethod
711
+ def save_config(self, output_file, verbose=True):
712
+ def convert(o): # json cannot serialize integer in np.int64 format
713
+ if isinstance(o, np.int64):
714
+ return int(o)
715
+ raise TypeError
716
+ if verbose:
717
+ print(self)
718
+ with open(output_file, 'w') as f:
719
+ json.dump(self.values(), f, indent=True, default=convert)
720
+
721
+ @staticmethod
722
+ def load_config(config_file, verbose=True):
723
+ """
724
+ parse hparams from config file
725
+ :param config_file: json config file
726
+ :param verbose: print out values
727
+ """
728
+ try:
729
+ with open(config_file, 'r') as fin:
730
+ json_dict = json.loads(fin.read())
731
+ hps = HParams(**json_dict)
732
+ if verbose:
733
+ print_config(hps)
734
+ except Exception as e:
735
+ print('Error reading config file %s: %s.\nConfig will not be updated.' % (config_file, e))
736
+ return hps
737
+
738
+ @staticmethod
739
+ def clone(self):
740
+ """
741
+ return a deep copy of this object
742
+ """
743
+ return HParams(**self.values)
tools/image_dataset.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ imagefolder loader
5
+ inspired from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
6
+ @author: Tu Bui @surrey.ac.uk
7
+ """
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+ import os
12
+ import sys
13
+ import io
14
+ import time
15
+ import pandas as pd
16
+ import numpy as np
17
+ import random
18
+ from PIL import Image
19
+ from typing import Any, Callable, List, Optional, Tuple
20
+ import torch
21
+ from torchvision import transforms
22
+ from .base_lmdb import PILlmdb, ArrayDatabase
23
+ # from . import debug
24
+
25
+
26
+ def worker_init_fn(worker_id):
27
+ # to be passed to torch.utils.data.DataLoader to fix the
28
+ # random seed issue with numpy in multi-worker settings
29
+ torch_seed = torch.initial_seed()
30
+ random.seed(torch_seed + worker_id)
31
+ if torch_seed >= 2**30: # make sure torch_seed + workder_id < 2**32
32
+ torch_seed = torch_seed % 2**30
33
+ np.random.seed(torch_seed + worker_id)
34
+
35
+
36
+ def pil_loader(path: str) -> Image.Image:
37
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
38
+ with open(path, 'rb') as f:
39
+ img = Image.open(f)
40
+ return img.convert('RGB')
41
+
42
+
43
+ def dataset_wrapper(data_dir, data_list, **kwargs):
44
+ if os.path.exists(os.path.join(data_dir, 'data.mdb')):
45
+ return ImageDataset(data_dir, data_list, **kwargs)
46
+ else:
47
+ return ImageFolder(data_dir, data_list, **kwargs)
48
+
49
+
50
+ class ImageFolder(torch.utils.data.Dataset):
51
+ _repr_indent = 4
52
+ def __init__(self, data_dir, data_list, secret_len=100, resize=256, transform=None, **kwargs):
53
+ super().__init__()
54
+ self.transform = transforms.RandomResizedCrop((resize, resize), scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333)) if transform is None else transform
55
+ self.build_data(data_dir, data_list, **kwargs)
56
+ self.kwargs = kwargs
57
+ self.secret_len = secret_len
58
+
59
+ def build_data(self, data_dir, data_list, **kwargs):
60
+ self.data_dir = data_dir
61
+ if isinstance(data_list, list):
62
+ self.data_list = data_list
63
+ elif isinstance(data_list, str):
64
+ self.data_list = pd.read_csv(data_list)['path'].tolist()
65
+ elif isinstance(data_list, pd.DataFrame):
66
+ self.data_list = data_list['path'].tolist()
67
+ else:
68
+ raise ValueError('data_list must be a list, str or pd.DataFrame')
69
+ self.N = len(self.data_list)
70
+
71
+ def __getitem__(self, index):
72
+ path = self.data_list[index]
73
+ img = pil_loader(os.path.join(self.data_dir, path))
74
+ img = self.transform(img)
75
+ img = np.array(img, dtype=np.float32)/127.5-1. # [-1, 1]
76
+ secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
77
+ return {'image': img, 'secret': secret} # {'img': x, 'index': index}
78
+
79
+ def __len__(self) -> int:
80
+ # raise NotImplementedError
81
+ return self.N
82
+
83
+ class ImageDataset(torch.utils.data.Dataset):
84
+ r"""
85
+ Customised Image Folder class for pytorch.
86
+ Accept lmdb and a csv list as the input.
87
+ Usage:
88
+ dataset = ImageDataset(img_dir, img_list)
89
+ dataset.set_transform(some_pytorch_transforms)
90
+ loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
91
+ num_workers=4, worker_init_fn=worker_init_fn)
92
+ for x,y in loader:
93
+ # x and y is input and target (dict), the keys can be customised.
94
+ """
95
+ _repr_indent = 4
96
+ def __init__(self, data_dir, data_list, secret_len=100, resize=None, transform=None, target_transform=None, **kwargs):
97
+ super().__init__()
98
+ if resize is not None:
99
+ self.resize = transforms.Resize((resize, resize))
100
+ self.set_transform(transform, target_transform)
101
+ self.build_data(data_dir, data_list, **kwargs)
102
+ self.secret_len = secret_len
103
+ self.kwargs = kwargs
104
+
105
+ def set_transform(self, transform, target_transform=None):
106
+ self.transform, self.target_transform = transform, target_transform
107
+
108
+ def build_data(self, data_dir, data_list, **kwargs):
109
+ """
110
+ Args:
111
+ data_list (text file) must have at least 3 fields: id, path and label
112
+
113
+ This method must create an attribute self.samples containing ID, input and target samples; and another attribute N storing the dataset size
114
+
115
+ Optional attributes: classes (list of unique classes), group (useful for
116
+ metric learning)
117
+ """
118
+ self.data_dir, self.list = data_dir, data_list
119
+ if ('dtype' in kwargs) and (kwargs['dtype'].lower() == 'array'):
120
+ data = ArrayDatabase(data_dir, data_list)
121
+ else:
122
+ data = PILlmdb(data_dir, data_list, **kwargs)
123
+ self.N = len(data)
124
+ self.classes = np.unique(data.labels)
125
+ self.samples = {'x': data, 'y': data.labels}
126
+ # assert isinstance(data_list, str) or isinstance(data_list, pd.DataFrame)
127
+ # df = pd.read_csv(data_list) if isinstance(data_list, str) else data_list
128
+ # assert 'id' in df and 'label' in df, f'[DATA] Error! {data_list} must contains "id" and "label".'
129
+ # ids = df['id'].tolist()
130
+ # labels = np.array(df['label'].tolist())
131
+ # data = PILlmdb(data_dir)
132
+ # assert set(ids).issubset(set(data.keys)) # ids should exist in lmdb
133
+ # self.N = len(ids)
134
+ # self.classes, inds = np.unique(labels, return_index=True)
135
+ # self.samples = {'id': ids, 'x': data, 'y': labels}
136
+
137
+ def set_ids(self, ids):
138
+ self.samples['x'].set_ids(ids)
139
+ self.samples['y'] = [self.samples['y'][i] for i in ids]
140
+ self.N = len(self.samples['x'])
141
+
142
+ def __getitem__(self, index: int) -> Any:
143
+ """
144
+ Args:
145
+ index (int): Index
146
+ Returns:
147
+ dict: (x: sample, y: target, **kwargs)
148
+ """
149
+ x, y = self.samples['x'][index], self.samples['y'][index]
150
+ if hasattr(self, 'resize'):
151
+ x = self.resize(x)
152
+ if self.transform is not None:
153
+ x = self.transform(x)
154
+ if self.target_transform is not None:
155
+ y = self.target_transform(y)
156
+ x = np.array(x, dtype=np.float32)/127.5-1.
157
+ secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
158
+ return {'image': x, 'secret': secret} # {'img': x, 'index': index}
159
+
160
+ def __len__(self) -> int:
161
+ # raise NotImplementedError
162
+ return self.N
163
+
164
+ def __repr__(self) -> str:
165
+ head = "\nDataset " + self.__class__.__name__
166
+ body = ["Number of datapoints: {}".format(self.__len__())]
167
+ if hasattr(self, 'data_dir') and self.data_dir is not None:
168
+ body.append("data_dir location: {}".format(self.data_dir))
169
+ if hasattr(self, 'kwargs'):
170
+ body.append(f'kwargs: {self.kwargs}')
171
+ body += self.extra_repr().splitlines()
172
+ if hasattr(self, "transform") and self.transform is not None:
173
+ body += [repr(self.transform)]
174
+ lines = [head] + [" " * self._repr_indent + line for line in body]
175
+ return '\n'.join(lines)
176
+
177
+ def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
178
+ lines = transform.__repr__().splitlines()
179
+ return (["{}{}".format(head, lines[0])] +
180
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
181
+
182
+ def extra_repr(self) -> str:
183
+ return ""
184
+
tools/image_dataset_generic.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ imagefolder loader
5
+ inspired from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
6
+ @author: Tu Bui @surrey.ac.uk
7
+ """
8
+ from __future__ import absolute_import
9
+ from __future__ import division
10
+ from __future__ import print_function
11
+ import os
12
+ import sys
13
+ import io
14
+ import time
15
+ import pandas as pd
16
+ import numpy as np
17
+ import random
18
+ from PIL import Image
19
+ from typing import Any, Callable, List, Optional, Tuple
20
+ import torch
21
+ from .base_lmdb import PILlmdb, ArrayDatabase
22
+ from torchvision import transforms
23
+ # from . import debug
24
+
25
+
26
+ def worker_init_fn(worker_id):
27
+ # to be passed to torch.utils.data.DataLoader to fix the
28
+ # random seed issue with numpy in multi-worker settings
29
+ torch_seed = torch.initial_seed()
30
+ random.seed(torch_seed + worker_id)
31
+ if torch_seed >= 2**30: # make sure torch_seed + workder_id < 2**32
32
+ torch_seed = torch_seed % 2**30
33
+ np.random.seed(torch_seed + worker_id)
34
+
35
+
36
+ def pil_loader(path: str) -> Image.Image:
37
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
38
+ with open(path, 'rb') as f:
39
+ img = Image.open(f)
40
+ return img.convert('RGB')
41
+
42
+
43
+ class ImageDataset(torch.utils.data.Dataset):
44
+ r"""
45
+ Customised Image Folder class for pytorch.
46
+ Accept lmdb and a csv list as the input.
47
+ Usage:
48
+ dataset = ImageDataset(img_dir, img_list)
49
+ dataset.set_transform(some_pytorch_transforms)
50
+ loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
51
+ num_workers=4, worker_init_fn=worker_init_fn)
52
+ for x,y in loader:
53
+ # x and y is input and target (dict), the keys can be customised.
54
+ """
55
+ _repr_indent = 4
56
+ def __init__(self, data_dir, data_list, secret_len=100, transform=None, target_transform=None, **kwargs):
57
+ super().__init__()
58
+ self.set_transform(transform, target_transform)
59
+ self.build_data(data_dir, data_list, **kwargs)
60
+ self.secret_len = secret_len
61
+ self.kwargs = kwargs
62
+
63
+ def set_transform(self, transform, target_transform=None):
64
+ self.transform, self.target_transform = transform, target_transform
65
+
66
+ def build_data(self, data_dir, data_list, **kwargs):
67
+ """
68
+ Args:
69
+ data_list (text file) must have at least 3 fields: id, path and label
70
+
71
+ This method must create an attribute self.samples containing ID, input and target samples; and another attribute N storing the dataset size
72
+
73
+ Optional attributes: classes (list of unique classes), group (useful for
74
+ metric learning)
75
+ """
76
+ self.data_dir, self.list = data_dir, data_list
77
+ if ('dtype' in kwargs) and (kwargs['dtype'].lower() == 'array'):
78
+ data = ArrayDatabase(data_dir, data_list)
79
+ else:
80
+ data = PILlmdb(data_dir, data_list, **kwargs)
81
+ self.N = len(data)
82
+ self.classes = np.unique(data.labels)
83
+ self.samples = {'x': data, 'y': data.labels}
84
+
85
+ def __getitem__(self, index: int) -> Any:
86
+ """
87
+ Args:
88
+ index (int): Index
89
+ Returns:
90
+ dict: (x: sample, y: target, **kwargs)
91
+ """
92
+ x, y = self.samples['x'][index], self.samples['y'][index]
93
+ if self.transform is not None:
94
+ x = self.transform(x)
95
+ if self.target_transform is not None:
96
+ y = self.target_transform(y)
97
+ x = np.array(x, dtype=np.float32)/127.5-1.
98
+ secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
99
+ return {'image': x, 'secret': secret} # {'img': x, 'index': index}
100
+
101
+ def __len__(self) -> int:
102
+ # raise NotImplementedError
103
+ return self.N
104
+
105
+ def __repr__(self) -> str:
106
+ head = "\nDataset " + self.__class__.__name__
107
+ body = ["Number of datapoints: {}".format(self.__len__())]
108
+ if hasattr(self, 'data_dir') and self.data_dir is not None:
109
+ body.append("data_dir location: {}".format(self.data_dir))
110
+ if hasattr(self, 'kwargs'):
111
+ body.append(f'kwargs: {self.kwargs}')
112
+ body += self.extra_repr().splitlines()
113
+ if hasattr(self, "transform") and self.transform is not None:
114
+ body += [repr(self.transform)]
115
+ lines = [head] + [" " * self._repr_indent + line for line in body]
116
+ return '\n'.join(lines)
117
+
118
+ def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
119
+ lines = transform.__repr__().splitlines()
120
+ return (["{}{}".format(head, lines[0])] +
121
+ ["{}{}".format(" " * len(head), line) for line in lines[1:]])
122
+
123
+ def extra_repr(self) -> str:
124
+ return ""
125
+
126
+ class ImageFolder(torch.utils.data.Dataset):
127
+ _repr_indent = 4
128
+ def __init__(self, data_dir, data_list, secret_len=100, resize=256, transform=None, **kwargs):
129
+ super().__init__()
130
+ self.transform = transforms.Resize((resize, resize)) if transform is None else transform
131
+ self.build_data(data_dir, data_list, **kwargs)
132
+ self.kwargs = kwargs
133
+ self.secret_len = secret_len
134
+
135
+ def build_data(self, data_dir, data_list, **kwargs):
136
+ self.data_dir = data_dir
137
+ if isinstance(data_list, list):
138
+ self.data_list = data_list
139
+ elif isinstance(data_list, str):
140
+ self.data_list = pd.read_csv(data_list)['path'].tolist()
141
+ elif isinstance(data_list, pd.DataFrame):
142
+ self.data_list = data_list['path'].tolist()
143
+ else:
144
+ raise ValueError('data_list must be a list, str or pd.DataFrame')
145
+ self.N = len(self.data_list)
146
+
147
+ def __getitem__(self, index):
148
+ path = self.data_list[index]
149
+ img = pil_loader(os.path.join(self.data_dir, path))
150
+ img = self.transform(img)
151
+ img = np.array(img, dtype=np.float32)/127.5-1. # [-1, 1]
152
+ secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2) # not used
153
+ return {'image': img, 'secret': secret} # {'img': x, 'index': index}
154
+
155
+ def __len__(self) -> int:
156
+ # raise NotImplementedError
157
+ return self.N
tools/image_tools.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+
5
+ @author: Tu Bui @surrey.ac.uk
6
+ """
7
+ from __future__ import absolute_import
8
+ from __future__ import division
9
+ from __future__ import print_function
10
+ from scipy import fftpack
11
+ import sys, os
12
+ from pathlib import Path
13
+ import numpy as np
14
+ import random
15
+ import glob
16
+ import json
17
+ import time
18
+ import importlib
19
+ import pandas as pd
20
+ from tqdm import tqdm
21
+ # from IPython.display import display
22
+ # import seaborn as sns
23
+ import matplotlib
24
+ # matplotlib.use('Agg') # headless run
25
+ import matplotlib.pyplot as plt
26
+ import matplotlib.patches as mpatches
27
+ from PIL import Image, ImageDraw, ImageFont
28
+ cmap = plt.get_cmap("tab10") # cmap as function
29
+ cmap = plt.rcParams['axes.prop_cycle'].by_key()['color'] # cmap
30
+
31
+ FONT = '/vol/research/tubui1/_base/utils/FreeSans.ttf'
32
+
33
+ # def imshow(im):
34
+ # if type(im) is np.ndarray:
35
+ # im = Image.fromarray(im)
36
+ # display(im)
37
+
38
+ def make_grid(array_list, gsize=(3,3)):
39
+ """
40
+ make a grid image from a list of image array (RGB)
41
+ return: array RGB
42
+ """
43
+ assert len(gsize)==2 and gsize[0]*gsize[1]==len(array_list)
44
+ h,w,c = array_list[0].shape
45
+ out = np.array(array_list).reshape(gsize[0], gsize[1], h, w, c).transpose(0, 2, 1, 3, 4).reshape(gsize[0]*h, gsize[1]*w, c)
46
+ return out
47
+
48
+ def collage(im_list, size=None, pad=0, color=255):
49
+ """
50
+ generalised function of make_grid()
51
+ work on PIL/numpy images of arbitrary size
52
+ """
53
+ if size is None:
54
+ size=(1, len(im_list))
55
+ assert len(size)==2
56
+ if isinstance(im_list[0], np.ndarray):
57
+ im_list = [Image.fromarray(im) for im in im_list]
58
+ h, w = size
59
+ n = len(im_list)
60
+ canvas = []
61
+ for i in range(h):
62
+ start, end = i*w, min((i+1)*w, n)
63
+ row = combine_horz(im_list[start:end], pad, color)
64
+ canvas.append(row)
65
+ canvas = combine_vert(canvas, pad, color)
66
+ return canvas
67
+
68
+ def combine_horz(pil_ims, pad=0, c=255):
69
+ """
70
+ Combines multiple pil_ims into a single side-by-side PIL image object.
71
+ """
72
+ widths, heights = zip(*(i.size for i in pil_ims))
73
+ total_width = sum(widths) + (len(pil_ims)-1) * pad
74
+ max_height = max(heights)
75
+ color = (c,c,c)
76
+ new_im = Image.new('RGB', (total_width, max_height), color)
77
+ x_offset = 0
78
+ for im in pil_ims:
79
+ new_im.paste(im, (x_offset,0))
80
+ x_offset += (im.size[0] + pad)
81
+ return new_im
82
+
83
+
84
+ def combine_vert(pil_ims, pad=0, c=255):
85
+ """
86
+ Combines multiple pil_ims into a single vertical PIL image object.
87
+ """
88
+ widths, heights = zip(*(i.size for i in pil_ims))
89
+ max_width = max(widths)
90
+ total_height = sum(heights) + (len(pil_ims)-1)*pad
91
+ color = (c,c,c)
92
+ new_im = Image.new('RGB', (max_width, total_height), color)
93
+ y_offset = 0
94
+ for im in pil_ims:
95
+ new_im.paste(im, (0,y_offset))
96
+ y_offset += (im.size[1] + pad)
97
+ return new_im
98
+
99
+ def make_text_image(img_shape=(100,20), text='hello', font_path=FONT, offset=(0,0), font_size=16):
100
+ """
101
+ make a text image with given width/height and font size
102
+ Args:
103
+ img_shape, offset tuple (width, height)
104
+ font_path path to font file (TrueType)
105
+ font_size max font size, actual may smaller
106
+
107
+ Return:
108
+ pil image
109
+ """
110
+ im = Image.new('RGB', tuple(img_shape), (255,255,255))
111
+ draw = ImageDraw.Draw(im)
112
+
113
+ def get_font_size(max_font_size):
114
+ font = ImageFont.truetype(font_path, max_font_size)
115
+ text_size = font.getsize(text) # (w,h)
116
+ start_w = int((img_shape[0] - text_size[0]) / 2)
117
+ start_h = int((img_shape[1] - text_size[1])/2)
118
+ if start_h <0 or start_w < 0:
119
+ return get_font_size(max_font_size-2)
120
+ else:
121
+ return font, (start_w, start_h)
122
+ font, pos = get_font_size(font_size)
123
+ pos = (pos[0]+offset[0], pos[1]+offset[1])
124
+ draw.text(pos, text, font=font, fill=0)
125
+ return im
126
+
127
+
128
+ def log_scale(array, epsilon=1e-12):
129
+ """Log scale the input array.
130
+ """
131
+ array = np.abs(array)
132
+ array += epsilon # no zero in log
133
+ array = np.log(array)
134
+ return array
135
+
136
+ def dct2(array):
137
+ """2D DCT"""
138
+ array = fftpack.dct(array, type=2, norm="ortho", axis=0)
139
+ array = fftpack.dct(array, type=2, norm="ortho", axis=1)
140
+ return array
141
+
142
+ def idct2(array):
143
+ """inverse 2D DCT"""
144
+ array = fftpack.idct(array, type=2, norm="ortho", axis=0)
145
+ array = fftpack.idct(array, type=2, norm="ortho", axis=1)
146
+ return array
147
+
148
+
149
+ class DCT(object):
150
+ def __init__(self, log=True):
151
+ self.log = log
152
+
153
+ def __call__(self, x):
154
+ x = np.array(x)
155
+ x = dct2(x)
156
+ if self.log:
157
+ x = log_scale(x)
158
+ # normalize
159
+ x = np.clip((x - x.min())/(x.max() - x.min()) * 255, 0, 255).astype(np.uint8)
160
+ return Image.fromarray(x)
161
+
162
+ def __repr__(self):
163
+ s = f'(Discrete Cosine Transform, logarithm={self.log})'
164
+ return self.__class__.__name__ + s
tools/imgcap_dataset.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Dataset class for image-caption
5
+ @author: Tu Bui @University of Surrey
6
+ """
7
+ import json
8
+ from PIL import Image
9
+ import numpy as np
10
+ from pathlib import Path
11
+ import torch
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from functools import partial
14
+ import pytorch_lightning as pl
15
+ from ldm.util import instantiate_from_config
16
+ import pandas as pd
17
+
18
+
19
+ def worker_init_fn(_):
20
+ worker_info = torch.utils.data.get_worker_info()
21
+ worker_id = worker_info.id
22
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
23
+
24
+
25
+ class WrappedDataset(Dataset):
26
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
27
+
28
+ def __init__(self, dataset):
29
+ self.data = dataset
30
+
31
+ def __len__(self):
32
+ return len(self.data)
33
+
34
+ def __getitem__(self, idx):
35
+ return self.data[idx]
36
+
37
+
38
+ class DataModuleFromConfig(pl.LightningDataModule):
39
+ def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
40
+ shuffle_val_dataloader=False):
41
+ super().__init__()
42
+ self.batch_size = batch_size
43
+ self.dataset_configs = dict()
44
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
45
+ self.use_worker_init_fn = use_worker_init_fn
46
+ if train is not None:
47
+ self.dataset_configs["train"] = train
48
+ self.train_dataloader = self._train_dataloader
49
+ if validation is not None:
50
+ self.dataset_configs["validation"] = validation
51
+ self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
52
+ if test is not None:
53
+ self.dataset_configs["test"] = test
54
+ self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
55
+ if predict is not None:
56
+ self.dataset_configs["predict"] = predict
57
+ self.predict_dataloader = self._predict_dataloader
58
+ self.wrap = wrap
59
+
60
+ def prepare_data(self):
61
+ for data_cfg in self.dataset_configs.values():
62
+ instantiate_from_config(data_cfg)
63
+
64
+ def setup(self, stage=None):
65
+ self.datasets = dict(
66
+ (k, instantiate_from_config(self.dataset_configs[k]))
67
+ for k in self.dataset_configs)
68
+ if self.wrap:
69
+ for k in self.datasets:
70
+ self.datasets[k] = WrappedDataset(self.datasets[k])
71
+
72
+ def _train_dataloader(self):
73
+ if self.use_worker_init_fn:
74
+ init_fn = worker_init_fn
75
+ else:
76
+ init_fn = None
77
+ return DataLoader(self.datasets["train"], batch_size=self.batch_size,
78
+ num_workers=self.num_workers, shuffle=True,
79
+ worker_init_fn=init_fn)
80
+
81
+ def _val_dataloader(self, shuffle=False):
82
+ if self.use_worker_init_fn:
83
+ init_fn = worker_init_fn
84
+ else:
85
+ init_fn = None
86
+ return DataLoader(self.datasets["validation"],
87
+ batch_size=self.batch_size,
88
+ num_workers=self.num_workers,
89
+ worker_init_fn=init_fn,
90
+ shuffle=shuffle)
91
+
92
+ def _test_dataloader(self, shuffle=False):
93
+ if self.use_worker_init_fn:
94
+ init_fn = worker_init_fn
95
+ else:
96
+ init_fn = None
97
+
98
+ return DataLoader(self.datasets["test"], batch_size=self.batch_size,
99
+ num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
100
+
101
+ def _predict_dataloader(self, shuffle=False):
102
+ if self.use_worker_init_fn:
103
+ init_fn = worker_init_fn
104
+ else:
105
+ init_fn = None
106
+ return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
107
+ num_workers=self.num_workers, worker_init_fn=init_fn)
108
+
109
+
110
+ class ImageCaptionRaw(Dataset):
111
+ def __init__(self, image_dir, caption_file, secret_len=100, transform=None):
112
+ super().__init__()
113
+ self.image_dir = Path(image_dir)
114
+ self.data = []
115
+ with open(caption_file, 'rt') as f:
116
+ for line in f:
117
+ self.data.append(json.loads(line))
118
+ self.secret_len = secret_len
119
+ self.transform = transform
120
+
121
+ def __len__(self):
122
+ return len(self.data)
123
+
124
+ def __getitem__(self, idx):
125
+ item = self.data[idx]
126
+ image = Image.open(self.image_dir/item['image']).convert('RGB').resize((512,512))
127
+ caption = item['captions']
128
+ cid = torch.randint(0, len(caption), (1,)).item()
129
+ caption = caption[cid]
130
+ if self.transform is not None:
131
+ image = self.transform(image)
132
+
133
+ image = np.array(image, dtype=np.float32)/ 255.0 # normalize to [0, 1]
134
+ target = image * 2.0 - 1.0 # normalize to [-1, 1]
135
+ secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
136
+ return dict(image=image, caption=caption, target=target, secret=secret)
137
+
138
+
139
+ class BAMFG(Dataset):
140
+ def __init__(self, style_dir, gt_dir, data_list, transform=None):
141
+ super().__init__()
142
+ self.style_dir = Path(style_dir)
143
+ self.gt_dir = Path(gt_dir)
144
+ self.data = pd.read_csv(data_list)
145
+ self.transform = transform
146
+
147
+ def __len__(self):
148
+ return len(self.data)
149
+
150
+ def __getitem__(self, idx):
151
+ item = self.data.iloc[idx]
152
+ gt_img = Image.open(self.gt_dir/item['gt_img']).convert('RGB').resize((512,512))
153
+ style_img = Image.open(self.style_dir/item['style_img']).convert('RGB').resize((512,512))
154
+ txt = item['prompt']
155
+ if self.transform is not None:
156
+ gt_img = self.transform(gt_img)
157
+ style_img = self.transform(style_img)
158
+
159
+ gt_img = np.array(gt_img, dtype=np.float32)/ 255.0 # normalize to [0, 1]
160
+ style_img = np.array(style_img, dtype=np.float32)/ 255.0 # normalize to [0, 1]
161
+ target = gt_img * 2.0 - 1.0 # normalize to [-1, 1]
162
+
163
+ return dict(image=gt_img, txt=txt, hint=style_img)
tools/sifid.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from scipy import linalg
4
+ import torchvision
5
+ from torchvision import transforms
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from PIL import Image
9
+
10
+
11
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
12
+ """Numpy implementation of the Frechet Distance.
13
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
14
+ and X_2 ~ N(mu_2, C_2) is
15
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
16
+ Stable version by Dougal J. Sutherland.
17
+ Params:
18
+ -- mu1 : Numpy array containing the activations of a layer of the
19
+ inception net (like returned by the function 'get_predictions')
20
+ for generated samples.
21
+ -- mu2 : The sample mean over activations, precalculated on an
22
+ representative data set.
23
+ -- sigma1: The covariance matrix over activations for generated samples.
24
+ -- sigma2: The covariance matrix over activations, precalculated on an
25
+ representative data set.
26
+ Returns:
27
+ -- : The Frechet Distance.
28
+ """
29
+
30
+ mu1 = np.atleast_1d(mu1)
31
+ mu2 = np.atleast_1d(mu2)
32
+
33
+ sigma1 = np.atleast_2d(sigma1)
34
+ sigma2 = np.atleast_2d(sigma2)
35
+
36
+ assert mu1.shape == mu2.shape, \
37
+ 'Training and test mean vectors have different lengths'
38
+ assert sigma1.shape == sigma2.shape, \
39
+ 'Training and test covariances have different dimensions'
40
+
41
+ diff = mu1 - mu2
42
+
43
+ # Product might be almost singular
44
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
45
+ if not np.isfinite(covmean).all():
46
+ msg = ('fid calculation produces singular product; '
47
+ 'adding %s to diagonal of cov estimates') % eps
48
+ print(msg)
49
+ offset = np.eye(sigma1.shape[0]) * eps
50
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
51
+
52
+ # Numerical error might give slight imaginary component
53
+ if np.iscomplexobj(covmean):
54
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
55
+ m = np.max(np.abs(covmean.imag))
56
+ raise ValueError('Imaginary component {}'.format(m))
57
+ covmean = covmean.real
58
+
59
+ tr_covmean = np.trace(covmean)
60
+
61
+ return (diff.dot(diff) + np.trace(sigma1) +
62
+ np.trace(sigma2) - 2 * tr_covmean)
63
+
64
+
65
+ class SIFID(object):
66
+ def __init__(self, dims=64) -> None:
67
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
68
+ self.model = InceptionV3([block_idx]).cuda()
69
+ self.model.eval()
70
+ self.dims = dims
71
+
72
+ def calculate_activation_statistics(self, x):
73
+ act = self.get_activations(x)
74
+ mu = np.mean(act, axis=0)
75
+ sigma = np.cov(act, rowvar=False)
76
+ return mu, sigma
77
+
78
+ def get_activations(self, x):
79
+ # x tensor (B, C, H, W) in range [0, 1]
80
+ batch_size = x.shape[0]
81
+ with torch.no_grad():
82
+ pred = self.model(x)[0]
83
+ pred = pred.cpu().numpy()
84
+ pred = pred.transpose(0, 2, 3, 1).reshape(batch_size*pred.shape[2]*pred.shape[3],-1)
85
+ return pred
86
+
87
+ def __call__(self, x1, x2):
88
+ # x1, x2 tensor (B, C, H, W) in range [-1, 1]
89
+ x1, x2 = (x1 + 1.)/2, (x2 + 1.)/2 # [-1, 1] -> [0, 1]
90
+ m1, s1 = self.calculate_activation_statistics(x1.unsqueeze(0).cuda())
91
+ m2, s2 = self.calculate_activation_statistics(x2.unsqueeze(0).cuda())
92
+ return calculate_frechet_distance(m1, s1, m2, s2)
93
+
94
+
95
+ class InceptionV3(nn.Module):
96
+ """Pretrained InceptionV3 network returning feature maps"""
97
+
98
+ # Index of default block of inception to return,
99
+ # corresponds to output of final average pooling
100
+ DEFAULT_BLOCK_INDEX = 3
101
+
102
+ # Maps feature dimensionality to their output blocks indices
103
+ BLOCK_INDEX_BY_DIM = {
104
+ 64: 0, # First max pooling features
105
+ 192: 1, # Second max pooling featurs
106
+ 768: 2, # Pre-aux classifier features
107
+ 2048: 3 # Final average pooling features
108
+ }
109
+
110
+ def __init__(self,
111
+ output_blocks=[DEFAULT_BLOCK_INDEX],
112
+ resize_input=False,
113
+ normalize_input=True,
114
+ requires_grad=False):
115
+ """Build pretrained InceptionV3
116
+ Parameters
117
+ ----------
118
+ output_blocks : list of int
119
+ Indices of blocks to return features of. Possible values are:
120
+ - 0: corresponds to output of first max pooling
121
+ - 1: corresponds to output of second max pooling
122
+ - 2: corresponds to output which is fed to aux classifier
123
+ - 3: corresponds to output of final average pooling
124
+ resize_input : bool
125
+ If true, bilinearly resizes input to width and height 299 before
126
+ feeding input to model. As the network without fully connected
127
+ layers is fully convolutional, it should be able to handle inputs
128
+ of arbitrary size, so resizing might not be strictly needed
129
+ normalize_input : bool
130
+ If true, scales the input from range (0, 1) to the range the
131
+ pretrained Inception network expects, namely (-1, 1)
132
+ requires_grad : bool
133
+ If true, parameters of the model require gradient. Possibly useful
134
+ for finetuning the network
135
+ """
136
+ super(InceptionV3, self).__init__()
137
+
138
+ self.resize_input = resize_input
139
+ self.normalize_input = normalize_input
140
+ self.output_blocks = sorted(output_blocks)
141
+ self.last_needed_block = max(output_blocks)
142
+
143
+ assert self.last_needed_block <= 3, \
144
+ 'Last possible output block index is 3'
145
+
146
+ self.blocks = nn.ModuleList()
147
+
148
+ inception = torchvision.models.inception_v3(pretrained=True)
149
+
150
+ # Block 0: input to maxpool1
151
+ block0 = [
152
+ inception.Conv2d_1a_3x3,
153
+ inception.Conv2d_2a_3x3,
154
+ inception.Conv2d_2b_3x3,
155
+ ]
156
+
157
+
158
+ self.blocks.append(nn.Sequential(*block0))
159
+
160
+ # Block 1: maxpool1 to maxpool2
161
+ if self.last_needed_block >= 1:
162
+ block1 = [
163
+ nn.MaxPool2d(kernel_size=3, stride=2),
164
+ inception.Conv2d_3b_1x1,
165
+ inception.Conv2d_4a_3x3,
166
+ ]
167
+ self.blocks.append(nn.Sequential(*block1))
168
+
169
+ # Block 2: maxpool2 to aux classifier
170
+ if self.last_needed_block >= 2:
171
+ block2 = [
172
+ nn.MaxPool2d(kernel_size=3, stride=2),
173
+ inception.Mixed_5b,
174
+ inception.Mixed_5c,
175
+ inception.Mixed_5d,
176
+ inception.Mixed_6a,
177
+ inception.Mixed_6b,
178
+ inception.Mixed_6c,
179
+ inception.Mixed_6d,
180
+ inception.Mixed_6e,
181
+ ]
182
+ self.blocks.append(nn.Sequential(*block2))
183
+
184
+ # Block 3: aux classifier to final avgpool
185
+ if self.last_needed_block >= 3:
186
+ block3 = [
187
+ inception.Mixed_7a,
188
+ inception.Mixed_7b,
189
+ inception.Mixed_7c,
190
+ ]
191
+ self.blocks.append(nn.Sequential(*block3))
192
+
193
+ if self.last_needed_block >= 4:
194
+ block4 = [
195
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
196
+ ]
197
+ self.blocks.append(nn.Sequential(*block4))
198
+
199
+ for param in self.parameters():
200
+ param.requires_grad = requires_grad
201
+
202
+ def forward(self, inp):
203
+ """Get Inception feature maps
204
+ Parameters
205
+ ----------
206
+ inp : torch.autograd.Variable
207
+ Input tensor of shape Bx3xHxW. Values are expected to be in
208
+ range (0, 1)
209
+ Returns
210
+ -------
211
+ List of torch.autograd.Variable, corresponding to the selected output
212
+ block, sorted ascending by index
213
+ """
214
+ outp = []
215
+ x = inp
216
+
217
+ if self.resize_input:
218
+ x = F.upsample(x,
219
+ size=(299, 299),
220
+ mode='bilinear',
221
+ align_corners=False)
222
+
223
+ if self.normalize_input:
224
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
225
+
226
+ for idx, block in enumerate(self.blocks):
227
+ x = block(x)
228
+ if idx in self.output_blocks:
229
+ outp.append(x)
230
+
231
+ if idx == self.last_needed_block:
232
+ break
233
+
234
+ return outp
235
+
236
+ if __name__ == '__main__':
237
+ tform = transforms.Compose([transforms.Resize((256,256)),
238
+ transforms.ToTensor(),
239
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
240
+ im1 = Image.open('test1.jpg')
241
+ im2 = Image.open('test2.jpg')
242
+ im1 = tform(im1) # 3xHxW in [-1,]
243
+ im2 = tform(im2)
244
+ sifid_model = SIFID()
245
+ sifid_score = sifid_model(im1, im2)
246
+ print(sifid_score)
tools/slack_bot.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ slack_bot.py
5
+ Created on May 02 2020 11:02
6
+ a bot to send message/image during program run
7
+ @author: Tu Bui [email protected]
8
+ """
9
+
10
+ from __future__ import absolute_import
11
+ from __future__ import division
12
+ from __future__ import print_function
13
+ import os
14
+ import sys
15
+ import requests
16
+ import socket
17
+ from slack import WebClient
18
+ from slack.errors import SlackApiError
19
+ import threading
20
+
21
+
22
+ SLACK_MAX_PRINT_ERROR = 3
23
+ SLACK_ERROR_CODE = {'not_active': 1,
24
+ 'API': 2}
25
+
26
+
27
+ def welcome_message():
28
+ hostname = socket.gethostname()
29
+ all_args = ' '.join(sys.argv)
30
+ out_text = 'On server {}: {}\n'.format(hostname, all_args)
31
+ return out_text
32
+
33
+
34
+ class Notifier(object):
35
+ """
36
+ A slack bot to send text/image to a given workspace channel.
37
+ This class initializes with a text file as input, the text file should contain 2 lines:
38
+ slack token
39
+ slack channel
40
+
41
+ Usage:
42
+ msg = Notifier(token_file)
43
+ msg.send_initial_text(' '.join(sys.argv))
44
+ msg.send_text('hi, this text is inside slack thread')
45
+ msg.send_file(your_file, 'file title')
46
+ """
47
+ def __init__(self, token_file):
48
+ """
49
+ setup slack
50
+ :param token_file: path to slack token file
51
+ """
52
+ self.active = True
53
+ self.thread_id = None
54
+ self.counter = 0 # count number of errors during Web API call
55
+ if not os.path.exists(token_file):
56
+ print('[SLACK] token file not found. You will not be notified.')
57
+ self.active = False
58
+ else:
59
+ try:
60
+ with open(token_file, 'r') as f:
61
+ lines = f.readlines()
62
+ self.token = lines[0].strip()
63
+ self.channel = lines[1].strip()
64
+ except Exception as e:
65
+ print(e)
66
+ print('[SLACK] fail to read token file. You will not be notified.')
67
+ self.active = False
68
+
69
+ def _handel_error(self, e):
70
+ assert e.response["ok"] is False
71
+ assert e.response["error"] # str like 'invalid_auth', 'channel_not_found'
72
+ self.counter += 1
73
+ if self.counter <= SLACK_MAX_PRINT_ERROR:
74
+ print(f"Got the following error, you will not be notified: {e.response['error']}")
75
+
76
+ def send_init_text(self, text=None):
77
+ """
78
+ start a new thread with a main message and register the thread id
79
+ :param text: initial message for this thread
80
+ :return:
81
+ """
82
+ if not self.active:
83
+ return SLACK_ERROR_CODE['not_active']
84
+ try:
85
+ if text is None:
86
+ text = welcome_message()
87
+ sc = WebClient(self.token)
88
+ response = sc.chat_postMessage(channel=self.channel, text=text)
89
+ self.thread_id = response['ts']
90
+ except SlackApiError as e:
91
+ self._handel_error(e)
92
+ return SLACK_ERROR_CODE['API']
93
+ print('[SLACK] sent initial text. Chat ID %s. Message %s' % (self.thread_id, text))
94
+ return 0
95
+
96
+ def send_init_file(self, file_path, title=''):
97
+ """
98
+ start a new thread with a file and register thread id
99
+ :param file_path: path to file
100
+ :param title: title of this file
101
+ :return: 0 if success otherwise error code
102
+ """
103
+ if not self.active:
104
+ return SLACK_ERROR_CODE['not_active']
105
+ try:
106
+ response = sc.files_upload(title=title, channels=self.channel, file=file_path)
107
+ self.thread_id = response['ts']
108
+ except SlackApiError as e:
109
+ self._handel_error(e)
110
+ return SLACK_ERROR_CODE['API']
111
+ print('[SLACK] sent initial file. Chat ID %s.' % self.thread_id)
112
+ return 0
113
+
114
+ def send_text(self, text, reply_broadcast=False):
115
+ """
116
+ send text as a thread if one is registered in self.thread_id.
117
+ Otherwise send as a new message
118
+ :param text: message to send.
119
+ :return: 0 if success, error code otherwise
120
+ """
121
+ print(text)
122
+ if not self.active:
123
+ return SLACK_ERROR_CODE['not_active']
124
+ if self.thread_id is None:
125
+ self.send_init_text(text)
126
+ else:
127
+ try:
128
+ sc = WebClient(self.token)
129
+ response = sc.chat_postMessage(channel=self.channel, text=text,
130
+ thread_ts=self.thread_id, as_user=True,
131
+ reply_broadcast=reply_broadcast)
132
+ except SlackApiError as e:
133
+ self._handel_error(e)
134
+ return SLACK_ERROR_CODE['API']
135
+ return 0
136
+
137
+ def _send_file(self, file_path, title='', reply_broadcast=False):
138
+ """can be multithread target"""
139
+ try:
140
+ sc = WebClient(self.token)
141
+ sc.files_upload(title=title, channels=self.channel,
142
+ thread_ts=self.thread_id, file=file_path,
143
+ reply_broadcast=reply_broadcast)
144
+ except SlackApiError as e:
145
+ self._handel_error(e)
146
+ return SLACK_ERROR_CODE['API']
147
+ return 0
148
+
149
+ def send_file(self, file_path, title='', reply_broadcast=False):
150
+ if not self.active:
151
+ return SLACK_ERROR_CODE['not_active']
152
+ if self.thread_id is None:
153
+ return self.send_init_file(file_path, title)
154
+ else:
155
+ os_thread = threading.Thread(target=self._send_file, args=(file_path, title, reply_broadcast))
156
+ os_thread.start()
157
+ return 0 # may still have error if _send_file() fail