Dionyssos commited on
Commit
3dfbf53
·
1 Parent(s): d912185

del diffusion [unused]

Browse files
audiocraft/audiogen.py DELETED
@@ -1,129 +0,0 @@
1
- import typing as tp
2
- import torch
3
- from audiocraft.loaders import load_compression_model, load_lm_model
4
- import typing as tp
5
- import omegaconf
6
- import torch
7
- import numpy as np
8
- from .lm import LMModel
9
- from .conditioners import ConditioningAttributes
10
- from .utils.autocast import TorchAutocast
11
-
12
-
13
-
14
- def _shift(x):
15
- n = x.shape[2]
16
- i = np.random.randint(.24 * n, max(1, .74 * n)) # high should be above >= 0 TBD do we have very short segments
17
- x = torch.roll(x, i, dims=2)
18
- return x
19
-
20
-
21
- class AudioGen():
22
- """Base generative model with convenient generation API.
23
-
24
- Args:
25
- name (str)
26
- compression_model (CompressionModel): Encodec with Seanet Decoder
27
- lm
28
- max_duration (float, optional): As is using top250 token draw() we can gen xN sequences
29
- """
30
- def __init__(self,
31
- name,
32
- compression_model,
33
- lm,
34
- max_duration=None):
35
- self.name = name
36
- self.compression_model = compression_model
37
- self.lm = lm
38
- self.cfg: tp.Optional[omegaconf.DictConfig] = None
39
- # Just to be safe, let's put everything in eval mode.
40
- self.compression_model.eval()
41
- self.lm.eval()
42
-
43
- if hasattr(lm, 'cfg'):
44
- cfg = lm.cfg
45
- assert isinstance(cfg, omegaconf.DictConfig)
46
- self.cfg = cfg
47
-
48
- if max_duration is None:
49
- if self.cfg is not None:
50
- max_duration = lm.cfg.dataset.segment_duration # type: ignore
51
- else:
52
- raise ValueError("You must provide max_duration when building directly your GenModel")
53
- assert max_duration is not None
54
-
55
- self.max_duration: float = max_duration
56
- self.duration = self.max_duration
57
- self.device = next(iter(lm.parameters())).device
58
- self.generation_params={}
59
-
60
- if self.device.type == 'cpu':
61
- self.autocast = TorchAutocast(enabled=False)
62
- else:
63
- self.autocast = TorchAutocast(
64
- enabled=True,
65
- device_type=self.device.type,
66
- dtype=torch.float16)
67
-
68
- @property
69
- def frame_rate(self) -> float:
70
- """Roughly the number of AR steps per seconds."""
71
- return self.compression_model.frame_rate
72
-
73
- @property
74
- def sample_rate(self) -> int:
75
- """Sample rate of the generated audio."""
76
- return self.compression_model.sample_rate
77
-
78
-
79
-
80
-
81
-
82
- def generate(self, descriptions):
83
- attributes = [
84
- ConditioningAttributes(text={'description': d}) for d in descriptions]
85
- tokens = self._generate_tokens(attributes)
86
- print(f'\n{tokens.shape=}\n{tokens=} FINAL 5 AUD')
87
- return self.generate_audio(tokens)
88
-
89
- def _generate_tokens(self, attributes):
90
-
91
- total_gen_len = int(self.duration * self.frame_rate)
92
-
93
- if self.duration <= self.max_duration:
94
- # generate by sampling from LM, simple case.
95
-
96
- with self.autocast:
97
- gen_tokens = self.lm.generate(conditions=attributes, max_gen_len=total_gen_len)
98
- else:
99
- print('<>Long gen ?<>')
100
- # print(f'{gen_tokens.shape=}') # [5,4,35]
101
- # FLATTEN BATCH AS EXTRA SEQUENCE (BATCH IS VIRTUAL JUST MULTINOMIAL SAMPLING OF N_DRAW TOKENS)
102
- gen_tokens = gen_tokens.transpose(0, 1).reshape(4, -1)[None, :, :]
103
- for _ in range(3):
104
- print(gen_tokens.shape)
105
- gen_tokens = _shift(gen_tokens)
106
- return gen_tokens
107
-
108
- def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
109
- """Generate Audio from tokens."""
110
- assert gen_tokens.dim() == 3
111
- with torch.no_grad():
112
- gen_audio = self.compression_model.decode(gen_tokens, None)
113
- return gen_audio
114
-
115
-
116
- def get_pretrained(name='facebook/audiogen-medium',
117
- device=None):
118
- """Return pretrained model, we provide a single model for now:
119
- - facebook/audiogen-medium (1.5B), text to sound,
120
- # see: https://huggingface.co/facebook/audiogen-medium
121
- """
122
- compression_model = load_compression_model(name, device=device)
123
- lm = load_lm_model(name, device=device)
124
- assert 'self_wav' not in lm.condition_provider.conditioners, \
125
- "AudioGen do not support waveform conditioning for now"
126
- return AudioGen(name, compression_model, lm)
127
-
128
-
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/builders.py CHANGED
@@ -16,10 +16,10 @@ from .conditioners import (
16
  ConditioningProvider,
17
  T5Conditioner,
18
  )
