Hugo Flores Garcia committed
Commit 99122c4 · Parent(s): 5a343f4

basic readme stuff

Files changed:
- README.md +31 -53
- conf/lora/gas-station.yml +10 -0
- demo.py +3 -2
- scripts/exp/train.py +9 -87
- scripts/utils/vamp_folder.py +5 -5
- vampnet/interface.py +13 -5
- vampnet/modules/base.py +8 -119
- vampnet/modules/layers.py +14 -0
- vampnet/signal.py +5 -0
- vampnet/util.py +3 -34
README.md
CHANGED
@@ -1,80 +1,58 @@
-# 
-
-This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
-
-## 
-### Setting everything up
-
-```bash
-
-```
-
-Once run, follow the instructions it prints out to create your
-environment file, which will be at `env/env.sh`.
-
-Note that if this is a new machine, and
-the data is not downloaded somewhere on it already, it will ask you
-for a directory to download the data to.
-
-For Github setup, if you don't have a .netrc token, create one by going to your Github profile -> Developer settings -> Personal access tokens -> Generate new token. Copy the token and [keep it secret, keep it safe](https://www.youtube.com/watch?v=iThtELZvfPs).
-
-When complete, run:
-
-```bash
-
-```
-
-```bash
-
-```
-
-docker compose run dev
-```
-
-To tear down your development environment, just do
-
-```bash
-docker compose down
-```
-
-`stage` creates a directory with a copy of all of the Git-tracked files in the root repository. `stage` launches a shell into said directory, so all commands are run on the
-copy of the original repository code. This is useful for rewinding to an old experiment
-and resuming it, for example. Even if the repository code changes, the snapshot in the experiment directory is unchanged from the original run, so it can be re-used.
-
-Then, the experiment can be run via:
-
-```bash
-scripts/exp/train.py \
-    --args.load=conf/args.yml \
-```
-
-#### Cleaning up after a run
-
-Sometimes DDP runs fail to clear themselves out of the machine. To fix this, run
-
-```bash
-
-```
+# VampNet
+
+This repository contains recipes for training generative music models on top of the Lyrebird Audio Codec.
+
+# Setting up
+
+## Install LAC
+
+Install AudioTools:
+
+```bash
+git clone https://github.com/hugofloresgarcia/audiotools.git
+pip install -e ./audiotools
+```
+
+Install the LAC library:
+
+```bash
+git clone https://github.com/hugofloresgarcia/lac.git
+pip install -e ./lac
+```
+
+Install VampNet:
+
+```bash
+git clone https://github.com/hugofloresgarcia/vampnet2.git
+pip install -e ./vampnet2
+```
+
+## A note on Argbind
+
+This repository relies on [argbind](https://github.com/pseeth/argbind) to manage CLIs and config files.
+Config files are stored in the `conf/` folder.
+
+# Usage
+
+## Staging a Run
+
+Staging a run makes a copy of all the git-tracked files in the codebase and saves them to a folder for reproducibility. You can then run the training script from the staged folder.
+
+coming soon
+
+## Training a model
+
+```bash
+python scripts/exp/train.py --args.load conf/vampnet.yml --save_path /path/to/checkpoints
+```
+
+## Fine-tuning
+
+To fine-tune a model, see the configuration files under `conf/lora/`.
+You just need to provide a list of audio files or folders to fine-tune on, then launch the training job as usual.
+
+```bash
+python scripts/exp/train.py --args.load conf/lora/birds.yml --save_path /path/to/checkpoints
+```
+
+## Launching the Gradio Interface
+
+```bash
+python demo.py --args.load conf/interface/spotdl.yml --Interface.device cuda
+```
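For readers new to [argbind](https://github.com/pseeth/argbind), which every command in the new README relies on: a config passed via `--args.load` is a YAML file whose keys map onto bound function arguments. A minimal sketch of the pattern; the function and its arguments here are illustrative, not the repository's exact code:

```python
import argbind

# Illustrative bound function; scripts/exp/train.py binds its train()
# the same way, so each keyword argument becomes a CLI flag / YAML key.
@argbind.bind(without_prefix=True)
def train(save_path: str = "ckpt", fine_tune: bool = False):
    print(f"saving to {save_path} (fine_tune={fine_tune})")

if __name__ == "__main__":
    # `--args.load conf/vampnet.yml` fills arguments from the YAML file;
    # explicit flags such as `--save_path` override it.
    args = argbind.parse_args()
    with argbind.scope(args):
        train()
```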
conf/lora/gas-station.yml
ADDED
@@ -0,0 +1,10 @@
+$include:
+  - conf/lora/lora.yml
+
+fine_tune: True
+
+train/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
+
+val/AudioLoader.sources:
+  - /media/CHONK/hugo/spotdl/subsets/gas-station-sushi.mp3
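For context on the config above: `$include` pulls in the base LoRA config, and the `train/` and `val/` prefixes are argbind scopes, letting the same `AudioLoader.sources` key take different values for training and validation. A rough sketch of how such scoped keys are consumed; the `AudioLoader` import path is an assumption based on audiotools conventions:

```python
import argbind
from audiotools.data.datasets import AudioLoader  # assumed import path

# Bind the loader under "train" and "val" scopes, matching the config prefixes.
AudioLoader = argbind.bind(AudioLoader, "train", "val")

args = argbind.parse_args()  # e.g. --args.load conf/lora/gas-station.yml
with argbind.scope(args, "train"):
    train_loader = AudioLoader()  # picks up train/AudioLoader.sources
with argbind.scope(args, "val"):
    val_loader = AudioLoader()    # picks up val/AudioLoader.sources
```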
demo.py
CHANGED
@@ -48,6 +48,7 @@ def load_audio(file):
     sig.write(out_dir / "input.wav")
     return sig.path_to_file
 
+
 def load_random_audio():
     index = np.random.randint(0, len(dataset))
     sig = dataset[index]["signal"]
@@ -68,7 +69,7 @@ def ez_vamp(
     sig = at.AudioSignal(input_audio)
 
     print(f"running standard vampnet with {num_vamps} vamps")
-    zv = interface.
+    zv = interface.coarse_vamp(
         sig,
         sampling_steps=num_steps,
         temperature=(init_temp, final_temp),
@@ -140,7 +141,7 @@ def vamp(
     if mode == "standard":
         print(f"running standard vampnet with {num_vamps} vamps")
-        zv, mask_z = interface.
+        zv, mask_z = interface.coarse_vamp(
             sig,
             sampling_steps=num_steps,
             temperature=(init_temp, final_temp),
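The demo hunks above fill in the renamed `coarse_vamp` method. Condensed, the call pattern looks like this; checkpoint paths and values are hypothetical, while the keyword names come from the call sites above:

```python
import audiotools as at
from vampnet.interface import Interface

# hypothetical checkpoint path; see conf/interface/ for real configs
interface = Interface(codec_ckpt="ckpt/codec.pth", device="cuda")

sig = at.AudioSignal("input.wav")  # hypothetical input file
zv = interface.coarse_vamp(
    sig,
    sampling_steps=36,        # num_steps in demo.py
    temperature=(0.8, 1.0),   # (init_temp, final_temp)
)
```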
scripts/exp/train.py
CHANGED
@@ -115,6 +115,10 @@ def load(
     }
     if (Path(kwargs["folder"]) / "vampnet").exists():
         model, v_extra = VampNet.load_from_folder(**kwargs)
+    else:
+        raise ValueError(
+            f"Could not find a VampNet checkpoint in {kwargs['folder']}"
+        )
 
     codec = LAC.load(args["codec_ckpt"], map_location="cpu")
     codec.eval()
@@ -149,25 +153,6 @@ def load(
     }
 
 
-def get_gpu_memory_map():
-    """Get the current gpu usage.
-
-    Returns
-    -------
-    usage: dict
-        Keys are device ids as integers.
-        Values are memory usage as integers in MB.
-    """
-    result = subprocess.check_output(
-        ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"],
-        encoding="utf-8",
-    )
-    # Convert lines into a dictionary
-    gpu_memory = [int(x) for x in result.strip().split("\n")]
-    gpu_memory_map = dict(zip(range(len(gpu_memory)), gpu_memory))
-    gpu_memory_map = {f"gpu/{k}": v / 1024 for k, v in gpu_memory_map.items()}
-    return gpu_memory_map
-
 
 def num_params_hook(o, p):
     return o + f" {p/1e6:<.3f}M params."
@@ -189,7 +174,6 @@ def accuracy(
     target: torch.Tensor,
     top_k: int = 1,
     ignore_index: Optional[int] = None,
-    **kwargs,
 ) -> torch.Tensor:
     # Flatten the predictions and targets to be of shape (batch_size * sequence_length, n_class)
     preds = rearrange(preds, "b p s -> (b s) p")
@@ -214,30 +198,6 @@ def accuracy(
 
     return accuracy
 
-def sample_prefix_suffix_amt(
-    z,
-    n_batch,
-    prefix_amt,
-    suffix_amt,
-    prefix_dropout,
-    suffix_dropout,
-    rng
-):
-    """
-    Sample the number of prefix and suffix tokens to drop.
-    """
-    if prefix_amt > 0.0:
-        prefix_mask = flip_coin(n_batch, 1 - prefix_dropout, rng)
-        n_prefix = int(prefix_amt * z.shape[-1]) * prefix_mask
-    else:
-        n_prefix = None
-    if suffix_amt > 0.0:
-        suffix_mask = flip_coin(n_batch, 1 - suffix_dropout, rng)
-        n_suffix = int(suffix_amt * z.shape[-1]) * suffix_mask
-    else:
-        n_suffix = None
-    return n_prefix, n_suffix
-
 
 @argbind.bind(without_prefix=True)
 def train(
@@ -256,10 +216,6 @@ def train(
     num_workers: int = 10,
     detect_anomaly: bool = False,
     grad_clip_val: float = 5.0,
-    prefix_amt: float = 0.0,
-    suffix_amt: float = 0.0,
-    prefix_dropout: float = 0.1,
-    suffix_dropout: float = 0.1,
     fine_tune: bool = False,
     quiet: bool = False,
 ):
@@ -342,16 +298,12 @@ def train(
             target=r_unmasked_target,
             ignore_index=IGNORE_INDEX,
             top_k=topk,
-            task="multiclass",
-            num_classes=vn.vocab_size,
         )
         output[f"{tag}/masked"] = accuracy(
             preds=r_z_hat,
             target=r_masked_target,
             ignore_index=IGNORE_INDEX,
             top_k=topk,
-            task="multiclass",
-            num_classes=vn.vocab_size,
         )
@@ -370,15 +322,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(
-            z, r, n_prefix=n_prefix, n_suffix=n_suffix
-        )
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         dtype = torch.bfloat16 if accel.amp else None
@@ -454,13 +398,7 @@ def train(
         n_batch = z.shape[0]
         r = rng.draw(n_batch)[:, 0].to(accel.device)
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
@@ -574,17 +512,8 @@ def train(
         )
 
         def save_imputation(self, z: torch.Tensor):
-            _prefix_amt = prefix_amt
-            _suffix_amt = suffix_amt
-
-            if _prefix_amt == 0:
-                _prefix_amt = 0.25
-            if _suffix_amt == 0:
-                _suffix_amt = 0.25
-
-            n_prefix = int(z.shape[-1] * _prefix_amt)
-            n_suffix = int(z.shape[-1] * _suffix_amt)
+            n_prefix = int(z.shape[-1] * 0.25)
+            n_suffix = int(z.shape[-1] * 0.25)
             downsample_factor = None
 
             vn = accel.unwrap(model)
@@ -647,13 +576,7 @@ def train(
 
         n_batch = z.shape[0]
 
-        n_prefix, n_suffix = sample_prefix_suffix_amt(
-            n_batch=n_batch, prefix_amt=prefix_amt, suffix_amt=suffix_amt,
-            prefix_dropout=prefix_dropout, suffix_dropout=suffix_dropout,
-            rng=rng
-        )
-
-        z_mask, mask = vn.add_noise(z, r, n_prefix=n_prefix, n_suffix=n_suffix)
+        z_mask, mask = vn.add_noise(z, r)
         z_mask_latent = vn.embedding.from_codes(z_mask, codec)
 
         z_hat = model(z_mask_latent, r)
@@ -664,7 +587,6 @@ def train(
         z_pred = vn.embedding.unflatten(z_pred, n_codebooks=vn.n_predict_codebooks)
         z_pred = torch.cat([z[:, : vn.n_conditioning_codebooks, :], z_pred], dim=1)
 
-        print("z_mask", z_mask.shape)
        generated = vn.to_signal(z_pred, codec)
        reconstructed = vn.to_signal(z, codec)
        masked = vn.to_signal(z_mask.squeeze(1), codec)
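The substantive change above is that prefix/suffix masking leaves the training loop: `vn.add_noise(z, r)` now owns all masking. `add_noise` itself is not part of this diff; as a rough mental model only (an assumption, not the repository's implementation), mask-mode noising replaces roughly a fraction `r` of token positions with a special mask token:

```python
import torch

def add_noise_sketch(z: torch.Tensor, r: torch.Tensor, mask_token: int):
    """Sketch only: z is (batch, n_codebooks, seq_len), r is (batch,) in [0, 1].
    Each position is masked independently with probability r."""
    probs = r.view(-1, 1, 1).expand(z.shape)
    mask = torch.bernoulli(probs).long()         # 1 marks a masked position
    z_mask = z * (1 - mask) + mask_token * mask
    return z_mask, mask

z = torch.randint(0, 1024, (2, 4, 10))
z_mask, mask = add_noise_sketch(z, torch.tensor([0.3, 0.9]), mask_token=1024)
```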
scripts/utils/vamp_folder.py
CHANGED
@@ -56,7 +56,7 @@ class CoarseCond:
 
     def __call__(self, sig, interface):
        n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
-        zv = interface.
+        zv = interface.coarse_vamp(sig,
            n_conditioning_codebooks=n_conditioning_codebooks,
            downsample_factor=self.downsample_factor,
        )
@@ -113,7 +113,7 @@ def mask_ratio_1_step(ratio=1.0):
    r = interface.coarse.invgamma(ratio).to(interface.device)
    intensity = 1-r
 
-    zv = interface.
+    zv = interface.coarse_vamp(
        sig,
        sample='argmax',
        sampling_steps=1,
@@ -125,7 +125,7 @@ def mask_ratio_1_step(ratio=1.0):
 
 def num_sampling_steps(num_steps=1):
    def wrapper(sig, interface):
-        zv = interface.
+        zv = interface.coarse_vamp(
            sig,
            downsample_factor=16,
            sampling_steps=num_steps,
@@ -143,7 +143,7 @@ def beat_mask(ctx_time):
        after_beat_s=ctx_time,
        invert=True
    )
-    zv = interface.
+    zv = interface.coarse_vamp(
        sig,
        ext_mask=beat_mask,
    )
@@ -154,7 +154,7 @@ def beat_mask(ctx_time):
 
 def inpaint(ctx_time):
    def wrapper(sig, interface):
-        zv = interface.
+        zv = interface.coarse_vamp(
            sig,
            prefix_dur_s=ctx_time,
            suffix_dur_s=ctx_time,
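Each helper above packages one `coarse_vamp` call pattern as a function of `(sig, interface)`. A new strategy can be added in the same shape; this hypothetical example reuses the inpainting keywords from the `inpaint` wrapper:

```python
def inpaint_short(ctx_time=1.5):
    """Hypothetical strategy: keep ctx_time seconds of context on each side
    and regenerate the middle, mirroring the inpaint() wrapper above."""
    def wrapper(sig, interface):
        zv = interface.coarse_vamp(
            sig,
            prefix_dur_s=ctx_time,
            suffix_dur_s=ctx_time,
        )
        return zv
    return wrapper
```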
vampnet/interface.py
CHANGED
@@ -20,6 +20,14 @@ def signal_concat(
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
 
+class SignalPrompt:
+
+    def __init__(self, signal: AudioSignal):
+        self.sig = signal
+
+
+
+
 class Interface(torch.nn.Module):
     def __init__(
         self,
@@ -28,7 +36,7 @@ class Interface(torch.nn.Module):
         codec_ckpt: str = None,
         wavebeat_ckpt: str = None,
         device: str = "cpu",
-        coarse_chunk_size_s: int = 
+        coarse_chunk_size_s: int = 10,
         coarse2fine_chunk_size_s: int = 3,
     ):
         super().__init__()
@@ -141,7 +149,7 @@ class Interface(torch.nn.Module):
         """make a beat synced mask. that is, make a mask that
         places 1s at and around the beat, and 0s everywhere else.
         """
-        assert 
+        assert self.beat_tracker is not None, "No beat tracker loaded"
 
         # get the beat times
         beats, downbeats = self.beat_tracker.extract_beats(signal)
@@ -242,7 +250,7 @@ class Interface(torch.nn.Module):
         return fine_z[:, :, :length].clone()
 
 
-    def 
+    def coarse_vamp(
         self,
         signal,
         prefix_dur_s: float = 0.0,
@@ -471,7 +479,7 @@ class Interface(torch.nn.Module):
         else:
             ext_mask = None
 
-        out_z = self.
+        out_z = self.coarse_vamp(
             sig,
             num_vamps=1,
             swap_prefix_suffix=False,
@@ -520,7 +528,7 @@ class Interface(torch.nn.Module):
         range_fn = range if not verbose else tqdm.trange
         for i in range_fn(num_loops):
             is_flipped = i % 2 == 0
-            vamped = self.
+            vamped = self.coarse_vamp(
                 signal,
                 prefix_dur_s=prefix_dur_s,
                 suffix_dur_s=suffix_dur_s,
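Two of the fixes above interact: `make_beat_mask` now asserts that a beat tracker is loaded, so beat-synced vamping requires passing `wavebeat_ckpt`. A sketch (paths are hypothetical; the `make_beat_mask` kwargs are taken from its call site in scripts/utils/vamp_folder.py, and the full signature may differ):

```python
import audiotools as at
from vampnet.interface import Interface

interface = Interface(
    codec_ckpt="ckpt/codec.pth",        # hypothetical path
    wavebeat_ckpt="ckpt/wavebeat.pth",  # required, or make_beat_mask asserts
    device="cuda",
    coarse_chunk_size_s=10,             # the new default from this diff
)

sig = at.AudioSignal("input.wav")
beat_mask = interface.make_beat_mask(sig, after_beat_s=0.1, invert=True)
zv = interface.coarse_vamp(sig, ext_mask=beat_mask)
```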
vampnet/modules/base.py
CHANGED
@@ -10,6 +10,8 @@ import torch.nn.functional as F
 from einops import rearrange
 from tqdm import tqdm
 
+from ..util import scalar_to_batch_tensor
+
 
 def log(t, eps=1e-20):
     return torch.log(t + eps)
@@ -24,9 +26,6 @@ def gumbel_sample(t, temperature=1.0, dim=-1):
     return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)
 
 
-def scalar_to_batch_tensor(x, batch_size):
-    return torch.tensor(x).repeat(batch_size)
-
 class VampBase(at.ml.BaseModel):
     def forward(self, x: torch.Tensor, r: torch.Tensor):
         raise NotImplementedError
@@ -150,6 +149,8 @@ class VampBase(at.ml.BaseModel):
             z_hat = z_hat * mask + truth * (1 - mask)
 
             z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
+        else:
+            raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
 
         return z_hat
 
@@ -186,6 +187,9 @@ class VampBase(at.ml.BaseModel):
 
     @torch.no_grad()
     def to_signal(self, z, codec):
+        """
+        convert a sequence of latents to a signal.
+        """
         if z.ndim == 2:
             z = self.embedding.unflatten(z)
         assert z.ndim == 3
@@ -207,122 +211,7 @@ class VampBase(at.ml.BaseModel):
         return signal
 
     @torch.no_grad()
-    def sample(self, **kwargs):
-        if self.noise_mode == "mask":
-            return self.maskgit_sample(**kwargs)
-        else:
-            return self.paella_sample(**kwargs)
-
-    def paella_sample(
-        self,
-        codec,
-        time_steps: int = 400,
-        sampling_steps: int = 36,
-        start_tokens: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] = 0.8,
-        top_k: int = None,
-        sample: str = "gumbel",
-        renoise_mode: str = "start",
-        renoise_steps=None,
-        typical_filtering=True,
-        typical_mass=0.2,
-        typical_min_tokens=1,
-        return_signal=True,
-    ):
-        r = torch.linspace(0, 1, sampling_steps + 1)[:-1][:, None].to(self.device)
-        if renoise_steps == None:
-            renoise_steps = sampling_steps - 1
-
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
-        if self.n_conditioning_codebooks > 0:
-            assert (
-                start_tokens is not None
-            ), "must provide start_tokens if n_conditioning_codebooks > 0"
-
-        if start_tokens is None:
-            if self.noise_mode == "noise":
-                z = torch.randint(
-                    0, self.vocab_size, size=(1, self.n_codebooks, time_steps)
-                ).to(self.device)
-            elif self.noise_mode == "mask":
-                z = torch.full((1, self.n_codebooks, time_steps), self.mask_token)
-        else:
-            z = start_tokens
-            assert (
-                z.ndim == 3
-            ), f"start_tokens must be shape (batch, n_codebooks, seq_len), got {z.shape}"
-            assert z.shape[0] == 1, f"batch size must be 1"
-
-        if mask is None:
-            mask = torch.ones(z.shape[0], z.shape[-1]).to(self.device).int()
-            mask = mask[:, None, :]
-            mask = mask.repeat(1, z.shape[1], 1)
-
-        mask[:, : self.n_conditioning_codebooks, :] = 0.0
-
-        z_true = z.clone()
-
-        z, mask = self.add_noise(z, r=r[0], random_x=None, mask=mask)
-        z_init = z.clone()
-        for i, tmpt in enumerate(temperature):
-            if renoise_mode == "prev":
-                z_prev = z.clone()
-
-            latents = self.embedding.from_codes(z, codec)
-            logits = self.forward(latents, r[i])
-
-            # for mask mode
-            logits = self.add_truth_to_logits(z_true, logits, mask)
-
-            # Apply topk sampling
-            logits = logits.permute(0, 2, 1)
-
-            z = self.sample_from_logits(
-                logits,
-                top_k=top_k,
-                temperature=tmpt,
-                sample=sample,
-                typical_filtering=typical_filtering,
-                typical_mass=typical_mass,
-                typical_min_tokens=typical_min_tokens,
-            )
-
-            # add back in conditioning codebooks
-            z = self.embedding.unflatten(z, n_codebooks=self.n_predict_codebooks)
-            z = torch.cat(
-                [z_init[:, : self.n_conditioning_codebooks, :], z], dim=1
-            ).int()
-
-            if i < renoise_steps:
-                if renoise_mode == "prev":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_prev)
-                elif renoise_mode == "start":
-                    z, _ = self.add_noise(z, r[i + 1], random_x=z_init)
-                elif renoise_mode == "rand":
-                    z, _ = self.add_noise(z, r[i + 1])
-                else:
-                    raise ValueError(f"Invalid renoise_mode: {renoise_mode}")
-
-            if mask is not None:
-                z = start_tokens * (1 - mask) + z * mask
-
-        if return_signal:
-            return self.to_signal(z, codec)
-        else:
-            return z
-
-    def maskgit_sample(
+    def sample(
         self,
         codec,
         time_steps: int = 300,
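One piece of the deleted `paella_sample` worth keeping in mind, assuming the surviving `sample` (the old `maskgit_sample`) handles its `temperature` argument the same way: a float is repeated for every sampling step, while a `(low, high)` pair becomes a linear ramp. The logic, lifted from the deleted code:

```python
import torch

def temperature_schedule(temperature, sampling_steps: int):
    # mirrors the temperature handling in the deleted paella_sample above
    if isinstance(temperature, float):
        return torch.tensor(temperature).repeat(sampling_steps)
    low, high = temperature
    return torch.linspace(low, high, sampling_steps)

temperature_schedule((0.8, 1.0), 5)  # tensor([0.8000, 0.8500, 0.9000, 0.9500, 1.0000])
```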
vampnet/modules/layers.py
CHANGED
@@ -132,6 +132,11 @@ class CodebookEmbedding(nn.Module):
         self.out_proj = nn.Conv1d(n_codebooks * self.latent_dim, self.emb_dim, 1)
 
     def from_codes(self, codes: torch.Tensor, codec):
+        """
+        get a sequence of continuous embeddings from a sequence of discrete codes.
+        unlike its counterpart in the original VQ-VAE, this function adds any special tokens
+        necessary for the language model, like <MASK>.
+        """
         n_codebooks = codes.shape[1]
         latent = []
         for i in range(n_codebooks):
@@ -151,14 +156,23 @@ class CodebookEmbedding(nn.Module):
         return latent
 
     def forward(self, latents: torch.Tensor):
+        """
+        project a sequence of latents to a sequence of embeddings
+        """
         x = self.out_proj(latents)
         return x
 
     def flatten(self, tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        flatten a sequence of tokens from (batch, codebook, time) to (batch, codebook * time)
+        """
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
         return rearrange(tokens, "b c t -> b (t c)", c=n_c)
 
     def unflatten(self, flat_tokens: torch.Tensor, n_codebooks: int = None):
+        """
+        unflatten a sequence of tokens from (batch, codebook * time) to (batch, codebook, time)
+        """
         nb, nt = flat_tokens.shape
 
         n_c = n_codebooks if n_codebooks is not None else self.n_codebooks
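The new `flatten`/`unflatten` docstrings describe an exact round trip; here it is spelled out with the same einops patterns the methods use (standalone tensors, assuming `unflatten` inverts the identical rearrange):

```python
import torch
from einops import rearrange

tokens = torch.randint(0, 1024, (2, 4, 10))       # (batch, codebook, time)
flat = rearrange(tokens, "b c t -> b (t c)")      # (2, 40), as in flatten()
back = rearrange(flat, "b (t c) -> b c t", c=4)   # (2, 4, 10)
assert torch.equal(tokens, back)
```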
vampnet/signal.py
ADDED
@@ -0,0 +1,5 @@
+import torch
+from typing import Optional, Tuple
+
+from .util import scalar_to_batch_tensor
+
vampnet/util.py
CHANGED
@@ -1,40 +1,9 @@
 import tqdm
-# import pathos
 
+import torch
 
-    """
-    Equivalent of `list(map(fn, *iterables))`
-    driven by `concurrent.futures.ProcessPoolExecutor`.
-
-    Parameters
-    ----------
-    tqdm_class : optional
-        `tqdm` class to use for bars [default: tqdm.auto.tqdm].
-    max_workers : int, optional
-        Maximum number of workers to spawn; passed to
-        `concurrent.futures.ProcessPoolExecutor.__init__`.
-        [default: min(32, cpu_count() + 4)].
-    chunksize : int, optional
-        Size of chunks sent to worker processes; passed to
-        `concurrent.futures.ProcessPoolExecutor.map`. [default: 1].
-    lock_name : str, optional
-        Member of `tqdm_class.get_lock()` to use [default: mp_lock].
-    """
-    from concurrent.futures import ProcessPoolExecutor
-    if iterables and "chunksize" not in tqdm_kwargs:
-        # default `chunksize=1` has poor performance for large iterables
-        # (most time spent dispatching items to workers).
-        longest_iterable_len = max(map(length_hint, iterables))
-        if longest_iterable_len > 1000:
-            from warnings import warn
-            warn("Iterable length %d > 1000 but `chunksize` is not set."
-                 " This may seriously degrade multiprocess performance."
-                 " Set `chunksize=1` or more." % longest_iterable_len,
-                 TqdmWarning, stacklevel=2)
-    if "lock_name" not in tqdm_kwargs:
-        tqdm_kwargs = tqdm_kwargs.copy()
-        tqdm_kwargs["lock_name"] = "mp_lock"
-    return _executor_map(ProcessPoolExecutor, fn, *iterables, **tqdm_kwargs)
+def scalar_to_batch_tensor(x, batch_size):
+    return torch.tensor(x).repeat(batch_size)
 
 
 def parallelize(
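The helper that moved here is small enough to show in full; it broadcasts a scalar into a batch-sized tensor, which is how callers such as vampnet/modules/base.py can treat a single mask ratio as per-item values:

```python
import torch

def scalar_to_batch_tensor(x, batch_size):
    return torch.tensor(x).repeat(batch_size)

scalar_to_batch_tensor(0.5, 4)  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```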