ZYMPKU commited on
Commit
ed25868
·
1 Parent(s): 0b659a7
app.py CHANGED
@@ -8,10 +8,56 @@ from omegaconf import OmegaConf
8
  from contextlib import nullcontext
9
  from pytorch_lightning import seed_everything
10
  from os.path import join as ospj
 
 
 
11
 
12
  from util import *
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def predict(cfgs, model, sampler, batch):
16
 
17
  context = nullcontext if cfgs.aae_enabled else torch.no_grad
@@ -58,15 +104,8 @@ def demo_predict(input_blk, text, num_samples, steps, scale, seed, show_detail):
58
 
59
  image = input_blk["image"]
60
  mask = input_blk["mask"]
61
- image = cv2.resize(image, (cfgs.W, cfgs.H))
62
- mask = cv2.resize(mask, (cfgs.W, cfgs.H))
63
-
64
- mask = (mask == 0).astype(np.int32)
65
 
66
- image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0
67
- mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32).mean(dim=0, keepdim=True)
68
- masked = image * mask
69
- mask = 1 - mask
70
 
71
  seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text))))
72
 
@@ -131,6 +170,7 @@ if __name__ == "__main__":
131
  model = init_model(cfgs)
132
  sampler = init_sampling(cfgs)
133
  global_index = 0
 
134
 
135
  block = gr.Blocks().queue()
136
  with block:
@@ -161,6 +201,7 @@ if __name__ == "__main__":
161
  with gr.Column():
162
 
163
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
 
164
  text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
165
  run_button = gr.Button(variant="primary")
166
 
 
8
  from contextlib import nullcontext
9
  from pytorch_lightning import seed_everything
10
  from os.path import join as ospj
11
+ from random import randint
12
+ from torchvision.utils import save_image
13
+ from torchvision.transforms import Resize
14
 
15
  from util import *
16
 
17
 
