jadechoghari commited on
Commit
72fd365
·
verified ·
1 Parent(s): 4fc142b

add initial files

Browse files
Files changed (8) hide show
  1. checkpoint-last.pth +3 -0
  2. diffloss.py +248 -0
  3. diffusion.py +47 -0
  4. diffusion_utils.py +73 -0
  5. gaussian_diffusion.py +877 -0
  6. kl16.ckpt +3 -0
  7. mar.py +353 -0
  8. respace.py +129 -0
checkpoint-last.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7e970a33bc90353e2fabe3498ed1f2d194dd8d17cd387665f80b2984dfca538c
3
+ size 1663614946
diffloss.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ import math
5
+
6
+ from diffusion import create_diffusion
7
+
8
+
9
+ class DiffLoss(nn.Module):
10
+ """Diffusion Loss"""
11
+ def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12
+ super(DiffLoss, self).__init__()
13
+ self.in_channels = target_channels
14
+ self.net = SimpleMLPAdaLN(
15
+ in_channels=target_channels,
16
+ model_channels=width,
17
+ out_channels=target_channels * 2, # for vlb loss
18
+ z_channels=z_channels,
19
+ num_res_blocks=depth,
20
+ grad_checkpointing=grad_checkpointing
21
+ )
22
+
23
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25
+
26
+ def forward(self, target, z, mask=None):
27
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28
+ model_kwargs = dict(c=z)
29
+ loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30
+ loss = loss_dict["loss"]
31
+ if mask is not None:
32
+ loss = (loss * mask).sum() / mask.sum()
33
+ return loss.mean()
34
+
35
+ def sample(self, z, temperature=1.0, cfg=1.0):
36
+ # diffusion loss sampling
37
+ if not cfg == 1.0:
38
+ noise = torch.randn(z.shape[0] // 2, self.in_channels)
39
+ noise = torch.cat([noise, noise], dim=0)
40
+ model_kwargs = dict(c=z, cfg_scale=cfg)
41
+ sample_fn = self.net.forward_with_cfg
42
+ else:
43
+ noise = torch.randn(z.shape[0], self.in_channels)
44
+ model_kwargs = dict(c=z)
45
+ sample_fn = self.net.forward
46
+
47
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
48
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49
+ temperature=temperature
50
+ )
51
+
52
+ return sampled_token_latent
53
+
54
+
55
+ def modulate(x, shift, scale):
56
+ return x * (1 + scale) + shift
57
+
58
+
59
+ class TimestepEmbedder(nn.Module):
60
+ """
61
+ Embeds scalar timesteps into vector representations.
62
+ """
63
+ def __init__(self, hidden_size, frequency_embedding_size=256):
64
+ super().__init__()
65
+ self.mlp = nn.Sequential(
66
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67
+ nn.SiLU(),
68
+ nn.Linear(hidden_size, hidden_size, bias=True),
69
+ )
70
+ self.frequency_embedding_size = frequency_embedding_size
71
+
72
+ @staticmethod
73
+ def timestep_embedding(t, dim, max_period=10000):
74
+ """
75
+ Create sinusoidal timestep embeddings.
76
+ :param t: a 1-D Tensor of N indices, one per batch element.
77
+ These may be fractional.
78
+ :param dim: the dimension of the output.
79
+ :param max_period: controls the minimum frequency of the embeddings.
80
+ :return: an (N, D) Tensor of positional embeddings.
81
+ """
82
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83
+ half = dim // 2
84
+ freqs = torch.exp(
85
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86
+ ).to(device=t.device)
87
+ args = t[:, None].float() * freqs[None]
88
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89
+ if dim % 2:
90
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91
+ return embedding
92
+
93
+ def forward(self, t):
94
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95
+ t_emb = self.mlp(t_freq)
96
+ return t_emb
97
+
98
+
99
+ class ResBlock(nn.Module):
100
+ """
101
+ A residual block that can optionally change the number of channels.
102
+ :param channels: the number of input channels.
103
+ """
104
+
105
+ def __init__(
106
+ self,
107
+ channels
108
+ ):
109
+ super().__init__()
110
+ self.channels = channels
111
+
112
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113
+ self.mlp = nn.Sequential(
114
+ nn.Linear(channels, channels, bias=True),
115
+ nn.SiLU(),
116
+ nn.Linear(channels, channels, bias=True),
117
+ )
118
+
119
+ self.adaLN_modulation = nn.Sequential(
120
+ nn.SiLU(),
121
+ nn.Linear(channels, 3 * channels, bias=True)
122
+ )
123
+
124
+ def forward(self, x, y):
125
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127
+ h = self.mlp(h)
128
+ return x + gate_mlp * h
129
+
130
+
131
+ class FinalLayer(nn.Module):
132
+ """
133
+ The final layer of DiT.
134
+ """
135
+ def __init__(self, model_channels, out_channels):
136
+ super().__init__()
137
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138
+ self.linear = nn.Linear(model_channels, out_channels, bias=True)
139
+ self.adaLN_modulation = nn.Sequential(
140
+ nn.SiLU(),
141
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
142
+ )
143
+
144
+ def forward(self, x, c):
145
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146
+ x = modulate(self.norm_final(x), shift, scale)
147
+ x = self.linear(x)
148
+ return x
149
+
150
+
151
+ class SimpleMLPAdaLN(nn.Module):
152
+ """
153
+ The MLP for Diffusion Loss.
154
+ :param in_channels: channels in the input Tensor.
155
+ :param model_channels: base channel count for the model.
156
+ :param out_channels: channels in the output Tensor.
157
+ :param z_channels: channels in the condition.
158
+ :param num_res_blocks: number of residual blocks per downsample.
159
+ """
160
+
161
+ def __init__(
162
+ self,
163
+ in_channels,
164
+ model_channels,
165
+ out_channels,
166
+ z_channels,
167
+ num_res_blocks,
168
+ grad_checkpointing=False
169
+ ):
170
+ super().__init__()
171
+
172
+ self.in_channels = in_channels
173
+ self.model_channels = model_channels
174
+ self.out_channels = out_channels
175
+ self.num_res_blocks = num_res_blocks
176
+ self.grad_checkpointing = grad_checkpointing
177
+
178
+ self.time_embed = TimestepEmbedder(model_channels)
179
+ self.cond_embed = nn.Linear(z_channels, model_channels)
180
+
181
+ self.input_proj = nn.Linear(in_channels, model_channels)
182
+
183
+ res_blocks = []
184
+ for i in range(num_res_blocks):
185
+ res_blocks.append(ResBlock(
186
+ model_channels,
187
+ ))
188
+
189
+ self.res_blocks = nn.ModuleList(res_blocks)
190
+ self.final_layer = FinalLayer(model_channels, out_channels)
191
+
192
+ self.initialize_weights()
193
+
194
+ def initialize_weights(self):
195
+ def _basic_init(module):
196
+ if isinstance(module, nn.Linear):
197
+ torch.nn.init.xavier_uniform_(module.weight)
198
+ if module.bias is not None:
199
+ nn.init.constant_(module.bias, 0)
200
+ self.apply(_basic_init)
201
+
202
+ # Initialize timestep embedding MLP
203
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205
+
206
+ # Zero-out adaLN modulation layers
207
+ for block in self.res_blocks:
208
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210
+
211
+ # Zero-out output layers
212
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214
+ nn.init.constant_(self.final_layer.linear.weight, 0)
215
+ nn.init.constant_(self.final_layer.linear.bias, 0)
216
+
217
+ def forward(self, x, t, c):
218
+ """
219
+ Apply the model to an input batch.
220
+ :param x: an [N x C x ...] Tensor of inputs.
221
+ :param t: a 1-D batch of timesteps.
222
+ :param c: conditioning from AR transformer.
223
+ :return: an [N x C x ...] Tensor of outputs.
224
+ """
225
+ x = self.input_proj(x)
226
+ t = self.time_embed(t)
227
+ c = self.cond_embed(c)
228
+
229
+ y = t + c
230
+
231
+ if self.grad_checkpointing and not torch.jit.is_scripting():
232
+ for block in self.res_blocks:
233
+ x = checkpoint(block, x, y)
234
+ else:
235
+ for block in self.res_blocks:
236
+ x = block(x, y)
237
+
238
+ return self.final_layer(x, y)
239
+
240
+ def forward_with_cfg(self, x, t, c, cfg_scale):
241
+ half = x[: len(x) // 2]
242
+ combined = torch.cat([half, half], dim=0)
243
+ model_out = self.forward(combined, t, c)
244
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
245
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
246
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
247
+ eps = torch.cat([half_eps, half_eps], dim=0)
248
+ return torch.cat([eps, rest], dim=1)
diffusion.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adopted from DiT, which is modified from OpenAI's diffusion repos
2
+ # DiT: https://github.com/facebookresearch/DiT/diffusion
3
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6
+
7
+ import gaussian_diffusion as gd
8
+ from respace import SpacedDiffusion, space_timesteps
9
+
10
+
11
+ def create_diffusion(
12
+ timestep_respacing,
13
+ noise_schedule="linear",
14
+ use_kl=False,
15
+ sigma_small=False,
16
+ predict_xstart=False,
17
+ learn_sigma=True,
18
+ rescale_learned_sigmas=False,
19
+ diffusion_steps=1000
20
+ ):
21
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
+ if use_kl:
23
+ loss_type = gd.LossType.RESCALED_KL
24
+ elif rescale_learned_sigmas:
25
+ loss_type = gd.LossType.RESCALED_MSE
26
+ else:
27
+ loss_type = gd.LossType.MSE
28
+ if timestep_respacing is None or timestep_respacing == "":
29
+ timestep_respacing = [diffusion_steps]
30
+ return SpacedDiffusion(
31
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
+ betas=betas,
33
+ model_mean_type=(
34
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35
+ ),
36
+ model_var_type=(
37
+ (
38
+ gd.ModelVarType.FIXED_LARGE
39
+ if not sigma_small
40
+ else gd.ModelVarType.FIXED_SMALL
41
+ )
42
+ if not learn_sigma
43
+ else gd.ModelVarType.LEARNED_RANGE
44
+ ),
45
+ loss_type=loss_type
46
+ # rescale_timesteps=rescale_timesteps,
47
+ )
diffusion_utils.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
50
+ given image.
51
+ :param x: the target images. It is assumed that this was uint8 values,
52
+ rescaled to the range [-1, 1].
53
+ :param means: the Gaussian mean Tensor.
54
+ :param log_scales: the Gaussian log stddev Tensor.
55
+ :return: a tensor like x of log probabilities (in nats).
56
+ """
57
+ assert x.shape == means.shape == log_scales.shape
58
+ centered_x = x - means
59
+ inv_stdv = th.exp(-log_scales)
60
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
61
+ cdf_plus = approx_standard_normal_cdf(plus_in)
62
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
63
+ cdf_min = approx_standard_normal_cdf(min_in)
64
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
65
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
66
+ cdf_delta = cdf_plus - cdf_min
67
+ log_probs = th.where(
68
+ x < -0.999,
69
+ log_cdf_plus,
70
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
71
+ )
72
+ assert log_probs.shape == x.shape
73
+ return log_probs
gaussian_diffusion.py ADDED
@@ -0,0 +1,877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "cosine":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317
+ if self.model_mean_type == ModelMeanType.START_X:
318
+ pred_xstart = process_xstart(model_output)
319
+ else:
320
+ pred_xstart = process_xstart(
321
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322
+ )
323
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324
+
325
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ "extra": extra,
332
+ }
333
+
334
+ def _predict_xstart_from_eps(self, x_t, t, eps):
335
+ assert x_t.shape == eps.shape
336
+ return (
337
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339
+ )
340
+
341
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345
+
346
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347
+ """
348
+ Compute the mean for the previous step, given a function cond_fn that
349
+ computes the gradient of a conditional log probability with respect to
350
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351
+ condition on y.
352
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353
+ """
354
+ gradient = cond_fn(x, t, **model_kwargs)
355
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356
+ return new_mean
357
+
358
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute what the p_mean_variance output would have been, should the
361
+ model's score function be conditioned by cond_fn.
362
+ See condition_mean() for details on cond_fn.
363
+ Unlike condition_mean(), this instead uses the conditioning strategy
364
+ from Song et al (2020).
365
+ """
366
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367
+
368
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370
+
371
+ out = p_mean_var.copy()
372
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374
+ return out
375
+
376
+ def p_sample(
377
+ self,
378
+ model,
379
+ x,
380
+ t,
381
+ clip_denoised=True,
382
+ denoised_fn=None,
383
+ cond_fn=None,
384
+ model_kwargs=None,
385
+ temperature=1.0
386
+ ):
387
+ """
388
+ Sample x_{t-1} from the model at the given timestep.
389
+ :param model: the model to sample from.
390
+ :param x: the current tensor at x_{t-1}.
391
+ :param t: the value of t, starting at 0 for the first diffusion step.
392
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393
+ :param denoised_fn: if not None, a function which applies to the
394
+ x_start prediction before it is used to sample.
395
+ :param cond_fn: if not None, this is a gradient function that acts
396
+ similarly to the model.
397
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
398
+ pass to the model. This can be used for conditioning.
399
+ :param temperature: temperature scaling during Diff Loss sampling.
400
+ :return: a dict containing the following keys:
401
+ - 'sample': a random sample from the model.
402
+ - 'pred_xstart': a prediction of x_0.
403
+ """
404
+ out = self.p_mean_variance(
405
+ model,
406
+ x,
407
+ t,
408
+ clip_denoised=clip_denoised,
409
+ denoised_fn=denoised_fn,
410
+ model_kwargs=model_kwargs,
411
+ )
412
+ noise = th.randn_like(x)
413
+ nonzero_mask = (
414
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
415
+ ) # no noise when t == 0
416
+ if cond_fn is not None:
417
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
418
+ # scale the noise by temperature
419
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
420
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
421
+
422
+ def p_sample_loop(
423
+ self,
424
+ model,
425
+ shape,
426
+ noise=None,
427
+ clip_denoised=True,
428
+ denoised_fn=None,
429
+ cond_fn=None,
430
+ model_kwargs=None,
431
+ device=None,
432
+ progress=False,
433
+ temperature=1.0,
434
+ ):
435
+ """
436
+ Generate samples from the model.
437
+ :param model: the model module.
438
+ :param shape: the shape of the samples, (N, C, H, W).
439
+ :param noise: if specified, the noise from the encoder to sample.
440
+ Should be of the same shape as `shape`.
441
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
442
+ :param denoised_fn: if not None, a function which applies to the
443
+ x_start prediction before it is used to sample.
444
+ :param cond_fn: if not None, this is a gradient function that acts
445
+ similarly to the model.
446
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
447
+ pass to the model. This can be used for conditioning.
448
+ :param device: if specified, the device to create the samples on.
449
+ If not specified, use a model parameter's device.
450
+ :param progress: if True, show a tqdm progress bar.
451
+ :param temperature: temperature scaling during Diff Loss sampling.
452
+ :return: a non-differentiable batch of samples.
453
+ """
454
+ final = None
455
+ for sample in self.p_sample_loop_progressive(
456
+ model,
457
+ shape,
458
+ noise=noise,
459
+ clip_denoised=clip_denoised,
460
+ denoised_fn=denoised_fn,
461
+ cond_fn=cond_fn,
462
+ model_kwargs=model_kwargs,
463
+ device=device,
464
+ progress=progress,
465
+ temperature=temperature,
466
+ ):
467
+ final = sample
468
+ return final["sample"]
469
+
470
+ def p_sample_loop_progressive(
471
+ self,
472
+ model,
473
+ shape,
474
+ noise=None,
475
+ clip_denoised=True,
476
+ denoised_fn=None,
477
+ cond_fn=None,
478
+ model_kwargs=None,
479
+ device=None,
480
+ progress=False,
481
+ temperature=1.0,
482
+ ):
483
+ """
484
+ Generate samples from the model and yield intermediate samples from
485
+ each timestep of diffusion.
486
+ Arguments are the same as p_sample_loop().
487
+ Returns a generator over dicts, where each dict is the return value of
488
+ p_sample().
489
+ """
490
+ assert isinstance(shape, (tuple, list))
491
+ if noise is not None:
492
+ img = noise
493
+ else:
494
+ img = th.randn(*shape)
495
+ indices = list(range(self.num_timesteps))[::-1]
496
+
497
+ if progress:
498
+ # Lazy import so that we don't depend on tqdm.
499
+ from tqdm.auto import tqdm
500
+
501
+ indices = tqdm(indices)
502
+
503
+ for i in indices:
504
+ t = th.tensor([i] * shape[0])
505
+ with th.no_grad():
506
+ out = self.p_sample(
507
+ model,
508
+ img,
509
+ t,
510
+ clip_denoised=clip_denoised,
511
+ denoised_fn=denoised_fn,
512
+ cond_fn=cond_fn,
513
+ model_kwargs=model_kwargs,
514
+ temperature=temperature,
515
+ )
516
+ yield out
517
+ img = out["sample"]
518
+
519
+ def ddim_sample(
520
+ self,
521
+ model,
522
+ x,
523
+ t,
524
+ clip_denoised=True,
525
+ denoised_fn=None,
526
+ cond_fn=None,
527
+ model_kwargs=None,
528
+ eta=0.0,
529
+ ):
530
+ """
531
+ Sample x_{t-1} from the model using DDIM.
532
+ Same usage as p_sample().
533
+ """
534
+ out = self.p_mean_variance(
535
+ model,
536
+ x,
537
+ t,
538
+ clip_denoised=clip_denoised,
539
+ denoised_fn=denoised_fn,
540
+ model_kwargs=model_kwargs,
541
+ )
542
+ if cond_fn is not None:
543
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
544
+
545
+ # Usually our model outputs epsilon, but we re-derive it
546
+ # in case we used x_start or x_prev prediction.
547
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
548
+
549
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
550
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
551
+ sigma = (
552
+ eta
553
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
554
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
555
+ )
556
+ # Equation 12.
557
+ noise = th.randn_like(x)
558
+ mean_pred = (
559
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
560
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
561
+ )
562
+ nonzero_mask = (
563
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
564
+ ) # no noise when t == 0
565
+ sample = mean_pred + nonzero_mask * sigma * noise
566
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
567
+
568
+ def ddim_reverse_sample(
569
+ self,
570
+ model,
571
+ x,
572
+ t,
573
+ clip_denoised=True,
574
+ denoised_fn=None,
575
+ cond_fn=None,
576
+ model_kwargs=None,
577
+ eta=0.0,
578
+ ):
579
+ """
580
+ Sample x_{t+1} from the model using DDIM reverse ODE.
581
+ """
582
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
583
+ out = self.p_mean_variance(
584
+ model,
585
+ x,
586
+ t,
587
+ clip_denoised=clip_denoised,
588
+ denoised_fn=denoised_fn,
589
+ model_kwargs=model_kwargs,
590
+ )
591
+ if cond_fn is not None:
592
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
593
+ # Usually our model outputs epsilon, but we re-derive it
594
+ # in case we used x_start or x_prev prediction.
595
+ eps = (
596
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
597
+ - out["pred_xstart"]
598
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
599
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
600
+
601
+ # Equation 12. reversed
602
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
603
+
604
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
605
+
606
+ def ddim_sample_loop(
607
+ self,
608
+ model,
609
+ shape,
610
+ noise=None,
611
+ clip_denoised=True,
612
+ denoised_fn=None,
613
+ cond_fn=None,
614
+ model_kwargs=None,
615
+ device=None,
616
+ progress=False,
617
+ eta=0.0,
618
+ ):
619
+ """
620
+ Generate samples from the model using DDIM.
621
+ Same usage as p_sample_loop().
622
+ """
623
+ final = None
624
+ for sample in self.ddim_sample_loop_progressive(
625
+ model,
626
+ shape,
627
+ noise=noise,
628
+ clip_denoised=clip_denoised,
629
+ denoised_fn=denoised_fn,
630
+ cond_fn=cond_fn,
631
+ model_kwargs=model_kwargs,
632
+ device=device,
633
+ progress=progress,
634
+ eta=eta,
635
+ ):
636
+ final = sample
637
+ return final["sample"]
638
+
639
+ def ddim_sample_loop_progressive(
640
+ self,
641
+ model,
642
+ shape,
643
+ noise=None,
644
+ clip_denoised=True,
645
+ denoised_fn=None,
646
+ cond_fn=None,
647
+ model_kwargs=None,
648
+ device=None,
649
+ progress=False,
650
+ eta=0.0,
651
+ ):
652
+ """
653
+ Use DDIM to sample from the model and yield intermediate samples from
654
+ each timestep of DDIM.
655
+ Same usage as p_sample_loop_progressive().
656
+ """
657
+ assert isinstance(shape, (tuple, list))
658
+ if noise is not None:
659
+ img = noise
660
+ else:
661
+ img = th.randn(*shape)
662
+ indices = list(range(self.num_timesteps))[::-1]
663
+
664
+ if progress:
665
+ # Lazy import so that we don't depend on tqdm.
666
+ from tqdm.auto import tqdm
667
+
668
+ indices = tqdm(indices)
669
+
670
+ for i in indices:
671
+ t = th.tensor([i] * shape[0])
672
+ with th.no_grad():
673
+ out = self.ddim_sample(
674
+ model,
675
+ img,
676
+ t,
677
+ clip_denoised=clip_denoised,
678
+ denoised_fn=denoised_fn,
679
+ cond_fn=cond_fn,
680
+ model_kwargs=model_kwargs,
681
+ eta=eta,
682
+ )
683
+ yield out
684
+ img = out["sample"]
685
+
686
+ def _vb_terms_bpd(
687
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
688
+ ):
689
+ """
690
+ Get a term for the variational lower-bound.
691
+ The resulting units are bits (rather than nats, as one might expect).
692
+ This allows for comparison to other papers.
693
+ :return: a dict with the following keys:
694
+ - 'output': a shape [N] tensor of NLLs or KLs.
695
+ - 'pred_xstart': the x_0 predictions.
696
+ """
697
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
698
+ x_start=x_start, x_t=x_t, t=t
699
+ )
700
+ out = self.p_mean_variance(
701
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
702
+ )
703
+ kl = normal_kl(
704
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
705
+ )
706
+ kl = mean_flat(kl) / np.log(2.0)
707
+
708
+ decoder_nll = -discretized_gaussian_log_likelihood(
709
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
710
+ )
711
+ assert decoder_nll.shape == x_start.shape
712
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
713
+
714
+ # At the first timestep return the decoder NLL,
715
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
716
+ output = th.where((t == 0), decoder_nll, kl)
717
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
718
+
719
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
720
+ """
721
+ Compute training losses for a single timestep.
722
+ :param model: the model to evaluate loss on.
723
+ :param x_start: the [N x C x ...] tensor of inputs.
724
+ :param t: a batch of timestep indices.
725
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
726
+ pass to the model. This can be used for conditioning.
727
+ :param noise: if specified, the specific Gaussian noise to try to remove.
728
+ :return: a dict with the key "loss" containing a tensor of shape [N].
729
+ Some mean or variance settings may also have other keys.
730
+ """
731
+ if model_kwargs is None:
732
+ model_kwargs = {}
733
+ if noise is None:
734
+ noise = th.randn_like(x_start)
735
+ x_t = self.q_sample(x_start, t, noise=noise)
736
+
737
+ terms = {}
738
+
739
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
740
+ terms["loss"] = self._vb_terms_bpd(
741
+ model=model,
742
+ x_start=x_start,
743
+ x_t=x_t,
744
+ t=t,
745
+ clip_denoised=False,
746
+ model_kwargs=model_kwargs,
747
+ )["output"]
748
+ if self.loss_type == LossType.RESCALED_KL:
749
+ terms["loss"] *= self.num_timesteps
750
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
751
+ model_output = model(x_t, t, **model_kwargs)
752
+
753
+ if self.model_var_type in [
754
+ ModelVarType.LEARNED,
755
+ ModelVarType.LEARNED_RANGE,
756
+ ]:
757
+ B, C = x_t.shape[:2]
758
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
759
+ model_output, model_var_values = th.split(model_output, C, dim=1)
760
+ # Learn the variance using the variational bound, but don't let
761
+ # it affect our mean prediction.
762
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
763
+ terms["vb"] = self._vb_terms_bpd(
764
+ model=lambda *args, r=frozen_out: r,
765
+ x_start=x_start,
766
+ x_t=x_t,
767
+ t=t,
768
+ clip_denoised=False,
769
+ )["output"]
770
+ if self.loss_type == LossType.RESCALED_MSE:
771
+ # Divide by 1000 for equivalence with initial implementation.
772
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
773
+ terms["vb"] *= self.num_timesteps / 1000.0
774
+
775
+ target = {
776
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
777
+ x_start=x_start, x_t=x_t, t=t
778
+ )[0],
779
+ ModelMeanType.START_X: x_start,
780
+ ModelMeanType.EPSILON: noise,
781
+ }[self.model_mean_type]
782
+ assert model_output.shape == target.shape == x_start.shape
783
+ terms["mse"] = mean_flat((target - model_output) ** 2)
784
+ if "vb" in terms:
785
+ terms["loss"] = terms["mse"] + terms["vb"]
786
+ else:
787
+ terms["loss"] = terms["mse"]
788
+ else:
789
+ raise NotImplementedError(self.loss_type)
790
+
791
+ return terms
792
+
793
+ def _prior_bpd(self, x_start):
794
+ """
795
+ Get the prior KL term for the variational lower-bound, measured in
796
+ bits-per-dim.
797
+ This term can't be optimized, as it only depends on the encoder.
798
+ :param x_start: the [N x C x ...] tensor of inputs.
799
+ :return: a batch of [N] KL values (in bits), one per batch element.
800
+ """
801
+ batch_size = x_start.shape[0]
802
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
803
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
804
+ kl_prior = normal_kl(
805
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
806
+ )
807
+ return mean_flat(kl_prior) / np.log(2.0)
808
+
809
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
810
+ """
811
+ Compute the entire variational lower-bound, measured in bits-per-dim,
812
+ as well as other related quantities.
813
+ :param model: the model to evaluate loss on.
814
+ :param x_start: the [N x C x ...] tensor of inputs.
815
+ :param clip_denoised: if True, clip denoised samples.
816
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
817
+ pass to the model. This can be used for conditioning.
818
+ :return: a dict containing the following keys:
819
+ - total_bpd: the total variational lower-bound, per batch element.
820
+ - prior_bpd: the prior term in the lower-bound.
821
+ - vb: an [N x T] tensor of terms in the lower-bound.
822
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
823
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
824
+ """
825
+ device = x_start.device
826
+ batch_size = x_start.shape[0]
827
+
828
+ vb = []
829
+ xstart_mse = []
830
+ mse = []
831
+ for t in list(range(self.num_timesteps))[::-1]:
832
+ t_batch = th.tensor([t] * batch_size, device=device)
833
+ noise = th.randn_like(x_start)
834
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
835
+ # Calculate VLB term at the current timestep
836
+ with th.no_grad():
837
+ out = self._vb_terms_bpd(
838
+ model,
839
+ x_start=x_start,
840
+ x_t=x_t,
841
+ t=t_batch,
842
+ clip_denoised=clip_denoised,
843
+ model_kwargs=model_kwargs,
844
+ )
845
+ vb.append(out["output"])
846
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
847
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
848
+ mse.append(mean_flat((eps - noise) ** 2))
849
+
850
+ vb = th.stack(vb, dim=1)
851
+ xstart_mse = th.stack(xstart_mse, dim=1)
852
+ mse = th.stack(mse, dim=1)
853
+
854
+ prior_bpd = self._prior_bpd(x_start)
855
+ total_bpd = vb.sum(dim=1) + prior_bpd
856
+ return {
857
+ "total_bpd": total_bpd,
858
+ "prior_bpd": prior_bpd,
859
+ "vb": vb,
860
+ "xstart_mse": xstart_mse,
861
+ "mse": mse,
862
+ }
863
+
864
+
865
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
866
+ """
867
+ Extract values from a 1-D numpy array for a batch of indices.
868
+ :param arr: the 1-D numpy array.
869
+ :param timesteps: a tensor of indices into the array to extract.
870
+ :param broadcast_shape: a larger shape of K dimensions with the batch
871
+ dimension equal to the length of timesteps.
872
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
873
+ """
874
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
875
+ while len(res.shape) < len(broadcast_shape):
876
+ res = res[..., None]
877
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
kl16.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34ce001bcfffb7af67ec8af1e683a30d7bd45760855ddc7deedc1330f2cfd38f
3
+ size 265900046
mar.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import scipy.stats as stats
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.checkpoint import checkpoint
10
+
11
+ from timm.models.vision_transformer import Block
12
+
13
+ from diffloss import DiffLoss
14
+
15
+
16
+ def mask_by_order(mask_len, order, bsz, seq_len):
17
+ masking = torch.zeros(bsz, seq_len)
18
+ masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len)).bool()
19
+ return masking
20
+
21
+
22
+ class MAR(nn.Module):
23
+ """ Masked Autoencoder with VisionTransformer backbone
24
+ """
25
+ def __init__(self, img_size=256, vae_stride=16, patch_size=1,
26
+ encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
27
+ decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
28
+ mlp_ratio=4., norm_layer=nn.LayerNorm,
29
+ vae_embed_dim=16,
30
+ mask_ratio_min=0.7,
31
+ label_drop_prob=0.1,
32
+ class_num=1000,
33
+ attn_dropout=0.1,
34
+ proj_dropout=0.1,
35
+ buffer_size=64,
36
+ diffloss_d=3,
37
+ diffloss_w=1024,
38
+ num_sampling_steps='100',
39
+ diffusion_batch_mul=4,
40
+ grad_checkpointing=False,
41
+ ):
42
+ super().__init__()
43
+
44
+ # --------------------------------------------------------------------------
45
+ # VAE and patchify specifics
46
+ self.vae_embed_dim = vae_embed_dim
47
+
48
+ self.img_size = img_size
49
+ self.vae_stride = vae_stride
50
+ self.patch_size = patch_size
51
+ self.seq_h = self.seq_w = img_size // vae_stride // patch_size
52
+ self.seq_len = self.seq_h * self.seq_w
53
+ self.token_embed_dim = vae_embed_dim * patch_size**2
54
+ self.grad_checkpointing = grad_checkpointing
55
+
56
+ # --------------------------------------------------------------------------
57
+ # Class Embedding
58
+ self.num_classes = class_num
59
+ self.class_emb = nn.Embedding(1000, encoder_embed_dim)
60
+ self.label_drop_prob = label_drop_prob
61
+ # Fake class embedding for CFG's unconditional generation
62
+ self.fake_latent = nn.Parameter(torch.zeros(1, encoder_embed_dim))
63
+
64
+ # --------------------------------------------------------------------------
65
+ # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
66
+ self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)
67
+
68
+ # --------------------------------------------------------------------------
69
+ # MAR encoder specifics
70
+ self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True)
71
+ self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6)
72
+ self.buffer_size = buffer_size
73
+ self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim))
74
+
75
+ self.encoder_blocks = nn.ModuleList([
76
+ Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
77
+ proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)])
78
+ self.encoder_norm = norm_layer(encoder_embed_dim)
79
+
80
+ # --------------------------------------------------------------------------
81
+ # MAR decoder specifics
82
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
83
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
84
+ self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim))
85
+
86
+ self.decoder_blocks = nn.ModuleList([
87
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
88
+ proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)])
89
+
90
+ self.decoder_norm = norm_layer(decoder_embed_dim)
91
+ self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim))
92
+
93
+ self.initialize_weights()
94
+
95
+ # --------------------------------------------------------------------------
96
+ # Diffusion Loss
97
+ self.diffloss = DiffLoss(
98
+ target_channels=self.token_embed_dim,
99
+ z_channels=decoder_embed_dim,
100
+ width=diffloss_w,
101
+ depth=diffloss_d,
102
+ num_sampling_steps=num_sampling_steps,
103
+ grad_checkpointing=grad_checkpointing
104
+ )
105
+ self.diffusion_batch_mul = diffusion_batch_mul
106
+
107
+ def initialize_weights(self):
108
+ # parameters
109
+ torch.nn.init.normal_(self.class_emb.weight, std=.02)
110
+ torch.nn.init.normal_(self.fake_latent, std=.02)
111
+ torch.nn.init.normal_(self.mask_token, std=.02)
112
+ torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02)
113
+ torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02)
114
+ torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02)
115
+
116
+ # initialize nn.Linear and nn.LayerNorm
117
+ self.apply(self._init_weights)
118
+
119
+ def _init_weights(self, m):
120
+ if isinstance(m, nn.Linear):
121
+ # we use xavier_uniform following official JAX ViT:
122
+ torch.nn.init.xavier_uniform_(m.weight)
123
+ if isinstance(m, nn.Linear) and m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+ elif isinstance(m, nn.LayerNorm):
126
+ if m.bias is not None:
127
+ nn.init.constant_(m.bias, 0)
128
+ if m.weight is not None:
129
+ nn.init.constant_(m.weight, 1.0)
130
+
131
+ def patchify(self, x):
132
+ bsz, c, h, w = x.shape
133
+ p = self.patch_size
134
+ h_, w_ = h // p, w // p
135
+
136
+ x = x.reshape(bsz, c, h_, p, w_, p)
137
+ x = torch.einsum('nchpwq->nhwcpq', x)
138
+ x = x.reshape(bsz, h_ * w_, c * p ** 2)
139
+ return x # [n, l, d]
140
+
141
+ def unpatchify(self, x):
142
+ bsz = x.shape[0]
143
+ p = self.patch_size
144
+ c = self.vae_embed_dim
145
+ h_, w_ = self.seq_h, self.seq_w
146
+
147
+ x = x.reshape(bsz, h_, w_, c, p, p)
148
+ x = torch.einsum('nhwcpq->nchpwq', x)
149
+ x = x.reshape(bsz, c, h_ * p, w_ * p)
150
+ return x # [n, c, h, w]
151
+
152
+ def sample_orders(self, bsz):
153
+ # generate a batch of random generation orders
154
+ orders = []
155
+ for _ in range(bsz):
156
+ order = np.array(list(range(self.seq_len)))
157
+ np.random.shuffle(order)
158
+ orders.append(order)
159
+ orders = torch.Tensor(np.array(orders)).long()
160
+ return orders
161
+
162
+ def random_masking(self, x, orders):
163
+ # generate token mask
164
+ bsz, seq_len, embed_dim = x.shape
165
+ mask_rate = self.mask_ratio_generator.rvs(1)[0]
166
+ num_masked_tokens = int(np.ceil(seq_len * mask_rate))
167
+ mask = torch.zeros(bsz, seq_len, device=x.device)
168
+ mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
169
+ src=torch.ones(bsz, seq_len, device=x.device))
170
+ return mask
171
+
172
+ def forward_mae_encoder(self, x, mask, class_embedding):
173
+ x = self.z_proj(x)
174
+ bsz, seq_len, embed_dim = x.shape
175
+
176
+ # concat buffer
177
+ x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1)
178
+ mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
179
+
180
+ # random drop class embedding during training
181
+ if self.training:
182
+ drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
183
+ drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(x.dtype)
184
+ class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding
185
+
186
+ x[:, :self.buffer_size] = class_embedding.unsqueeze(1)
187
+
188
+ # encoder position embedding
189
+ x = x + self.encoder_pos_embed_learned
190
+ x = self.z_proj_ln(x)
191
+
192
+ # dropping
193
+ x = x[(1-mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)
194
+
195
+ # apply Transformer blocks
196
+ if self.grad_checkpointing and not torch.jit.is_scripting():
197
+ for block in self.encoder_blocks:
198
+ x = checkpoint(block, x)
199
+ else:
200
+ for block in self.encoder_blocks:
201
+ x = block(x)
202
+ x = self.encoder_norm(x)
203
+
204
+ return x
205
+
206
+ def forward_mae_decoder(self, x, mask):
207
+
208
+ x = self.decoder_embed(x)
209
+ mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1)
210
+
211
+ # pad mask tokens
212
+ mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype)
213
+ x_after_pad = mask_tokens.clone()
214
+ x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])
215
+
216
+ # decoder position embedding
217
+ x = x_after_pad + self.decoder_pos_embed_learned
218
+
219
+ # apply Transformer blocks
220
+ if self.grad_checkpointing and not torch.jit.is_scripting():
221
+ for block in self.decoder_blocks:
222
+ x = checkpoint(block, x)
223
+ else:
224
+ for block in self.decoder_blocks:
225
+ x = block(x)
226
+ x = self.decoder_norm(x)
227
+
228
+ x = x[:, self.buffer_size:]
229
+ x = x + self.diffusion_pos_embed_learned
230
+ return x
231
+
232
+ def forward_loss(self, z, target, mask):
233
+ bsz, seq_len, _ = target.shape
234
+ target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
235
+ z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
236
+ mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
237
+ loss = self.diffloss(z=z, target=target, mask=mask)
238
+ return loss
239
+
240
+ def forward(self, imgs, labels):
241
+
242
+ # class embed
243
+ class_embedding = self.class_emb(labels)
244
+
245
+ # patchify and mask (drop) tokens
246
+ x = self.patchify(imgs)
247
+ gt_latents = x.clone().detach()
248
+ orders = self.sample_orders(bsz=x.size(0))
249
+ mask = self.random_masking(x, orders)
250
+
251
+ # mae encoder
252
+ x = self.forward_mae_encoder(x, mask, class_embedding)
253
+
254
+ # mae decoder
255
+ z = self.forward_mae_decoder(x, mask)
256
+
257
+ # diffloss
258
+ loss = self.forward_loss(z=z, target=gt_latents, mask=mask)
259
+
260
+ return loss
261
+
262
+ def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=False):
263
+
264
+ # init and sample generation orders
265
+ mask = torch.ones(bsz, self.seq_len)
266
+ tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim)
267
+ orders = self.sample_orders(bsz)
268
+
269
+ indices = list(range(num_iter))
270
+ if progress:
271
+ indices = tqdm(indices)
272
+ # generate latents
273
+ for step in indices:
274
+ cur_tokens = tokens.clone()
275
+
276
+ # class embedding and CFG
277
+ if labels is not None:
278
+ class_embedding = self.class_emb(labels)
279
+ else:
280
+ class_embedding = self.fake_latent.repeat(bsz, 1)
281
+ if not cfg == 1.0:
282
+ tokens = torch.cat([tokens, tokens], dim=0)
283
+ class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
284
+ mask = torch.cat([mask, mask], dim=0)
285
+
286
+ # mae encoder
287
+ x = self.forward_mae_encoder(tokens, mask, class_embedding)
288
+
289
+ # mae decoder
290
+ z = self.forward_mae_decoder(x, mask)
291
+
292
+ # mask ratio for the next round, following MaskGIT and MAGE.
293
+ mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter)
294
+ mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)])
295
+
296
+ # masks out at least one for the next iteration
297
+ mask_len = torch.maximum(torch.Tensor([1]),
298
+ torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len))
299
+
300
+ # get masking for next iteration and locations to be predicted in this iteration
301
+ mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
302
+ if step >= num_iter - 1:
303
+ mask_to_pred = mask[:bsz].bool()
304
+ else:
305
+ mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
306
+ mask = mask_next
307
+ if not cfg == 1.0:
308
+ mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)
309
+
310
+ # sample token latents for this step
311
+ z = z[mask_to_pred.nonzero(as_tuple=True)]
312
+ # cfg schedule follow Muse
313
+ if cfg_schedule == "linear":
314
+ cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
315
+ elif cfg_schedule == "constant":
316
+ cfg_iter = cfg
317
+ else:
318
+ raise NotImplementedError
319
+ sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
320
+ if not cfg == 1.0:
321
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples
322
+ mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)
323
+
324
+ cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
325
+ tokens = cur_tokens.clone()
326
+
327
+ # unpatchify
328
+ tokens = self.unpatchify(tokens)
329
+ return tokens
330
+
331
+
332
+ def mar_base(**kwargs):
333
+ model = MAR(
334
+ encoder_embed_dim=768, encoder_depth=12, encoder_num_heads=12,
335
+ decoder_embed_dim=768, decoder_depth=12, decoder_num_heads=12,
336
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
337
+ return model
338
+
339
+
340
+ def mar_large(**kwargs):
341
+ model = MAR(
342
+ encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16,
343
+ decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16,
344
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
345
+ return model
346
+
347
+
348
+ def mar_huge(**kwargs):
349
+ model = MAR(
350
+ encoder_embed_dim=1280, encoder_depth=20, encoder_num_heads=16,
351
+ decoder_embed_dim=1280, decoder_depth=20, decoder_num_heads=16,
352
+ mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
353
+ return model
respace.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)