patrickvonplaten committed
Commit b41c4e7
1 Parent(s): 45fbffb

Upload modeling_ddim.py

Files changed (1)
  1. modeling_ddim.py +45 -23
modeling_ddim.py CHANGED
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 
-from diffusers import DiffusionPipeline
-import tqdm
 import torch
 
+import tqdm
+from diffusers import DiffusionPipeline
+
 
 class DDIM(DiffusionPipeline):
-
     def __init__(self, unet, noise_scheduler):
         super().__init__()
         self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
@@ -34,39 +34,61 @@ class DDIM(DiffusionPipeline):
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
 
         self.unet.to(torch_device)
-        image = self.noise_scheduler.sample_noise((batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution), device=torch_device, generator=generator)
+
+        # Sample gaussian noise to begin the loop
+        image = self.noise_scheduler.sample_noise(
+            (batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
+            device=torch_device,
+            generator=generator,
+        )
+
+        # See formulas (9), (10) and (7) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail for a full understanding
+
+        # Notation (<variable name> -> <name in paper>)
+        # - pred_noise_t -> e_theta(x_t, t)
+        # - pred_original_image -> f_theta(x_t, t) or x_0
+        # - std_dev_t -> sigma_t
 
         for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
-            # get actual t and t-1
+            # 1. predict noise residual
+            with torch.no_grad():
+                pred_noise_t = self.unet(image, inference_step_times[t])
+
+            # 2. get actual t and t-1
             train_step = inference_step_times[t]
-            prev_train_step = inference_step_times[t - 1] if t > 0 else - 1
+            prev_train_step = inference_step_times[t - 1] if t > 0 else -1
 
-            # compute alphas
+            # 3. compute alphas, betas
             alpha_prod_t = self.noise_scheduler.get_alpha_prod(train_step)
             alpha_prod_t_prev = self.noise_scheduler.get_alpha_prod(prev_train_step)
-            alpha_prod_t_rsqrt = 1 / alpha_prod_t.sqrt()
-            alpha_prod_t_prev_rsqrt = 1 / alpha_prod_t_prev.sqrt()
             beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
             beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()
 
-            # compute relevant coefficients
-            coeff_1 = (alpha_prod_t_prev - alpha_prod_t).sqrt() * alpha_prod_t_prev_rsqrt * beta_prod_t_prev_sqrt / beta_prod_t_sqrt * eta
-            coeff_2 = ((1 - alpha_prod_t_prev) - coeff_1 ** 2).sqrt()
-
-            # model forward
-            with torch.no_grad():
-                noise_residual = self.unet(image, train_step)
-
-            # predict mean of prev image
-            pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
-            pred_mean = torch.clamp(pred_mean, -1, 1)
-            pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
+            # 4. Compute the predicted previous image from the predicted noise
+            # First: compute the predicted original image from the predicted noise, also called
+            # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            pred_original_image = (image - beta_prod_t_sqrt * pred_noise_t) / alpha_prod_t.sqrt()
+            # Second: Clip "predicted x_0"
+            pred_original_image = torch.clamp(pred_original_image, -1, 1)
+            # Third: Compute the variance "sigma_t" -> see formula (16) of https://arxiv.org/pdf/2010.02502.pdf
+            # std_dev_t = (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() * beta_prod_t_prev_sqrt / beta_prod_t_sqrt
+            std_dev_t = (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() * beta_prod_t_prev_sqrt / beta_prod_t_sqrt
+            std_dev_t = std_dev_t * eta
+            # Fourth: Compute the "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+            pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t
+
+            # Fifth: Compute x_t-1 via the outer DDIM formula (12)
+            pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
 
             # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM
             if eta > 0.0:
                 noise = self.noise_scheduler.sample_noise(image.shape, device=image.device, generator=generator)
-                image = pred_mean + coeff_1 * noise
+                prev_image = pred_prev_image + std_dev_t * noise
            else:
-                image = pred_mean
+                prev_image = pred_prev_image
+
+            # Set current image to prev_image: x_t -> x_t-1
+            image = prev_image
 
         return image
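
For reference, the rewritten loop implements the DDIM update of formula (12), with the variance sigma_t of formula (16), from https://arxiv.org/pdf/2010.02502.pdf. Writing alpha_prod as \bar{\alpha} and the predicted noise as \epsilon_\theta, a single step x_t -> x_{t-1} is:

x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\,\underbrace{\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}}_{\text{pred\_original\_image}}
\;+\; \underbrace{\sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t, t)}_{\text{pred\_image\_direction}}
\;+\; \sigma_t\,\epsilon,
\qquad
\sigma_t = \eta\,\sqrt{\frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}}\,\sqrt{1-\frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}}}

with \epsilon \sim \mathcal{N}(0, I). Setting eta = 0 gives deterministic DDIM sampling, and eta = 1 essentially corresponds to DDPM, matching the comment in the loop.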
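
A minimal, self-contained sketch of one such update, mirroring steps 1-5 of the new loop: the helper name ddim_step and its signature are illustrative (not part of this commit), and alpha_prod_t / alpha_prod_t_prev are assumed to be scalar torch tensors such as those returned by noise_scheduler.get_alpha_prod.

import torch

def ddim_step(image, pred_noise_t, alpha_prod_t, alpha_prod_t_prev, eta=0.0, generator=None):
    """One DDIM update x_t -> x_t-1 (illustrative helper, not part of the commit).

    image: current sample x_t; pred_noise_t: unet noise prediction e_theta(x_t, t);
    alpha_prod_t / alpha_prod_t_prev: cumulative alpha products as scalar tensors.
    """
    beta_prod_t_sqrt = (1 - alpha_prod_t).sqrt()
    beta_prod_t_prev_sqrt = (1 - alpha_prod_t_prev).sqrt()

    # "predicted x_0" of formula (12), clipped to the image range
    pred_original_image = (image - beta_prod_t_sqrt * pred_noise_t) / alpha_prod_t.sqrt()
    pred_original_image = torch.clamp(pred_original_image, -1, 1)

    # sigma_t of formula (16), scaled by eta
    std_dev_t = eta * (1 - alpha_prod_t / alpha_prod_t_prev).sqrt() * beta_prod_t_prev_sqrt / beta_prod_t_sqrt

    # "direction pointing to x_t" of formula (12)
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2).sqrt() * pred_noise_t

    # deterministic part of x_t-1
    prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction

    # stochastic part: only added when eta > 0 (eta = 1 essentially corresponds to DDPM)
    if eta > 0.0:
        noise = torch.randn(image.shape, generator=generator, device=image.device)
        prev_image = prev_image + std_dev_t * noise
    return prev_image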
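
A hypothetical usage sketch follows. The __call__ signature is not shown in this diff, so the keyword names below (batch_size, generator, torch_device, eta, num_inference_steps) are only inferred from the variables used in the loop, and the checkpoint path is a placeholder.

import torch
from modeling_ddim import DDIM

# Hypothetical usage; keyword names and checkpoint path are assumptions, not a confirmed API.
pipeline = DDIM.from_pretrained("path/to/ddim-checkpoint")
generator = torch.manual_seed(0)
image = pipeline(
    batch_size=1,
    generator=generator,
    torch_device="cuda" if torch.cuda.is_available() else "cpu",
    eta=0.0,  # deterministic DDIM; eta=1.0 essentially corresponds to DDPM
    num_inference_steps=50,
)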