# PoseDiffusion_MVP / models / gaussian_diffuser.py
# hugoycj — "Initial commit" (3d3e4e9)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/beb2f2d8dd9b4f2bd5be4719f37082fe061ee450/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
import math
import copy
from pathlib import Path
from random import random
from functools import partial
from collections import namedtuple
from multiprocessing import cpu_count
import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms as T, utils
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from PIL import Image
from tqdm.auto import tqdm
from typing import Any, Dict, List, Optional, Tuple, Union
# constants
# Lightweight result record returned by GaussianDiffusion.model_predictions:
# the predicted noise and the matching predicted clean sample x_0.
ModelPrediction = namedtuple("ModelPrediction", "pred_noise pred_x_start")
# helper functions
def exists(x):
    """Report whether ``x`` holds a value, i.e. is anything but ``None``.

    Falsy values such as ``0``, ``""`` or an empty list still count as
    existing; only ``None`` does not.
    """
    if x is None:
        return False
    return True
def default(val, d):
    """Return ``val`` when it is set, otherwise fall back to ``d``.

    If ``d`` is callable it is only invoked when a fallback is actually
    needed, which lets callers pass lazily-built defaults such as
    ``lambda: torch.randn_like(x)``.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val
def extract(a, t, x_shape):
    """Pick per-sample entries of a schedule and shape them for broadcasting.

    Args:
        a: schedule tensor indexed along its last dimension by timestep (in
            this file always one of the 1-D buffers registered by
            ``GaussianDiffusion``).
        t: tensor of timestep indices; its first dimension is the batch size.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        ``a[..., t]`` reshaped to ``(batch, 1, ..., 1)`` with
        ``len(x_shape)`` dimensions in total.
    """
    batch_size, *_ = t.shape
    picked = a.gather(-1, t)
    broadcast_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(batch_size, *broadcast_dims)
def linear_beta_schedule(timesteps):
    """Return ``timesteps`` linearly spaced betas as a float64 tensor.

    The reference endpoints 1e-4 and 0.02 correspond to a 1000-step chain;
    both are multiplied by ``1000 / timesteps`` for other chain lengths.
    """
    rescale = 1000 / timesteps
    start = 0.0001 * rescale
    end = 0.02 * rescale
    return torch.linspace(start, end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (https://openreview.net/forum?id=-NEXDKk8gZ).

    The cumulative product alpha_bar is a squared cosine of the normalised
    time, shifted by the small offset ``s`` and normalised so alpha_bar(0)=1.
    Betas are recovered from consecutive alpha_bar ratios and clipped to
    [0, 0.999].

    Returns:
        float64 tensor of ``timesteps`` betas.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)
    angle = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    alpha_bar = torch.cos(angle) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - step_ratio, 0, 0.999)
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process (forward noising, reverse sampling, loss).

    On construction the beta schedule and every quantity derived from it
    (cumulative alphas, posterior coefficients, p2 loss weights) are
    registered as float32 buffers.

    NOTE(review): the denoising network is read from ``self.model`` but is
    never assigned in this class. It must be attached by the caller before
    ``forward`` / ``sample`` are used, and it is invoked as
    ``self.model(x, t, z)``.

    Args:
        timesteps: number of diffusion steps.
        sampling_timesteps: number of sampling steps. Defaults to
            ``timesteps`` and must not exceed it. It is only validated and
            stored; ``p_sample_loop`` always walks all ``num_timesteps``.
        beta_1: first beta of the "custom" linear schedule.
        beta_T: last beta of the "custom" linear schedule.
        loss_type: "l1" or "l2".
        objective: "pred_noise" (the network predicts the added noise) or
            "pred_x0" (the network predicts the clean sample).
        beta_schedule: "linear", "cosine" or "custom".
        p2_loss_weight_gamma: exponent of the per-timestep loss weight
            ``(k + alphas_cumprod / (1 - alphas_cumprod)) ** -gamma``;
            0 makes every weight 1.
        p2_loss_weight_k: the ``k`` in the weight above.
    """

    def __init__(
        self,
        timesteps=100,
        sampling_timesteps=None,
        beta_1=0.0001,
        beta_T=0.1,
        loss_type="l1",
        objective="pred_noise",
        beta_schedule="custom",
        p2_loss_weight_gamma=0.0,
        p2_loss_weight_k=1,
    ):
        super().__init__()
        assert objective in {"pred_noise", "pred_x0"}, (
            "objective must be either pred_noise (predict noise) "
            "or pred_x0 (predict image start)"
        )
        self.timesteps = timesteps
        self.sampling_timesteps = sampling_timesteps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.loss_type = loss_type
        self.objective = objective
        self.beta_schedule = beta_schedule
        self.p2_loss_weight_gamma = p2_loss_weight_gamma
        self.p2_loss_weight_k = p2_loss_weight_k

        self.init_diff_hyper(
            self.timesteps,
            self.sampling_timesteps,
            self.beta_1,
            self.beta_T,
            self.loss_type,
            self.objective,
            self.beta_schedule,
            self.p2_loss_weight_gamma,
            self.p2_loss_weight_k,
        )

    def init_diff_hyper(
        self,
        timesteps,
        sampling_timesteps,
        beta_1,
        beta_T,
        loss_type,
        objective,
        beta_schedule,
        p2_loss_weight_gamma,
        p2_loss_weight_k,
    ):
        """Build the beta schedule and register every derived buffer.

        All quantities are computed in float64 and stored as float32
        buffers. ``objective`` is accepted for signature symmetry with
        ``__init__`` but is not used here.

        Raises:
            ValueError: if ``beta_schedule`` is not "linear", "cosine" or
                "custom".
        """
        if beta_schedule == "linear":
            betas = linear_beta_schedule(timesteps)
        elif beta_schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif beta_schedule == "custom":
            # Plain linear ramp between the user-supplied endpoints.
            betas = torch.linspace(
                beta_1, beta_T, timesteps, dtype=torch.float64
            )
        else:
            raise ValueError(f"unknown beta schedule {beta_schedule}")

        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shift right by one and prepend 1.0 so index t holds the cumulative
        # product up to step t-1 (and 1.0 for the very first step).
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.loss_type = loss_type

        # Number of sampling steps defaults to the number of training steps.
        self.sampling_timesteps = default(sampling_timesteps, timesteps)
        assert self.sampling_timesteps <= timesteps

        def register_f32(name, val):
            # Schedules are built in float64 for accuracy, stored as float32.
            self.register_buffer(name, val.to(torch.float32))

        register_f32("betas", betas)
        register_f32("alphas_cumprod", alphas_cumprod)
        register_f32("alphas_cumprod_prev", alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_f32("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        register_f32(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)
        )
        register_f32(
            "log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)
        )
        register_f32(
            "sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)
        )
        register_f32(
            "sqrt_recipm1_alphas_cumprod",
            torch.sqrt(1.0 / alphas_cumprod - 1),
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_f32("posterior_variance", posterior_variance)
        # The log is taken on a clamped variance because the posterior
        # variance is 0 at the first step of the chain.
        register_f32(
            "posterior_log_variance_clipped",
            torch.log(posterior_variance.clamp(min=1e-20)),
        )
        register_f32(
            "posterior_mean_coef1",
            betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
        )
        register_f32(
            "posterior_mean_coef2",
            (1.0 - alphas_cumprod_prev)
            * torch.sqrt(alphas)
            / (1.0 - alphas_cumprod),
        )

        # p2 reweighting of the per-timestep loss
        register_f32(
            "p2_loss_weight",
            (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod))
            ** -p2_loss_weight_gamma,
        )

    # helper functions
    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and the noise that produced it.

        Computes ``sqrt(1/abar_t) * x_t - sqrt(1/abar_t - 1) * noise``.
        """
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Recover the noise from x_t and a (predicted) clean sample x_0.

        This is the algebraic inverse of ``predict_start_from_noise``.
        """
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0
        ) / extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def q_posterior(self, x_start, x_t, t):
        """Return the posterior q(x_{t-1} | x_t, x_0).

        Returns:
            Tuple ``(mean, variance, log_variance_clipped)``. The variance
            terms are broadcastable against ``x_t``.
        """
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(
            self.posterior_log_variance_clipped, t, x_t.shape
        )
        return (
            posterior_mean,
            posterior_variance,
            posterior_log_variance_clipped,
        )

    def q_sample(self, x_start, t, noise=None):
        """Diffuse ``x_start`` to step ``t``.

        Computes ``sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise``. Fresh
        Gaussian noise is drawn when ``noise`` is None.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def model_predictions(self, x, t, z, x_self_cond=None):
        """Run ``self.model(x, t, z)`` and express the output both ways.

        ``x_self_cond`` is accepted but currently unused.

        Returns:
            ``ModelPrediction(pred_noise, pred_x_start)``.

        Raises:
            ValueError: if ``self.objective`` is not a known objective.
        """
        model_output = self.model(x, t, z)
        if self.objective == "pred_noise":
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, model_output)
        elif self.objective == "pred_x0":
            pred_noise = self.predict_noise_from_start(x, t, model_output)
            x_start = model_output
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(f"unknown objective {self.objective}")
        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: torch.Tensor,  # per-sample timestep indices, shape (B,)
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
    ):
        """Mean and variance of the model's reverse step p(x_{t-1} | x_t).

        The predicted x_0 is plugged into the analytic posterior
        ``q_posterior``. ``x_self_cond`` is accepted but not forwarded to
        the model.

        Returns:
            Tuple ``(model_mean, posterior_variance, posterior_log_variance,
            x_start)``.

        Raises:
            NotImplementedError: if ``clip_denoised`` is requested.
        """
        preds = self.model_predictions(x, t, z)
        x_start = preds.pred_x_start

        if clip_denoised:
            raise NotImplementedError(
                "We don't clip the output because "
                "pose does not have a clear bound."
            )

        (
            model_mean,
            posterior_variance,
            posterior_log_variance,
        ) = self.q_posterior(x_start=x_start, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(
        self,
        x: torch.Tensor,  # B x N_x x dim
        t: int,
        z: torch.Tensor,
        x_self_cond=None,
        clip_denoised=False,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Draw x_{t-1} from x_t with one ancestral sampling step.

        When ``cond_fn`` is given and ``t < cond_start_step``, the mean is
        passed through ``cond_fn(model_mean, t)`` and no noise is added.
        Otherwise Gaussian noise scaled by the posterior std is added,
        except at ``t == 0``.

        Returns:
            Tuple ``(x_{t-1}, predicted x_0)``.
        """
        batched_times = torch.full(
            (x.shape[0],), t, device=x.device, dtype=torch.long
        )
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(
            x=x,
            t=batched_times,
            z=z,
            x_self_cond=x_self_cond,
            clip_denoised=clip_denoised,
        )

        if cond_fn is not None and t < cond_start_step:
            model_mean = cond_fn(model_mean, t)
            noise = 0.0
        else:
            noise = torch.randn_like(x) if t > 0 else 0.0  # no noise if t == 0

        pred = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred, x_start

    @torch.no_grad()
    def p_sample_loop(
        self,
        shape,
        z: torch.Tensor,
        cond_fn=None,
        cond_start_step=0,
    ):
        """Run the full reverse chain from pure Gaussian noise.

        Returns:
            Tuple ``(final_pose, trajectory)``. ``trajectory`` stacks the
            initial noise and every intermediate pose along a new leading
            dimension, giving shape ``(num_timesteps + 1, *shape)``.
        """
        device = self.betas.device
        pose = torch.randn(shape, device=device)
        pose_process = [pose.unsqueeze(0)]

        for t in reversed(range(0, self.num_timesteps)):
            pose, _ = self.p_sample(
                x=pose,
                t=t,
                z=z,
                cond_fn=cond_fn,
                cond_start_step=cond_start_step,
            )
            pose_process.append(pose.unsqueeze(0))

        return pose, torch.cat(pose_process)

    @torch.no_grad()
    def sample(self, shape, z, cond_fn=None, cond_start_step=0):
        """Sample poses of ``shape`` conditioned on ``z``.

        Currently this always delegates to ``p_sample_loop``.
        """
        # TODO: add more variants
        sample_fn = self.p_sample_loop
        return sample_fn(
            shape, z=z, cond_fn=cond_fn, cond_start_step=cond_start_step
        )

    def p_losses(
        self,
        x_start,
        t,
        z=None,
        noise=None,
    ):
        """Compute the training loss at timesteps ``t``.

        ``x_start`` is noised to step ``t``, the model is queried, and its
        output is compared to the objective's target with ``self.loss_fn``.

        Returns:
            Dict with keys:
                "loss": per-sample loss flattened to ``(B, -1)`` and scaled
                    by the p2 weight. It is not reduced to a scalar, so the
                    caller presumably averages it.
                "noise": the noise used.
                "x_0_pred": predicted clean sample.
                "x_t": the noised input.
                "t": the timesteps.

        Raises:
            ValueError: if ``self.objective`` is not a known objective.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.model(x, t, z)

        if self.objective == "pred_noise":
            target = noise
            x_0_pred = self.predict_start_from_noise(x, t, model_out)
        elif self.objective == "pred_x0":
            target = x_start
            x_0_pred = model_out
        else:
            raise ValueError(f"unknown objective {self.objective}")

        loss = self.loss_fn(model_out, target, reduction="none")
        # Every non-batch axis appears in the output pattern, so nothing is
        # averaged here; the loss is only flattened to (B, -1).
        loss = reduce(loss, "b ... -> b (...)", "mean")
        loss = loss * extract(self.p2_loss_weight, t, loss.shape)

        return {
            "loss": loss,
            "noise": noise,
            "x_0_pred": x_0_pred,
            "x_t": x,
            "t": t,
        }

    def forward(self, pose, z=None, *args, **kwargs):
        """Sample a random timestep per batch element and return p_losses.

        Extra positional/keyword arguments are forwarded to ``p_losses``
        after ``z``. For example, ``forward(pose, z, noise)`` supplies
        ``noise``.
        """
        b = len(pose)
        t = torch.randint(
            0, self.num_timesteps, (b,), device=pose.device
        ).long()
        # Pass z positionally. Writing ``z=z`` before ``*args`` made the
        # first extra positional argument collide with ``z`` (TypeError).
        return self.p_losses(pose, t, z, *args, **kwargs)

    @property
    def loss_fn(self):
        """Elementwise loss function selected by ``self.loss_type``.

        Raises:
            ValueError: if ``loss_type`` is not "l1" or "l2".
        """
        if self.loss_type == "l1":
            return F.l1_loss
        elif self.loss_type == "l2":
            return F.mse_loss
        else:
            raise ValueError(f"invalid loss type {self.loss_type}")