# diffused-heads / diffusion.py
# Uploaded by Sof22 ("Upload 3 files"), commit 73eb2d5.
import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose
class Diffusion(nn.Module):
    """Gaussian diffusion wrapper that samples frames autoregressively with ``nn_backbone``.

    Holds a cosine beta schedule plus the derived posterior coefficients and runs
    ancestral sampling. When ``out_channels`` differs from ``in_channels`` the
    backbone output is split into a noise prediction and a variance-interpolation
    term (log-variance interpolated between log(beta_t) and log(beta_tilde_t)).

    NOTE: the schedule tensors are plain attributes (not registered buffers), so
    ``.to(device)`` on the module does not move them; lookups move the indices
    to the table's device and the gathered values to ``self.device`` instead.
    """

    def __init__(
            self, nn_backbone, device, n_timesteps=1000, in_channels=3, image_size=128, out_channels=6,
            motion_transforms=None):
        """
        Args:
            nn_backbone: denoising network, called as
                ``nn_backbone(x_t, t, x_cond, motion_frames=..., audio_emb=...)``.
            device: device that gathered schedule coefficients are moved to.
            n_timesteps: length of the full diffusion schedule.
            in_channels: channels of a generated frame.
            image_size: height and width of a generated frame.
            out_channels: backbone output channels; when different from
                ``in_channels`` the output is split into (eps, nu) halves.
            motion_transforms: callable applied to the conditioning frame and to
                generated frames before they are used as motion context;
                defaults to an empty ``Compose``.
        """
        super(Diffusion, self).__init__()
        self.nn_backbone = nn_backbone
        self.n_timesteps = n_timesteps
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x_shape = (image_size, image_size)
        self.device = device
        self.motion_transforms = motion_transforms if motion_transforms else Compose([])
        self.timesteps = torch.arange(n_timesteps)
        self.beta = self.get_beta_schedule()
        self.set_params()
        # Remember the full schedule so space() always respaces relative to it,
        # even when it is called more than once.
        self._full_n_timesteps = n_timesteps
        self._full_alpha_bar = self.alpha_bar

    def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
        """Autoregressively sample one frame per entry of ``audio_emb`` along dim 1.

        Each frame is denoised from fresh Gaussian noise, conditioned on
        ``x_cond``, on the last ``n_motion_frames`` motion frames (initially
        copies of ``motion_transforms(x_cond)``), and on a window of
        ``2 * n_audio_motion_embs + 1`` audio embeddings around the frame, with
        ids clamped to ``[0, n_frames - 1]``.

        Returns:
            The sampled frames concatenated along dim 1.
        """
        with torch.no_grad():
            n_frames = audio_emb.shape[1]
            last_frame = n_frames - 1
            xT = torch.randn(
                x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]
            ).to(x_cond.device)

            # Initial audio window: pad the past with frame 0 and clamp the future
            # so clips shorter than the window do not index out of range.
            audio_ids = [0] * n_audio_motion_embs + [min(i, last_frame) for i in range(n_audio_motion_embs + 1)]

            motion_frames = torch.cat([self.motion_transforms(x_cond) for _ in range(n_motion_frames)], dim=1)

            samples = []
            for i in trange(n_frames, desc='Sampling'):
                sample_frame = self.sample_loop(xT[:, i], x_cond, motion_frames, audio_emb[:, audio_ids])
                samples.append(sample_frame.unsqueeze(1))
                # Slide the motion context: drop the oldest frame's channels, append the new frame.
                motion_frames = torch.cat(
                    [motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
                # Slide the audio window one frame forward.
                audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, last_frame)]
            return torch.cat(samples, dim=1)

    def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
        """Run the reverse chain for a single frame, starting from noise ``xT``.

        The values in ``self.timesteps`` (possibly spaced) are what the backbone
        sees as ``t``; their positions index the schedule coefficients.
        """
        xt = xT
        batch_size = xT.shape[0]
        for i, t in reversed(list(enumerate(self.timesteps))):
            t = int(t)
            timesteps = torch.full((batch_size,), t, dtype=torch.long, device=xT.device)
            timesteps_ids = torch.full((batch_size,), i, dtype=torch.long, device=xT.device)
            nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
            mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
            # No noise is added on the final step (t == 0).
            noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
            xt = mean + noise * torch.exp(logvar / 2)
        return xt

    def _extract(self, arr, ids):
        """Gather ``arr[ids]`` and reshape it to broadcast against 4-D batches.

        ``ids`` is moved to ``arr``'s device first: the schedule tables live on
        the CPU while ``ids`` may be created on an accelerator, and indexing a
        CPU tensor with non-CPU indices raises.
        """
        return self.expand(arr[ids.to(arr.device)])

    def get_p_params(self, xt, timesteps, nn_out):
        """Mean and log-variance of the reverse step p(x_{t-1} | x_t).

        ``timesteps`` are indices into the current (possibly spaced) schedule.
        With ``out_channels == in_channels`` the backbone output is the noise
        prediction and the log-variance is log(beta_t). Otherwise the output is
        split along dim 1 into (eps, nu), nu is mapped via (nu + 1) / 2
        (presumably from [-1, 1] to [0, 1] — depends on the backbone), and the
        log-variance is nu * log(beta_t) + (1 - nu) * log(beta_tilde_t).
        """
        log_beta = self._extract(torch.log(self.beta), timesteps)
        if self.in_channels == self.out_channels:
            eps_pred = nn_out
            p_logvar = log_beta
        else:
            eps_pred, nu = nn_out.chunk(2, 1)
            nu = (nu + 1) / 2
            p_logvar = nu * log_beta + (1 - nu) * self._extract(self.log_beta_tilde_clipped, timesteps)
        p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
        return p_mean, p_logvar

    def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
        """Mean and log-variance of the posterior q(x_{t-1} | x_t, x_0).

        If ``x0`` is not given it is reconstructed from ``xt`` and ``eps_pred``
        and clamped to [-1, 1].

        Raises:
            ValueError: if neither ``x0`` nor ``eps_pred`` is given.
        """
        if x0 is None:
            if eps_pred is None:
                raise ValueError('get_q_params needs either eps_pred or x0')
            # x0 = sqrt(1 / alpha_bar_t) * x_t - sqrt(1 / alpha_bar_t - 1) * eps
            x0 = self._extract(self.coef1_x0, timesteps) * xt - self._extract(self.coef2_x0, timesteps) * eps_pred
            x0 = x0.clamp(-1, 1)
        q_mean = self._extract(self.coef1_q, timesteps) * x0 + self._extract(self.coef2_q, timesteps) * xt
        q_logvar = self._extract(self.log_beta_tilde_clipped, timesteps)
        return q_mean, q_logvar

    def get_beta_schedule(self, max_beta=0.999):
        """Cosine beta schedule of length ``self.n_timesteps``, each beta capped at ``max_beta``.

        beta_i = 1 - alpha_bar((i + 1) / T) / alpha_bar(i / T) with
        alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2. Returns float32.
        """
        alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
        betas = []
        for i in range(self.n_timesteps):
            t1 = i / self.n_timesteps
            t2 = (i + 1) / self.n_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas).float()

    def set_params(self):
        """Recompute every schedule-derived coefficient from ``self.beta``."""
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])
        self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        # beta_tilde[0] is 0 (alpha_bar_prev[0] == 1), so its log is replaced by log(beta_tilde[1]).
        self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))
        # to calculate x0 from eps_pred
        self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
        self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)
        # for q(x_{t-1} | x_t, x_0)
        self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)

    def space(self, n_timesteps_new):
        """Switch to a respaced schedule of ``n_timesteps_new`` steps for faster sampling.

        Spacing is always computed against the full schedule built in
        ``__init__``, so repeated calls are safe (passing the full length
        restores the original schedule).
        """
        self.timesteps = self.space_timesteps(self._full_n_timesteps, n_timesteps_new)
        self.n_timesteps = n_timesteps_new
        self.beta = self.get_spaced_beta(self._full_alpha_bar)
        self.set_params()

    def space_timesteps(self, n_timesteps, target_timesteps):
        """Pick ``target_timesteps`` roughly evenly spaced steps from ``range(n_timesteps)``.

        The first step is 0 and the last is ``n_timesteps - 1``.

        Raises:
            ValueError: if ``target_timesteps`` is outside ``[2, n_timesteps]``
                (fewer than 2 divides by zero; more than ``n_timesteps`` would
                repeat steps and desynchronise the spaced betas).
        """
        if not 2 <= target_timesteps <= n_timesteps:
            raise ValueError(
                f'target_timesteps must be between 2 and {n_timesteps}, got {target_timesteps}')
        frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
        # Accumulate the stride (rather than multiplying) to keep the original rounding.
        taken_steps = []
        cur_idx = 0.0
        for _ in range(target_timesteps):
            taken_steps.append(round(cur_idx))
            cur_idx += frac_stride
        return taken_steps

    def get_spaced_beta(self, alpha_bar=None):
        """Betas of a chain that visits only ``self.timesteps``.

        Chosen so that the cumulative product of the new alphas equals
        ``alpha_bar`` at the kept steps. ``alpha_bar`` defaults to the current
        ``self.alpha_bar``.
        """
        if alpha_bar is None:
            alpha_bar = self.alpha_bar
        kept = {int(t) for t in self.timesteps}
        last_alpha_cumprod = 1.0
        new_beta = []
        for i, alpha_cumprod in enumerate(alpha_bar):
            if i in kept:
                new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        return torch.stack(new_beta)

    def expand(self, arr, dim=4):
        """Append trailing singleton dims until ``arr`` has ``dim`` dims; move to ``self.device``."""
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)