19
- from .unet import DiffusionUnet
20
  from .vq import ResidualVectorQuantizer
21
 
22
- from .diffusion_schedule import MultiBandProcessor, SampleProcessor
23
 
24
  def dict_from_config(cfg):
25
  dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
@@ -155,25 +155,3 @@ def get_codebooks_pattern_provider(n_q, cfg):
155
 
156
  klass = pattern_providers[name]
157
  return klass(n_q, **kwargs)
158
-
159
-
160
-
161
-
162
-
163
- def get_diffusion_model(cfg: omegaconf.DictConfig):
164
- # TODO Find a way to infer the channels from dset
165
- channels = cfg.channels
166
- num_steps = cfg.schedule.num_steps
167
- return DiffusionUnet(
168
- chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
169
-
170
-
171
- def get_processor(cfg, sample_rate: int = 24000):
172
- sample_processor = SampleProcessor()
173
- if cfg.use:
174
- kw = dict(cfg)
175
- kw.pop('use')
176
- kw.pop('name')
177
- if cfg.name == "multi_band_processor":
178
- sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
179
- return sample_processor
 
16
  ConditioningProvider,
17
  T5Conditioner,
18
  )
19
+
20
  from .vq import ResidualVectorQuantizer
21
 
22
+
23
 
24
  def dict_from_config(cfg):
25
  dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
 
155
 
156
  klass = pattern_providers[name]