18
+ def process(image, mask):
19
+
20
+ img_h, img_w = image.shape[:2]
21
+
22
+ mask = mask[...,:1]//255
23
+ contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
24
+ if len(contours) != 1: raise gr.Error("One masked area only!")
25
+
26
+ m_x, m_y, m_w, m_h = cv2.boundingRect(contours[0])
27
+ c_x, c_y = m_x + m_w//2, m_y + m_h//2
28
+
29
+ if img_w > img_h:
30
+ if m_w > img_h: raise gr.Error("Illegal mask area!")
31
+ if c_x < img_w - c_x:
32
+ c_l = max(0, c_x - img_h//2)
33
+ c_r = c_l + img_h
34
+ else:
35
+ c_r = min(img_w, c_x + img_h//2)
36
+ c_l = c_r - img_h
37
+ image = image[:,c_l:c_r,:]
38
+ mask = mask[:,c_l:c_r,:]
39
+ else:
40
+ if m_h > img_w: raise gr.Error("Illegal mask area!")
41
+ if c_y < img_h - c_y:
42
+ c_t = max(0, c_y - img_w//2)
43
+ c_b = c_t + img_w
44
+ else:
45
+ c_b = min(img_h, c_y + img_w//2)
46
+ c_t = c_b - img_w
47
+ image = image[c_t:c_b,:,:]
48
+ mask = mask[c_t:c_b,:,:]
49
+
50
+ image = torch.from_numpy(image.transpose(2,0,1)).to(dtype=torch.float32) / 127.5 - 1.0
51
+ mask = torch.from_numpy(mask.transpose(2,0,1)).to(dtype=torch.float32)
52
+
53
+ image = resize(image[None])[0]
54
+ mask = resize(mask[None])[0]
55
+ masked = image * (1 - mask)
56
+
57
+ return image, mask, masked
58
+
59
+
60
+
61
  def predict(cfgs, model, sampler, batch):
62
 
63
  context = nullcontext if cfgs.aae_enabled else torch.no_grad
 
104
 
105
  image = input_blk["image"]
106
  mask = input_blk["mask"]
 
 
 
 
107
 
108
+ image, mask, masked = process(image, mask)
 
 
 
109
 
110
  seg_mask = torch.cat((torch.ones(len(text)), torch.zeros(cfgs.seq_len-len(text))))
111
 
 
170
  model = init_model(cfgs)
171
  sampler = init_sampling(cfgs)
172
  global_index = 0
173
+ resize = Resize((cfgs.H, cfgs.W))
174
 
175
  block = gr.Blocks().queue()
176
  with block:
 
201
  with gr.Column():
202
 
203
  input_blk = gr.Image(source='upload', tool='sketch', type="numpy", label="Input", height=512)
204
+ gr.Markdown("Notice: please draw horizontally to indicate only **one** masked area.")
205
  text = gr.Textbox(label="Text to render: (1~12 characters)", info="the text you want to render at the masked region")
206
  run_button = gr.Button(variant="primary")
207
 
checkpoints/{st-step=100000+la-step=100000-v2.ckpt → st-step=100000+la-step=100000-v1.ckpt} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b87a307ed6e240208b415166e88c0f3e6467ec9330836d70c6d662f423bfbc15
3
- size 4173692086
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edea71eb83b6be72c33ef787a7122a810a7b9257bf97a276ef322707d5769878
3
+ size 6148465904
configs/demo.yaml CHANGED
@@ -1,7 +1,7 @@
1
  type: "demo"
2
 
3
  # path
4
- load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-v2.ckpt"
5
  model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
6
 
7
  # param
@@ -15,7 +15,7 @@ channel: 4 # AE latent channel
15
  factor: 8 # AE downsample factor
16
  scale: [4.0, 0.0] # content scale, style scale
17
  noise_iters: 10
18
- force_uc_zero_embeddings: ["ref", "label"]
19
  aae_enabled: False
20
  detailed: False
21
 
 
1
  type: "demo"
2
 
3
  # path
4
+ load_ckpt_path: "./checkpoints/st-step=100000+la-step=100000-v1.ckpt"
5
  model_cfg_path: "./configs/test/textdesign_sd_2.yaml"
6
 
7
  # param
 
15
  factor: 8 # AE downsample factor
16
  scale: [4.0, 0.0] # content scale, style scale
17
  noise_iters: 10
18
+ force_uc_zero_embeddings: ["label"]
19
  aae_enabled: False
20
  detailed: False
21
 
configs/test/textdesign_sd_2.yaml CHANGED
@@ -1,8 +1,6 @@
1
  model:
2
  target: sgm.models.diffusion.DiffusionEngine
3
  params:
4
- opt_keys:
5
- - t_attn
6
  input_key: image
7
  scale_factor: 0.18215
8
  disable_first_stage_autocast: True
@@ -20,45 +18,54 @@ model:
20
  target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
21
 
22
  network_config:
23
- target: sgm.modules.diffusionmodules.openaimodel.UnifiedUNetModel
24
  params:
 
25
  in_channels: 9
26
  out_channels: 4
27
  ctrl_channels: 0
28
  model_channels: 320
29
  attention_resolutions: [4, 2, 1]
30
- save_attn_type: [t_attn]
31
- save_attn_layers: [output_blocks.6.1]
 
32
  num_res_blocks: 2
33
  channel_mult: [1, 2, 4, 4]
34
  num_head_channels: 64
 
35
  use_linear_in_transformer: True
36
  transformer_depth: 1
37
- t_context_dim: 2048
 
 
38
 
39
  conditioner_config:
40
  target: sgm.modules.GeneralConditioner
41
  params:
42
  emb_models:
43
- # textual crossattn cond
 
 
 
 
 
 
 
 
44
  - is_trainable: False
45
- emb_key: t_crossattn
46
- ucg_rate: 0.1
47
  input_key: label
48
  target: sgm.modules.encoders.modules.LabelEncoder
49
  params:
 
50
  max_len: 12
51
  emb_dim: 2048
52
  n_heads: 8
53
  n_trans_layers: 12
54
- ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
55
  # concat cond
56
  - is_trainable: False
57
  input_key: mask
58
- target: sgm.modules.encoders.modules.SpatialRescaler
59
- params:
60
- in_channels: 1
61
- multiplier: 0.125
62
  - is_trainable: False
63
  input_key: masked
64
  target: sgm.modules.encoders.modules.LatentEncoder
@@ -88,7 +95,6 @@ model:
88
  first_stage_config:
89
  target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
90
  params:
91
- ckpt_path: ./checkpoints/AEs/AE_inpainting_2.safetensors
92
  embed_dim: 4
93
  monitor: val/rec_loss
94
  ddconfig:
@@ -111,11 +117,16 @@ model:
111
  params:
112
  seq_len: 12
113
  kernel_size: 3
114
- gaussian_sigma: 1.0
115
  min_attn_size: 16
116
- lambda_local_loss: 0.01
117
  lambda_ocr_loss: 0.001
118
  ocr_enabled: False
 
 
 
 
 
119
 
120
  sigma_sampler_config:
121
  target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
 
1
  model:
2
  target: sgm.models.diffusion.DiffusionEngine
3
  params:
 
 
4
  input_key: image
5
  scale_factor: 0.18215
6
  disable_first_stage_autocast: True
 
18
  target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
19
 
20
  network_config:
21
+ target: sgm.modules.diffusionmodules.openaimodel.UNetAddModel
22
  params:
23
+ use_checkpoint: False
24
  in_channels: 9
25
  out_channels: 4
26
  ctrl_channels: 0
27
  model_channels: 320
28
  attention_resolutions: [4, 2, 1]
29
+ attn_type: add_attn
30
+ attn_layers:
31
+ - output_blocks.6.1
32
  num_res_blocks: 2
33
  channel_mult: [1, 2, 4, 4]
34
  num_head_channels: 64
35
+ use_spatial_transformer: True
36
  use_linear_in_transformer: True
37
  transformer_depth: 1
38
+ context_dim: 0
39
+ add_context_dim: 2048
40
+ legacy: False
41
 
42
  conditioner_config:
43
  target: sgm.modules.GeneralConditioner
44
  params:
45
  emb_models:
46
+ # crossattn cond
47
+ # - is_trainable: False
48
+ # input_key: txt
49
+ # target: sgm.modules.encoders.modules.FrozenOpenCLIPEmbedder
50
+ # params:
51
+ # arch: ViT-H-14
52
+ # version: ./checkpoints/encoders/OpenCLIP/ViT-H-14/open_clip_pytorch_model.bin
53
+ # layer: penultimate
54
+ # add crossattn cond
55
  - is_trainable: False
 
 
56
  input_key: label
57
  target: sgm.modules.encoders.modules.LabelEncoder
58
  params:
59
+ is_add_embedder: True
60
  max_len: 12
61
  emb_dim: 2048
62
  n_heads: 8
63
  n_trans_layers: 12
64
+ ckpt_path: ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt # ./checkpoints/encoders/LabelEncoder/epoch=19-step=7820.ckpt
65
  # concat cond
66
  - is_trainable: False
67
  input_key: mask
68
+ target: sgm.modules.encoders.modules.IdentityEncoder
 
 
 
69
  - is_trainable: False
70
  input_key: masked
71
  target: sgm.modules.encoders.modules.LatentEncoder
 
95
  first_stage_config:
96
  target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
97
  params:
 
98
  embed_dim: 4
99
  monitor: val/rec_loss
100
  ddconfig:
 
117
  params:
118
  seq_len: 12
119
  kernel_size: 3
120
+ gaussian_sigma: 0.5
121
  min_attn_size: 16
122
+ lambda_local_loss: 0.02
123
  lambda_ocr_loss: 0.001
124
  ocr_enabled: False
125
+
126
+ predictor_config:
127
+ target: sgm.modules.predictors.model.ParseqPredictor
128
+ params:
129
+ ckpt_path: "./checkpoints/predictors/parseq-bb5792a6.pt"
130
 
131
  sigma_sampler_config:
132
  target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
sgm/modules/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .encoders.modules import GeneralConditioner
2
 
3
  UNCONDITIONAL_CONFIG = {
4
  "target": "sgm.modules.GeneralConditioner",
 
1
+ from .encoders.modules import GeneralConditioner, DualConditioner
2
 
3
  UNCONDITIONAL_CONFIG = {
4
  "target": "sgm.modules.GeneralConditioner",
sgm/modules/attention.py CHANGED
@@ -5,15 +5,53 @@ from typing import Any, Optional
5
  import torch
6
  import torch.nn.functional as F
7
  from einops import rearrange, repeat
 
8
  from torch import nn, einsum
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  try:
11
  import xformers
12
  import xformers.ops
 
13
  XFORMERS_IS_AVAILABLE = True
14
  except:
15
  XFORMERS_IS_AVAILABLE = False
16
- print("No module 'xformers'.")
 
 
17
 
18
 
19
  def exists(val):
@@ -108,6 +146,51 @@ class LinearAttention(nn.Module):
108
  return self.to_out(out)
109
 
110
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  class CrossAttention(nn.Module):
112
  def __init__(
113
  self,
@@ -115,7 +198,8 @@ class CrossAttention(nn.Module):
115
  context_dim=None,
116
  heads=8,
117
  dim_head=64,
118
- dropout=0.0
 
119
  ):
120
  super().__init__()
121
  inner_dim = dim_head * heads
@@ -128,38 +212,60 @@ class CrossAttention(nn.Module):
128
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
129
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
130
 
131
- self.to_out = zero_module(
132
- nn.Sequential(
133
- nn.Linear(inner_dim, query_dim),
134
- nn.Dropout(dropout)
135
- )
136
- )
137
 
138
  self.attn_map_cache = None
139
 
140
  def forward(
141
  self,
142
  x,
143
- context=None
 
 
 
144
  ):
145
  h = self.heads
146
 
 
 
 
 
 
 
147
  q = self.to_q(x)
148
  context = default(context, x)
149
  k = self.to_k(context)
150
  v = self.to_v(context)
151
 
 
 
 
 
 
 
 
 
 
 
 
152
  q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
153
 
154
  ## old
 
155
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
156
  del q, k
157
 
 
 
 
 
 
 
158
  # attention, what we cannot get enough of
159
- if sim.shape[-1] > 1:
160
- sim = sim.softmax(dim=-1) # softmax on token dim
161
- else:
162
- sim = sim.sigmoid() # sigmoid on pixel dim
163
 
164
  # save attn_map
165
  if self.attn_map_cache is not None:
@@ -170,7 +276,20 @@ class CrossAttention(nn.Module):
170
 
171
  out = einsum('b i j, b j d -> b i d', sim, v)
172
  out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
173
-
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  return self.to_out(out)
175
 
176
 
@@ -263,6 +382,10 @@ class MemoryEfficientCrossAttention(nn.Module):
263
 
264
 
265
  class BasicTransformerBlock(nn.Module):
 
 
 
 
266
 
267
  def __init__(
268
  self,
@@ -270,78 +393,169 @@ class BasicTransformerBlock(nn.Module):
270
  n_heads,
271
  d_head,
272
  dropout=0.0,
273
- t_context_dim=None,
274
- v_context_dim=None,
275
- gated_ff=True
 
 
 
 
276
  ):
277
  super().__init__()
278
-
279
- # self-attention
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  self.attn1 = MemoryEfficientCrossAttention(
281
  query_dim=dim,
282
  heads=n_heads,
283
  dim_head=d_head,
284
  dropout=dropout,
285
- context_dim=None
286
- )
287
-
288
- # textual cross-attention
289
- if t_context_dim is not None and t_context_dim > 0:
290
- self.t_attn = CrossAttention(
291
  query_dim=dim,
292
- context_dim=t_context_dim,
293
  heads=n_heads,
294
  dim_head=d_head,
295
- dropout=dropout
296
- )
297
- self.t_norm = nn.LayerNorm(dim)
298
-
299
- # visual cross-attention
300
- if v_context_dim is not None and v_context_dim > 0:
301
- self.v_attn = CrossAttention(
302
  query_dim=dim,
303
- context_dim=v_context_dim,
304
  heads=n_heads,
305
  dim_head=d_head,
306
- dropout=dropout
307
- )
308
- self.v_norm = nn.LayerNorm(dim)
309
-
310
  self.norm1 = nn.LayerNorm(dim)
 
311
  self.norm3 = nn.LayerNorm(dim)
312
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- def forward(self, x, t_context=None, v_context=None):
 
 
 
 
 
 
 
 
 
 
 
315
  x = (
316
  self.attn1(
317
  self.norm1(x),
318
- context=None
 
 
 
 
319
  )
320
  + x
321
  )
322
- if hasattr(self, "t_attn"):
323
  x = (
324
- self.t_attn(
325
- self.t_norm(x),
326
- context=t_context
327
  )
328
  + x
329
  )
330
- if hasattr(self, "v_attn"):
331
  x = (
332
- self.v_attn(
333
- self.v_norm(x),
334
- context=v_context
335
  )
336
  + x
337
  )
338
-
339
  x = self.ff(self.norm3(x)) + x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
 
 
 
 
 
 
 
 
341
  return x
342
 
343
 
344
- class SpatialTransformer(nn.Module):
345
  """
346
  Transformer block for image-like data.
347
  First, project the input (aka embedding)
@@ -358,12 +572,36 @@ class SpatialTransformer(nn.Module):
358
  d_head,
359
  depth=1,
360
  dropout=0.0,
361
- t_context_dim=None,
362
- v_context_dim=None,
363
- use_linear=False
 
 
 
 
 
364
  ):
365
  super().__init__()
366
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  self.in_channels = in_channels
368
  inner_dim = n_heads * d_head
369
  self.norm = Normalize(in_channels)
@@ -381,8 +619,12 @@ class SpatialTransformer(nn.Module):
381
  n_heads,
382
  d_head,
383
  dropout=dropout,
384
- t_context_dim=t_context_dim,
385
- v_context_dim=v_context_dim
 
 
 
 
386
  )
387
  for d in range(depth)
388
  ]
@@ -392,11 +634,14 @@ class SpatialTransformer(nn.Module):
392
  nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
393
  )
394
  else:
 
395
  self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
396
  self.use_linear = use_linear
397
 
398
- def forward(self, x, t_context=None, v_context=None):
399
-
 
 
400
  b, c, h, w = x.shape
401
  x_in = x
402
  x = self.norm(x)
@@ -406,11 +651,326 @@ class SpatialTransformer(nn.Module):
406
  if self.use_linear:
407
  x = self.proj_in(x)
408
  for i, block in enumerate(self.transformer_blocks):
409
- x = block(x, t_context=t_context, v_context=v_context)
 
 
410
  if self.use_linear:
411
  x = self.proj_out(x)
412
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
413
  if not self.use_linear:
414
  x = self.proj_out(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
 
416
- return x + x_in
 
5
  import torch
6
  import torch.nn.functional as F
7
  from einops import rearrange, repeat
8
+ from packaging import version
9
  from torch import nn, einsum
10
 
11
+
12
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
13
+ SDP_IS_AVAILABLE = True
14
+ from torch.backends.cuda import SDPBackend, sdp_kernel
15
+
16
+ BACKEND_MAP = {
17
+ SDPBackend.MATH: {
18
+ "enable_math": True,
19
+ "enable_flash": False,
20
+ "enable_mem_efficient": False,
21
+ },
22
+ SDPBackend.FLASH_ATTENTION: {
23
+ "enable_math": False,
24
+ "enable_flash": True,
25
+ "enable_mem_efficient": False,
26
+ },
27
+ SDPBackend.EFFICIENT_ATTENTION: {
28
+ "enable_math": False,
29
+ "enable_flash": False,
30
+ "enable_mem_efficient": True,
31
+ },
32
+ None: {"enable_math": True, "enable_flash": True, "enable_mem_efficient": True},
33
+ }
34
+ else:
35
+ from contextlib import nullcontext
36
+
37
+ SDP_IS_AVAILABLE = False
38
+ sdp_kernel = nullcontext
39
+ BACKEND_MAP = {}
40
+ print(
41
+ f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
42
+ f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
43
+ )
44
+
45
  try:
46
  import xformers
47
  import xformers.ops
48
+
49
  XFORMERS_IS_AVAILABLE = True
50
  except:
51
  XFORMERS_IS_AVAILABLE = False
52
+ print("no module 'xformers'. Processing without...")
53
+
54
+ from .diffusionmodules.util import checkpoint
55
 
56
 
57
  def exists(val):
 
146
  return self.to_out(out)
147
 
148
 
149
+ class SpatialSelfAttention(nn.Module):
150
+ def __init__(self, in_channels):
151
+ super().__init__()
152
+ self.in_channels = in_channels
153
+
154
+ self.norm = Normalize(in_channels)
155
+ self.q = torch.nn.Conv2d(
156
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
157
+ )
158
+ self.k = torch.nn.Conv2d(
159
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
160
+ )
161
+ self.v = torch.nn.Conv2d(
162
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
163
+ )
164
+ self.proj_out = torch.nn.Conv2d(
165
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
166
+ )
167
+
168
+ def forward(self, x):
169
+ h_ = x
170
+ h_ = self.norm(h_)
171
+ q = self.q(h_)
172
+ k = self.k(h_)
173
+ v = self.v(h_)
174
+
175
+ # compute attention
176
+ b, c, h, w = q.shape
177
+ q = rearrange(q, "b c h w -> b (h w) c")
178
+ k = rearrange(k, "b c h w -> b c (h w)")
179
+ w_ = torch.einsum("bij,bjk->bik", q, k)
180
+
181
+ w_ = w_ * (int(c) ** (-0.5))
182
+ w_ = torch.nn.functional.softmax(w_, dim=2)
183
+
184
+ # attend to values
185
+ v = rearrange(v, "b c h w -> b c (h w)")
186
+ w_ = rearrange(w_, "b i j -> b j i")
187
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
188
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
189
+ h_ = self.proj_out(h_)
190
+
191
+ return x + h_
192
+
193
+
194
  class CrossAttention(nn.Module):
195
  def __init__(
196
  self,
 
198
  context_dim=None,
199
  heads=8,
200
  dim_head=64,
201
+ dropout=0.0,
202
+ backend=None,
203
  ):
204
  super().__init__()
205
  inner_dim = dim_head * heads
 
212
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
213
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
214
 
215
+ self.to_out = zero_module(nn.Sequential(
216
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
217
+ ))
218
+ self.backend = backend
 
 
219
 
220
  self.attn_map_cache = None
221
 
222
  def forward(
223
  self,
224
  x,
225
+ context=None,
226
+ mask=None,
227
+ additional_tokens=None,
228
+ n_times_crossframe_attn_in_self=0,
229
  ):
230
  h = self.heads
231
 
232
+ if additional_tokens is not None:
233
+ # get the number of masked tokens at the beginning of the output sequence
234
+ n_tokens_to_mask = additional_tokens.shape[1]
235
+ # add additional token
236
+ x = torch.cat([additional_tokens, x], dim=1)
237
+
238
  q = self.to_q(x)
239
  context = default(context, x)
240
  k = self.to_k(context)
241
  v = self.to_v(context)
242
 
243
+ if n_times_crossframe_attn_in_self:
244
+ # reprogramming cross-frame attention as in https://arxiv.org/abs/2303.13439
245
+ assert x.shape[0] % n_times_crossframe_attn_in_self == 0
246
+ n_cp = x.shape[0] // n_times_crossframe_attn_in_self
247
+ k = repeat(
248
+ k[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
249
+ )
250
+ v = repeat(
251
+ v[::n_times_crossframe_attn_in_self], "b ... -> (b n) ...", n=n_cp
252
+ )
253
+
254
  q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
255
 
256
  ## old
257
+
258
  sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
259
  del q, k
260
 
261
+ if exists(mask):
262
+ mask = rearrange(mask, 'b ... -> b (...)')
263
+ max_neg_value = -torch.finfo(sim.dtype).max
264
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
265
+ sim.masked_fill_(~mask, max_neg_value)
266
+
267
  # attention, what we cannot get enough of
268
+ sim = sim.softmax(dim=-1)
 
 
 
269
 
270
  # save attn_map
271
  if self.attn_map_cache is not None:
 
276
 
277
  out = einsum('b i j, b j d -> b i d', sim, v)
278
  out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
279
+
280
+ ## new
281
+ # with sdp_kernel(**BACKEND_MAP[self.backend]):
282
+ # # print("dispatching into backend", self.backend, "q/k/v shape: ", q.shape, k.shape, v.shape)
283
+ # out = F.scaled_dot_product_attention(
284
+ # q, k, v, attn_mask=mask
285
+ # ) # scale is dim_head ** -0.5 per default
286
+
287
+ # del q, k, v
288
+ # out = rearrange(out, "b h n d -> b n (h d)", h=h)
289
+
290
+ if additional_tokens is not None:
291
+ # remove additional token
292
+ out = out[:, n_tokens_to_mask:]
293
  return self.to_out(out)
294
 
295
 
 
382
 
383
 
384
  class BasicTransformerBlock(nn.Module):
385
+ ATTENTION_MODES = {
386
+ "softmax": CrossAttention, # vanilla attention
387
+ "softmax-xformers": MemoryEfficientCrossAttention, # ampere
388
+ }
389
 
390
  def __init__(
391
  self,
 
393
  n_heads,
394
  d_head,
395
  dropout=0.0,
396
+ context_dim=None,
397
+ add_context_dim=None,
398
+ gated_ff=True,
399
+ checkpoint=True,
400
+ disable_self_attn=False,
401
+ attn_mode="softmax",
402
+ sdp_backend=None,
403
  ):
404
  super().__init__()
405
+ assert attn_mode in self.ATTENTION_MODES
406
+ if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
407
+ print(
408
+ f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
409
+ f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
410
+ )
411
+ attn_mode = "softmax"
412
+ elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
413
+ print(
414
+ "We do not support vanilla attention anymore, as it is too expensive. Sorry."
415
+ )
416
+ if not XFORMERS_IS_AVAILABLE:
417
+ assert (
418
+ False
419
+ ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
420
+ else:
421
+ print("Falling back to xformers efficient attention.")
422
+ attn_mode = "softmax-xformers"
423
+ attn_cls = self.ATTENTION_MODES[attn_mode]
424
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
425
+ assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
426
+ else:
427
+ assert sdp_backend is None
428
+ self.disable_self_attn = disable_self_attn
429
  self.attn1 = MemoryEfficientCrossAttention(
430
  query_dim=dim,
431
  heads=n_heads,
432
  dim_head=d_head,
433
  dropout=dropout,
434
+ context_dim=context_dim if self.disable_self_attn else None,
435
+ backend=sdp_backend,
436
+ ) # is a self-attention if not self.disable_self_attn
437
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
438
+ if context_dim is not None and context_dim > 0:
439
+ self.attn2 = attn_cls(
440
  query_dim=dim,
441
+ context_dim=context_dim,
442
  heads=n_heads,
443
  dim_head=d_head,
444
+ dropout=dropout,
445
+ backend=sdp_backend,
446
+ ) # is self-attn if context is none
447
+ if add_context_dim is not None and add_context_dim > 0:
448
+ self.add_attn = attn_cls(
 
 
449
  query_dim=dim,
450
+ context_dim=add_context_dim,
451
  heads=n_heads,
452
  dim_head=d_head,
453
+ dropout=dropout,
454
+ backend=sdp_backend,
455
+ ) # is self-attn if context is none
456
+ self.add_norm = nn.LayerNorm(dim)
457
  self.norm1 = nn.LayerNorm(dim)
458
+ self.norm2 = nn.LayerNorm(dim)
459
  self.norm3 = nn.LayerNorm(dim)
460
+ self.checkpoint = checkpoint
461
+
462
+ def forward(
463
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
464
+ ):
465
+ kwargs = {"x": x}
466
+
467
+ if context is not None:
468
+ kwargs.update({"context": context})
469
+
470
+ if additional_tokens is not None:
471
+ kwargs.update({"additional_tokens": additional_tokens})
472
 
473
+ if n_times_crossframe_attn_in_self:
474
+ kwargs.update(
475
+ {"n_times_crossframe_attn_in_self": n_times_crossframe_attn_in_self}
476
+ )
477
+
478
+ return checkpoint(
479
+ self._forward, (x, context, add_context), self.parameters(), self.checkpoint
480
+ )
481
+
482
+ def _forward(
483
+ self, x, context=None, add_context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
484
+ ):
485
  x = (
486
  self.attn1(
487
  self.norm1(x),
488
+ context=context if self.disable_self_attn else None,
489
+ additional_tokens=additional_tokens,
490
+ n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self
491
+ if not self.disable_self_attn
492
+ else 0,
493
  )
494
  + x
495
  )
496
+ if hasattr(self, "attn2"):
497
  x = (
498
+ self.attn2(
499
+ self.norm2(x), context=context, additional_tokens=additional_tokens
 
500
  )
501
  + x
502
  )
503
+ if hasattr(self, "add_attn"):
504
  x = (
505
+ self.add_attn(
506
+ self.add_norm(x), context=add_context, additional_tokens=additional_tokens
 
507
  )
508
  + x
509
  )
 
510
  x = self.ff(self.norm3(x)) + x
511
+ return x
512
+
513
+
514
+ class BasicTransformerSingleLayerBlock(nn.Module):
515
+ ATTENTION_MODES = {
516
+ "softmax": CrossAttention, # vanilla attention
517
+ "softmax-xformers": MemoryEfficientCrossAttention # on the A100s not quite as fast as the above version
518
+ # (todo might depend on head_dim, check, falls back to semi-optimized kernels for dim!=[16,32,64,128])
519
+ }
520
+
521
+ def __init__(
522
+ self,
523
+ dim,
524
+ n_heads,
525
+ d_head,
526
+ dropout=0.0,
527
+ context_dim=None,
528
+ gated_ff=True,
529
+ checkpoint=True,
530
+ attn_mode="softmax",
531
+ ):
532
+ super().__init__()
533
+ assert attn_mode in self.ATTENTION_MODES
534
+ attn_cls = self.ATTENTION_MODES[attn_mode]
535
+ self.attn1 = attn_cls(
536
+ query_dim=dim,
537
+ heads=n_heads,
538
+ dim_head=d_head,
539
+ dropout=dropout,
540
+ context_dim=context_dim,
541
+ )
542
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
543
+ self.norm1 = nn.LayerNorm(dim)
544
+ self.norm2 = nn.LayerNorm(dim)
545
+ self.checkpoint = checkpoint
546
 
547
+ def forward(self, x, context=None):
548
+ return checkpoint(
549
+ self._forward, (x, context), self.parameters(), self.checkpoint
550
+ )
551
+
552
+ def _forward(self, x, context=None):
553
+ x = self.attn1(self.norm1(x), context=context) + x
554
+ x = self.ff(self.norm2(x)) + x
555
  return x
556
 
557
 
558
+ class SpatialTransformer(nn.Module):
559
  """
560
  Transformer block for image-like data.
561
  First, project the input (aka embedding)
 
572
  d_head,
573
  depth=1,
574
  dropout=0.0,
575
+ context_dim=None,
576
+ add_context_dim=None,
577
+ disable_self_attn=False,
578
+ use_linear=False,
579
+ attn_type="softmax",
580
+ use_checkpoint=True,
581
+ # sdp_backend=SDPBackend.FLASH_ATTENTION
582
+ sdp_backend=None,
583
  ):
584
  super().__init__()
585
+ # print(
586
+ # f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
587
+ # )
588
+ from omegaconf import ListConfig
589
+
590
+ if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
591
+ context_dim = [context_dim]
592
+ if exists(context_dim) and isinstance(context_dim, list):
593
+ if depth != len(context_dim):
594
+ # print(
595
+ # f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
596
+ # f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
597
+ # )
598
+ # depth does not match context dims.
599
+ assert all(
600
+ map(lambda x: x == context_dim[0], context_dim)
601
+ ), "need homogenous context_dim to match depth automatically"
602
+ context_dim = depth * [context_dim[0]]
603
+ elif context_dim is None:
604
+ context_dim = [None] * depth
605
  self.in_channels = in_channels
606
  inner_dim = n_heads * d_head
607
  self.norm = Normalize(in_channels)
 
619
  n_heads,
620
  d_head,
621
  dropout=dropout,
622
+ context_dim=context_dim[d],
623
+ add_context_dim=add_context_dim,
624
+ disable_self_attn=disable_self_attn,
625
+ attn_mode=attn_type,
626
+ checkpoint=use_checkpoint,
627
+ sdp_backend=sdp_backend,
628
  )
629
  for d in range(depth)
630
  ]
 
634
  nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
635
  )
636
  else:
637
+ # self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
638
  self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
639
  self.use_linear = use_linear
640
 
641
+ def forward(self, x, context=None, add_context=None):
642
+ # note: if no context is given, cross-attention defaults to self-attention
643
+ if not isinstance(context, list):
644
+ context = [context]
645
  b, c, h, w = x.shape
646
  x_in = x
647
  x = self.norm(x)
 
651
  if self.use_linear:
652
  x = self.proj_in(x)
653
  for i, block in enumerate(self.transformer_blocks):
654
+ if i > 0 and len(context) == 1:
655
+ i = 0 # use same context for each block
656
+ x = block(x, context=context[i], add_context=add_context)
657
  if self.use_linear:
658
  x = self.proj_out(x)
659
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
660
  if not self.use_linear:
661
  x = self.proj_out(x)
662
+ return x + x_in
663
+
664
+
665
+ def benchmark_attn():
666
+ # Lets define a helpful benchmarking function:
667
+ # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
668
+ device = "cuda" if torch.cuda.is_available() else "cpu"
669
+ import torch.nn.functional as F
670
+ import torch.utils.benchmark as benchmark
671
+
672
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
673
+ t0 = benchmark.Timer(
674
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
675
+ )
676
+ return t0.blocked_autorange().mean * 1e6
677
+
678
+ # Lets define the hyper-parameters of our input
679
+ batch_size = 32
680
+ max_sequence_len = 1024
681
+ num_heads = 32
682
+ embed_dimension = 32
683
+
684
+ dtype = torch.float16
685
+
686
+ query = torch.rand(
687
+ batch_size,
688
+ num_heads,
689
+ max_sequence_len,
690
+ embed_dimension,
691
+ device=device,
692
+ dtype=dtype,
693
+ )
694
+ key = torch.rand(
695
+ batch_size,
696
+ num_heads,
697
+ max_sequence_len,
698
+ embed_dimension,
699
+ device=device,
700
+ dtype=dtype,
701
+ )
702
+ value = torch.rand(
703
+ batch_size,
704
+ num_heads,
705
+ max_sequence_len,
706
+ embed_dimension,
707
+ device=device,
708
+ dtype=dtype,
709
+ )
710
+
711
+ print(f"q/k/v shape:", query.shape, key.shape, value.shape)
712
+
713
+ # Lets explore the speed of each of the 3 implementations
714
+ from torch.backends.cuda import SDPBackend, sdp_kernel
715
+
716
+ # Helpful arguments mapper
717
+ backend_map = {
718
+ SDPBackend.MATH: {
719
+ "enable_math": True,
720
+ "enable_flash": False,
721
+ "enable_mem_efficient": False,
722
+ },
723
+ SDPBackend.FLASH_ATTENTION: {
724
+ "enable_math": False,
725
+ "enable_flash": True,
726
+ "enable_mem_efficient": False,
727
+ },
728
+ SDPBackend.EFFICIENT_ATTENTION: {
729
+ "enable_math": False,
730
+ "enable_flash": False,
731
+ "enable_mem_efficient": True,
732
+ },
733
+ }
734
+
735
+ from torch.profiler import ProfilerActivity, profile, record_function
736
+
737
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
738
+
739
+ print(
740
+ f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
741
+ )
742
+ with profile(
743
+ activities=activities, record_shapes=False, profile_memory=True
744
+ ) as prof:
745
+ with record_function("Default detailed stats"):
746
+ for _ in range(25):
747
+ o = F.scaled_dot_product_attention(query, key, value)
748
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
749
+
750
+ print(
751
+ f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
752
+ )
753
+ with sdp_kernel(**backend_map[SDPBackend.MATH]):
754
+ with profile(
755
+ activities=activities, record_shapes=False, profile_memory=True
756
+ ) as prof:
757
+ with record_function("Math implmentation stats"):
758
+ for _ in range(25):
759
+ o = F.scaled_dot_product_attention(query, key, value)
760
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
761
+
762
+ with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
763
+ try:
764
+ print(
765
+ f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
766
+ )
767
+ except RuntimeError:
768
+ print("FlashAttention is not supported. See warnings for reasons.")
769
+ with profile(
770
+ activities=activities, record_shapes=False, profile_memory=True
771
+ ) as prof:
772
+ with record_function("FlashAttention stats"):
773
+ for _ in range(25):
774
+ o = F.scaled_dot_product_attention(query, key, value)
775
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
776
+
777
+ with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
778
+ try:
779
+ print(
780
+ f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
781
+ )
782
+ except RuntimeError:
783
+ print("EfficientAttention is not supported. See warnings for reasons.")
784
+ with profile(
785
+ activities=activities, record_shapes=False, profile_memory=True
786
+ ) as prof:
787
+ with record_function("EfficientAttention stats"):
788
+ for _ in range(25):
789
+ o = F.scaled_dot_product_attention(query, key, value)
790
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
791
+
792
+
793
+ def run_model(model, x, context):
794
+ return model(x, context)
795
+
796
+
797
+ def benchmark_transformer_blocks():
798
+ device = "cuda" if torch.cuda.is_available() else "cpu"
799
+ import torch.utils.benchmark as benchmark
800
+
801
+ def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
802
+ t0 = benchmark.Timer(
803
+ stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
804
+ )
805
+ return t0.blocked_autorange().mean * 1e6
806
+
807
+ checkpoint = True
808
+ compile = False
809
+
810
+ batch_size = 32
811
+ h, w = 64, 64
812
+ context_len = 77
813
+ embed_dimension = 1024
814
+ context_dim = 1024
815
+ d_head = 64
816
+
817
+ transformer_depth = 4
818
+
819
+ n_heads = embed_dimension // d_head
820
+
821
+ dtype = torch.float16
822
+
823
+ model_native = SpatialTransformer(
824
+ embed_dimension,
825
+ n_heads,
826
+ d_head,
827
+ context_dim=context_dim,
828
+ use_linear=True,
829
+ use_checkpoint=checkpoint,
830
+ attn_type="softmax",
831
+ depth=transformer_depth,
832
+ sdp_backend=SDPBackend.FLASH_ATTENTION,
833
+ ).to(device)
834
+ model_efficient_attn = SpatialTransformer(
835
+ embed_dimension,
836
+ n_heads,
837
+ d_head,
838
+ context_dim=context_dim,
839
+ use_linear=True,
840
+ depth=transformer_depth,
841
+ use_checkpoint=checkpoint,
842
+ attn_type="softmax-xformers",
843
+ ).to(device)
844
+ if not checkpoint and compile:
845
+ print("compiling models")
846
+ model_native = torch.compile(model_native)
847
+ model_efficient_attn = torch.compile(model_efficient_attn)
848
+
849
+ x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
850
+ c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
851
+
852
+ from torch.profiler import ProfilerActivity, profile, record_function
853
+
854
+ activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
855
+
856
+ with torch.autocast("cuda"):
857
+ print(
858
+ f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
859
+ )
860
+ print(
861
+ f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
862
+ )
863
+
864
+ print(75 * "+")
865
+ print("NATIVE")
866
+ print(75 * "+")
867
+ torch.cuda.reset_peak_memory_stats()
868
+ with profile(
869
+ activities=activities, record_shapes=False, profile_memory=True
870
+ ) as prof:
871
+ with record_function("NativeAttention stats"):
872
+ for _ in range(25):
873
+ model_native(x, c)
874
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
875
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
876
+
877
+ print(75 * "+")
878
+ print("Xformers")
879
+ print(75 * "+")
880
+ torch.cuda.reset_peak_memory_stats()
881
+ with profile(
882
+ activities=activities, record_shapes=False, profile_memory=True
883
+ ) as prof:
884
+ with record_function("xformers stats"):
885
+ for _ in range(25):
886
+ model_efficient_attn(x, c)
887
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
888
+ print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
889
+
890
+
891
+ def test01():
892
+ # conv1x1 vs linear
893
+ from ..util import count_params
894
+
895
+ conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
896
+ print(count_params(conv))
897
+ linear = torch.nn.Linear(3, 32).cuda()
898
+ print(count_params(linear))
899
+
900
+ print(conv.weight.shape)
901
+
902
+ # use same initialization
903
+ linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
904
+ linear.bias = torch.nn.Parameter(conv.bias)
905
+
906
+ print(linear.weight.shape)
907
+
908
+ x = torch.randn(11, 3, 64, 64).cuda()
909
+
910
+ xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
911
+ print(xr.shape)
912
+ out_linear = linear(xr)
913
+ print(out_linear.mean(), out_linear.shape)
914
+
915
+ out_conv = conv(x)
916
+ print(out_conv.mean(), out_conv.shape)
917
+ print("done with test01.\n")
918
+
919
+
920
+ def test02():
921
+ # try cosine flash attention
922
+ import time
923
+
924
+ torch.backends.cuda.matmul.allow_tf32 = True
925
+ torch.backends.cudnn.allow_tf32 = True
926
+ torch.backends.cudnn.benchmark = True
927
+ print("testing cosine flash attention...")
928
+ DIM = 1024
929
+ SEQLEN = 4096
930
+ BS = 16
931
+
932
+ print(" softmax (vanilla) first...")
933
+ model = BasicTransformerBlock(
934
+ dim=DIM,
935
+ n_heads=16,
936
+ d_head=64,
937
+ dropout=0.0,
938
+ context_dim=None,
939
+ attn_mode="softmax",
940
+ ).cuda()
941
+ try:
942
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
943
+ tic = time.time()
944
+ y = model(x)
945
+ toc = time.time()
946
+ print(y.shape, toc - tic)
947
+ except RuntimeError as e:
948
+ # likely oom
949
+ print(str(e))
950
+
951
+ print("\n now flash-cosine...")
952
+ model = BasicTransformerBlock(
953
+ dim=DIM,
954
+ n_heads=16,
955
+ d_head=64,
956
+ dropout=0.0,
957
+ context_dim=None,
958
+ attn_mode="flash-cosine",
959
+ ).cuda()
960
+ x = torch.randn(BS, SEQLEN, DIM).cuda()
961
+ tic = time.time()
962
+ y = model(x)
963
+ toc = time.time()
964
+ print(y.shape, toc - tic)
965
+ print("done with test02.\n")
966
+
967
+
968
+ if __name__ == "__main__":
969
+ # test01()
970
+ # test02()
971
+ # test03()
972
+
973
+ # benchmark_attn()
974
+ benchmark_transformer_blocks()
975
 
976
+ print("done.")
sgm/modules/diffusionmodules/__init__.py CHANGED
@@ -2,6 +2,6 @@ from .denoiser import Denoiser
2
  from .discretizer import Discretization
3
  from .loss import StandardDiffusionLoss
4
  from .model import Model, Encoder, Decoder
5
- from .openaimodel import UnifiedUNetModel
6
  from .sampling import BaseDiffusionSampler
7
  from .wrappers import OpenAIWrapper
 
2
  from .discretizer import Discretization
3
  from .loss import StandardDiffusionLoss
4
  from .model import Model, Encoder, Decoder
5
+ from .openaimodel import UNetModel
6
  from .sampling import BaseDiffusionSampler
7
  from .wrappers import OpenAIWrapper
sgm/modules/diffusionmodules/guiders.py CHANGED
@@ -11,8 +11,8 @@ class VanillaCFG:
11
  """
12
 
13
  def __init__(self, scale, dyn_thresh_config=None):
14
-
15
- self.scale_value = scale
16
  self.dyn_thresh = instantiate_from_config(
17
  default(
18
  dyn_thresh_config,
@@ -24,14 +24,15 @@ class VanillaCFG:
24
 
25
  def __call__(self, x, sigma):
26
  x_u, x_c = x.chunk(2)
27
- x_pred = self.dyn_thresh(x_u, x_c, self.scale_value)
 
28
  return x_pred
29
 
30
  def prepare_inputs(self, x, s, c, uc):
31
  c_out = dict()
32
 
33
  for k in c:
34
- if k in ["vector", "t_crossattn", "v_crossattn", "concat"]:
35
  c_out[k] = torch.cat((uc[k], c[k]), 0)
36
  else:
37
  assert c[k] == uc[k]
@@ -39,6 +40,34 @@ class VanillaCFG:
39
  return torch.cat([x] * 2), torch.cat([s] * 2), c_out
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class IdentityGuider:
43
  def __call__(self, x, sigma):
44
  return x
 
11
  """
12
 
13
  def __init__(self, scale, dyn_thresh_config=None):
14
+ scale_schedule = lambda scale, sigma: scale # independent of step
15
+ self.scale_schedule = partial(scale_schedule, scale)
16
  self.dyn_thresh = instantiate_from_config(
17
  default(
18
  dyn_thresh_config,
 
24
 
25
  def __call__(self, x, sigma):
26
  x_u, x_c = x.chunk(2)
27
+ scale_value = self.scale_schedule(sigma)
28
+ x_pred = self.dyn_thresh(x_u, x_c, scale_value)
29
  return x_pred
30
 
31
  def prepare_inputs(self, x, s, c, uc):
32
  c_out = dict()
33
 
34
  for k in c:
35
+ if k in ["vector", "crossattn", "add_crossattn", "concat"]:
36
  c_out[k] = torch.cat((uc[k], c[k]), 0)
37
  else:
38
  assert c[k] == uc[k]
 
40
  return torch.cat([x] * 2), torch.cat([s] * 2), c_out
41
 
42
 
43
+ class DualCFG:
44
+
45
+ def __init__(self, scale):
46
+ self.scale = scale
47
+ self.dyn_thresh = instantiate_from_config(
48
+ {
49
+ "target": "sgm.modules.diffusionmodules.sampling_utils.DualThresholding"
50
+ },
51
+ )
52
+
53
+ def __call__(self, x, sigma):
54
+ x_u_1, x_u_2, x_c = x.chunk(3)
55
+ x_pred = self.dyn_thresh(x_u_1, x_u_2, x_c, self.scale)
56
+ return x_pred
57
+
58
+ def prepare_inputs(self, x, s, c, uc_1, uc_2):
59
+ c_out = dict()
60
+
61
+ for k in c:
62
+ if k in ["vector", "crossattn", "concat", "add_crossattn"]:
63
+ c_out[k] = torch.cat((uc_1[k], uc_2[k], c[k]), 0)
64
+ else:
65
+ assert c[k] == uc_1[k]
66
+ c_out[k] = c[k]
67
+ return torch.cat([x] * 3), torch.cat([s] * 3), c_out
68
+
69
+
70
+
71
  class IdentityGuider:
72
  def __call__(self, x, sigma):
73
  return x
sgm/modules/diffusionmodules/loss.py CHANGED
@@ -78,9 +78,7 @@ class FullLoss(StandardDiffusionLoss):
78
  min_attn_size=16,
79
  lambda_local_loss=0.0,
80
  lambda_ocr_loss=0.0,
81
- lambda_style_loss=0.0,
82
  ocr_enabled = False,
83
- style_enabled = False,
84
  predictor_config = None,
85
  *args, **kwarg
86
  ):
@@ -93,9 +91,7 @@ class FullLoss(StandardDiffusionLoss):
93
  self.min_attn_size = min_attn_size
94
  self.lambda_local_loss = lambda_local_loss
95
  self.lambda_ocr_loss = lambda_ocr_loss
96
- self.lambda_style_loss = lambda_style_loss
97
 
98
- self.style_enabled = style_enabled
99
  self.ocr_enabled = ocr_enabled
100
  if ocr_enabled:
101
  self.predictor = instantiate_from_config(predictor_config)
@@ -152,15 +148,9 @@ class FullLoss(StandardDiffusionLoss):
152
  ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler)
153
  ocr_loss = ocr_loss.mean()
154
 
155
- if self.style_enabled:
156
- style_loss = self.get_style_local_loss(network.diffusion_model.attn_map_cache, batch["mask"])
157
- style_loss = style_loss.mean()
158
-
159
  loss = diff_loss + self.lambda_local_loss * local_loss
160
  if self.ocr_enabled:
161
  loss += self.lambda_ocr_loss * ocr_loss
162
- if self.style_enabled:
163
- loss += self.lambda_style_loss * style_loss
164
 
165
  loss_dict = {
166
  "loss/diff_loss": diff_loss,
@@ -170,8 +160,6 @@ class FullLoss(StandardDiffusionLoss):
170
 
171
  if self.ocr_enabled:
172
  loss_dict["loss/ocr_loss"] = ocr_loss
173
- if self.style_enabled:
174
- loss_dict["loss/style_loss"] = style_loss
175
 
176
  return loss, loss_dict
177
 
@@ -196,9 +184,6 @@ class FullLoss(StandardDiffusionLoss):
196
 
197
  for item in attn_map_cache:
198
 
199
- name = item["name"]
200
- if not name.endswith("t_attn"): continue
201
-
202
  heads = item["heads"]
203
  size = item["size"]
204
  attn_map = item["attn_map"]
@@ -241,9 +226,6 @@ class FullLoss(StandardDiffusionLoss):
241
 
242
  for item in attn_map_cache:
243
 
244
- name = item["name"]
245
- if not name.endswith("t_attn"): continue
246
-
247
  heads = item["heads"]
248
  size = item["size"]
249
  attn_map = item["attn_map"]
@@ -252,7 +234,7 @@ class FullLoss(StandardDiffusionLoss):
252
 
253
  seg_l = seg_mask.shape[1]
254
 
255
- bh, n, l = attn_map.shape # bh: batch size * heads / n: pixel length(h*w) / l: token length
256
  attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
257
 
258
  assert seg_l <= l
@@ -283,43 +265,4 @@ class FullLoss(StandardDiffusionLoss):
283
 
284
  loss = loss / count
285
 
286
- return loss
287
-
288
- def get_style_local_loss(self, attn_map_cache, mask):
289
-
290
- loss = 0
291
- count = 0
292
-
293
- for item in attn_map_cache:
294
-
295
- name = item["name"]
296
- if not name.endswith("v_attn"): continue
297
-
298
- heads = item["heads"]
299
- size = item["size"]
300
- attn_map = item["attn_map"]
301
-
302
- if size < self.min_attn_size: continue
303
-
304
- bh, n, l = attn_map.shape # bh: batch size * heads / n: pixel length(h*w) / l: token length
305
- attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
306
- attn_map = attn_map.permute(0, 1, 3, 2) # b, h, l, n
307
- attn_map = attn_map.mean(dim = 1) # b, l, n
308
-
309
- mask_map = F.interpolate(mask, (size, size))
310
- mask_map = mask_map.reshape((-1, l, n)) # b, l, n
311
- n_mask_map = 1 - mask_map
312
-
313
- p_loss = (mask_map * attn_map).sum(dim = -1) / (mask_map.sum(dim = -1) + 1e-5) # b, l
314
- n_loss = (n_mask_map * attn_map).sum(dim = -1) / (n_mask_map.sum(dim = -1) + 1e-5) # b, l
315
-
316
- p_loss = p_loss.mean(dim = -1)
317
- n_loss = n_loss.mean(dim = -1)
318
-
319
- f_loss = n_loss - p_loss # b,
320
- loss += f_loss
321
- count += 1
322
-
323
- loss = loss / count
324
-
325
  return loss
 
78
  min_attn_size=16,
79
  lambda_local_loss=0.0,
80
  lambda_ocr_loss=0.0,
 
81
  ocr_enabled = False,
 
82
  predictor_config = None,
83
  *args, **kwarg
84
  ):
 
91
  self.min_attn_size = min_attn_size
92
  self.lambda_local_loss = lambda_local_loss
93
  self.lambda_ocr_loss = lambda_ocr_loss
 
94
 
 
95
  self.ocr_enabled = ocr_enabled
96
  if ocr_enabled:
97
  self.predictor = instantiate_from_config(predictor_config)
 
148
  ocr_loss = self.get_ocr_loss(model_output, batch["r_bbox"], batch["label"], first_stage_model, scaler)
149
  ocr_loss = ocr_loss.mean()
150
 
 
 
 
 
151
  loss = diff_loss + self.lambda_local_loss * local_loss
152
  if self.ocr_enabled:
153
  loss += self.lambda_ocr_loss * ocr_loss
 
 
154
 
155
  loss_dict = {
156
  "loss/diff_loss": diff_loss,
 
160
 
161
  if self.ocr_enabled:
162
  loss_dict["loss/ocr_loss"] = ocr_loss
 
 
163
 
164
  return loss, loss_dict
165
 
 
184
 
185
  for item in attn_map_cache:
186
 
 
 
 
187
  heads = item["heads"]
188
  size = item["size"]
189
  attn_map = item["attn_map"]
 
226
 
227
  for item in attn_map_cache:
228
 
 
 
 
229
  heads = item["heads"]
230
  size = item["size"]
231
  attn_map = item["attn_map"]
 
234
 
235
  seg_l = seg_mask.shape[1]
236
 
237
+ bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
238
  attn_map = attn_map.reshape((-1, heads, n, l)) # b, h, n, l
239
 
240
  assert seg_l <= l
 
265
 
266
  loss = loss / count
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  return loss
sgm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -1,4 +1,7 @@
 
 
1
  from abc import abstractmethod
 
2
  from typing import Iterable
3
 
4
  import numpy as np
@@ -10,6 +13,7 @@ from einops import rearrange
10
  from ...modules.attention import SpatialTransformer
11
  from ...modules.diffusionmodules.util import (
12
  avg_pool_nd,
 
13
  conv_nd,
14
  linear,
15
  normalization,
@@ -19,14 +23,47 @@ from ...modules.diffusionmodules.util import (
19
  from ...util import default, exists
20
 
21
 
22
- class Timestep(nn.Module):
23
- def __init__(self, dim):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  super().__init__()
25
- self.dim = dim
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- def forward(self, t):
28
- return timestep_embedding(t, self.dim)
29
-
30
 
31
  class TimestepBlock(nn.Module):
32
  """
@@ -50,14 +87,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
50
  self,
51
  x,
52
  emb,
53
- t_context=None,
54
- v_context=None
 
 
 
 
 
55
  ):
56
  for layer in self:
57
  if isinstance(layer, TimestepBlock):
58
  x = layer(x, emb)
59
  elif isinstance(layer, SpatialTransformer):
60
- x = layer(x, t_context, v_context)
61
  else:
62
  x = layer(x)
63
  return x
@@ -102,6 +144,22 @@ class Upsample(nn.Module):
102
  return x
103
 
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  class Downsample(nn.Module):
106
  """
107
  A downsampling layer with an optional convolution.
@@ -149,6 +207,17 @@ class Downsample(nn.Module):
149
  class ResBlock(TimestepBlock):
150
  """
151
  A residual block that can optionally change the number of channels.
 
 
 
 
 
 
 
 
 
 
 
152
  """
153
 
154
  def __init__(
@@ -160,11 +229,12 @@ class ResBlock(TimestepBlock):
160
  use_conv=False,
161
  use_scale_shift_norm=False,
162
  dims=2,
 
163
  up=False,
164
  down=False,
165
  kernel_size=3,
166
  exchange_temb_dims=False,
167
- skip_t_emb=False
168
  ):
169
  super().__init__()
170
  self.channels = channels
@@ -172,6 +242,7 @@ class ResBlock(TimestepBlock):
172
  self.dropout = dropout
173
  self.out_channels = out_channels or channels
174
  self.use_conv = use_conv
 
175
  self.use_scale_shift_norm = use_scale_shift_norm
176
  self.exchange_temb_dims = exchange_temb_dims
177
 
@@ -240,6 +311,17 @@ class ResBlock(TimestepBlock):
240
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
241
 
242
  def forward(self, x, emb):
 
 
 
 
 
 
 
 
 
 
 
243
  if self.updown:
244
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
245
  h = in_rest(x)
@@ -267,42 +349,233 @@ class ResBlock(TimestepBlock):
267
  h = self.out_layers(h)
268
  return self.skip_connection(x) + h
269
 
270
-
271
- import seaborn as sns
272
- import matplotlib.pyplot as plt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
 
275
- class UnifiedUNetModel(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  def __init__(
278
  self,
279
  in_channels,
280
- ctrl_channels,
281
  model_channels,
282
  out_channels,
283
  num_res_blocks,
284
  attention_resolutions,
285
  dropout=0,
286
  channel_mult=(1, 2, 4, 8),
287
- save_attn_type=None,
288
- save_attn_layers=[],
289
  conv_resample=True,
290
  dims=2,
291
- use_label=None,
 
 
292
  num_heads=-1,
293
  num_head_channels=-1,
294
  num_heads_upsample=-1,
295
  use_scale_shift_norm=False,
296
  resblock_updown=False,
297
- transformer_depth=1,
298
- t_context_dim=None,
299
- v_context_dim=None,
 
 
 
 
300
  num_attention_blocks=None,
 
301
  use_linear_in_transformer=False,
 
302
  adm_in_channels=None,
303
- transformer_depth_middle=None
 
 
304
  ):
305
  super().__init__()
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
  if num_heads_upsample == -1:
308
  num_heads_upsample = num_heads
@@ -318,39 +591,106 @@ class UnifiedUNetModel(nn.Module):
318
  ), "Either num_heads or num_head_channels has to be set"
319
 
320
  self.in_channels = in_channels
321
- self.ctrl_channels = ctrl_channels
322
  self.model_channels = model_channels
323
  self.out_channels = out_channels
 
 
 
 
 
 
 
324
 
325
- transformer_depth = len(channel_mult) * [transformer_depth]
326
- transformer_depth_middle = default(transformer_depth_middle, transformer_depth[-1])
327
-
328
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
  self.attention_resolutions = attention_resolutions
331
  self.dropout = dropout
332
  self.channel_mult = channel_mult
333
  self.conv_resample = conv_resample
334
- self.use_label = use_label
 
 
 
 
335
  self.num_heads = num_heads
336
  self.num_head_channels = num_head_channels
337
  self.num_heads_upsample = num_heads_upsample
 
 
 
 
 
 
 
 
 
 
 
 
338
 
339
  time_embed_dim = model_channels * 4
340
- self.time_embed = nn.Sequential(
341
- linear(model_channels, time_embed_dim),
342
- nn.SiLU(),
343
- linear(time_embed_dim, time_embed_dim),
 
 
344
  )
345
-
346
- if self.use_label is not None:
347
- self.label_emb = nn.Sequential(
348
- nn.Sequential(
349
- linear(adm_in_channels, time_embed_dim),
350
- nn.SiLU(),
351
- linear(time_embed_dim, time_embed_dim),
 
 
 
 
 
 
 
 
 
 
352
  )
353
- )
 
 
 
 
 
 
 
 
 
 
354
 
355
  self.input_blocks = nn.ModuleList(
356
  [
@@ -359,26 +699,6 @@ class UnifiedUNetModel(nn.Module):
359
  )
360
  ]
361
  )
362
-
363
- if self.ctrl_channels > 0:
364
- self.ctrl_block = TimestepEmbedSequential(
365
- conv_nd(dims, ctrl_channels, 16, 3, padding=1),
366
- nn.SiLU(),
367
- conv_nd(dims, 16, 16, 3, padding=1),
368
- nn.SiLU(),
369
- conv_nd(dims, 16, 32, 3, padding=1),
370
- nn.SiLU(),
371
- conv_nd(dims, 32, 32, 3, padding=1),
372
- nn.SiLU(),
373
- conv_nd(dims, 32, 96, 3, padding=1),
374
- nn.SiLU(),
375
- conv_nd(dims, 96, 96, 3, padding=1),
376
- nn.SiLU(),
377
- conv_nd(dims, 96, 256, 3, padding=1),
378
- nn.SiLU(),
379
- zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
380
- )
381
-
382
  self._feature_size = model_channels
383
  input_block_chans = [model_channels]
384
  ch = model_channels
@@ -386,13 +706,16 @@ class UnifiedUNetModel(nn.Module):
386
  for level, mult in enumerate(channel_mult):
387
  for nr in range(self.num_res_blocks[level]):
388
  layers = [
389
- ResBlock(
390
- ch,
391
- time_embed_dim,
392
- dropout,
393
- out_channels=mult * model_channels,
394
- dims=dims,
395
- use_scale_shift_norm=use_scale_shift_norm
 
 
 
396
  )
397
  ]
398
  ch = mult * model_channels
@@ -402,19 +725,45 @@ class UnifiedUNetModel(nn.Module):
402
  else:
403
  num_heads = ch // num_head_channels
404
  dim_head = num_head_channels
 
 
 
 
 
 
 
 
 
 
 
 
405
  if (
406
  not exists(num_attention_blocks)
407
  or nr < num_attention_blocks[level]
408
  ):
409
  layers.append(
410
- SpatialTransformer(
411
- ch,
412
- num_heads,
413
- dim_head,
414
- depth=transformer_depth[level],
415
- t_context_dim=t_context_dim,
416
- v_context_dim=v_context_dim,
417
- use_linear=use_linear_in_transformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  )
419
  )
420
  self.input_blocks.append(TimestepEmbedSequential(*layers))
@@ -424,14 +773,17 @@ class UnifiedUNetModel(nn.Module):
424
  out_ch = ch
425
  self.input_blocks.append(
426
  TimestepEmbedSequential(
427
- ResBlock(
428
- ch,
429
- time_embed_dim,
430
- dropout,
431
- out_channels=out_ch,
432
- dims=dims,
433
- use_scale_shift_norm=use_scale_shift_norm,
434
- down=True
 
 
 
435
  )
436
  if resblock_updown
437
  else Downsample(
@@ -449,33 +801,54 @@ class UnifiedUNetModel(nn.Module):
449
  else:
450
  num_heads = ch // num_head_channels
451
  dim_head = num_head_channels
452
-
 
 
453
  self.middle_block = TimestepEmbedSequential(
454
- ResBlock(
455
- ch,
456
- time_embed_dim,
457
- dropout,
458
- dims=dims,
459
- use_scale_shift_norm=use_scale_shift_norm
460
- ),
461
- SpatialTransformer( # always uses a self-attn
462
- ch,
463
- num_heads,
464
- dim_head,
465
- depth=transformer_depth_middle,
466
- t_context_dim=t_context_dim,
467
- v_context_dim=v_context_dim,
468
- use_linear=use_linear_in_transformer
469
  ),
470
- ResBlock(
471
- ch,
472
- time_embed_dim,
473
- dropout,
474
- dims=dims,
475
- use_scale_shift_norm=use_scale_shift_norm
 
 
476
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
477
  )
478
-
479
  self._feature_size += ch
480
 
481
  self.output_blocks = nn.ModuleList([])
@@ -483,13 +856,16 @@ class UnifiedUNetModel(nn.Module):
483
  for i in range(self.num_res_blocks[level] + 1):
484
  ich = input_block_chans.pop()
485
  layers = [
486
- ResBlock(
487
- ch + ich,
488
- time_embed_dim,
489
- dropout,
490
- out_channels=model_channels * mult,
491
- dims=dims,
492
- use_scale_shift_norm=use_scale_shift_norm
 
 
 
493
  )
494
  ]
495
  ch = model_channels * mult
@@ -499,32 +875,61 @@ class UnifiedUNetModel(nn.Module):
499
  else:
500
  num_heads = ch // num_head_channels
501
  dim_head = num_head_channels
 
 
 
 
 
 
 
 
 
 
 
 
502
  if (
503
  not exists(num_attention_blocks)
504
  or i < num_attention_blocks[level]
505
  ):
506
  layers.append(
507
- SpatialTransformer(
508
- ch,
509
- num_heads,
510
- dim_head,
511
- depth=transformer_depth[level],
512
- t_context_dim=t_context_dim,
513
- v_context_dim=v_context_dim,
514
- use_linear=use_linear_in_transformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  )
516
  )
517
  if level and i == self.num_res_blocks[level]:
518
  out_ch = ch
519
  layers.append(
520
- ResBlock(
521
- ch,
522
- time_embed_dim,
523
- dropout,
524
- out_channels=out_ch,
525
- dims=dims,
526
- use_scale_shift_norm=use_scale_shift_norm,
527
- up=True
 
 
 
528
  )
529
  if resblock_updown
530
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
@@ -533,92 +938,1133 @@ class UnifiedUNetModel(nn.Module):
533
  self.output_blocks.append(TimestepEmbedSequential(*layers))
534
  self._feature_size += ch
535
 
536
- self.out = nn.Sequential(
537
- normalization(ch),
538
- nn.SiLU(),
539
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1))
 
 
540
  )
541
-
542
- # cache attn map
543
- self.attn_type = save_attn_type
544
- self.attn_layers = save_attn_layers
545
- self.attn_map_cache = []
546
- for name, module in self.named_modules():
547
- if any([name.endswith(attn_type) for attn_type in self.attn_type]):
548
- item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
549
- self.attn_map_cache.append(item)
550
- module.attn_map_cache = item
551
-
552
- def clear_attn_map(self):
553
-
554
- for item in self.attn_map_cache:
555
- if item["attn_map"] is not None:
556
- del item["attn_map"]
557
- item["attn_map"] = None
558
-
559
- def save_attn_map(self, attn_type="t_attn", save_name="temp", tokens=""):
560
-
561
- attn_maps = []
562
- for item in self.attn_map_cache:
563
- name = item["name"]
564
- if any([name.startswith(block) for block in self.attn_layers]) and name.endswith(attn_type):
565
- heads = item["heads"]
566
- attn_maps.append(item["attn_map"].detach().cpu())
567
-
568
- attn_map = th.stack(attn_maps, dim=0)
569
- attn_map = th.mean(attn_map, dim=0)
570
-
571
- # attn_map: bh * n * l
572
- bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
573
- attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
574
- b = attn_map.shape[0]
575
-
576
- h = w = int(n**0.5)
577
- attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
578
- attn_map_i = attn_map[-1]
579
 
580
- l = attn_map_i.shape[0]
581
- fig = plt.figure(figsize=(12, 8), dpi=300)
582
- for j in range(12):
583
- if j >= l: break
584
- ax = fig.add_subplot(3, 4, j+1)
585
- sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
586
- if j < len(tokens):
587
- ax.set_title(tokens[j])
588
- fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
589
- plt.close()
590
 
591
- return attn_map_i
592
-
593
- def forward(self, x, timesteps=None, t_context=None, v_context=None, y=None, **kwargs):
 
 
 
 
594
 
 
 
 
 
 
 
 
 
 
595
  assert (y is not None) == (
596
- self.use_label is not None
597
  ), "must specify y if and only if the model is class-conditional"
598
-
599
- self.clear_attn_map()
600
-
601
  hs = []
602
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
603
  emb = self.time_embed(t_emb)
604
 
605
- if self.use_label is not None:
606
  assert y.shape[0] == x.shape[0]
607
  emb = emb + self.label_emb(y)
608
 
 
609
  h = x
610
- if self.ctrl_channels > 0:
611
- in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
612
  for i, module in enumerate(self.input_blocks):
613
  if self.ctrl_channels > 0 and i == 0:
614
- h = module(in_h, emb, t_context, v_context) + self.ctrl_block(add_h, emb, t_context, v_context)
615
  else:
616
- h = module(h, emb, t_context, v_context)
617
  hs.append(h)
618
- h = self.middle_block(h, emb, t_context, v_context)
619
  for i, module in enumerate(self.output_blocks):
620
  h = th.cat([h, hs.pop()], dim=1)
621
- h = module(h, emb, t_context, v_context)
622
  h = h.type(x.dtype)
623
 
624
  return self.out(h)
 
1
+ import os
2
+ import math
3
  from abc import abstractmethod
4
+ from functools import partial
5
  from typing import Iterable
6
 
7
  import numpy as np
 
13
  from ...modules.attention import SpatialTransformer
14
  from ...modules.diffusionmodules.util import (
15
  avg_pool_nd,
16
+ checkpoint,
17
  conv_nd,
18
  linear,
19
  normalization,
 
23
  from ...util import default, exists
24
 
25
 
26
+ # dummy replace
27
+ def convert_module_to_f16(x):
28
+ pass
29
+
30
+
31
+ def convert_module_to_f32(x):
32
+ pass
33
+
34
+
35
+ ## go
36
+ class AttentionPool2d(nn.Module):
37
+ """
38
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ spacial_dim: int,
44
+ embed_dim: int,
45
+ num_heads_channels: int,
46
+ output_dim: int = None,
47
+ ):
48
  super().__init__()
49
+ self.positional_embedding = nn.Parameter(
50
+ th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
51
+ )
52
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
53
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
54
+ self.num_heads = embed_dim // num_heads_channels
55
+ self.attention = QKVAttention(self.num_heads)
56
+
57
+ def forward(self, x):
58
+ b, c, *_spatial = x.shape
59
+ x = x.reshape(b, c, -1) # NC(HW)
60
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
61
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
62
+ x = self.qkv_proj(x)
63
+ x = self.attention(x)
64
+ x = self.c_proj(x)
65
+ return x[:, :, 0]
66
 
 
 
 
67
 
68
  class TimestepBlock(nn.Module):
69
  """
 
87
  self,
88
  x,
89
  emb,
90
+ context=None,
91
+ add_context=None,
92
+ skip_time_mix=False,
93
+ time_context=None,
94
+ num_video_frames=None,
95
+ time_context_cat=None,
96
+ use_crossframe_attention_in_spatial_layers=False,
97
  ):
98
  for layer in self:
99
  if isinstance(layer, TimestepBlock):
100
  x = layer(x, emb)
101
  elif isinstance(layer, SpatialTransformer):
102
+ x = layer(x, context, add_context)
103
  else:
104
  x = layer(x)
105
  return x
 
144
  return x
145
 
146
 
147
+ class TransposedUpsample(nn.Module):
148
+ "Learned 2x upsampling without padding"
149
+
150
+ def __init__(self, channels, out_channels=None, ks=5):
151
+ super().__init__()
152
+ self.channels = channels
153
+ self.out_channels = out_channels or channels
154
+
155
+ self.up = nn.ConvTranspose2d(
156
+ self.channels, self.out_channels, kernel_size=ks, stride=2
157
+ )
158
+
159
+ def forward(self, x):
160
+ return self.up(x)
161
+
162
+
163
  class Downsample(nn.Module):
164
  """
165
  A downsampling layer with an optional convolution.
 
207
  class ResBlock(TimestepBlock):
208
  """
209
  A residual block that can optionally change the number of channels.
210
+ :param channels: the number of input channels.
211
+ :param emb_channels: the number of timestep embedding channels.
212
+ :param dropout: the rate of dropout.
213
+ :param out_channels: if specified, the number of out channels.
214
+ :param use_conv: if True and out_channels is specified, use a spatial
215
+ convolution instead of a smaller 1x1 convolution to change the
216
+ channels in the skip connection.
217
+ :param dims: determines if the signal is 1D, 2D, or 3D.
218
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
219
+ :param up: if True, use this block for upsampling.
220
+ :param down: if True, use this block for downsampling.
221
  """
222
 
223
  def __init__(
 
229
  use_conv=False,
230
  use_scale_shift_norm=False,
231
  dims=2,
232
+ use_checkpoint=False,
233
  up=False,
234
  down=False,
235
  kernel_size=3,
236
  exchange_temb_dims=False,
237
+ skip_t_emb=False,
238
  ):
239
  super().__init__()
240
  self.channels = channels
 
242
  self.dropout = dropout
243
  self.out_channels = out_channels or channels
244
  self.use_conv = use_conv
245
+ self.use_checkpoint = use_checkpoint
246
  self.use_scale_shift_norm = use_scale_shift_norm
247
  self.exchange_temb_dims = exchange_temb_dims
248
 
 
311
  self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
312
 
313
  def forward(self, x, emb):
314
+ """
315
+ Apply the block to a Tensor, conditioned on a timestep embedding.
316
+ :param x: an [N x C x ...] Tensor of features.
317
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
318
+ :return: an [N x C x ...] Tensor of outputs.
319
+ """
320
+ return checkpoint(
321
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
322
+ )
323
+
324
+ def _forward(self, x, emb):
325
  if self.updown:
326
  in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
327
  h = in_rest(x)
 
349
  h = self.out_layers(h)
350
  return self.skip_connection(x) + h
351
 
352
+
353
+ class AttentionBlock(nn.Module):
354
+ """
355
+ An attention block that allows spatial positions to attend to each other.
356
+ Originally ported from here, but adapted to the N-d case.
357
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
358
+ """
359
+
360
+ def __init__(
361
+ self,
362
+ channels,
363
+ num_heads=1,
364
+ num_head_channels=-1,
365
+ use_checkpoint=False,
366
+ use_new_attention_order=False,
367
+ ):
368
+ super().__init__()
369
+ self.channels = channels
370
+ if num_head_channels == -1:
371
+ self.num_heads = num_heads
372
+ else:
373
+ assert (
374
+ channels % num_head_channels == 0
375
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
376
+ self.num_heads = channels // num_head_channels
377
+ self.use_checkpoint = use_checkpoint
378
+ self.norm = normalization(channels)
379
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
380
+ if use_new_attention_order:
381
+ # split qkv before split heads
382
+ self.attention = QKVAttention(self.num_heads)
383
+ else:
384
+ # split heads before split qkv
385
+ self.attention = QKVAttentionLegacy(self.num_heads)
386
+
387
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
388
+
389
+ def forward(self, x, **kwargs):
390
+ # TODO add crossframe attention and use mixed checkpoint
391
+ return checkpoint(
392
+ self._forward, (x,), self.parameters(), True
393
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
394
+ # return pt_checkpoint(self._forward, x) # pytorch
395
+
396
+ def _forward(self, x):
397
+ b, c, *spatial = x.shape
398
+ x = x.reshape(b, c, -1)
399
+ qkv = self.qkv(self.norm(x))
400
+ h = self.attention(qkv)
401
+ h = self.proj_out(h)
402
+ return (x + h).reshape(b, c, *spatial)
403
+
404
+
405
+ def count_flops_attn(model, _x, y):
406
+ """
407
+ A counter for the `thop` package to count the operations in an
408
+ attention operation.
409
+ Meant to be used like:
410
+ macs, params = thop.profile(
411
+ model,
412
+ inputs=(inputs, timestamps),
413
+ custom_ops={QKVAttention: QKVAttention.count_flops},
414
+ )
415
+ """
416
+ b, c, *spatial = y[0].shape
417
+ num_spatial = int(np.prod(spatial))
418
+ # We perform two matmuls with the same number of ops.
419
+ # The first computes the weight matrix, the second computes
420
+ # the combination of the value vectors.
421
+ matmul_ops = 2 * b * (num_spatial**2) * c
422
+ model.total_ops += th.DoubleTensor([matmul_ops])
423
+
424
+
425
+ class QKVAttentionLegacy(nn.Module):
426
+ """
427
+ A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
428
+ """
429
+
430
+ def __init__(self, n_heads):
431
+ super().__init__()
432
+ self.n_heads = n_heads
433
+
434
+ def forward(self, qkv):
435
+ """
436
+ Apply QKV attention.
437
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
438
+ :return: an [N x (H * C) x T] tensor after attention.
439
+ """
440
+ bs, width, length = qkv.shape
441
+ assert width % (3 * self.n_heads) == 0
442
+ ch = width // (3 * self.n_heads)
443
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
444
+ scale = 1 / math.sqrt(math.sqrt(ch))
445
+ weight = th.einsum(
446
+ "bct,bcs->bts", q * scale, k * scale
447
+ ) # More stable with f16 than dividing afterwards
448
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
449
+ a = th.einsum("bts,bcs->bct", weight, v)
450
+ return a.reshape(bs, -1, length)
451
+
452
+ @staticmethod
453
+ def count_flops(model, _x, y):
454
+ return count_flops_attn(model, _x, y)
455
+
456
+
457
+ class QKVAttention(nn.Module):
458
+ """
459
+ A module which performs QKV attention and splits in a different order.
460
+ """
461
+
462
+ def __init__(self, n_heads):
463
+ super().__init__()
464
+ self.n_heads = n_heads
465
+
466
+ def forward(self, qkv):
467
+ """
468
+ Apply QKV attention.
469
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
470
+ :return: an [N x (H * C) x T] tensor after attention.
471
+ """
472
+ bs, width, length = qkv.shape
473
+ assert width % (3 * self.n_heads) == 0
474
+ ch = width // (3 * self.n_heads)
475
+ q, k, v = qkv.chunk(3, dim=1)
476
+ scale = 1 / math.sqrt(math.sqrt(ch))
477
+ weight = th.einsum(
478
+ "bct,bcs->bts",
479
+ (q * scale).view(bs * self.n_heads, ch, length),
480
+ (k * scale).view(bs * self.n_heads, ch, length),
481
+ ) # More stable with f16 than dividing afterwards
482
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
483
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
484
+ return a.reshape(bs, -1, length)
485
+
486
+ @staticmethod
487
+ def count_flops(model, _x, y):
488
+ return count_flops_attn(model, _x, y)
489
+
490
+
491
+ class Timestep(nn.Module):
492
+ def __init__(self, dim):
493
+ super().__init__()
494
+ self.dim = dim
495
+
496
+ def forward(self, t):
497
+ return timestep_embedding(t, self.dim)
498
 
499
 
500
+ class UNetModel(nn.Module):
501
+ """
502
+ The full UNet model with attention and timestep embedding.
503
+ :param in_channels: channels in the input Tensor.
504
+ :param model_channels: base channel count for the model.
505
+ :param out_channels: channels in the output Tensor.
506
+ :param num_res_blocks: number of residual blocks per downsample.
507
+ :param attention_resolutions: a collection of downsample rates at which
508
+ attention will take place. May be a set, list, or tuple.
509
+ For example, if this contains 4, then at 4x downsampling, attention
510
+ will be used.
511
+ :param dropout: the dropout probability.
512
+ :param channel_mult: channel multiplier for each level of the UNet.
513
+ :param conv_resample: if True, use learned convolutions for upsampling and
514
+ downsampling.
515
+ :param dims: determines if the signal is 1D, 2D, or 3D.
516
+ :param num_classes: if specified (as an int), then this model will be
517
+ class-conditional with `num_classes` classes.
518
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
519
+ :param num_heads: the number of attention heads in each attention layer.
520
+ :param num_heads_channels: if specified, ignore num_heads and instead use
521
+ a fixed channel width per attention head.
522
+ :param num_heads_upsample: works with num_heads to set a different number
523
+ of heads for upsampling. Deprecated.
524
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
525
+ :param resblock_updown: use residual blocks for up/downsampling.
526
+ :param use_new_attention_order: use a different attention pattern for potentially
527
+ increased efficiency.
528
+ """
529
 
530
  def __init__(
531
  self,
532
  in_channels,
 
533
  model_channels,
534
  out_channels,
535
  num_res_blocks,
536
  attention_resolutions,
537
  dropout=0,
538
  channel_mult=(1, 2, 4, 8),
 
 
539
  conv_resample=True,
540
  dims=2,
541
+ num_classes=None,
542
+ use_checkpoint=False,
543
+ use_fp16=False,
544
  num_heads=-1,
545
  num_head_channels=-1,
546
  num_heads_upsample=-1,
547
  use_scale_shift_norm=False,
548
  resblock_updown=False,
549
+ use_new_attention_order=False,
550
+ use_spatial_transformer=False, # custom transformer support
551
+ transformer_depth=1, # custom transformer support
552
+ context_dim=None, # custom transformer support
553
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
554
+ legacy=True,
555
+ disable_self_attentions=None,
556
  num_attention_blocks=None,
557
+ disable_middle_self_attn=False,
558
  use_linear_in_transformer=False,
559
+ spatial_transformer_attn_type="softmax",
560
  adm_in_channels=None,
561
+ use_fairscale_checkpoint=False,
562
+ offload_to_cpu=False,
563
+ transformer_depth_middle=None,
564
  ):
565
  super().__init__()
566
+ from omegaconf.listconfig import ListConfig
567
+
568
+ if use_spatial_transformer:
569
+ assert (
570
+ context_dim is not None
571
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
572
+
573
+ if context_dim is not None:
574
+ assert (
575
+ use_spatial_transformer
576
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
577
+ if type(context_dim) == ListConfig:
578
+ context_dim = list(context_dim)
579
 
580
  if num_heads_upsample == -1:
581
  num_heads_upsample = num_heads
 
591
  ), "Either num_heads or num_head_channels has to be set"
592
 
593
  self.in_channels = in_channels
 
594
  self.model_channels = model_channels
595
  self.out_channels = out_channels
596
+ if isinstance(transformer_depth, int):
597
+ transformer_depth = len(channel_mult) * [transformer_depth]
598
+ elif isinstance(transformer_depth, ListConfig):
599
+ transformer_depth = list(transformer_depth)
600
+ transformer_depth_middle = default(
601
+ transformer_depth_middle, transformer_depth[-1]
602
+ )
603
 
604
+ if isinstance(num_res_blocks, int):
605
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
606
+ else:
607
+ if len(num_res_blocks) != len(channel_mult):
608
+ raise ValueError(
609
+ "provide num_res_blocks either as an int (globally constant) or "
610
+ "as a list/tuple (per-level) with the same length as channel_mult"
611
+ )
612
+ self.num_res_blocks = num_res_blocks
613
+ # self.num_res_blocks = num_res_blocks
614
+ if disable_self_attentions is not None:
615
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
616
+ assert len(disable_self_attentions) == len(channel_mult)
617
+ if num_attention_blocks is not None:
618
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
619
+ assert all(
620
+ map(
621
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
622
+ range(len(num_attention_blocks)),
623
+ )
624
+ )
625
+ print(
626
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
627
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
628
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
629
+ f"attention will still not be set."
630
+ ) # todo: convert to warning
631
 
632
  self.attention_resolutions = attention_resolutions
633
  self.dropout = dropout
634
  self.channel_mult = channel_mult
635
  self.conv_resample = conv_resample
636
+ self.num_classes = num_classes
637
+ self.use_checkpoint = use_checkpoint
638
+ if use_fp16:
639
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
640
+ # self.dtype = th.float16 if use_fp16 else th.float32
641
  self.num_heads = num_heads
642
  self.num_head_channels = num_head_channels
643
  self.num_heads_upsample = num_heads_upsample
644
+ self.predict_codebook_ids = n_embed is not None
645
+
646
+ assert use_fairscale_checkpoint != use_checkpoint or not (
647
+ use_checkpoint or use_fairscale_checkpoint
648
+ )
649
+
650
+ self.use_fairscale_checkpoint = False
651
+ checkpoint_wrapper_fn = (
652
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
653
+ if self.use_fairscale_checkpoint
654
+ else lambda x: x
655
+ )
656
 
657
  time_embed_dim = model_channels * 4
658
+ self.time_embed = checkpoint_wrapper_fn(
659
+ nn.Sequential(
660
+ linear(model_channels, time_embed_dim),
661
+ nn.SiLU(),
662
+ linear(time_embed_dim, time_embed_dim),
663
+ )
664
  )
665
+
666
+ if self.num_classes is not None:
667
+ if isinstance(self.num_classes, int):
668
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
669
+ elif self.num_classes == "continuous":
670
+ print("setting up linear c_adm embedding layer")
671
+ self.label_emb = nn.Linear(1, time_embed_dim)
672
+ elif self.num_classes == "timestep":
673
+ self.label_emb = checkpoint_wrapper_fn(
674
+ nn.Sequential(
675
+ Timestep(model_channels),
676
+ nn.Sequential(
677
+ linear(model_channels, time_embed_dim),
678
+ nn.SiLU(),
679
+ linear(time_embed_dim, time_embed_dim),
680
+ ),
681
+ )
682
  )
683
+ elif self.num_classes == "sequential":
684
+ assert adm_in_channels is not None
685
+ self.label_emb = nn.Sequential(
686
+ nn.Sequential(
687
+ linear(adm_in_channels, time_embed_dim),
688
+ nn.SiLU(),
689
+ linear(time_embed_dim, time_embed_dim),
690
+ )
691
+ )
692
+ else:
693
+ raise ValueError()
694
 
695
  self.input_blocks = nn.ModuleList(
696
  [
 
699
  )
700
  ]
701
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
702
  self._feature_size = model_channels
703
  input_block_chans = [model_channels]
704
  ch = model_channels
 
706
  for level, mult in enumerate(channel_mult):
707
  for nr in range(self.num_res_blocks[level]):
708
  layers = [
709
+ checkpoint_wrapper_fn(
710
+ ResBlock(
711
+ ch,
712
+ time_embed_dim,
713
+ dropout,
714
+ out_channels=mult * model_channels,
715
+ dims=dims,
716
+ use_checkpoint=use_checkpoint,
717
+ use_scale_shift_norm=use_scale_shift_norm,
718
+ )
719
  )
720
  ]
721
  ch = mult * model_channels
 
725
  else:
726
  num_heads = ch // num_head_channels
727
  dim_head = num_head_channels
728
+ if legacy:
729
+ # num_heads = 1
730
+ dim_head = (
731
+ ch // num_heads
732
+ if use_spatial_transformer
733
+ else num_head_channels
734
+ )
735
+ if exists(disable_self_attentions):
736
+ disabled_sa = disable_self_attentions[level]
737
+ else:
738
+ disabled_sa = False
739
+
740
  if (
741
  not exists(num_attention_blocks)
742
  or nr < num_attention_blocks[level]
743
  ):
744
  layers.append(
745
+ checkpoint_wrapper_fn(
746
+ AttentionBlock(
747
+ ch,
748
+ use_checkpoint=use_checkpoint,
749
+ num_heads=num_heads,
750
+ num_head_channels=dim_head,
751
+ use_new_attention_order=use_new_attention_order,
752
+ )
753
+ )
754
+ if not use_spatial_transformer
755
+ else checkpoint_wrapper_fn(
756
+ SpatialTransformer(
757
+ ch,
758
+ num_heads,
759
+ dim_head,
760
+ depth=transformer_depth[level],
761
+ context_dim=context_dim,
762
+ disable_self_attn=disabled_sa,
763
+ use_linear=use_linear_in_transformer,
764
+ attn_type=spatial_transformer_attn_type,
765
+ use_checkpoint=use_checkpoint,
766
+ )
767
  )
768
  )
769
  self.input_blocks.append(TimestepEmbedSequential(*layers))
 
773
  out_ch = ch
774
  self.input_blocks.append(
775
  TimestepEmbedSequential(
776
+ checkpoint_wrapper_fn(
777
+ ResBlock(
778
+ ch,
779
+ time_embed_dim,
780
+ dropout,
781
+ out_channels=out_ch,
782
+ dims=dims,
783
+ use_checkpoint=use_checkpoint,
784
+ use_scale_shift_norm=use_scale_shift_norm,
785
+ down=True,
786
+ )
787
  )
788
  if resblock_updown
789
  else Downsample(
 
801
  else:
802
  num_heads = ch // num_head_channels
803
  dim_head = num_head_channels
804
+ if legacy:
805
+ # num_heads = 1
806
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
807
  self.middle_block = TimestepEmbedSequential(
808
+ checkpoint_wrapper_fn(
809
+ ResBlock(
810
+ ch,
811
+ time_embed_dim,
812
+ dropout,
813
+ dims=dims,
814
+ use_checkpoint=use_checkpoint,
815
+ use_scale_shift_norm=use_scale_shift_norm,
816
+ )
 
 
 
 
 
 
817
  ),
818
+ checkpoint_wrapper_fn(
819
+ AttentionBlock(
820
+ ch,
821
+ use_checkpoint=use_checkpoint,
822
+ num_heads=num_heads,
823
+ num_head_channels=dim_head,
824
+ use_new_attention_order=use_new_attention_order,
825
+ )
826
  )
827
+ if not use_spatial_transformer
828
+ else checkpoint_wrapper_fn(
829
+ SpatialTransformer( # always uses a self-attn
830
+ ch,
831
+ num_heads,
832
+ dim_head,
833
+ depth=transformer_depth_middle,
834
+ context_dim=context_dim,
835
+ disable_self_attn=disable_middle_self_attn,
836
+ use_linear=use_linear_in_transformer,
837
+ attn_type=spatial_transformer_attn_type,
838
+ use_checkpoint=use_checkpoint,
839
+ )
840
+ ),
841
+ checkpoint_wrapper_fn(
842
+ ResBlock(
843
+ ch,
844
+ time_embed_dim,
845
+ dropout,
846
+ dims=dims,
847
+ use_checkpoint=use_checkpoint,
848
+ use_scale_shift_norm=use_scale_shift_norm,
849
+ )
850
+ ),
851
  )
 
852
  self._feature_size += ch
853
 
854
  self.output_blocks = nn.ModuleList([])
 
856
  for i in range(self.num_res_blocks[level] + 1):
857
  ich = input_block_chans.pop()
858
  layers = [
859
+ checkpoint_wrapper_fn(
860
+ ResBlock(
861
+ ch + ich,
862
+ time_embed_dim,
863
+ dropout,
864
+ out_channels=model_channels * mult,
865
+ dims=dims,
866
+ use_checkpoint=use_checkpoint,
867
+ use_scale_shift_norm=use_scale_shift_norm,
868
+ )
869
  )
870
  ]
871
  ch = model_channels * mult
 
875
  else:
876
  num_heads = ch // num_head_channels
877
  dim_head = num_head_channels
878
+ if legacy:
879
+ # num_heads = 1
880
+ dim_head = (
881
+ ch // num_heads
882
+ if use_spatial_transformer
883
+ else num_head_channels
884
+ )
885
+ if exists(disable_self_attentions):
886
+ disabled_sa = disable_self_attentions[level]
887
+ else:
888
+ disabled_sa = False
889
+
890
  if (
891
  not exists(num_attention_blocks)
892
  or i < num_attention_blocks[level]
893
  ):
894
  layers.append(
895
+ checkpoint_wrapper_fn(
896
+ AttentionBlock(
897
+ ch,
898
+ use_checkpoint=use_checkpoint,
899
+ num_heads=num_heads_upsample,
900
+ num_head_channels=dim_head,
901
+ use_new_attention_order=use_new_attention_order,
902
+ )
903
+ )
904
+ if not use_spatial_transformer
905
+ else checkpoint_wrapper_fn(
906
+ SpatialTransformer(
907
+ ch,
908
+ num_heads,
909
+ dim_head,
910
+ depth=transformer_depth[level],
911
+ context_dim=context_dim,
912
+ disable_self_attn=disabled_sa,
913
+ use_linear=use_linear_in_transformer,
914
+ attn_type=spatial_transformer_attn_type,
915
+ use_checkpoint=use_checkpoint,
916
+ )
917
  )
918
  )
919
  if level and i == self.num_res_blocks[level]:
920
  out_ch = ch
921
  layers.append(
922
+ checkpoint_wrapper_fn(
923
+ ResBlock(
924
+ ch,
925
+ time_embed_dim,
926
+ dropout,
927
+ out_channels=out_ch,
928
+ dims=dims,
929
+ use_checkpoint=use_checkpoint,
930
+ use_scale_shift_norm=use_scale_shift_norm,
931
+ up=True,
932
+ )
933
  )
934
  if resblock_updown
935
  else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
 
938
  self.output_blocks.append(TimestepEmbedSequential(*layers))
939
  self._feature_size += ch
940
 
941
+ self.out = checkpoint_wrapper_fn(
942
+ nn.Sequential(
943
+ normalization(ch),
944
+ nn.SiLU(),
945
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
946
+ )
947
  )
948
+ if self.predict_codebook_ids:
949
+ self.id_predictor = checkpoint_wrapper_fn(
950
+ nn.Sequential(
951
+ normalization(ch),
952
+ conv_nd(dims, model_channels, n_embed, 1),
953
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
954
+ )
955
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
956
 
957
+ def convert_to_fp16(self):
958
+ """
959
+ Convert the torso of the model to float16.
960
+ """
961
+ self.input_blocks.apply(convert_module_to_f16)
962
+ self.middle_block.apply(convert_module_to_f16)
963
+ self.output_blocks.apply(convert_module_to_f16)
 
 
 
964
 
965
+ def convert_to_fp32(self):
966
+ """
967
+ Convert the torso of the model to float32.
968
+ """
969
+ self.input_blocks.apply(convert_module_to_f32)
970
+ self.middle_block.apply(convert_module_to_f32)
971
+ self.output_blocks.apply(convert_module_to_f32)
972
 
973
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
974
+ """
975
+ Apply the model to an input batch.
976
+ :param x: an [N x C x ...] Tensor of inputs.
977
+ :param timesteps: a 1-D batch of timesteps.
978
+ :param context: conditioning plugged in via crossattn
979
+ :param y: an [N] Tensor of labels, if class-conditional.
980
+ :return: an [N x C x ...] Tensor of outputs.
981
+ """
982
  assert (y is not None) == (
983
+ self.num_classes is not None
984
  ), "must specify y if and only if the model is class-conditional"
 
 
 
985
  hs = []
986
  t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
987
  emb = self.time_embed(t_emb)
988
 
989
+ if self.num_classes is not None:
990
  assert y.shape[0] == x.shape[0]
991
  emb = emb + self.label_emb(y)
992
 
993
+ # h = x.type(self.dtype)
994
  h = x
995
+ for i, module in enumerate(self.input_blocks):
996
+ h = module(h, emb, context)
997
+ hs.append(h)
998
+ h = self.middle_block(h, emb, context)
999
+ for i, module in enumerate(self.output_blocks):
1000
+ h = th.cat([h, hs.pop()], dim=1)
1001
+ h = module(h, emb, context)
1002
+ h = h.type(x.dtype)
1003
+ if self.predict_codebook_ids:
1004
+ assert False, "not supported anymore. what the f*** are you doing?"
1005
+ else:
1006
+ return self.out(h)
1007
+
1008
+
1009
+
1010
+ class UNetModel(nn.Module):
1011
+ """
1012
+ The full UNet model with attention and timestep embedding.
1013
+ :param in_channels: channels in the input Tensor.
1014
+ :param model_channels: base channel count for the model.
1015
+ :param out_channels: channels in the output Tensor.
1016
+ :param num_res_blocks: number of residual blocks per downsample.
1017
+ :param attention_resolutions: a collection of downsample rates at which
1018
+ attention will take place. May be a set, list, or tuple.
1019
+ For example, if this contains 4, then at 4x downsampling, attention
1020
+ will be used.
1021
+ :param dropout: the dropout probability.
1022
+ :param channel_mult: channel multiplier for each level of the UNet.
1023
+ :param conv_resample: if True, use learned convolutions for upsampling and
1024
+ downsampling.
1025
+ :param dims: determines if the signal is 1D, 2D, or 3D.
1026
+ :param num_classes: if specified (as an int), then this model will be
1027
+ class-conditional with `num_classes` classes.
1028
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
1029
+ :param num_heads: the number of attention heads in each attention layer.
1030
+ :param num_heads_channels: if specified, ignore num_heads and instead use
1031
+ a fixed channel width per attention head.
1032
+ :param num_heads_upsample: works with num_heads to set a different number
1033
+ of heads for upsampling. Deprecated.
1034
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
1035
+ :param resblock_updown: use residual blocks for up/downsampling.
1036
+ :param use_new_attention_order: use a different attention pattern for potentially
1037
+ increased efficiency.
1038
+ """
1039
+
1040
+ def __init__(
1041
+ self,
1042
+ in_channels,
1043
+ model_channels,
1044
+ out_channels,
1045
+ num_res_blocks,
1046
+ attention_resolutions,
1047
+ dropout=0,
1048
+ channel_mult=(1, 2, 4, 8),
1049
+ conv_resample=True,
1050
+ dims=2,
1051
+ num_classes=None,
1052
+ use_checkpoint=False,
1053
+ use_fp16=False,
1054
+ num_heads=-1,
1055
+ num_head_channels=-1,
1056
+ num_heads_upsample=-1,
1057
+ use_scale_shift_norm=False,
1058
+ resblock_updown=False,
1059
+ use_new_attention_order=False,
1060
+ use_spatial_transformer=False, # custom transformer support
1061
+ transformer_depth=1, # custom transformer support
1062
+ context_dim=None, # custom transformer support
1063
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1064
+ legacy=True,
1065
+ disable_self_attentions=None,
1066
+ num_attention_blocks=None,
1067
+ disable_middle_self_attn=False,
1068
+ use_linear_in_transformer=False,
1069
+ spatial_transformer_attn_type="softmax",
1070
+ adm_in_channels=None,
1071
+ use_fairscale_checkpoint=False,
1072
+ offload_to_cpu=False,
1073
+ transformer_depth_middle=None,
1074
+ ):
1075
+ super().__init__()
1076
+ from omegaconf.listconfig import ListConfig
1077
+
1078
+ if use_spatial_transformer:
1079
+ assert (
1080
+ context_dim is not None
1081
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1082
+
1083
+ if context_dim is not None:
1084
+ assert (
1085
+ use_spatial_transformer
1086
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1087
+ if type(context_dim) == ListConfig:
1088
+ context_dim = list(context_dim)
1089
+
1090
+ if num_heads_upsample == -1:
1091
+ num_heads_upsample = num_heads
1092
+
1093
+ if num_heads == -1:
1094
+ assert (
1095
+ num_head_channels != -1
1096
+ ), "Either num_heads or num_head_channels has to be set"
1097
+
1098
+ if num_head_channels == -1:
1099
+ assert (
1100
+ num_heads != -1
1101
+ ), "Either num_heads or num_head_channels has to be set"
1102
+
1103
+ self.in_channels = in_channels
1104
+ self.model_channels = model_channels
1105
+ self.out_channels = out_channels
1106
+ if isinstance(transformer_depth, int):
1107
+ transformer_depth = len(channel_mult) * [transformer_depth]
1108
+ elif isinstance(transformer_depth, ListConfig):
1109
+ transformer_depth = list(transformer_depth)
1110
+ transformer_depth_middle = default(
1111
+ transformer_depth_middle, transformer_depth[-1]
1112
+ )
1113
+
1114
+ if isinstance(num_res_blocks, int):
1115
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1116
+ else:
1117
+ if len(num_res_blocks) != len(channel_mult):
1118
+ raise ValueError(
1119
+ "provide num_res_blocks either as an int (globally constant) or "
1120
+ "as a list/tuple (per-level) with the same length as channel_mult"
1121
+ )
1122
+ self.num_res_blocks = num_res_blocks
1123
+ # self.num_res_blocks = num_res_blocks
1124
+ if disable_self_attentions is not None:
1125
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1126
+ assert len(disable_self_attentions) == len(channel_mult)
1127
+ if num_attention_blocks is not None:
1128
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
1129
+ assert all(
1130
+ map(
1131
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1132
+ range(len(num_attention_blocks)),
1133
+ )
1134
+ )
1135
+ print(
1136
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1137
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1138
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1139
+ f"attention will still not be set."
1140
+ ) # todo: convert to warning
1141
+
1142
+ self.attention_resolutions = attention_resolutions
1143
+ self.dropout = dropout
1144
+ self.channel_mult = channel_mult
1145
+ self.conv_resample = conv_resample
1146
+ self.num_classes = num_classes
1147
+ self.use_checkpoint = use_checkpoint
1148
+ if use_fp16:
1149
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
1150
+ # self.dtype = th.float16 if use_fp16 else th.float32
1151
+ self.num_heads = num_heads
1152
+ self.num_head_channels = num_head_channels
1153
+ self.num_heads_upsample = num_heads_upsample
1154
+ self.predict_codebook_ids = n_embed is not None
1155
+
1156
+ assert use_fairscale_checkpoint != use_checkpoint or not (
1157
+ use_checkpoint or use_fairscale_checkpoint
1158
+ )
1159
+
1160
+ self.use_fairscale_checkpoint = False
1161
+ checkpoint_wrapper_fn = (
1162
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1163
+ if self.use_fairscale_checkpoint
1164
+ else lambda x: x
1165
+ )
1166
+
1167
+ time_embed_dim = model_channels * 4
1168
+ self.time_embed = checkpoint_wrapper_fn(
1169
+ nn.Sequential(
1170
+ linear(model_channels, time_embed_dim),
1171
+ nn.SiLU(),
1172
+ linear(time_embed_dim, time_embed_dim),
1173
+ )
1174
+ )
1175
+
1176
+ if self.num_classes is not None:
1177
+ if isinstance(self.num_classes, int):
1178
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1179
+ elif self.num_classes == "continuous":
1180
+ print("setting up linear c_adm embedding layer")
1181
+ self.label_emb = nn.Linear(1, time_embed_dim)
1182
+ elif self.num_classes == "timestep":
1183
+ self.label_emb = checkpoint_wrapper_fn(
1184
+ nn.Sequential(
1185
+ Timestep(model_channels),
1186
+ nn.Sequential(
1187
+ linear(model_channels, time_embed_dim),
1188
+ nn.SiLU(),
1189
+ linear(time_embed_dim, time_embed_dim),
1190
+ ),
1191
+ )
1192
+ )
1193
+ elif self.num_classes == "sequential":
1194
+ assert adm_in_channels is not None
1195
+ self.label_emb = nn.Sequential(
1196
+ nn.Sequential(
1197
+ linear(adm_in_channels, time_embed_dim),
1198
+ nn.SiLU(),
1199
+ linear(time_embed_dim, time_embed_dim),
1200
+ )
1201
+ )
1202
+ else:
1203
+ raise ValueError()
1204
+
1205
+ self.input_blocks = nn.ModuleList(
1206
+ [
1207
+ TimestepEmbedSequential(
1208
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1209
+ )
1210
+ ]
1211
+ )
1212
+ self._feature_size = model_channels
1213
+ input_block_chans = [model_channels]
1214
+ ch = model_channels
1215
+ ds = 1
1216
+ for level, mult in enumerate(channel_mult):
1217
+ for nr in range(self.num_res_blocks[level]):
1218
+ layers = [
1219
+ checkpoint_wrapper_fn(
1220
+ ResBlock(
1221
+ ch,
1222
+ time_embed_dim,
1223
+ dropout,
1224
+ out_channels=mult * model_channels,
1225
+ dims=dims,
1226
+ use_checkpoint=use_checkpoint,
1227
+ use_scale_shift_norm=use_scale_shift_norm,
1228
+ )
1229
+ )
1230
+ ]
1231
+ ch = mult * model_channels
1232
+ if ds in attention_resolutions:
1233
+ if num_head_channels == -1:
1234
+ dim_head = ch // num_heads
1235
+ else:
1236
+ num_heads = ch // num_head_channels
1237
+ dim_head = num_head_channels
1238
+ if legacy:
1239
+ # num_heads = 1
1240
+ dim_head = (
1241
+ ch // num_heads
1242
+ if use_spatial_transformer
1243
+ else num_head_channels
1244
+ )
1245
+ if exists(disable_self_attentions):
1246
+ disabled_sa = disable_self_attentions[level]
1247
+ else:
1248
+ disabled_sa = False
1249
+
1250
+ if (
1251
+ not exists(num_attention_blocks)
1252
+ or nr < num_attention_blocks[level]
1253
+ ):
1254
+ layers.append(
1255
+ checkpoint_wrapper_fn(
1256
+ AttentionBlock(
1257
+ ch,
1258
+ use_checkpoint=use_checkpoint,
1259
+ num_heads=num_heads,
1260
+ num_head_channels=dim_head,
1261
+ use_new_attention_order=use_new_attention_order,
1262
+ )
1263
+ )
1264
+ if not use_spatial_transformer
1265
+ else checkpoint_wrapper_fn(
1266
+ SpatialTransformer(
1267
+ ch,
1268
+ num_heads,
1269
+ dim_head,
1270
+ depth=transformer_depth[level],
1271
+ context_dim=context_dim,
1272
+ disable_self_attn=disabled_sa,
1273
+ use_linear=use_linear_in_transformer,
1274
+ attn_type=spatial_transformer_attn_type,
1275
+ use_checkpoint=use_checkpoint,
1276
+ )
1277
+ )
1278
+ )
1279
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1280
+ self._feature_size += ch
1281
+ input_block_chans.append(ch)
1282
+ if level != len(channel_mult) - 1:
1283
+ out_ch = ch
1284
+ self.input_blocks.append(
1285
+ TimestepEmbedSequential(
1286
+ checkpoint_wrapper_fn(
1287
+ ResBlock(
1288
+ ch,
1289
+ time_embed_dim,
1290
+ dropout,
1291
+ out_channels=out_ch,
1292
+ dims=dims,
1293
+ use_checkpoint=use_checkpoint,
1294
+ use_scale_shift_norm=use_scale_shift_norm,
1295
+ down=True,
1296
+ )
1297
+ )
1298
+ if resblock_updown
1299
+ else Downsample(
1300
+ ch, conv_resample, dims=dims, out_channels=out_ch
1301
+ )
1302
+ )
1303
+ )
1304
+ ch = out_ch
1305
+ input_block_chans.append(ch)
1306
+ ds *= 2
1307
+ self._feature_size += ch
1308
+
1309
+ if num_head_channels == -1:
1310
+ dim_head = ch // num_heads
1311
+ else:
1312
+ num_heads = ch // num_head_channels
1313
+ dim_head = num_head_channels
1314
+ if legacy:
1315
+ # num_heads = 1
1316
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1317
+ self.middle_block = TimestepEmbedSequential(
1318
+ checkpoint_wrapper_fn(
1319
+ ResBlock(
1320
+ ch,
1321
+ time_embed_dim,
1322
+ dropout,
1323
+ dims=dims,
1324
+ use_checkpoint=use_checkpoint,
1325
+ use_scale_shift_norm=use_scale_shift_norm,
1326
+ )
1327
+ ),
1328
+ checkpoint_wrapper_fn(
1329
+ AttentionBlock(
1330
+ ch,
1331
+ use_checkpoint=use_checkpoint,
1332
+ num_heads=num_heads,
1333
+ num_head_channels=dim_head,
1334
+ use_new_attention_order=use_new_attention_order,
1335
+ )
1336
+ )
1337
+ if not use_spatial_transformer
1338
+ else checkpoint_wrapper_fn(
1339
+ SpatialTransformer( # always uses a self-attn
1340
+ ch,
1341
+ num_heads,
1342
+ dim_head,
1343
+ depth=transformer_depth_middle,
1344
+ context_dim=context_dim,
1345
+ disable_self_attn=disable_middle_self_attn,
1346
+ use_linear=use_linear_in_transformer,
1347
+ attn_type=spatial_transformer_attn_type,
1348
+ use_checkpoint=use_checkpoint,
1349
+ )
1350
+ ),
1351
+ checkpoint_wrapper_fn(
1352
+ ResBlock(
1353
+ ch,
1354
+ time_embed_dim,
1355
+ dropout,
1356
+ dims=dims,
1357
+ use_checkpoint=use_checkpoint,
1358
+ use_scale_shift_norm=use_scale_shift_norm,
1359
+ )
1360
+ ),
1361
+ )
1362
+ self._feature_size += ch
1363
+
1364
+ self.output_blocks = nn.ModuleList([])
1365
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1366
+ for i in range(self.num_res_blocks[level] + 1):
1367
+ ich = input_block_chans.pop()
1368
+ layers = [
1369
+ checkpoint_wrapper_fn(
1370
+ ResBlock(
1371
+ ch + ich,
1372
+ time_embed_dim,
1373
+ dropout,
1374
+ out_channels=model_channels * mult,
1375
+ dims=dims,
1376
+ use_checkpoint=use_checkpoint,
1377
+ use_scale_shift_norm=use_scale_shift_norm,
1378
+ )
1379
+ )
1380
+ ]
1381
+ ch = model_channels * mult
1382
+ if ds in attention_resolutions:
1383
+ if num_head_channels == -1:
1384
+ dim_head = ch // num_heads
1385
+ else:
1386
+ num_heads = ch // num_head_channels
1387
+ dim_head = num_head_channels
1388
+ if legacy:
1389
+ # num_heads = 1
1390
+ dim_head = (
1391
+ ch // num_heads
1392
+ if use_spatial_transformer
1393
+ else num_head_channels
1394
+ )
1395
+ if exists(disable_self_attentions):
1396
+ disabled_sa = disable_self_attentions[level]
1397
+ else:
1398
+ disabled_sa = False
1399
+
1400
+ if (
1401
+ not exists(num_attention_blocks)
1402
+ or i < num_attention_blocks[level]
1403
+ ):
1404
+ layers.append(
1405
+ checkpoint_wrapper_fn(
1406
+ AttentionBlock(
1407
+ ch,
1408
+ use_checkpoint=use_checkpoint,
1409
+ num_heads=num_heads_upsample,
1410
+ num_head_channels=dim_head,
1411
+ use_new_attention_order=use_new_attention_order,
1412
+ )
1413
+ )
1414
+ if not use_spatial_transformer
1415
+ else checkpoint_wrapper_fn(
1416
+ SpatialTransformer(
1417
+ ch,
1418
+ num_heads,
1419
+ dim_head,
1420
+ depth=transformer_depth[level],
1421
+ context_dim=context_dim,
1422
+ disable_self_attn=disabled_sa,
1423
+ use_linear=use_linear_in_transformer,
1424
+ attn_type=spatial_transformer_attn_type,
1425
+ use_checkpoint=use_checkpoint,
1426
+ )
1427
+ )
1428
+ )
1429
+ if level and i == self.num_res_blocks[level]:
1430
+ out_ch = ch
1431
+ layers.append(
1432
+ checkpoint_wrapper_fn(
1433
+ ResBlock(
1434
+ ch,
1435
+ time_embed_dim,
1436
+ dropout,
1437
+ out_channels=out_ch,
1438
+ dims=dims,
1439
+ use_checkpoint=use_checkpoint,
1440
+ use_scale_shift_norm=use_scale_shift_norm,
1441
+ up=True,
1442
+ )
1443
+ )
1444
+ if resblock_updown
1445
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1446
+ )
1447
+ ds //= 2
1448
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1449
+ self._feature_size += ch
1450
+
1451
+ self.out = checkpoint_wrapper_fn(
1452
+ nn.Sequential(
1453
+ normalization(ch),
1454
+ nn.SiLU(),
1455
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1456
+ )
1457
+ )
1458
+ if self.predict_codebook_ids:
1459
+ self.id_predictor = checkpoint_wrapper_fn(
1460
+ nn.Sequential(
1461
+ normalization(ch),
1462
+ conv_nd(dims, model_channels, n_embed, 1),
1463
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1464
+ )
1465
+ )
1466
+
1467
+ def convert_to_fp16(self):
1468
+ """
1469
+ Convert the torso of the model to float16.
1470
+ """
1471
+ self.input_blocks.apply(convert_module_to_f16)
1472
+ self.middle_block.apply(convert_module_to_f16)
1473
+ self.output_blocks.apply(convert_module_to_f16)
1474
+
1475
+ def convert_to_fp32(self):
1476
+ """
1477
+ Convert the torso of the model to float32.
1478
+ """
1479
+ self.input_blocks.apply(convert_module_to_f32)
1480
+ self.middle_block.apply(convert_module_to_f32)
1481
+ self.output_blocks.apply(convert_module_to_f32)
1482
+
1483
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
1484
+ """
1485
+ Apply the model to an input batch.
1486
+ :param x: an [N x C x ...] Tensor of inputs.
1487
+ :param timesteps: a 1-D batch of timesteps.
1488
+ :param context: conditioning plugged in via crossattn
1489
+ :param y: an [N] Tensor of labels, if class-conditional.
1490
+ :return: an [N x C x ...] Tensor of outputs.
1491
+ """
1492
+ assert (y is not None) == (
1493
+ self.num_classes is not None
1494
+ ), "must specify y if and only if the model is class-conditional"
1495
+ hs = []
1496
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
1497
+ emb = self.time_embed(t_emb)
1498
+
1499
+ if self.num_classes is not None:
1500
+ assert y.shape[0] == x.shape[0]
1501
+ emb = emb + self.label_emb(y)
1502
+
1503
+ # h = x.type(self.dtype)
1504
+ h = x
1505
+ for i, module in enumerate(self.input_blocks):
1506
+ h = module(h, emb, context)
1507
+ hs.append(h)
1508
+ h = self.middle_block(h, emb, context)
1509
+ for i, module in enumerate(self.output_blocks):
1510
+ h = th.cat([h, hs.pop()], dim=1)
1511
+ h = module(h, emb, context)
1512
+ h = h.type(x.dtype)
1513
+ if self.predict_codebook_ids:
1514
+ assert False, "not supported anymore. what the f*** are you doing?"
1515
+ else:
1516
+ return self.out(h)
1517
+
1518
+
1519
+ import seaborn as sns
1520
+ import matplotlib.pyplot as plt
1521
+
1522
+ class UNetAddModel(nn.Module):
1523
+
1524
+ def __init__(
1525
+ self,
1526
+ in_channels,
1527
+ ctrl_channels,
1528
+ model_channels,
1529
+ out_channels,
1530
+ num_res_blocks,
1531
+ attention_resolutions,
1532
+ dropout=0,
1533
+ channel_mult=(1, 2, 4, 8),
1534
+ attn_type="attn2",
1535
+ attn_layers=[],
1536
+ conv_resample=True,
1537
+ dims=2,
1538
+ num_classes=None,
1539
+ use_checkpoint=False,
1540
+ use_fp16=False,
1541
+ num_heads=-1,
1542
+ num_head_channels=-1,
1543
+ num_heads_upsample=-1,
1544
+ use_scale_shift_norm=False,
1545
+ resblock_updown=False,
1546
+ use_new_attention_order=False,
1547
+ use_spatial_transformer=False, # custom transformer support
1548
+ transformer_depth=1, # custom transformer support
1549
+ context_dim=None, # custom transformer support
1550
+ add_context_dim=None,
1551
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
1552
+ legacy=True,
1553
+ disable_self_attentions=None,
1554
+ num_attention_blocks=None,
1555
+ disable_middle_self_attn=False,
1556
+ use_linear_in_transformer=False,
1557
+ spatial_transformer_attn_type="softmax",
1558
+ adm_in_channels=None,
1559
+ use_fairscale_checkpoint=False,
1560
+ offload_to_cpu=False,
1561
+ transformer_depth_middle=None,
1562
+ ):
1563
+ super().__init__()
1564
+ from omegaconf.listconfig import ListConfig
1565
+
1566
+ if use_spatial_transformer:
1567
+ assert (
1568
+ context_dim is not None
1569
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
1570
+
1571
+ if context_dim is not None:
1572
+ assert (
1573
+ use_spatial_transformer
1574
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
1575
+ if type(context_dim) == ListConfig:
1576
+ context_dim = list(context_dim)
1577
+
1578
+ if num_heads_upsample == -1:
1579
+ num_heads_upsample = num_heads
1580
+
1581
+ if num_heads == -1:
1582
+ assert (
1583
+ num_head_channels != -1
1584
+ ), "Either num_heads or num_head_channels has to be set"
1585
+
1586
+ if num_head_channels == -1:
1587
+ assert (
1588
+ num_heads != -1
1589
+ ), "Either num_heads or num_head_channels has to be set"
1590
+
1591
+ self.in_channels = in_channels
1592
+ self.ctrl_channels = ctrl_channels
1593
+ self.model_channels = model_channels
1594
+ self.out_channels = out_channels
1595
+ if isinstance(transformer_depth, int):
1596
+ transformer_depth = len(channel_mult) * [transformer_depth]
1597
+ elif isinstance(transformer_depth, ListConfig):
1598
+ transformer_depth = list(transformer_depth)
1599
+ transformer_depth_middle = default(
1600
+ transformer_depth_middle, transformer_depth[-1]
1601
+ )
1602
+
1603
+ if isinstance(num_res_blocks, int):
1604
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
1605
+ else:
1606
+ if len(num_res_blocks) != len(channel_mult):
1607
+ raise ValueError(
1608
+ "provide num_res_blocks either as an int (globally constant) or "
1609
+ "as a list/tuple (per-level) with the same length as channel_mult"
1610
+ )
1611
+ self.num_res_blocks = num_res_blocks
1612
+ # self.num_res_blocks = num_res_blocks
1613
+ if disable_self_attentions is not None:
1614
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
1615
+ assert len(disable_self_attentions) == len(channel_mult)
1616
+ if num_attention_blocks is not None:
1617
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
1618
+ assert all(
1619
+ map(
1620
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
1621
+ range(len(num_attention_blocks)),
1622
+ )
1623
+ )
1624
+ print(
1625
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
1626
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
1627
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
1628
+ f"attention will still not be set."
1629
+ ) # todo: convert to warning
1630
+
1631
+ self.attention_resolutions = attention_resolutions
1632
+ self.dropout = dropout
1633
+ self.channel_mult = channel_mult
1634
+ self.conv_resample = conv_resample
1635
+ self.num_classes = num_classes
1636
+ self.use_checkpoint = use_checkpoint
1637
+ if use_fp16:
1638
+ print("WARNING: use_fp16 was dropped and has no effect anymore.")
1639
+ # self.dtype = th.float16 if use_fp16 else th.float32
1640
+ self.num_heads = num_heads
1641
+ self.num_head_channels = num_head_channels
1642
+ self.num_heads_upsample = num_heads_upsample
1643
+ self.predict_codebook_ids = n_embed is not None
1644
+
1645
+ assert use_fairscale_checkpoint != use_checkpoint or not (
1646
+ use_checkpoint or use_fairscale_checkpoint
1647
+ )
1648
+
1649
+ self.use_fairscale_checkpoint = False
1650
+ checkpoint_wrapper_fn = (
1651
+ partial(checkpoint_wrapper, offload_to_cpu=offload_to_cpu)
1652
+ if self.use_fairscale_checkpoint
1653
+ else lambda x: x
1654
+ )
1655
+
1656
+ time_embed_dim = model_channels * 4
1657
+ self.time_embed = checkpoint_wrapper_fn(
1658
+ nn.Sequential(
1659
+ linear(model_channels, time_embed_dim),
1660
+ nn.SiLU(),
1661
+ linear(time_embed_dim, time_embed_dim),
1662
+ )
1663
+ )
1664
+
1665
+ if self.num_classes is not None:
1666
+ if isinstance(self.num_classes, int):
1667
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
1668
+ elif self.num_classes == "continuous":
1669
+ print("setting up linear c_adm embedding layer")
1670
+ self.label_emb = nn.Linear(1, time_embed_dim)
1671
+ elif self.num_classes == "timestep":
1672
+ self.label_emb = checkpoint_wrapper_fn(
1673
+ nn.Sequential(
1674
+ Timestep(model_channels),
1675
+ nn.Sequential(
1676
+ linear(model_channels, time_embed_dim),
1677
+ nn.SiLU(),
1678
+ linear(time_embed_dim, time_embed_dim),
1679
+ ),
1680
+ )
1681
+ )
1682
+ elif self.num_classes == "sequential":
1683
+ assert adm_in_channels is not None
1684
+ self.label_emb = nn.Sequential(
1685
+ nn.Sequential(
1686
+ linear(adm_in_channels, time_embed_dim),
1687
+ nn.SiLU(),
1688
+ linear(time_embed_dim, time_embed_dim),
1689
+ )
1690
+ )
1691
+ else:
1692
+ raise ValueError()
1693
+
1694
+ self.input_blocks = nn.ModuleList(
1695
+ [
1696
+ TimestepEmbedSequential(
1697
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
1698
+ )
1699
+ ]
1700
+ )
1701
+ if self.ctrl_channels > 0:
1702
+ self.add_input_block = TimestepEmbedSequential(
1703
+ conv_nd(dims, ctrl_channels, 16, 3, padding=1),
1704
+ nn.SiLU(),
1705
+ conv_nd(dims, 16, 16, 3, padding=1),
1706
+ nn.SiLU(),
1707
+ conv_nd(dims, 16, 32, 3, padding=1),
1708
+ nn.SiLU(),
1709
+ conv_nd(dims, 32, 32, 3, padding=1),
1710
+ nn.SiLU(),
1711
+ conv_nd(dims, 32, 96, 3, padding=1),
1712
+ nn.SiLU(),
1713
+ conv_nd(dims, 96, 96, 3, padding=1),
1714
+ nn.SiLU(),
1715
+ conv_nd(dims, 96, 256, 3, padding=1),
1716
+ nn.SiLU(),
1717
+ zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
1718
+ )
1719
+
1720
+ self._feature_size = model_channels
1721
+ input_block_chans = [model_channels]
1722
+ ch = model_channels
1723
+ ds = 1
1724
+ for level, mult in enumerate(channel_mult):
1725
+ for nr in range(self.num_res_blocks[level]):
1726
+ layers = [
1727
+ checkpoint_wrapper_fn(
1728
+ ResBlock(
1729
+ ch,
1730
+ time_embed_dim,
1731
+ dropout,
1732
+ out_channels=mult * model_channels,
1733
+ dims=dims,
1734
+ use_checkpoint=use_checkpoint,
1735
+ use_scale_shift_norm=use_scale_shift_norm,
1736
+ )
1737
+ )
1738
+ ]
1739
+ ch = mult * model_channels
1740
+ if ds in attention_resolutions:
1741
+ if num_head_channels == -1:
1742
+ dim_head = ch // num_heads
1743
+ else:
1744
+ num_heads = ch // num_head_channels
1745
+ dim_head = num_head_channels
1746
+ if legacy:
1747
+ # num_heads = 1
1748
+ dim_head = (
1749
+ ch // num_heads
1750
+ if use_spatial_transformer
1751
+ else num_head_channels
1752
+ )
1753
+ if exists(disable_self_attentions):
1754
+ disabled_sa = disable_self_attentions[level]
1755
+ else:
1756
+ disabled_sa = False
1757
+
1758
+ if (
1759
+ not exists(num_attention_blocks)
1760
+ or nr < num_attention_blocks[level]
1761
+ ):
1762
+ layers.append(
1763
+ checkpoint_wrapper_fn(
1764
+ AttentionBlock(
1765
+ ch,
1766
+ use_checkpoint=use_checkpoint,
1767
+ num_heads=num_heads,
1768
+ num_head_channels=dim_head,
1769
+ use_new_attention_order=use_new_attention_order,
1770
+ )
1771
+ )
1772
+ if not use_spatial_transformer
1773
+ else checkpoint_wrapper_fn(
1774
+ SpatialTransformer(
1775
+ ch,
1776
+ num_heads,
1777
+ dim_head,
1778
+ depth=transformer_depth[level],
1779
+ context_dim=context_dim,
1780
+ add_context_dim=add_context_dim,
1781
+ disable_self_attn=disabled_sa,
1782
+ use_linear=use_linear_in_transformer,
1783
+ attn_type=spatial_transformer_attn_type,
1784
+ use_checkpoint=use_checkpoint,
1785
+ )
1786
+ )
1787
+ )
1788
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
1789
+ self._feature_size += ch
1790
+ input_block_chans.append(ch)
1791
+ if level != len(channel_mult) - 1:
1792
+ out_ch = ch
1793
+ self.input_blocks.append(
1794
+ TimestepEmbedSequential(
1795
+ checkpoint_wrapper_fn(
1796
+ ResBlock(
1797
+ ch,
1798
+ time_embed_dim,
1799
+ dropout,
1800
+ out_channels=out_ch,
1801
+ dims=dims,
1802
+ use_checkpoint=use_checkpoint,
1803
+ use_scale_shift_norm=use_scale_shift_norm,
1804
+ down=True,
1805
+ )
1806
+ )
1807
+ if resblock_updown
1808
+ else Downsample(
1809
+ ch, conv_resample, dims=dims, out_channels=out_ch
1810
+ )
1811
+ )
1812
+ )
1813
+ ch = out_ch
1814
+ input_block_chans.append(ch)
1815
+ ds *= 2
1816
+ self._feature_size += ch
1817
+
1818
+ if num_head_channels == -1:
1819
+ dim_head = ch // num_heads
1820
+ else:
1821
+ num_heads = ch // num_head_channels
1822
+ dim_head = num_head_channels
1823
+ if legacy:
1824
+ # num_heads = 1
1825
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
1826
+ self.middle_block = TimestepEmbedSequential(
1827
+ checkpoint_wrapper_fn(
1828
+ ResBlock(
1829
+ ch,
1830
+ time_embed_dim,
1831
+ dropout,
1832
+ dims=dims,
1833
+ use_checkpoint=use_checkpoint,
1834
+ use_scale_shift_norm=use_scale_shift_norm,
1835
+ )
1836
+ ),
1837
+ checkpoint_wrapper_fn(
1838
+ AttentionBlock(
1839
+ ch,
1840
+ use_checkpoint=use_checkpoint,
1841
+ num_heads=num_heads,
1842
+ num_head_channels=dim_head,
1843
+ use_new_attention_order=use_new_attention_order,
1844
+ )
1845
+ )
1846
+ if not use_spatial_transformer
1847
+ else checkpoint_wrapper_fn(
1848
+ SpatialTransformer( # always uses a self-attn
1849
+ ch,
1850
+ num_heads,
1851
+ dim_head,
1852
+ depth=transformer_depth_middle,
1853
+ context_dim=context_dim,
1854
+ add_context_dim=add_context_dim,
1855
+ disable_self_attn=disable_middle_self_attn,
1856
+ use_linear=use_linear_in_transformer,
1857
+ attn_type=spatial_transformer_attn_type,
1858
+ use_checkpoint=use_checkpoint,
1859
+ )
1860
+ ),
1861
+ checkpoint_wrapper_fn(
1862
+ ResBlock(
1863
+ ch,
1864
+ time_embed_dim,
1865
+ dropout,
1866
+ dims=dims,
1867
+ use_checkpoint=use_checkpoint,
1868
+ use_scale_shift_norm=use_scale_shift_norm,
1869
+ )
1870
+ ),
1871
+ )
1872
+ self._feature_size += ch
1873
+
1874
+ self.output_blocks = nn.ModuleList([])
1875
+ for level, mult in list(enumerate(channel_mult))[::-1]:
1876
+ for i in range(self.num_res_blocks[level] + 1):
1877
+ ich = input_block_chans.pop()
1878
+ layers = [
1879
+ checkpoint_wrapper_fn(
1880
+ ResBlock(
1881
+ ch + ich,
1882
+ time_embed_dim,
1883
+ dropout,
1884
+ out_channels=model_channels * mult,
1885
+ dims=dims,
1886
+ use_checkpoint=use_checkpoint,
1887
+ use_scale_shift_norm=use_scale_shift_norm,
1888
+ )
1889
+ )
1890
+ ]
1891
+ ch = model_channels * mult
1892
+ if ds in attention_resolutions:
1893
+ if num_head_channels == -1:
1894
+ dim_head = ch // num_heads
1895
+ else:
1896
+ num_heads = ch // num_head_channels
1897
+ dim_head = num_head_channels
1898
+ if legacy:
1899
+ # num_heads = 1
1900
+ dim_head = (
1901
+ ch // num_heads
1902
+ if use_spatial_transformer
1903
+ else num_head_channels
1904
+ )
1905
+ if exists(disable_self_attentions):
1906
+ disabled_sa = disable_self_attentions[level]
1907
+ else:
1908
+ disabled_sa = False
1909
+
1910
+ if (
1911
+ not exists(num_attention_blocks)
1912
+ or i < num_attention_blocks[level]
1913
+ ):
1914
+ layers.append(
1915
+ checkpoint_wrapper_fn(
1916
+ AttentionBlock(
1917
+ ch,
1918
+ use_checkpoint=use_checkpoint,
1919
+ num_heads=num_heads_upsample,
1920
+ num_head_channels=dim_head,
1921
+ use_new_attention_order=use_new_attention_order,
1922
+ )
1923
+ )
1924
+ if not use_spatial_transformer
1925
+ else checkpoint_wrapper_fn(
1926
+ SpatialTransformer(
1927
+ ch,
1928
+ num_heads,
1929
+ dim_head,
1930
+ depth=transformer_depth[level],
1931
+ context_dim=context_dim,
1932
+ add_context_dim=add_context_dim,
1933
+ disable_self_attn=disabled_sa,
1934
+ use_linear=use_linear_in_transformer,
1935
+ attn_type=spatial_transformer_attn_type,
1936
+ use_checkpoint=use_checkpoint,
1937
+ )
1938
+ )
1939
+ )
1940
+ if level and i == self.num_res_blocks[level]:
1941
+ out_ch = ch
1942
+ layers.append(
1943
+ checkpoint_wrapper_fn(
1944
+ ResBlock(
1945
+ ch,
1946
+ time_embed_dim,
1947
+ dropout,
1948
+ out_channels=out_ch,
1949
+ dims=dims,
1950
+ use_checkpoint=use_checkpoint,
1951
+ use_scale_shift_norm=use_scale_shift_norm,
1952
+ up=True,
1953
+ )
1954
+ )
1955
+ if resblock_updown
1956
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1957
+ )
1958
+ ds //= 2
1959
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
1960
+ self._feature_size += ch
1961
+
1962
+ self.out = checkpoint_wrapper_fn(
1963
+ nn.Sequential(
1964
+ normalization(ch),
1965
+ nn.SiLU(),
1966
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1967
+ )
1968
+ )
1969
+ if self.predict_codebook_ids:
1970
+ self.id_predictor = checkpoint_wrapper_fn(
1971
+ nn.Sequential(
1972
+ normalization(ch),
1973
+ conv_nd(dims, model_channels, n_embed, 1),
1974
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1975
+ )
1976
+ )
1977
+
1978
+ # cache attn map
1979
+ self.attn_type = attn_type
1980
+ self.attn_layers = attn_layers
1981
+ self.attn_map_cache = []
1982
+ for name, module in self.named_modules():
1983
+ if name.endswith(self.attn_type):
1984
+ item = {"name": name, "heads": module.heads, "size": None, "attn_map": None}
1985
+ self.attn_map_cache.append(item)
1986
+ module.attn_map_cache = item
1987
+
1988
+ def clear_attn_map(self):
1989
+
1990
+ for item in self.attn_map_cache:
1991
+ if item["attn_map"] is not None:
1992
+ del item["attn_map"]
1993
+ item["attn_map"] = None
1994
+
1995
+ def save_attn_map(self, save_name="temp", tokens=""):
1996
+
1997
+ attn_maps = []
1998
+ for item in self.attn_map_cache:
1999
+ name = item["name"]
2000
+ if any([name.startswith(block) for block in self.attn_layers]):
2001
+ heads = item["heads"]
2002
+ attn_maps.append(item["attn_map"].detach().cpu())
2003
+
2004
+ attn_map = th.stack(attn_maps, dim=0)
2005
+ attn_map = th.mean(attn_map, dim=0)
2006
+
2007
+ # attn_map: bh * n * l
2008
+ bh, n, l = attn_map.shape # bh: batch size * heads / n : pixel length(h*w) / l: token length
2009
+ attn_map = attn_map.reshape((-1,heads,n,l)).mean(dim=1)
2010
+ b = attn_map.shape[0]
2011
+
2012
+ h = w = int(n**0.5)
2013
+ attn_map = attn_map.permute(0,2,1).reshape((b,l,h,w)).numpy()
2014
+
2015
+ attn_map_i = attn_map[-1]
2016
+
2017
+ l = attn_map_i.shape[0]
2018
+ fig = plt.figure(figsize=(12, 8), dpi=300)
2019
+ for j in range(12):
2020
+ if j >= l: break
2021
+ ax = fig.add_subplot(3, 4, j+1)
2022
+ sns.heatmap(attn_map_i[j], square=True, xticklabels=False, yticklabels=False)
2023
+ if j < len(tokens):
2024
+ ax.set_title(tokens[j])
2025
+ fig.savefig(f"temp/attn_map/attn_map_{save_name}.png")
2026
+ plt.close()
2027
+
2028
+ return attn_map_i
2029
+
2030
+ def forward(self, x, timesteps=None, context=None, add_context=None, y=None, **kwargs):
2031
+ """
2032
+ Apply the model to an input batch.
2033
+ :param x: an [N x C x ...] Tensor of inputs.
2034
+ :param timesteps: a 1-D batch of timesteps.
2035
+ :param context: conditioning plugged in via crossattn
2036
+ :param y: an [N] Tensor of labels, if class-conditional.
2037
+ :return: an [N x C x ...] Tensor of outputs.
2038
+ """
2039
+ assert (y is not None) == (
2040
+ self.num_classes is not None
2041
+ ), "must specify y if and only if the model is class-conditional"
2042
+
2043
+ self.clear_attn_map()
2044
+
2045
+ hs = []
2046
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
2047
+ emb = self.time_embed(t_emb)
2048
+
2049
+ if self.num_classes is not None:
2050
+ assert y.shape[0] == x.shape[0]
2051
+ emb = emb + self.label_emb(y)
2052
+
2053
+ # h = x.type(self.dtype)
2054
+ h = x
2055
+ if self.ctrl_channels > 0:
2056
+ in_h, add_h = th.split(h, [self.in_channels, self.ctrl_channels], dim=1)
2057
+
2058
  for i, module in enumerate(self.input_blocks):
2059
  if self.ctrl_channels > 0 and i == 0:
2060
+ h = module(in_h, emb, context, add_context) + self.add_input_block(add_h, emb, context, add_context)
2061
  else:
2062
+ h = module(h, emb, context, add_context)
2063
  hs.append(h)
2064
+ h = self.middle_block(h, emb, context, add_context)
2065
  for i, module in enumerate(self.output_blocks):
2066
  h = th.cat([h, hs.pop()], dim=1)
2067
+ h = module(h, emb, context, add_context)
2068
  h = h.type(x.dtype)
2069
 
2070
  return self.out(h)
sgm/modules/diffusionmodules/sampling.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, Union
7
 
8
  import imageio
9
  import torch
 
10
  import numpy as np
11
  import torch.nn.functional as F
12
  from omegaconf import ListConfig, OmegaConf
@@ -251,15 +252,47 @@ class EulerEDMSampler(EDMSampler):
251
 
252
  return x
253
 
254
- def save_segment_map(self, attn_maps, tokens=None, save_name=None):
 
 
255
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
  sections = []
257
  for i in range(len(tokens)):
258
  attn_map = attn_maps[i]
 
 
 
 
 
 
 
 
 
259
  sections.append(attn_map)
260
 
261
  section = np.stack(sections)
262
- np.save(f"./temp/seg_map/seg_{save_name}.npy", section)
 
 
263
 
264
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
265
 
@@ -343,7 +376,8 @@ class EulerEDMSampler(EDMSampler):
343
  local_loss = torch.zeros(1)
344
  if save_attn:
345
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
346
- self.save_segment_map(attn_map, tokens=batch["label"][0], save_name=name)
 
347
 
348
  d = to_d(x, sigma_hat, denoised)
349
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
@@ -376,7 +410,7 @@ class EulerEDMSampler(EDMSampler):
376
 
377
  alpha = 20 * np.sqrt(scales[i])
378
  update = aae_enabled
379
- save_loss = aae_enabled
380
  save_attn = detailed and (i == (num_sigmas-1)//2)
381
  save_inter = aae_enabled
382
 
@@ -412,12 +446,195 @@ class EulerEDMSampler(EDMSampler):
412
  inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
413
  inters.append(inter.astype(np.uint8))
414
 
415
- # print(f"Local losses: {local_losses}")
416
 
417
  if len(inters) > 0:
418
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
419
 
420
  return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
 
423
  class HeunEDMSampler(EDMSampler):
 
7
 
8
  import imageio
9
  import torch
10
+ import json
11
  import numpy as np
12
  import torch.nn.functional as F
13
  from omegaconf import ListConfig, OmegaConf
 
252
 
253
  return x
254
 
255
+ def create_pascal_label_colormap(self):
256
+ """
257
+ PASCAL VOC 分割数据集的类别标签颜色映射label colormap
258
 
259
+ 返回:
260
+ 可视化分割结果的颜色映射Colormap
261
+ """
262
+ colormap = np.zeros((256, 3), dtype=int)
263
+ ind = np.arange(256, dtype=int)
264
+
265
+ for shift in reversed(range(8)):
266
+ for channel in range(3):
267
+ colormap[:, channel] |= ((ind >> channel) & 1) << shift
268
+ ind >>= 3
269
+
270
+ return colormap
271
+
272
+ def save_segment_map(self, image, attn_maps, tokens=None, save_name=None):
273
+
274
+ colormap = self.create_pascal_label_colormap()
275
+ H, W = image.shape[-2:]
276
+
277
+ image_ = image*0.3
278
  sections = []
279
  for i in range(len(tokens)):
280
  attn_map = attn_maps[i]
281
+ attn_map_t = np.tile(attn_map[None], (1,3,1,1)) # b, 3, h, w
282
+ attn_map_t = torch.from_numpy(attn_map_t)
283
+ attn_map_t = F.interpolate(attn_map_t, (W, H))
284
+
285
+ color = torch.from_numpy(colormap[i+1][None,:,None,None] / 255.0)
286
+ colored_attn_map = attn_map_t * color
287
+ colored_attn_map = colored_attn_map.to(device=image_.device)
288
+
289
+ image_ += colored_attn_map*0.7
290
  sections.append(attn_map)
291
 
292
  section = np.stack(sections)
293
+ np.save(f"temp/seg_map/seg_{save_name}.npy", section)
294
+
295
+ save_image(image_, f"temp/seg_map/seg_{save_name}.png", normalize=True)
296
 
297
  def get_init_noise(self, cfgs, model, cond, batch, uc=None):
298
 
 
376
  local_loss = torch.zeros(1)
377
  if save_attn:
378
  attn_map = model.model.diffusion_model.save_attn_map(save_name=name, tokens=batch["label"][0])
379
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
380
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
381
 
382
  d = to_d(x, sigma_hat, denoised)
383
  dt = append_dims(next_sigma - sigma_hat, x.ndim)
 
410
 
411
  alpha = 20 * np.sqrt(scales[i])
412
  update = aae_enabled
413
+ save_loss = detailed
414
  save_attn = detailed and (i == (num_sigmas-1)//2)
415
  save_inter = aae_enabled
416
 
 
446
  inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
447
  inters.append(inter.astype(np.uint8))
448
 
449
+ print(f"Local losses: {local_losses}")
450
 
451
  if len(inters) > 0:
452
  imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.02)
453
 
454
  return x
455
+
456
+
457
+ class EulerEDMDualSampler(EulerEDMSampler):
458
+
459
+ def prepare_sampling_loop(self, x, cond, uc_1=None, uc_2=None, num_steps=None):
460
+ sigmas = self.discretization(
461
+ self.num_steps if num_steps is None else num_steps, device=self.device
462
+ )
463
+ uc_1 = default(uc_1, cond)
464
+ uc_2 = default(uc_2, cond)
465
+
466
+ x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
467
+ num_sigmas = len(sigmas)
468
+
469
+ s_in = x.new_ones([x.shape[0]])
470
+
471
+ return x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2
472
+
473
+ def denoise(self, x, model, sigma, cond, uc_1, uc_2):
474
+ denoised = model.denoiser(model.model, *self.guider.prepare_inputs(x, sigma, cond, uc_1, uc_2))
475
+ denoised = self.guider(denoised, sigma)
476
+ return denoised
477
+
478
+ def get_init_noise(self, cfgs, model, cond, batch, uc_1=None, uc_2=None):
479
+
480
+ H, W = batch["target_size_as_tuple"][0]
481
+ shape = (cfgs.batch_size, cfgs.channel, int(H) // cfgs.factor, int(W) // cfgs.factor)
482
+
483
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
484
+ x = randn.clone()
485
+
486
+ xs = []
487
+ self.verbose = False
488
+ for _ in range(cfgs.noise_iters):
489
+
490
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
491
+ x, cond, uc_1, uc_2, num_steps=2
492
+ )
493
+
494
+ superv = {
495
+ "mask": batch["mask"] if "mask" in batch else None,
496
+ "seg_mask": batch["seg_mask"] if "seg_mask" in batch else None
497
+ }
498
+
499
+ local_losses = []
500
+
501
+ for i in self.get_sigma_gen(num_sigmas):
502
+
503
+ gamma = (
504
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
505
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
506
+ else 0.0
507
+ )
508
+
509
+ x, inter, local_loss = self.sampler_step(
510
+ s_in * sigmas[i],
511
+ s_in * sigmas[i + 1],
512
+ model,
513
+ x,
514
+ cond,
515
+ superv,
516
+ uc_1,
517
+ uc_2,
518
+ gamma,
519
+ save_loss=True
520
+ )
521
+
522
+ local_losses.append(local_loss.item())
523
+
524
+ xs.append((randn, local_losses[-1]))
525
+
526
+ randn = torch.randn(shape).to(torch.device("cuda", index=cfgs.gpu))
527
+ x = randn.clone()
528
+
529
+ self.verbose = True
530
+
531
+ xs.sort(key = lambda x: x[-1])
532
+
533
+ if len(xs) > 0:
534
+ print(f"Init local loss: Best {xs[0][1]} Worst {xs[-1][1]}")
535
+ x = xs[0][0]
536
+
537
+ return x
538
+
539
+ def sampler_step(self, sigma, next_sigma, model, x, cond, batch=None, uc_1=None, uc_2=None,
540
+ gamma=0.0, alpha=0, iter_enabled=False, thres=None, update=False,
541
+ name=None, save_loss=False, save_attn=False, save_inter=False):
542
+
543
+ sigma_hat = sigma * (gamma + 1.0)
544
+ if gamma > 0:
545
+ eps = torch.randn_like(x) * self.s_noise
546
+ x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
547
+
548
+ if update:
549
+ x = self.attend_and_excite(x, model, sigma_hat, cond, batch, alpha, iter_enabled, thres)
550
+
551
+ denoised = self.denoise(x, model, sigma_hat, cond, uc_1, uc_2)
552
+ denoised_decode = model.decode_first_stage(denoised) if save_inter else None
553
+
554
+ if save_loss:
555
+ local_loss = model.loss_fn.get_min_local_loss(model.model.diffusion_model.attn_map_cache, batch["mask"], batch["seg_mask"])
556
+ local_loss = local_loss[-local_loss.shape[0]//3:]
557
+ else:
558
+ local_loss = torch.zeros(1)
559
+ if save_attn:
560
+ attn_map = model.model.diffusion_model.save_attn_map(save_name=name, save_single=True)
561
+ denoised_decode = model.decode_first_stage(denoised) if denoised_decode is None else denoised_decode
562
+ self.save_segment_map(denoised_decode, attn_map, tokens=batch["label"][0], save_name=name)
563
+
564
+ d = to_d(x, sigma_hat, denoised)
565
+ dt = append_dims(next_sigma - sigma_hat, x.ndim)
566
+
567
+ euler_step = self.euler_step(x, d, dt)
568
+
569
+ return euler_step, denoised_decode, local_loss
570
+
571
+ def __call__(self, model, x, cond, batch=None, uc_1=None, uc_2=None, num_steps=None, init_step=0,
572
+ name=None, aae_enabled=False, detailed=False):
573
+
574
+ x, s_in, sigmas, num_sigmas, cond, uc_1, uc_2 = self.prepare_sampling_loop(
575
+ x, cond, uc_1, uc_2, num_steps
576
+ )
577
+
578
+ name = batch["name"][0]
579
+ inters = []
580
+ local_losses = []
581
+ scales = np.linspace(start=1.0, stop=0, num=num_sigmas)
582
+ iter_lst = np.linspace(start=5, stop=25, num=6, dtype=np.int32)
583
+ thres_lst = np.linspace(start=-0.5, stop=-0.8, num=6)
584
+
585
+ for i in self.get_sigma_gen(num_sigmas, init_step=init_step):
586
+
587
+ gamma = (
588
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
589
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
590
+ else 0.0
591
+ )
592
+
593
+ alpha = 20 * np.sqrt(scales[i])
594
+ update = aae_enabled
595
+ save_loss = aae_enabled
596
+ save_attn = detailed and (i == (num_sigmas-1)//2)
597
+ save_inter = aae_enabled
598
+
599
+ if i in iter_lst:
600
+ iter_enabled = True
601
+ thres = thres_lst[list(iter_lst).index(i)]
602
+ else:
603
+ iter_enabled = False
604
+ thres = 0.0
605
+
606
+ x, inter, local_loss = self.sampler_step(
607
+ s_in * sigmas[i],
608
+ s_in * sigmas[i + 1],
609
+ model,
610
+ x,
611
+ cond,
612
+ batch,
613
+ uc_1,
614
+ uc_2,
615
+ gamma,
616
+ alpha=alpha,
617
+ iter_enabled=iter_enabled,
618
+ thres=thres,
619
+ update=update,
620
+ name=name,
621
+ save_loss=save_loss,
622
+ save_attn=save_attn,
623
+ save_inter=save_inter
624
+ )
625
+
626
+ local_losses.append(local_loss.item())
627
+ if inter is not None:
628
+ inter = torch.clamp((inter + 1.0) / 2.0, min=0.0, max=1.0)[0]
629
+ inter = inter.cpu().numpy().transpose(1, 2, 0) * 255
630
+ inters.append(inter.astype(np.uint8))
631
+
632
+ print(f"Local losses: {local_losses}")
633
+
634
+ if len(inters) > 0:
635
+ imageio.mimsave(f"./temp/inters/{name}.gif", inters, 'GIF', duration=0.1)
636
+
637
+ return x
638
 
639
 
640
  class HeunEDMSampler(EDMSampler):
sgm/modules/diffusionmodules/sampling_utils.py CHANGED
@@ -7,7 +7,10 @@ from ...util import append_dims
7
  class NoDynamicThresholding:
8
  def __call__(self, uncond, cond, scale):
9
  return uncond + scale * (cond - uncond)
10
-
 
 
 
11
 
12
  def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
13
  if order - 1 > i:
 
7
  class NoDynamicThresholding:
8
  def __call__(self, uncond, cond, scale):
9
  return uncond + scale * (cond - uncond)
10
+
11
+ class DualThresholding: # Dual condition CFG (from instructPix2Pix)
12
+ def __call__(self, uncond_1, uncond_2, cond, scale):
13
+ return uncond_1 + scale[0] * (uncond_2 - uncond_1) + scale[1] * (cond - uncond_2)
14
 
15
  def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
16
  if order - 1 > i:
sgm/modules/diffusionmodules/wrappers.py CHANGED
@@ -28,8 +28,8 @@ class OpenAIWrapper(IdentityWrapper):
28
  return self.diffusion_model(
29
  x,
30
  timesteps=t,
31
- t_context=c.get("t_crossattn", None),
32
- v_context=c.get("v_crossattn", None),
33
  y=c.get("vector", None),
34
  **kwargs
35
  )
 
28
  return self.diffusion_model(
29
  x,
30
  timesteps=t,
31
+ context=c.get("crossattn", None),
32
+ add_context=c.get("add_crossattn", None),
33
  y=c.get("vector", None),
34
  **kwargs
35
  )
sgm/modules/encoders/modules.py CHANGED
@@ -14,7 +14,6 @@ from transformers import (
14
  ByT5Tokenizer,
15
  CLIPTextModel,
16
  CLIPTokenizer,
17
- CLIPVisionModel,
18
  T5EncoderModel,
19
  T5Tokenizer,
20
  )
@@ -39,19 +38,18 @@ import pytorch_lightning as pl
39
  from torchvision import transforms
40
  from timm.models.vision_transformer import VisionTransformer
41
  from safetensors.torch import load_file as load_safetensors
42
- from torchvision.utils import save_image
43
 
44
  # disable warning
45
  from transformers import logging
46
  logging.set_verbosity_error()
47
 
48
  class AbstractEmbModel(nn.Module):
49
- def __init__(self):
50
  super().__init__()
51
  self._is_trainable = None
52
  self._ucg_rate = None
53
  self._input_key = None
54
- self._emb_key = None
55
 
56
  @property
57
  def is_trainable(self) -> bool:
@@ -65,10 +63,6 @@ class AbstractEmbModel(nn.Module):
65
  def input_key(self) -> str:
66
  return self._input_key
67
 
68
- @property
69
- def emb_key(self) -> str:
70
- return self._emb_key
71
-
72
  @is_trainable.setter
73
  def is_trainable(self, value: bool):
74
  self._is_trainable = value
@@ -81,10 +75,6 @@ class AbstractEmbModel(nn.Module):
81
  def input_key(self, value: str):
82
  self._input_key = value
83
 
84
- @emb_key.setter
85
- def emb_key(self, value: str):
86
- self._emb_key = value
87
-
88
  @is_trainable.deleter
89
  def is_trainable(self):
90
  del self._is_trainable
@@ -97,13 +87,8 @@ class AbstractEmbModel(nn.Module):
97
  def input_key(self):
98
  del self._input_key
99
 
100
- @emb_key.deleter
101
- def emb_key(self):
102
- del self._emb_key
103
-
104
 
105
  class GeneralConditioner(nn.Module):
106
-
107
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
108
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
109
 
@@ -124,8 +109,7 @@ class GeneralConditioner(nn.Module):
124
  f"Initialized embedder #{n}: {embedder.__class__.__name__} "
125
  f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
126
  )
127
- if "emb_key" in embconfig:
128
- embedder.emb_key = embconfig["emb_key"]
129
  if "input_key" in embconfig:
130
  embedder.input_key = embconfig["input_key"]
131
  elif "input_keys" in embconfig:
@@ -172,10 +156,13 @@ class GeneralConditioner(nn.Module):
172
  if not isinstance(emb_out, (list, tuple)):
173
  emb_out = [emb_out]
174
  for emb in emb_out:
175
- if embedder.emb_key is not None:
176
- out_key = embedder.emb_key
177
  else:
178
  out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
 
 
 
179
  if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
180
  emb = (
181
  expand_dims_like(
@@ -217,6 +204,28 @@ class GeneralConditioner(nn.Module):
217
  return c, uc
218
 
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  class InceptionV3(nn.Module):
221
  """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
222
  port with an additional squeeze at the end"""
@@ -400,6 +409,7 @@ class FrozenCLIPEmbedder(AbstractEmbModel):
400
 
401
  def freeze(self):
402
  self.transformer = self.transformer.eval()
 
403
  for param in self.parameters():
404
  param.requires_grad = False
405
 
@@ -684,24 +694,24 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
684
  if self.output_tokens:
685
  z, tokens = z[0], z[1]
686
  z = z.to(image.dtype)
687
- # if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
688
- # z = (
689
- # torch.bernoulli(
690
- # (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
691
- # )[:, None]
692
- # * z
693
- # )
694
- # if tokens is not None:
695
- # tokens = (
696
- # expand_dims_like(
697
- # torch.bernoulli(
698
- # (1.0 - self.ucg_rate)
699
- # * torch.ones(tokens.shape[0], device=tokens.device)
700
- # ),
701
- # tokens,
702
- # )
703
- # * tokens
704
- # )
705
  if self.unsqueeze_dim:
706
  z = z[:, None, :]
707
  if self.output_tokens:
@@ -797,7 +807,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
797
  return [clip_z, t5_z]
798
 
799
 
800
- class SpatialRescaler(AbstractEmbModel):
801
  def __init__(
802
  self,
803
  n_stages=1,
@@ -836,9 +846,6 @@ class SpatialRescaler(AbstractEmbModel):
836
  padding=kernel_size // 2,
837
  )
838
  self.wrap_video = wrap_video
839
-
840
- def freeze(self):
841
- pass
842
 
843
  def forward(self, x):
844
  if self.wrap_video and x.ndim == 5:
 
14
  ByT5Tokenizer,
15
  CLIPTextModel,
16
  CLIPTokenizer,
 
17
  T5EncoderModel,
18
  T5Tokenizer,
19
  )
 
38
  from torchvision import transforms
39
  from timm.models.vision_transformer import VisionTransformer
40
  from safetensors.torch import load_file as load_safetensors
 
41
 
42
  # disable warning
43
  from transformers import logging
44
  logging.set_verbosity_error()
45
 
46
  class AbstractEmbModel(nn.Module):
47
+ def __init__(self, is_add_embedder=False):
48
  super().__init__()
49
  self._is_trainable = None
50
  self._ucg_rate = None
51
  self._input_key = None
52
+ self.is_add_embedder = is_add_embedder
53
 
54
  @property
55
  def is_trainable(self) -> bool:
 
63
  def input_key(self) -> str:
64
  return self._input_key
65
 
 
 
 
 
66
  @is_trainable.setter
67
  def is_trainable(self, value: bool):
68
  self._is_trainable = value
 
75
  def input_key(self, value: str):
76
  self._input_key = value
77
 
 
 
 
 
78
  @is_trainable.deleter
79
  def is_trainable(self):
80
  del self._is_trainable
 
87
  def input_key(self):
88
  del self._input_key
89
 
 
 
 
 
90
 
91
  class GeneralConditioner(nn.Module):
 
92
  OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
93
  KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}
94
 
 
109
  f"Initialized embedder #{n}: {embedder.__class__.__name__} "
110
  f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
111
  )
112
+
 
113
  if "input_key" in embconfig:
114
  embedder.input_key = embconfig["input_key"]
115
  elif "input_keys" in embconfig:
 
156
  if not isinstance(emb_out, (list, tuple)):
157
  emb_out = [emb_out]
158
  for emb in emb_out:
159
+ if embedder.is_add_embedder:
160
+ out_key = "add_crossattn"
161
  else:
162
  out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
163
+ if embedder.input_key == "mask":
164
+ H, W = batch["image"].shape[-2:]
165
+ emb = nn.functional.interpolate(emb, (H//8, W//8))
166
  if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
167
  emb = (
168
  expand_dims_like(
 
204
  return c, uc
205
 
206
 
207
+ class DualConditioner(GeneralConditioner):
208
+
209
+ def get_unconditional_conditioning(
210
+ self, batch_c, batch_uc_1=None, batch_uc_2=None, force_uc_zero_embeddings=None
211
+ ):
212
+ if force_uc_zero_embeddings is None:
213
+ force_uc_zero_embeddings = []
214
+ ucg_rates = list()
215
+ for embedder in self.embedders:
216
+ ucg_rates.append(embedder.ucg_rate)
217
+ embedder.ucg_rate = 0.0
218
+
219
+ c = self(batch_c)
220
+ uc_1 = self(batch_uc_1, force_uc_zero_embeddings) if batch_uc_1 is not None else None
221
+ uc_2 = self(batch_uc_2, force_uc_zero_embeddings[:1]) if batch_uc_2 is not None else None
222
+
223
+ for embedder, rate in zip(self.embedders, ucg_rates):
224
+ embedder.ucg_rate = rate
225
+
226
+ return c, uc_1, uc_2
227
+
228
+
229
  class InceptionV3(nn.Module):
230
  """Wrapper around the https://github.com/mseitzer/pytorch-fid inception
231
  port with an additional squeeze at the end"""
 
409
 
410
  def freeze(self):
411
  self.transformer = self.transformer.eval()
412
+
413
  for param in self.parameters():
414
  param.requires_grad = False
415
 
 
694
  if self.output_tokens:
695
  z, tokens = z[0], z[1]
696
  z = z.to(image.dtype)
697
+ if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
698
+ z = (
699
+ torch.bernoulli(
700
+ (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
701
+ )[:, None]
702
+ * z
703
+ )
704
+ if tokens is not None:
705
+ tokens = (
706
+ expand_dims_like(
707
+ torch.bernoulli(
708
+ (1.0 - self.ucg_rate)
709
+ * torch.ones(tokens.shape[0], device=tokens.device)
710
+ ),
711
+ tokens,
712
+ )
713
+ * tokens
714
+ )
715
  if self.unsqueeze_dim:
716
  z = z[:, None, :]
717
  if self.output_tokens:
 
807
  return [clip_z, t5_z]
808
 
809
 
810
+ class SpatialRescaler(nn.Module):
811
  def __init__(
812
  self,
813
  n_stages=1,
 
846
  padding=kernel_size // 2,
847
  )
848
  self.wrap_video = wrap_video
 
 
 
849
 
850
  def forward(self, x):
851
  if self.wrap_video and x.ndim == 5:
util.py CHANGED
@@ -65,14 +65,6 @@ def prepare_batch(cfgs, batch):
65
  if isinstance(batch[key], torch.Tensor):
66
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
67
 
68
- batch_uc = deep_copy(batch)
69
-
70
- if "ntxt" in batch:
71
- batch_uc["txt"] = batch["ntxt"]
72
- else:
73
- batch_uc["txt"] = ["" for _ in range(len(batch["txt"]))]
74
-
75
- if "label" in batch:
76
- batch_uc["label"] = ["" for _ in range(len(batch["label"]))]
77
 
78
  return batch, batch_uc
 
65
  if isinstance(batch[key], torch.Tensor):
66
  batch[key] = batch[key].to(torch.device("cuda", index=cfgs.gpu))
67
 
68
+ batch_uc = batch
 
 
 
 
 
 
 
 
69
 
70
  return batch, batch_uc