import numpy as np
import torch
import torch.nn as nn
from tqdm import trange
from torchvision.transforms import Compose


class Diffusion(nn.Module):
    def __init__(
        self,
        nn_backbone,
        device,
        n_timesteps=1000,
        in_channels=3,
        image_size=128,
        out_channels=6,
        motion_transforms=None,
    ):
        super().__init__()

        self.nn_backbone = nn_backbone
        self.n_timesteps = n_timesteps
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.x_shape = (image_size, image_size)
        self.device = device

        self.motion_transforms = motion_transforms if motion_transforms else Compose([])

        self.timesteps = torch.arange(n_timesteps)
        self.beta = self.get_beta_schedule()
        self.set_params()

    def sample(self, x_cond, audio_emb, n_audio_motion_embs=2, n_motion_frames=2, motion_channels=3):
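        # Generate one frame per audio embedding, autoregressively: each
        # frame is denoised from noise, conditioned on the identity frame
        # x_cond, a rolling buffer of previously generated motion frames,
        # and a sliding window of audio embeddings.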
        with torch.no_grad():
            n_frames = audio_emb.shape[1]

            # one noise tensor per output frame: (B, n_frames, C, H, W)
            xT = torch.randn(x_cond.shape[0], n_frames, self.in_channels, self.x_shape[0], self.x_shape[1]).to(x_cond.device)

            # sliding window of audio indices around the current frame;
            # index 0 is repeated to pad the missing past context
            audio_ids = [0] * n_audio_motion_embs
            for i in range(n_audio_motion_embs + 1):
                audio_ids += [i]

            # initialize the motion buffer with transformed copies of the conditioning frame
            motion_frames = [self.motion_transforms(x_cond) for _ in range(n_motion_frames)]
            motion_frames = torch.cat(motion_frames, dim=1)

            samples = []
            for i in trange(n_frames, desc='Sampling'):
                sample_frame = self.sample_loop(xT[:, i].to(x_cond.device), x_cond, motion_frames, audio_emb[:, audio_ids])
                samples.append(sample_frame.unsqueeze(1))
                # roll the motion buffer: drop the oldest frame, append the new one
                motion_frames = torch.cat([motion_frames[:, motion_channels:, :], self.motion_transforms(sample_frame)], dim=1)
                # advance the audio window, clamping at the last frame
                audio_ids = audio_ids[1:] + [min(i + n_audio_motion_embs + 1, n_frames - 1)]
        return torch.cat(samples, dim=1)

    def sample_loop(self, xT, x_cond, motion_frames, audio_emb):
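        # Reverse diffusion: start from Gaussian noise xT and apply
        # p(x_{t-1} | x_t) for every timestep in the (possibly spaced)
        # schedule, from t = T-1 down to t = 0.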
        xt = xT
        for i, t in reversed(list(enumerate(self.timesteps))):
            # `t` is the original timestep value fed to the backbone;
            # `i` indexes the (possibly re-spaced) schedule arrays
            timesteps = torch.tensor([t] * xT.shape[0]).to(xT.device)
            timesteps_ids = torch.tensor([i] * xT.shape[0]).to(xT.device)
            nn_out = self.nn_backbone(xt, timesteps, x_cond, motion_frames=motion_frames, audio_emb=audio_emb)
            mean, logvar = self.get_p_params(xt, timesteps_ids, nn_out)
            # no noise is added at the final step (t == 0)
            noise = torch.randn_like(xt) if t > 0 else torch.zeros_like(xt)
            xt = mean + noise * torch.exp(logvar / 2)

        return xt

    def get_p_params(self, xt, timesteps, nn_out):
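        # When out_channels == in_channels the backbone predicts only the
        # noise eps and the variance is fixed to beta_t; otherwise the
        # extra channels predict nu, which interpolates between log(beta_t)
        # and log(beta_tilde_t) (learned variance, as in Improved DDPM).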
        if self.in_channels == self.out_channels:
            eps_pred = nn_out
            p_logvar = self.expand(torch.log(self.beta[timesteps]))
        else:
            eps_pred, nu = nn_out.chunk(2, 1)
            nu = (nu + 1) / 2
            p_logvar = nu * self.expand(torch.log(self.beta[timesteps])) + (1 - nu) * self.expand(self.log_beta_tilde_clipped[timesteps])

        p_mean, _ = self.get_q_params(xt, timesteps, eps_pred=eps_pred)
        return p_mean, p_logvar

    def get_q_params(self, xt, timesteps, eps_pred=None, x0=None):
        if x0 is None:
            # predict x0 from xt and eps_pred
            coef1_x0 = self.expand(self.coef1_x0[timesteps])
            coef2_x0 = self.expand(self.coef2_x0[timesteps])
            x0 = coef1_x0 * xt - coef2_x0 * eps_pred
            x0 = x0.clamp(-1, 1)

        # q(x_{t-1} | x_t, x_0)
        coef1_q = self.expand(self.coef1_q[timesteps])
        coef2_q = self.expand(self.coef2_q[timesteps])
        q_mean = coef1_q * x0 + coef2_q * xt

        q_logvar = self.expand(self.log_beta_tilde_clipped[timesteps])

        return q_mean, q_logvar

    def get_beta_schedule(self, max_beta=0.999):
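        # cosine noise schedule: derive betas from a squared-cosine
        # alpha_bar curve (Nichol & Dhariwal, 2021), capped at max_beta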
        alpha_bar = lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2
        betas = []
        for i in range(self.n_timesteps):
            t1 = i / self.n_timesteps
            t2 = (i + 1) / self.n_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas).float()

    def set_params(self):
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.alpha_bar_prev = torch.cat([torch.ones(1,), self.alpha_bar[:-1]])

        self.beta_tilde = self.beta * (1.0 - self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        # beta_tilde[0] is 0, so substitute beta_tilde[1] before taking the log
        self.log_beta_tilde_clipped = torch.log(torch.cat([self.beta_tilde[1, None], self.beta_tilde[1:]]))

        # to calculate x0 from eps_pred
        self.coef1_x0 = torch.sqrt(1.0 / self.alpha_bar)
        self.coef2_x0 = torch.sqrt(1.0 / self.alpha_bar - 1)

        # for q(x_{t-1} | x_t, x_0)
        self.coef1_q = self.beta * torch.sqrt(self.alpha_bar_prev) / (1.0 - self.alpha_bar)
        self.coef2_q = (1.0 - self.alpha_bar_prev) * torch.sqrt(self.alpha) / (1.0 - self.alpha_bar)

    def space(self, n_timesteps_new):
        # change parameters for spaced timesteps during sampling
        self.timesteps = self.space_timesteps(self.n_timesteps, n_timesteps_new)
        self.n_timesteps = n_timesteps_new

        self.beta = self.get_spaced_beta()
        self.set_params()

    def space_timesteps(self, n_timesteps, target_timesteps):
        # evenly spaced (after rounding) indices over [0, n_timesteps - 1]
        frac_stride = (n_timesteps - 1) / (target_timesteps - 1)
        cur_idx = 0.0
        taken_steps = []
        for _ in range(target_timesteps):
            taken_steps.append(round(cur_idx))
            cur_idx += frac_stride
        return taken_steps

    def get_spaced_beta(self):
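        # recompute betas for the kept timesteps so that the cumulative
        # product alpha_bar matches the original schedule at those steps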
        last_alpha_cumprod = 1.0
        new_beta = []
        for i, alpha_cumprod in enumerate(self.alpha_bar):
            if i in self.timesteps:
                new_beta.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
        return torch.tensor(new_beta)

    def expand(self, arr, dim=4):
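        # broadcast a (B,) tensor to `dim` dimensions, e.g. (B, 1, 1, 1),
        # and move it onto the model device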
        while arr.dim() < dim:
            arr = arr[:, None]
        return arr.to(self.device)
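

# --- Usage sketch (not part of the original file) ---
# A minimal, hedged example of running Diffusion.sample() with a stub
# backbone. `StubBackbone`, the tensor shapes, and the embedding sizes
# below are assumptions for illustration only; the real project plugs in
# its own UNet and audio encoder.
if __name__ == '__main__':
    class StubBackbone(nn.Module):
        def __init__(self, in_channels=3, out_channels=6):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        def forward(self, x, timesteps, x_cond, motion_frames=None, audio_emb=None):
            # a real backbone would condition on timesteps, x_cond,
            # motion_frames and audio_emb; this stub ignores them
            return self.conv(x)

    model = Diffusion(StubBackbone(), device='cpu', image_size=16)
    model.space(5)  # few sampling steps so the sketch finishes quickly
    x_cond = torch.randn(1, 3, 16, 16)  # identity/conditioning frame
    audio_emb = torch.randn(1, 4, 8)    # (batch, n_frames, emb_dim), assumed
    frames = model.sample(x_cond, audio_emb)
    print(frames.shape)  # expected: torch.Size([1, 4, 3, 16, 16])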