157
  return klass(n_q, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/diffusion_schedule.py DELETED
@@ -1,272 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Functions for Noise Schedule, defines diffusion process, reverse process and data processor.
9
- """
10
-
11
- from collections import namedtuple
12
- import random
13
- import typing as tp
14
- import julius
15
- import torch
16
-
17
- TrainingItem = namedtuple("TrainingItem", "noisy noise step")
18
-
19
-
20
- def betas_from_alpha_bar(alpha_bar):
21
- alphas = torch.cat([torch.Tensor([alpha_bar[0]]), alpha_bar[1:]/alpha_bar[:-1]])
22
- return 1 - alphas
23
-
24
-
25
- class SampleProcessor(torch.nn.Module):
26
- def project_sample(self, x: torch.Tensor):
27
- """Project the original sample to the 'space' where the diffusion will happen."""
28
- return x
29
-
30
- def return_sample(self, z: torch.Tensor):
31
- """Project back from diffusion space to the actual sample space."""
32
- return z
33
-
34
-
35
- class MultiBandProcessor(SampleProcessor):
36
- """
37
- MultiBand sample processor. The input audio is splitted across
38
- frequency bands evenly distributed in mel-scale.
39
-
40
- Each band will be rescaled to match the power distribution
41
- of Gaussian noise in that band, using online metrics
42
- computed on the first few samples.
43
-
44
- Args:
45
- n_bands (int): Number of mel-bands to split the signal over.
46
- sample_rate (int): Sample rate of the audio.
47
- num_samples (int): Number of samples to use to fit the rescaling
48
- for each band. The processor won't be stable
49
- until it has seen that many samples.
50
- power_std (float or list/tensor): The rescaling factor computed to match the
51
- power of Gaussian noise in each band is taken to
52
- that power, i.e. `1.` means full correction of the energy
53
- in each band, and values less than `1` means only partial
54
- correction. Can be used to balance the relative importance
55
- of low vs. high freq in typical audio signals.
56
- """
57
- def __init__(self, n_bands: int = 8, sample_rate: float = 24_000,
58
- num_samples: int = 10_000, power_std: tp.Union[float, tp.List[float], torch.Tensor] = 1.):
59
- super().__init__()
60
- self.n_bands = n_bands
61
- self.split_bands = julius.SplitBands(sample_rate, n_bands=n_bands)
62
- self.num_samples = num_samples
63
- self.power_std = power_std
64
- if isinstance(power_std, list):
65
- assert len(power_std) == n_bands
66
- power_std = torch.tensor(power_std)
67
- self.register_buffer('counts', torch.zeros(1))
68
- self.register_buffer('sum_x', torch.zeros(n_bands))
69
- self.register_buffer('sum_x2', torch.zeros(n_bands))
70
- self.register_buffer('sum_target_x2', torch.zeros(n_bands))
71
- self.counts: torch.Tensor
72
- self.sum_x: torch.Tensor
73
- self.sum_x2: torch.Tensor
74
- self.sum_target_x2: torch.Tensor
75
-
76
- @property
77
- def mean(self):
78
- mean = self.sum_x / self.counts
79
- return mean
80
-
81
- @property
82
- def std(self):
83
- std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt()
84
- return std
85
-
86
- @property
87
- def target_std(self):
88
- target_std = self.sum_target_x2 / self.counts
89
- return target_std
90
-
91
- def project_sample(self, x: torch.Tensor):
92
- assert x.dim() == 3
93
- bands = self.split_bands(x)
94
- if self.counts.item() < self.num_samples:
95
- ref_bands = self.split_bands(torch.randn_like(x))
96
- self.counts += len(x)
97
- self.sum_x += bands.mean(dim=(2, 3)).sum(dim=1)
98
- self.sum_x2 += bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
99
- self.sum_target_x2 += ref_bands.pow(2).mean(dim=(2, 3)).sum(dim=1)
100
- rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size
101
- bands = (bands - self.mean.view(-1, 1, 1, 1)) * rescale.view(-1, 1, 1, 1)
102
- return bands.sum(dim=0)
103
-
104
- def return_sample(self, x: torch.Tensor):
105
- assert x.dim() == 3
106
- bands = self.split_bands(x)
107
- rescale = (self.std / self.target_std) ** self.power_std
108
- bands = bands * rescale.view(-1, 1, 1, 1) + self.mean.view(-1, 1, 1, 1)
109
- return bands.sum(dim=0)
110
-
111
-
112
- class NoiseSchedule:
113
- """Noise schedule for diffusion.
114
-
115
- Args:
116
- beta_t0 (float): Variance of the first diffusion step.
117
- beta_t1 (float): Variance of the last diffusion step.
118
- beta_exp (float): Power schedule exponent
119
- num_steps (int): Number of diffusion step.
120
- variance (str): choice of the sigma value for the denoising eq. Choices: "beta" or "beta_tilde"
121
- clip (float): clipping value for the denoising steps
122
- rescale (float): rescaling value to avoid vanishing signals unused by default (i.e 1)
123
- repartition (str): shape of the schedule only power schedule is supported
124
- sample_processor (SampleProcessor): Module that normalize data to match better the gaussian distribution
125
- noise_scale (float): Scaling factor for the noise
126
- """
127
- def __init__(self, beta_t0: float = 1e-4, beta_t1: float = 0.02, num_steps: int = 1000, variance: str = 'beta',
128
- clip: float = 5., rescale: float = 1., device='cuda', beta_exp: float = 1,
129
- repartition: str = "power", alpha_sigmoid: dict = {}, n_bands: tp.Optional[int] = None,
130
- sample_processor: SampleProcessor = SampleProcessor(), noise_scale: float = 1.0, **kwargs):
131
-
132
- self.beta_t0 = beta_t0
133
- self.beta_t1 = beta_t1
134
- self.variance = variance
135
- self.num_steps = num_steps
136
- self.clip = clip
137
- self.sample_processor = sample_processor
138
- self.rescale = rescale
139
- self.n_bands = n_bands
140
- self.noise_scale = noise_scale
141
- assert n_bands is None
142
- if repartition == "power":
143
- self.betas = torch.linspace(beta_t0 ** (1 / beta_exp), beta_t1 ** (1 / beta_exp), num_steps,
144
- device=device, dtype=torch.float) ** beta_exp
145
- else:
146
- raise RuntimeError('Not implemented')
147
- self.rng = random.Random(1234)
148
-
149
- def get_beta(self, step: tp.Union[int, torch.Tensor]):
150
- if self.n_bands is None:
151
- return self.betas[step]
152
- else:
153
- return self.betas[:, step] # [n_bands, len(step)]
154
-
155
- def get_initial_noise(self, x: torch.Tensor):
156
- if self.n_bands is None:
157
- return torch.randn_like(x)
158
- return torch.randn((x.size(0), self.n_bands, x.size(2)))
159
-
160
- def get_alpha_bar(self, step: tp.Optional[tp.Union[int, torch.Tensor]] = None) -> torch.Tensor:
161
- """Return 'alpha_bar', either for a given step, or as a tensor with its value for each step."""
162
- if step is None:
163
- return (1 - self.betas).cumprod(dim=-1) # works for simgle and multi bands
164
- if type(step) is int:
165
- return (1 - self.betas[:step + 1]).prod()
166
- else:
167
- return (1 - self.betas).cumprod(dim=0)[step].view(-1, 1, 1)
168
-
169
- def get_training_item(self, x: torch.Tensor, tensor_step: bool = False) -> TrainingItem:
170
- """Create a noisy data item for diffusion model training:
171
-
172
- Args:
173
- x (torch.Tensor): clean audio data torch.tensor(bs, 1, T)
174
- tensor_step (bool): If tensor_step = false, only one step t is sample,
175
- the whole batch is diffused to the same step and t is int.
176
- If tensor_step = true, t is a tensor of size (x.size(0),)
177
- every element of the batch is diffused to a independently sampled.
178
- """
179
- step: tp.Union[int, torch.Tensor]
180
- if tensor_step:
181
- bs = x.size(0)
182
- step = torch.randint(0, self.num_steps, size=(bs,), device=x.device)
183
- else:
184
- step = self.rng.randrange(self.num_steps)
185
- alpha_bar = self.get_alpha_bar(step) # [batch_size, n_bands, 1]
186
-
187
- x = self.sample_processor.project_sample(x)
188
- noise = torch.randn_like(x)
189
- noisy = (alpha_bar.sqrt() / self.rescale) * x + (1 - alpha_bar).sqrt() * noise * self.noise_scale
190
- return TrainingItem(noisy, noise, step)
191
-
192
- def generate(self, model: torch.nn.Module, initial: tp.Optional[torch.Tensor] = None,
193
- condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
194
- """Full ddpm reverse process.
195
-
196
- Args:
197
- model (nn.Module): Diffusion model.
198
- initial (tensor): Initial Noise.
199
- condition (tensor): Input conditionning Tensor (e.g. encodec compressed representation).
200
- return_list (bool): Whether to return the whole process or only the sampled point.
201
- """
202
- alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
203
- current = initial
204
- iterates = [initial]
205
- for step in range(self.num_steps)[::-1]:
206
- with torch.no_grad():
207
- estimate = model(current, step, condition=condition).sample
208
- alpha = 1 - self.betas[step]
209
- previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
210
- previous_alpha_bar = self.get_alpha_bar(step=step - 1)
211
- if step == 0:
212
- sigma2 = 0
213
- elif self.variance == 'beta':
214
- sigma2 = 1 - alpha
215
- elif self.variance == 'beta_tilde':
216
- sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
217
- elif self.variance == 'none':
218
- sigma2 = 0
219
- else:
220
- raise ValueError(f'Invalid variance type {self.variance}')
221
-
222
- if sigma2 > 0:
223
- previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
224
- if self.clip:
225
- previous = previous.clamp(-self.clip, self.clip)
226
- current = previous
227
- alpha_bar = previous_alpha_bar
228
- if step == 0:
229
- previous *= self.rescale
230
- if return_list:
231
- iterates.append(previous.cpu())
232
-
233
- if return_list:
234
- return iterates
235
- else:
236
- return self.sample_processor.return_sample(previous)
237
-
238
- def generate_subsampled(self, model: torch.nn.Module, initial: torch.Tensor, step_list: tp.Optional[list] = None,
239
- condition: tp.Optional[torch.Tensor] = None, return_list: bool = False):
240
- """Reverse process that only goes through Markov chain states in step_list."""
241
- if step_list is None:
242
- step_list = list(range(1000))[::-50] + [0]
243
- alpha_bar = self.get_alpha_bar(step=self.num_steps - 1)
244
- alpha_bars_subsampled = (1 - self.betas).cumprod(dim=0)[list(reversed(step_list))].cpu()
245
- betas_subsampled = betas_from_alpha_bar(alpha_bars_subsampled)
246
- current = initial * self.noise_scale
247
- iterates = [current]
248
- for idx, step in enumerate(step_list[:-1]):
249
- with torch.no_grad():
250
- estimate = model(current, step, condition=condition).sample * self.noise_scale
251
- alpha = 1 - betas_subsampled[-1 - idx]
252
- previous = (current - (1 - alpha) / (1 - alpha_bar).sqrt() * estimate) / alpha.sqrt()
253
- previous_alpha_bar = self.get_alpha_bar(step_list[idx + 1])
254
- if step == step_list[-2]:
255
- sigma2 = 0
256
- previous_alpha_bar = torch.tensor(1.0)
257
- else:
258
- sigma2 = (1 - previous_alpha_bar) / (1 - alpha_bar) * (1 - alpha)
259
- if sigma2 > 0:
260
- previous += sigma2**0.5 * torch.randn_like(previous) * self.noise_scale
261
- if self.clip:
262
- previous = previous.clamp(-self.clip, self.clip)
263
- current = previous
264
- alpha_bar = previous_alpha_bar
265
- if step == 0:
266
- previous *= self.rescale
267
- if return_list:
268
- iterates.append(previous.cpu())
269
- if return_list:
270
- return iterates
271
- else:
272
- return self.sample_processor.return_sample(previous)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/loaders.py CHANGED
@@ -1,33 +1,9 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Utility functions to load from the checkpoints.
9
- Each checkpoint is a torch.saved dict with the following keys:
10
- - 'xp.cfg': the hydra config as dumped during training. This should be used
11
- to rebuild the object using the audiocraft.models.builders functions,
12
- - 'model_best_state': a readily loadable best state for the model, including
13
- the conditioner. The model obtained from `xp.cfg` should be compatible
14
- with this state dict. In the case of a LM, the encodec model would not be
15
- bundled along but instead provided separately.
16
-
17
- Those functions also support loading from a remote location with the Torch Hub API.
18
- They also support overriding some parameters, in particular the device and dtype
19
- of the returned model.
20
- """
21
-
22
  from pathlib import Path
23
  from huggingface_hub import hf_hub_download
24
  import typing as tp
25
  import os
26
-
27
  from omegaconf import OmegaConf, DictConfig
28
  import torch
29
-
30
- import audiocraft
31
  from . import builders
32
  from .encodec import EncodecModel
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
  from huggingface_hub import hf_hub_download
3
  import typing as tp
4
  import os
 
5
  from omegaconf import OmegaConf, DictConfig
6
  import torch
 
 
7
  from . import builders
8
  from .encodec import EncodecModel
9
 
audiocraft/rope.py DELETED
@@ -1,125 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- import typing as tp
8
-
9
- from torch import nn
10
- import torch
11
-
12
-
13
- class XPos(nn.Module):
14
- """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
15
- This applies an exponential decay to the RoPE rotation matrix.
16
-
17
- Args:
18
- dim (int): Embedding dimension.
19
- smoothing (float): Smoothing factor applied to the decay rates.
20
- base_scale (int): Base decay rate, given in terms of scaling time.
21
- device (torch.device, optional): Device on which to initialize the module.
22
- dtype (torch.dtype): dtype to use to generate the embedding.
23
- """
24
- def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
25
- device=None, dtype: torch.dtype = torch.float32):
26
- super().__init__()
27
- assert dim % 2 == 0
28
- assert dtype in [torch.float64, torch.float32]
29
- self.dtype = dtype
30
- self.base_scale = base_scale
31
-
32
- half_dim = dim // 2
33
- adim = torch.arange(half_dim, device=device, dtype=dtype)
34
- decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
35
- self.register_buffer("decay_rates", decay_rates)
36
- self.decay: tp.Optional[torch.Tensor] = None
37
-
38
- def get_decay(self, start: int, end: int):
39
- """Create complex decay tensor, cache values for fast computation."""
40
- if self.decay is None or end > self.decay.shape[0]:
41
- assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
42
- idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
43
- power = idx / self.base_scale
44
- scale = self.decay_rates ** power.unsqueeze(-1)
45
- self.decay = torch.polar(scale, torch.zeros_like(scale))
46
- return self.decay[start:end] # [T, C/2]
47
-
48
-
49
- class RotaryEmbedding(nn.Module):
50
- """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
51
-
52
- Args:
53
- dim (int): Embedding dimension (twice the number of frequencies).
54
- max_period (float): Maximum period of the rotation frequencies.
55
- xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
56
- scale (float): Scale of positional embedding, set to 0 to deactivate.
57
- device (torch.device, optional): Device on which to initialize the module.
58
- dtype (torch.dtype): dtype to use to generate the embedding.
59
- """
60
- def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
61
- scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
62
- super().__init__()
63
- assert dim % 2 == 0
64
- self.scale = scale
65
- assert dtype in [torch.float64, torch.float32]
66
- self.dtype = dtype
67
-
68
- adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
69
- frequencies = 1.0 / (max_period ** (adim / dim))
70
- self.register_buffer("frequencies", frequencies)
71
- self.rotation: tp.Optional[torch.Tensor] = None
72
-
73
- self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
74
-
75
- def get_rotation(self, start: int, end: int):
76
- """Create complex rotation tensor, cache values for fast computation."""
77
- if self.rotation is None or end > self.rotation.shape[0]:
78
- assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
79
- idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
80
- angles = torch.outer(idx, self.frequencies)
81
- self.rotation = torch.polar(torch.ones_like(angles), angles)
82
- return self.rotation[start:end]
83
-
84
- def rotate(self, x: torch.Tensor, start: int = 0, time_dim: int = 1, invert_decay: bool = False):
85
- """Apply rope rotation to query or key tensor."""
86
- T = x.shape[time_dim]
87
- target_shape = [1] * x.dim()
88
- target_shape[time_dim] = T
89
- target_shape[-1] = -1
90
- rotation = self.get_rotation(start, start + T).view(target_shape)
91
-
92
- if self.xpos:
93
- decay = self.xpos.get_decay(start, start + T).view(target_shape)
94
- else:
95
- decay = 1.0
96
-
97
- if invert_decay:
98
- decay = decay ** -1
99
-
100
- x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
101
- scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
102
- x_out = torch.view_as_real(x_complex * scaled_rotation).view_as(x)
103
-
104
- return x_out.type_as(x)
105
-
106
- def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0, time_dim: int = 1):
107
- """ Apply rope rotation to both query and key tensors.
108
- Supports streaming mode, in which query and key are not expected to have the same shape.
109
- In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
110
- query will be [C] (typically C == 1).
111
-
112
- Args:
113
- query (torch.Tensor): Query to rotate.
114
- key (torch.Tensor): Key to rotate.
115
- start (int): Start index of the sequence for time offset.
116
- time_dim (int): which dimension represent the time steps.
117
- """
118
- query_timesteps = query.shape[time_dim]
119
- key_timesteps = key.shape[time_dim]
120
- streaming_offset = key_timesteps - query_timesteps
121
-
122
- query_out = self.rotate(query, start + streaming_offset, time_dim)
123
- key_out = self.rotate(key, start, time_dim, invert_decay=True)
124
-
125
- return query_out, key_out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audiocraft/unet.py DELETED
@@ -1,214 +0,0 @@
1
- # Copyright (c) Meta Platforms, Inc. and affiliates.
2
- # All rights reserved.
3
- #
4
- # This source code is licensed under the license found in the
5
- # LICENSE file in the root directory of this source tree.
6
-
7
- """
8
- Pytorch Unet Module used for diffusion.
9
- """
10
-
11
- from dataclasses import dataclass
12
- import typing as tp
13
-
14
- import torch
15
- from torch import nn
16
- from torch.nn import functional as F
17
- from .transformer import StreamingTransformer, create_sin_embedding
18
-
19
-
20
- @dataclass
21
- class Output:
22
- sample: torch.Tensor
23
-
24
-
25
- def get_model(cfg, channels: int, side: int, num_steps: int):
26
- if cfg.model == 'unet':
27
- return DiffusionUnet(
28
- chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
29
- else:
30
- raise RuntimeError('Not Implemented')
31
-
32
-
33
- class ResBlock(nn.Module):
34
- def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
35
- dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
36
- dropout: float = 0.):
37
- super().__init__()
38
- stride = 1
39
- padding = dilation * (kernel - stride) // 2
40
- Conv = nn.Conv1d
41
- Drop = nn.Dropout1d
42
- self.norm1 = nn.GroupNorm(norm_groups, channels)
43
- self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
44
- self.activation1 = activation()
45
- self.dropout1 = Drop(dropout)
46
-
47
- self.norm2 = nn.GroupNorm(norm_groups, channels)
48
- self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
49
- self.activation2 = activation()
50
- self.dropout2 = Drop(dropout)
51
-
52
- def forward(self, x):
53
- h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
54
- h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
55
- return x + h
56
-
57
-
58
- class DecoderLayer(nn.Module):
59
- def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
60
- norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
61
- dropout: float = 0.):
62
- super().__init__()
63
- padding = (kernel - stride) // 2
64
- self.res_blocks = nn.Sequential(
65
- *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
66
- for idx in range(res_blocks)])
67
- self.norm = nn.GroupNorm(norm_groups, chin)
68
- ConvTr = nn.ConvTranspose1d
69
- self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
70
- self.activation = activation()
71
-
72
- def forward(self, x: torch.Tensor) -> torch.Tensor:
73
- x = self.res_blocks(x)
74
- x = self.norm(x)
75
- x = self.activation(x)
76
- x = self.convtr(x)
77
- return x
78
-
79
-
80
- class EncoderLayer(nn.Module):
81
- def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
82
- norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
83
- dropout: float = 0.):
84
- super().__init__()
85
- padding = (kernel - stride) // 2
86
- Conv = nn.Conv1d
87
- self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
88
- self.norm = nn.GroupNorm(norm_groups, chout)
89
- self.activation = activation()
90
- self.res_blocks = nn.Sequential(
91
- *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
92
- for idx in range(res_blocks)])
93
-
94
- def forward(self, x: torch.Tensor) -> torch.Tensor:
95
- B, C, T = x.shape
96
- stride, = self.conv.stride
97
- pad = (stride - (T % stride)) % stride
98
- x = F.pad(x, (0, pad))
99
-
100
- x = self.conv(x)
101
- x = self.norm(x)
102
- x = self.activation(x)
103
- x = self.res_blocks(x)
104
- return x
105
-
106
-
107
- class BLSTM(nn.Module):
108
- """BiLSTM with same hidden units as input dim.
109
- """
110
- def __init__(self, dim, layers=2):
111
- super().__init__()
112
- self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
113
- self.linear = nn.Linear(2 * dim, dim)
114
-
115
- def forward(self, x):
116
- x = x.permute(2, 0, 1)
117
- x = self.lstm(x)[0]
118
- x = self.linear(x)
119
- x = x.permute(1, 2, 0)
120
- return x
121
-
122
-
123
- class DiffusionUnet(nn.Module):
124
- def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
125
- max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
126
- bilstm: bool = False, transformer: bool = False,
127
- codec_dim: tp.Optional[int] = None, **kwargs):
128
- super().__init__()
129
- self.encoders = nn.ModuleList()
130
- self.decoders = nn.ModuleList()
131
- self.embeddings: tp.Optional[nn.ModuleList] = None
132
- self.embedding = nn.Embedding(num_steps, hidden)
133
- if emb_all_layers:
134
- self.embeddings = nn.ModuleList()
135
- self.condition_embedding: tp.Optional[nn.Module] = None
136
- for d in range(depth):
137
- encoder = EncoderLayer(chin, hidden, **kwargs)
138
- decoder = DecoderLayer(hidden, chin, **kwargs)
139
- self.encoders.append(encoder)
140
- self.decoders.insert(0, decoder)
141
- if emb_all_layers and d > 0:
142
- assert self.embeddings is not None
143
- self.embeddings.append(nn.Embedding(num_steps, hidden))
144
- chin = hidden
145
- hidden = min(int(chin * growth), max_channels)
146
- self.bilstm: tp.Optional[nn.Module]
147
- if bilstm:
148
- self.bilstm = BLSTM(chin)
149
- else:
150
- self.bilstm = None
151
- self.use_transformer = transformer
152
- self.cross_attention = False
153
- if transformer:
154
- self.cross_attention = cross_attention
155
- self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
156
- cross_attention=cross_attention)
157
-
158
- self.use_codec = False
159
- if codec_dim is not None:
160
- self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
161
- self.use_codec = True
162
-
163
- def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
164
- skips = []
165
- bs = x.size(0)
166
- z = x
167
- view_args = [1]
168
- if type(step) is torch.Tensor:
169
- step_tensor = step
170
- else:
171
- step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
172
-
173
- for idx, encoder in enumerate(self.encoders):
174
- z = encoder(z)
175
- if idx == 0:
176
- z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
177
- elif self.embeddings is not None:
178
- z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
179
-
180
- skips.append(z)
181
-
182
- if self.use_codec: # insert condition in the bottleneck
183
- assert condition is not None, "Model defined for conditionnal generation"
184
- condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
185
- assert condition_emb.size(-1) <= 2 * z.size(-1), \
186
- f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
187
- if not self.cross_attention:
188
-
189
- condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
190
- assert z.size() == condition_emb.size()
191
- z += condition_emb
192
- cross_attention_src = None
193
- else:
194
- cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
195
- B, T, C = cross_attention_src.shape
196
- positions = torch.arange(T, device=x.device).view(1, -1, 1)
197
- pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
198
- cross_attention_src = cross_attention_src + pos_emb
199
- if self.use_transformer:
200
- z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
201
- else:
202
- if self.bilstm is None:
203
- z = torch.zeros_like(z)
204
- else:
205
- z = self.bilstm(z)
206
-
207
- for decoder in self.decoders:
208
- s = skips.pop(-1)
209
- z = z[:, :, :s.shape[2]]
210
- z = z + s
211
- z = decoder(z)
212
-
213
- z = z[:, :, :x.shape[2]]
214
- return Output(z)