Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
committed on
Commit
·
3f6f517
1
Parent(s):
75a7169
critical sampling fix, two demos for comparing old and new sampling
Browse files- conf/generated/bulgarian-tv-choir/c2f.yml +15 -0
- conf/generated/bulgarian-tv-choir/coarse.yml +8 -0
- conf/generated/bulgarian-tv-choir/interface.yml +7 -0
- conf/generated/panchos/c2f.yml +15 -0
- conf/generated/panchos/coarse.yml +8 -0
- conf/generated/panchos/interface.yml +7 -0
- conf/generated/titi-monkey/c2f.yml +15 -0
- conf/generated/titi-monkey/coarse.yml +8 -0
- conf/generated/titi-monkey/interface.yml +7 -0
- conf/interface/spotdl.yml +1 -1
- demo-new.py +518 -0
- demo.py +65 -5
- scripts/exp/train.py +6 -12
- scripts/utils/augment.py +53 -0
- vampnet/interface.py +46 -32
- vampnet/mask.py +1 -1
- vampnet/modules/transformer.py +288 -32
conf/generated/bulgarian-tv-choir/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/bulgarian-tv-choir/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/bulgarian-tv-choir/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/bulgarian-tv-choir/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/bulgarian-tv-choir/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/bulgarian-female-tv-choir/
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/bulgarian-tv-choir/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/bulgarian-tv-choir/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/panchos/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/panchos/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/panchos/
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/panchos/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/panchos/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/panchos/
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/panchos/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/panchos/
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/panchos/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/panchos/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/generated/titi-monkey/c2f.yml
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
AudioDataset.duration: 3.0
|
4 |
+
AudioDataset.loudness_cutoff: -40.0
|
5 |
+
VampNet.embedding_dim: 1280
|
6 |
+
VampNet.n_codebooks: 14
|
7 |
+
VampNet.n_conditioning_codebooks: 4
|
8 |
+
VampNet.n_heads: 20
|
9 |
+
VampNet.n_layers: 16
|
10 |
+
fine_tune: true
|
11 |
+
fine_tune_checkpoint: ./models/spotdl/c2f.pth
|
12 |
+
save_path: ./runs/titi-monkey/c2f
|
13 |
+
train/AudioLoader.sources: &id001
|
14 |
+
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
15 |
+
val/AudioLoader.sources: *id001
|
conf/generated/titi-monkey/coarse.yml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
$include:
|
2 |
+
- conf/lora/lora.yml
|
3 |
+
fine_tune: true
|
4 |
+
fine_tune_checkpoint: ./models/spotdl/coarse.pth
|
5 |
+
save_path: ./runs/titi-monkey/coarse
|
6 |
+
train/AudioLoader.sources: &id001
|
7 |
+
- /media/CHONK/hugo/loras/titi-monkey.mp3
|
8 |
+
val/AudioLoader.sources: *id001
|
conf/generated/titi-monkey/interface.yml
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
AudioLoader.sources:
|
2 |
+
- - /media/CHONK/hugo/loras/titi-monkey.mp3
|
3 |
+
Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
|
4 |
+
Interface.coarse2fine_lora_ckpt: ./runs/titi-monkey/c2f/latest/lora.pth
|
5 |
+
Interface.coarse_ckpt: ./models/spotdl/coarse.pth
|
6 |
+
Interface.coarse_lora_ckpt: ./runs/titi-monkey/coarse/latest/lora.pth
|
7 |
+
Interface.codec_ckpt: ./models/spotdl/codec.pth
|
conf/interface/spotdl.yml
CHANGED
@@ -7,6 +7,6 @@ Interface.coarse2fine_chunk_size_s: 3
|
|
7 |
|
8 |
|
9 |
AudioLoader.sources:
|
10 |
-
# - /media/CHONK/hugo/spotdl/subsets/jazz-blues
|
11 |
- /media/CHONK/null
|
12 |
|
|
|
7 |
|
8 |
|
9 |
AudioLoader.sources:
|
10 |
+
# - /media/CHONK/hugo/spotdl/subsets/jazz-blues/
|
11 |
- /media/CHONK/null
|
12 |
|
demo-new.py
ADDED
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import Tuple
|
3 |
+
import yaml
|
4 |
+
import tempfile
|
5 |
+
import uuid
|
6 |
+
from dataclasses import dataclass, asdict
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import audiotools as at
|
10 |
+
import argbind
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
from vampnet.interface import Interface
|
14 |
+
from vampnet import mask as pmask
|
15 |
+
|
16 |
+
import logging
|
17 |
+
logger = logging.getLogger()
|
18 |
+
logger.setLevel(logging.CRITICAL)
|
19 |
+
|
20 |
+
Interface = argbind.bind(Interface)
|
21 |
+
AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
|
22 |
+
|
23 |
+
conf = argbind.parse_args()
|
24 |
+
|
25 |
+
with argbind.scope(conf):
|
26 |
+
interface = Interface()
|
27 |
+
loader = AudioLoader()
|
28 |
+
print(f"interface device is {interface.device}")
|
29 |
+
|
30 |
+
dataset = at.data.datasets.AudioDataset(
|
31 |
+
loader,
|
32 |
+
sample_rate=interface.codec.sample_rate,
|
33 |
+
duration=interface.coarse.chunk_size_s,
|
34 |
+
n_examples=5000,
|
35 |
+
without_replacement=True,
|
36 |
+
)
|
37 |
+
|
38 |
+
|
39 |
+
checkpoints = {
|
40 |
+
"spotdl": {
|
41 |
+
"coarse": "./models/spotdl/coarse.pth",
|
42 |
+
"c2f": "./models/spotdl/c2f.pth",
|
43 |
+
"codec": "./models/spotdl/codec.pth",
|
44 |
+
"full_ckpt": True
|
45 |
+
},
|
46 |
+
"berta": {
|
47 |
+
"coarse": "./models/finetuned/berta-goldman-speech/coarse.pth",
|
48 |
+
"c2f": "./models/finetuned/berta-goldman-speech/c2f.pth",
|
49 |
+
"codec": "./model/spotdl/codec.pth",
|
50 |
+
"full_ckpt": True
|
51 |
+
},
|
52 |
+
"xeno-canto-2": {
|
53 |
+
"coarse": "./models/finetuned/xeno-canto-2/coarse.pth",
|
54 |
+
"c2f": "./models/finetuned/xeno-canto-2/c2f.pth",
|
55 |
+
"codec": "./models/spotdl/codec.pth",
|
56 |
+
"full_ckpt": True
|
57 |
+
},
|
58 |
+
"panchos": {
|
59 |
+
"coarse": "./models/finetuned/panchos/coarse.pth",
|
60 |
+
"c2f": "./models/finetuned/panchos/c2f.pth",
|
61 |
+
"codec": "./models/spotdl/codec.pth",
|
62 |
+
"full_ckpt": False
|
63 |
+
},
|
64 |
+
"tv-choir": {
|
65 |
+
"coarse": "./models/finetuned/tv-choir/coarse.pth",
|
66 |
+
"c2f": "./models/finetuned/tv-choir/c2f.pth",
|
67 |
+
"codec": "./models/spotdl/codec.pth",
|
68 |
+
"full_ckpt": False
|
69 |
+
},
|
70 |
+
"titi": {
|
71 |
+
"coarse": "./models/finetuned/titi/coarse.pth",
|
72 |
+
"c2f": "./models/finetuned/titi/c2f.pth",
|
73 |
+
"codec": "./models/spotdl/codec.pth",
|
74 |
+
"full_ckpt": False
|
75 |
+
},
|
76 |
+
"titi-clean": {
|
77 |
+
"coarse": "./models/finetuned/titi-clean/coarse.pth",
|
78 |
+
"c2f": "./models/finetuned/titi-clean/c2f.pth",
|
79 |
+
"codec": "./models/spotdl/codec.pth",
|
80 |
+
"full_ckpt": False
|
81 |
+
}
|
82 |
+
}
|
83 |
+
interface.checkpoint_key = "spotdl"
|
84 |
+
|
85 |
+
|
86 |
+
OUT_DIR = Path("gradio-outputs")
|
87 |
+
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
88 |
+
|
89 |
+
|
90 |
+
def load_audio(file):
|
91 |
+
print(file)
|
92 |
+
filepath = file.name
|
93 |
+
sig = at.AudioSignal.salient_excerpt(
|
94 |
+
filepath,
|
95 |
+
duration=interface.coarse.chunk_size_s
|
96 |
+
)
|
97 |
+
sig = interface.preprocess(sig)
|
98 |
+
|
99 |
+
out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
|
100 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
101 |
+
sig.write(out_dir / "input.wav")
|
102 |
+
return sig.path_to_file
|
103 |
+
|
104 |
+
|
105 |
+
def load_random_audio():
|
106 |
+
index = np.random.randint(0, len(dataset))
|
107 |
+
sig = dataset[index]["signal"]
|
108 |
+
sig = interface.preprocess(sig)
|
109 |
+
|
110 |
+
out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
|
111 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
112 |
+
sig.write(out_dir / "input.wav")
|
113 |
+
return sig.path_to_file
|
114 |
+
|
115 |
+
|
116 |
+
def _vamp(data, return_mask=False):
|
117 |
+
|
118 |
+
# if our checkpoint key is different, we need to load a new checkpoint
|
119 |
+
if data[checkpoint_key] != interface.checkpoint_key:
|
120 |
+
print(f"loading checkpoint {data[checkpoint_key]}")
|
121 |
+
interface.lora_load(
|
122 |
+
checkpoints[data[checkpoint_key]]["coarse"],
|
123 |
+
checkpoints[data[checkpoint_key]]["c2f"],
|
124 |
+
checkpoints[data[checkpoint_key]]["full_ckpt"],
|
125 |
+
)
|
126 |
+
interface.checkpoint_key = data[checkpoint_key]
|
127 |
+
|
128 |
+
out_dir = OUT_DIR / str(uuid.uuid4())
|
129 |
+
out_dir.mkdir()
|
130 |
+
sig = at.AudioSignal(data[input_audio])
|
131 |
+
#pitch shift input
|
132 |
+
sig = sig.shift_pitch(data[input_pitch_shift])
|
133 |
+
|
134 |
+
# TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
|
135 |
+
|
136 |
+
z = interface.encode(sig)
|
137 |
+
|
138 |
+
ncc = data[n_conditioning_codebooks]
|
139 |
+
|
140 |
+
# build the mask
|
141 |
+
mask = pmask.linear_random(z, data[rand_mask_intensity])
|
142 |
+
mask = pmask.mask_and(
|
143 |
+
mask, pmask.inpaint(
|
144 |
+
z,
|
145 |
+
interface.s2t(data[prefix_s]),
|
146 |
+
interface.s2t(data[suffix_s])
|
147 |
+
)
|
148 |
+
)
|
149 |
+
mask = pmask.mask_and(
|
150 |
+
mask, pmask.periodic_mask(
|
151 |
+
z,
|
152 |
+
data[periodic_p],
|
153 |
+
data[periodic_w],
|
154 |
+
random_roll=True
|
155 |
+
)
|
156 |
+
)
|
157 |
+
if data[onset_mask_width] > 0:
|
158 |
+
mask = pmask.mask_or(
|
159 |
+
mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
|
160 |
+
)
|
161 |
+
# these should be the last two mask ops
|
162 |
+
mask = pmask.dropout(mask, data[dropout])
|
163 |
+
mask = pmask.codebook_unmask(mask, ncc)
|
164 |
+
|
165 |
+
print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[init_temp]}, final temp {data[final_temp]}, use coarse2fine {data[use_coarse2fine]}")
|
166 |
+
# save the mask as a txt file
|
167 |
+
np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
|
168 |
+
|
169 |
+
# if data[topk] is not None:
|
170 |
+
# top_k = data[topk] if data[topk] > 0 else None
|
171 |
+
# else:
|
172 |
+
# top_k = None
|
173 |
+
|
174 |
+
zv, mask_z = interface.coarse_vamp(
|
175 |
+
z,
|
176 |
+
mask=mask,
|
177 |
+
sampling_steps=data[num_steps],
|
178 |
+
temperature=(data[init_temp]*10, data[final_temp]*10),
|
179 |
+
return_mask=True,
|
180 |
+
# sample=data[sampling_strategy],
|
181 |
+
typical_filtering=data[typical_filtering],
|
182 |
+
typical_mass=data[typical_mass],
|
183 |
+
typical_min_tokens=data[typical_min_tokens],
|
184 |
+
# top_k=top_k,
|
185 |
+
gen_fn=interface.coarse.generate,
|
186 |
+
)
|
187 |
+
|
188 |
+
if use_coarse2fine:
|
189 |
+
zv = interface.coarse_to_fine(zv)
|
190 |
+
|
191 |
+
sig = interface.to_signal(zv).cpu()
|
192 |
+
print("done")
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
sig.write(out_dir / "output.wav")
|
197 |
+
|
198 |
+
if return_mask:
|
199 |
+
mask = interface.to_signal(mask_z).cpu()
|
200 |
+
mask.write(out_dir / "mask.wav")
|
201 |
+
return sig.path_to_file, mask.path_to_file
|
202 |
+
else:
|
203 |
+
return sig.path_to_file
|
204 |
+
|
205 |
+
def vamp(data):
|
206 |
+
return _vamp(data, return_mask=True)
|
207 |
+
|
208 |
+
def api_vamp(data):
|
209 |
+
return _vamp(data, return_mask=False)
|
210 |
+
|
211 |
+
def save_vamp(data):
|
212 |
+
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
213 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
214 |
+
|
215 |
+
sig_in = at.AudioSignal(data[input_audio])
|
216 |
+
sig_out = at.AudioSignal(data[output_audio])
|
217 |
+
|
218 |
+
sig_in.write(out_dir / "input.wav")
|
219 |
+
sig_out.write(out_dir / "output.wav")
|
220 |
+
|
221 |
+
_data = {
|
222 |
+
"init_temp": data[init_temp],
|
223 |
+
"final_temp": data[final_temp],
|
224 |
+
"prefix_s": data[prefix_s],
|
225 |
+
"suffix_s": data[suffix_s],
|
226 |
+
"rand_mask_intensity": data[rand_mask_intensity],
|
227 |
+
"num_steps": data[num_steps],
|
228 |
+
"notes": data[notes_text],
|
229 |
+
"periodic_period": data[periodic_p],
|
230 |
+
"periodic_width": data[periodic_w],
|
231 |
+
"n_conditioning_codebooks": data[n_conditioning_codebooks],
|
232 |
+
"use_coarse2fine": data[use_coarse2fine],
|
233 |
+
"stretch_factor": data[stretch_factor],
|
234 |
+
}
|
235 |
+
|
236 |
+
# save with yaml
|
237 |
+
with open(out_dir / "data.yaml", "w") as f:
|
238 |
+
yaml.dump(_data, f)
|
239 |
+
|
240 |
+
import zipfile
|
241 |
+
zip_path = out_dir.with_suffix(".zip")
|
242 |
+
with zipfile.ZipFile(zip_path, "w") as zf:
|
243 |
+
for file in out_dir.iterdir():
|
244 |
+
zf.write(file, file.name)
|
245 |
+
|
246 |
+
return f"saved! your save code is {out_dir.stem}", zip_path
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
with gr.Blocks() as demo:
|
251 |
+
|
252 |
+
with gr.Row():
|
253 |
+
with gr.Column():
|
254 |
+
use_coarse2fine = gr.Checkbox(
|
255 |
+
label="use coarse2fine",
|
256 |
+
value=True
|
257 |
+
)
|
258 |
+
|
259 |
+
manual_audio_upload = gr.File(
|
260 |
+
label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
|
261 |
+
file_types=["audio"]
|
262 |
+
)
|
263 |
+
load_random_audio_button = gr.Button("or load random audio")
|
264 |
+
|
265 |
+
input_audio = gr.Audio(
|
266 |
+
label="input audio",
|
267 |
+
interactive=False,
|
268 |
+
type="filepath",
|
269 |
+
)
|
270 |
+
|
271 |
+
audio_mask = gr.Audio(
|
272 |
+
label="audio mask (listen to this to hear the mask hints)",
|
273 |
+
interactive=False,
|
274 |
+
type="filepath",
|
275 |
+
)
|
276 |
+
|
277 |
+
# connect widgets
|
278 |
+
load_random_audio_button.click(
|
279 |
+
fn=load_random_audio,
|
280 |
+
inputs=[],
|
281 |
+
outputs=[ input_audio]
|
282 |
+
)
|
283 |
+
|
284 |
+
manual_audio_upload.change(
|
285 |
+
fn=load_audio,
|
286 |
+
inputs=[manual_audio_upload],
|
287 |
+
outputs=[ input_audio]
|
288 |
+
)
|
289 |
+
|
290 |
+
# mask settings
|
291 |
+
with gr.Column():
|
292 |
+
|
293 |
+
input_pitch_shift = gr.Slider(
|
294 |
+
label="input pitch shift (semitones)",
|
295 |
+
minimum=-36,
|
296 |
+
maximum=36,
|
297 |
+
step=1,
|
298 |
+
value=0,
|
299 |
+
)
|
300 |
+
|
301 |
+
rand_mask_intensity = gr.Slider(
|
302 |
+
label="random mask intensity. (If this is less than 1, scatters prompts throughout the audio, should be between 0.9 and 1.0)",
|
303 |
+
minimum=0.0,
|
304 |
+
maximum=1.0,
|
305 |
+
value=1.0
|
306 |
+
)
|
307 |
+
|
308 |
+
periodic_p = gr.Slider(
|
309 |
+
label="periodic prompt (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
|
310 |
+
minimum=0,
|
311 |
+
maximum=128,
|
312 |
+
step=1,
|
313 |
+
value=3,
|
314 |
+
)
|
315 |
+
periodic_w = gr.Slider(
|
316 |
+
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
317 |
+
minimum=1,
|
318 |
+
maximum=20,
|
319 |
+
step=1,
|
320 |
+
value=1,
|
321 |
+
)
|
322 |
+
|
323 |
+
onset_mask_width = gr.Slider(
|
324 |
+
label="onset mask width (steps, 1 step ~= 10milliseconds)",
|
325 |
+
minimum=0,
|
326 |
+
maximum=20,
|
327 |
+
step=1,
|
328 |
+
value=5,
|
329 |
+
)
|
330 |
+
|
331 |
+
with gr.Accordion("extras ", open=False):
|
332 |
+
n_conditioning_codebooks = gr.Number(
|
333 |
+
label="number of conditioning codebooks. probably 0",
|
334 |
+
value=0,
|
335 |
+
precision=0,
|
336 |
+
)
|
337 |
+
|
338 |
+
stretch_factor = gr.Slider(
|
339 |
+
label="time stretch factor",
|
340 |
+
minimum=0,
|
341 |
+
maximum=64,
|
342 |
+
step=1,
|
343 |
+
value=1,
|
344 |
+
)
|
345 |
+
|
346 |
+
|
347 |
+
with gr.Accordion("prefix/suffix hints", open=False):
|
348 |
+
prefix_s = gr.Slider(
|
349 |
+
label="prefix hint length (seconds)",
|
350 |
+
minimum=0.0,
|
351 |
+
maximum=10.0,
|
352 |
+
value=0.0
|
353 |
+
)
|
354 |
+
suffix_s = gr.Slider(
|
355 |
+
label="suffix hint length (seconds)",
|
356 |
+
minimum=0.0,
|
357 |
+
maximum=10.0,
|
358 |
+
value=0.0
|
359 |
+
)
|
360 |
+
|
361 |
+
with gr.Accordion("temperature settings", open=False):
|
362 |
+
init_temp = gr.Slider(
|
363 |
+
label="initial temperature (should probably stay between 0.6 and 1)",
|
364 |
+
minimum=0.0,
|
365 |
+
maximum=1.5,
|
366 |
+
value=0.8
|
367 |
+
)
|
368 |
+
final_temp = gr.Slider(
|
369 |
+
label="final temperature (should probably stay between 0.7 and 2)",
|
370 |
+
minimum=0.0,
|
371 |
+
maximum=2.0,
|
372 |
+
value=0.8
|
373 |
+
)
|
374 |
+
|
375 |
+
with gr.Accordion("sampling settings", open=False):
|
376 |
+
sampling_strategy = gr.Radio(
|
377 |
+
label="sampling strategy",
|
378 |
+
choices=["gumbel", "multinomial"],
|
379 |
+
value="gumbel"
|
380 |
+
)
|
381 |
+
typical_filtering = gr.Checkbox(
|
382 |
+
label="typical filtering (cannot be used with topk)",
|
383 |
+
value=False
|
384 |
+
)
|
385 |
+
typical_mass = gr.Slider(
|
386 |
+
label="typical mass (should probably stay between 0.1 and 0.5)",
|
387 |
+
minimum=0.01,
|
388 |
+
maximum=0.99,
|
389 |
+
value=0.2
|
390 |
+
)
|
391 |
+
typical_min_tokens = gr.Slider(
|
392 |
+
label="typical min tokens (should probably stay between 1 and 256)",
|
393 |
+
minimum=1,
|
394 |
+
maximum=256,
|
395 |
+
step=1,
|
396 |
+
value=1
|
397 |
+
)
|
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
|
402 |
+
num_steps = gr.Slider(
|
403 |
+
label="number of steps (should normally be between 12 and 36)",
|
404 |
+
minimum=1,
|
405 |
+
maximum=128,
|
406 |
+
step=1,
|
407 |
+
value=36
|
408 |
+
)
|
409 |
+
|
410 |
+
dropout = gr.Slider(
|
411 |
+
label="mask dropout",
|
412 |
+
minimum=0.0,
|
413 |
+
maximum=1.0,
|
414 |
+
step=0.01,
|
415 |
+
value=0.0
|
416 |
+
)
|
417 |
+
|
418 |
+
|
419 |
+
# mask settings
|
420 |
+
with gr.Column():
|
421 |
+
checkpoint_key = gr.Radio(
|
422 |
+
label="checkpoint",
|
423 |
+
choices=list(checkpoints.keys()),
|
424 |
+
value="spotdl"
|
425 |
+
)
|
426 |
+
vamp_button = gr.Button("vamp!!!")
|
427 |
+
output_audio = gr.Audio(
|
428 |
+
label="output audio",
|
429 |
+
interactive=False,
|
430 |
+
type="filepath"
|
431 |
+
)
|
432 |
+
|
433 |
+
|
434 |
+
|
435 |
+
# with gr.Column():
|
436 |
+
# with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
|
437 |
+
# use_beats = gr.Checkbox(
|
438 |
+
# label="use beat hints (helps the output stick to the beat structure of the input)",
|
439 |
+
# value=False
|
440 |
+
# )
|
441 |
+
|
442 |
+
# snap_to_beats = gr.Checkbox(
|
443 |
+
# label="trim to beat markers (uncheck if the output audio is too short.)",
|
444 |
+
# value=True
|
445 |
+
# )
|
446 |
+
|
447 |
+
# beat_unmask_dur = gr.Slider(
|
448 |
+
# label="duration",
|
449 |
+
# minimum=0.0,
|
450 |
+
# maximum=3.0,
|
451 |
+
# value=0.07
|
452 |
+
# )
|
453 |
+
|
454 |
+
|
455 |
+
notes_text = gr.Textbox(
|
456 |
+
label="type any notes about the generated audio here",
|
457 |
+
value="",
|
458 |
+
interactive=True
|
459 |
+
)
|
460 |
+
save_button = gr.Button("save vamp")
|
461 |
+
download_file = gr.File(
|
462 |
+
label="vamp to download will appear here",
|
463 |
+
interactive=False
|
464 |
+
)
|
465 |
+
use_as_input_button = gr.Button("use output as input")
|
466 |
+
|
467 |
+
thank_you = gr.Markdown("")
|
468 |
+
|
469 |
+
|
470 |
+
_inputs = {
|
471 |
+
input_audio,
|
472 |
+
num_steps,
|
473 |
+
init_temp, final_temp,
|
474 |
+
prefix_s, suffix_s,
|
475 |
+
rand_mask_intensity,
|
476 |
+
periodic_p, periodic_w,
|
477 |
+
n_conditioning_codebooks,
|
478 |
+
dropout,
|
479 |
+
use_coarse2fine,
|
480 |
+
stretch_factor,
|
481 |
+
onset_mask_width,
|
482 |
+
input_pitch_shift,
|
483 |
+
sampling_strategy,
|
484 |
+
typical_filtering,
|
485 |
+
typical_mass,
|
486 |
+
typical_min_tokens,
|
487 |
+
# topk,
|
488 |
+
checkpoint_key
|
489 |
+
}
|
490 |
+
|
491 |
+
# connect widgets
|
492 |
+
vamp_button.click(
|
493 |
+
fn=vamp,
|
494 |
+
inputs=_inputs,
|
495 |
+
outputs=[output_audio, audio_mask],
|
496 |
+
)
|
497 |
+
|
498 |
+
api_vamp_button = gr.Button("api vamp")
|
499 |
+
api_vamp_button.click(
|
500 |
+
fn=api_vamp,
|
501 |
+
inputs=_inputs,
|
502 |
+
outputs=[output_audio],
|
503 |
+
api_name="vamp"
|
504 |
+
)
|
505 |
+
|
506 |
+
use_as_input_button.click(
|
507 |
+
fn=lambda x: x,
|
508 |
+
inputs=[output_audio],
|
509 |
+
outputs=[input_audio]
|
510 |
+
)
|
511 |
+
|
512 |
+
save_button.click(
|
513 |
+
fn=save_vamp,
|
514 |
+
inputs=_inputs | {notes_text, output_audio},
|
515 |
+
outputs=[thank_you, download_file]
|
516 |
+
)
|
517 |
+
|
518 |
+
demo.launch(share=True, enable_queue=False, debug=True, server_name="0.0.0.0")
|
demo.py
CHANGED
@@ -32,6 +32,47 @@ dataset = at.data.datasets.AudioDataset(
|
|
32 |
)
|
33 |
|
34 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
OUT_DIR = Path("gradio-outputs")
|
36 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
37 |
|
@@ -63,6 +104,19 @@ def load_random_audio():
|
|
63 |
|
64 |
|
65 |
def _vamp(data, return_mask=False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
67 |
out_dir.mkdir()
|
68 |
sig = at.AudioSignal(data[input_audio])
|
@@ -229,8 +283,8 @@ with gr.Blocks() as demo:
|
|
229 |
|
230 |
input_pitch_shift = gr.Slider(
|
231 |
label="input pitch shift (semitones)",
|
232 |
-
minimum=-
|
233 |
-
maximum=
|
234 |
step=1,
|
235 |
value=0,
|
236 |
)
|
@@ -247,7 +301,7 @@ with gr.Blocks() as demo:
|
|
247 |
minimum=0,
|
248 |
maximum=128,
|
249 |
step=1,
|
250 |
-
value=
|
251 |
)
|
252 |
periodic_w = gr.Slider(
|
253 |
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
@@ -262,7 +316,7 @@ with gr.Blocks() as demo:
|
|
262 |
minimum=0,
|
263 |
maximum=20,
|
264 |
step=1,
|
265 |
-
value=
|
266 |
)
|
267 |
|
268 |
with gr.Accordion("extras ", open=False):
|
@@ -361,6 +415,11 @@ with gr.Blocks() as demo:
|
|
361 |
|
362 |
# mask settings
|
363 |
with gr.Column():
|
|
|
|
|
|
|
|
|
|
|
364 |
vamp_button = gr.Button("vamp!!!")
|
365 |
output_audio = gr.Audio(
|
366 |
label="output audio",
|
@@ -423,6 +482,7 @@ with gr.Blocks() as demo:
|
|
423 |
typical_mass,
|
424 |
typical_min_tokens,
|
425 |
topk,
|
|
|
426 |
}
|
427 |
|
428 |
# connect widgets
|
@@ -452,4 +512,4 @@ with gr.Blocks() as demo:
|
|
452 |
outputs=[thank_you, download_file]
|
453 |
)
|
454 |
|
455 |
-
demo.launch(share=True, enable_queue=False, debug=True)
|
|
|
32 |
)
|
33 |
|
34 |
|
35 |
+
checkpoints = {
|
36 |
+
"spotdl": {
|
37 |
+
"coarse": "./models/spotdl/coarse.pth",
|
38 |
+
"c2f": "./models/spotdl/c2f.pth",
|
39 |
+
"codec": "./models/spotdl/codec.pth",
|
40 |
+
"full_ckpt": True
|
41 |
+
},
|
42 |
+
"berta": {
|
43 |
+
"coarse": "./models/finetuned/berta-goldman-speech/coarse.pth",
|
44 |
+
"c2f": "./models/finetuned/berta-goldman-speech/c2f.pth",
|
45 |
+
"codec": "./model/spotdl/codec.pth",
|
46 |
+
"full_ckpt": True
|
47 |
+
},
|
48 |
+
"xeno-canto-2": {
|
49 |
+
"coarse": "./models/finetuned/xeno-canto-2/coarse.pth",
|
50 |
+
"c2f": "./models/finetuned/xeno-canto-2/c2f.pth",
|
51 |
+
"codec": "./models/spotdl/codec.pth",
|
52 |
+
"full_ckpt": True
|
53 |
+
},
|
54 |
+
"panchos": {
|
55 |
+
"coarse": "./models/finetuned/panchos/coarse.pth",
|
56 |
+
"c2f": "./models/finetuned/panchos/c2f.pth",
|
57 |
+
"codec": "./models/spotdl/codec.pth",
|
58 |
+
"full_ckpt": False
|
59 |
+
},
|
60 |
+
"tv-choir": {
|
61 |
+
"coarse": "./models/finetuned/tv-choir/coarse.pth",
|
62 |
+
"c2f": "./models/finetuned/tv-choir/c2f.pth",
|
63 |
+
"codec": "./models/spotdl/codec.pth",
|
64 |
+
"full_ckpt": False
|
65 |
+
},
|
66 |
+
"titi": {
|
67 |
+
"coarse": "./models/finetuned/titi/coarse.pth",
|
68 |
+
"c2f": "./models/finetuned/titi/c2f.pth",
|
69 |
+
"codec": "./models/spotdl/codec.pth",
|
70 |
+
"full_ckpt": False
|
71 |
+
}
|
72 |
+
}
|
73 |
+
interface.checkpoint_key = "spotdl"
|
74 |
+
|
75 |
+
|
76 |
OUT_DIR = Path("gradio-outputs")
|
77 |
OUT_DIR.mkdir(exist_ok=True, parents=True)
|
78 |
|
|
|
104 |
|
105 |
|
106 |
def _vamp(data, return_mask=False):
|
107 |
+
|
108 |
+
# if our checkpoint key is different, we need to load a new checkpoint
|
109 |
+
if data[checkpoint_key] != interface.checkpoint_key:
|
110 |
+
print(f"loading checkpoint {data[checkpoint_key]}")
|
111 |
+
interface.lora_load(
|
112 |
+
checkpoints[data[checkpoint_key]]["coarse"],
|
113 |
+
checkpoints[data[checkpoint_key]]["c2f"],
|
114 |
+
checkpoints[data[checkpoint_key]]["full_ckpt"],
|
115 |
+
reset=(data[checkpoint_key] == "spotdl")
|
116 |
+
)
|
117 |
+
interface.checkpoint_key = data[checkpoint_key]
|
118 |
+
|
119 |
+
|
120 |
out_dir = OUT_DIR / str(uuid.uuid4())
|
121 |
out_dir.mkdir()
|
122 |
sig = at.AudioSignal(data[input_audio])
|
|
|
283 |
|
284 |
input_pitch_shift = gr.Slider(
|
285 |
label="input pitch shift (semitones)",
|
286 |
+
minimum=-36,
|
287 |
+
maximum=36,
|
288 |
step=1,
|
289 |
value=0,
|
290 |
)
|
|
|
301 |
minimum=0,
|
302 |
maximum=128,
|
303 |
step=1,
|
304 |
+
value=3,
|
305 |
)
|
306 |
periodic_w = gr.Slider(
|
307 |
label="periodic prompt width (steps, 1 step ~= 10milliseconds)",
|
|
|
316 |
minimum=0,
|
317 |
maximum=20,
|
318 |
step=1,
|
319 |
+
value=5,
|
320 |
)
|
321 |
|
322 |
with gr.Accordion("extras ", open=False):
|
|
|
415 |
|
416 |
# mask settings
|
417 |
with gr.Column():
|
418 |
+
checkpoint_key = gr.Radio(
|
419 |
+
label="checkpoint",
|
420 |
+
choices=list(checkpoints.keys()),
|
421 |
+
value="spotdl"
|
422 |
+
)
|
423 |
vamp_button = gr.Button("vamp!!!")
|
424 |
output_audio = gr.Audio(
|
425 |
label="output audio",
|
|
|
482 |
typical_mass,
|
483 |
typical_min_tokens,
|
484 |
topk,
|
485 |
+
checkpoint_key
|
486 |
}
|
487 |
|
488 |
# connect widgets
|
|
|
512 |
outputs=[thank_you, download_file]
|
513 |
)
|
514 |
|
515 |
+
demo.launch(share=True, enable_queue=False, debug=True, server_name="0.0.0.0")
|
scripts/exp/train.py
CHANGED
@@ -353,12 +353,9 @@ def train(
|
|
353 |
mask[:, vn.n_conditioning_codebooks :, :],
|
354 |
)
|
355 |
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
output["loss"] = criterion(z_hat, t_masked)
|
360 |
-
else:
|
361 |
-
output["loss"] = criterion(z_hat, target)
|
362 |
|
363 |
self._metrics(
|
364 |
vn=vn,
|
@@ -429,12 +426,9 @@ def train(
|
|
429 |
)
|
430 |
|
431 |
output = {}
|
432 |
-
|
433 |
-
|
434 |
-
|
435 |
-
output["loss"] = criterion(z_hat, t_masked)
|
436 |
-
else:
|
437 |
-
output["loss"] = criterion(z_hat, target)
|
438 |
|
439 |
self._metrics(
|
440 |
vn=vn,
|
|
|
353 |
mask[:, vn.n_conditioning_codebooks :, :],
|
354 |
)
|
355 |
|
356 |
+
# replace target with ignore index for masked tokens
|
357 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
358 |
+
output["loss"] = criterion(z_hat, t_masked)
|
|
|
|
|
|
|
359 |
|
360 |
self._metrics(
|
361 |
vn=vn,
|
|
|
426 |
)
|
427 |
|
428 |
output = {}
|
429 |
+
# replace target with ignore index for masked tokens
|
430 |
+
t_masked = target.masked_fill(~flat_mask.bool(), IGNORE_INDEX)
|
431 |
+
output["loss"] = criterion(z_hat, t_masked)
|
|
|
|
|
|
|
432 |
|
433 |
self._metrics(
|
434 |
vn=vn,
|
scripts/utils/augment.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
|
3 |
+
import audiotools as at
|
4 |
+
from audiotools import AudioSignal
|
5 |
+
|
6 |
+
import argbind
|
7 |
+
import tqdm
|
8 |
+
|
9 |
+
|
10 |
+
from pedalboard import (
|
11 |
+
Compressor, Gain, Chorus, LadderFilter, Phaser, Convolution, Reverb, Pedalboard
|
12 |
+
)
|
13 |
+
from pedalboard.io import AudioFile
|
14 |
+
|
15 |
+
# Read in a whole file, resampling to our desired sample rate:
|
16 |
+
samplerate = 44100.0
|
17 |
+
with AudioFile('guitar-input.wav').resampled_to(samplerate) as f:
|
18 |
+
audio = f.read(f.frames)
|
19 |
+
|
20 |
+
# Make a pretty interesting sounding guitar pedalboard:
|
21 |
+
board = Pedalboard([
|
22 |
+
Compressor(threshold_db=-50, ratio=25),
|
23 |
+
Gain(gain_db=30),
|
24 |
+
Chorus(),
|
25 |
+
LadderFilter(mode=LadderFilter.Mode.HPF12, cutoff_hz=900),
|
26 |
+
Phaser(),
|
27 |
+
Convolution("./guitar_amp.wav", 1.0),
|
28 |
+
Reverb(room_size=0.25),
|
29 |
+
])
|
30 |
+
|
31 |
+
|
32 |
+
@argbind.bind(without_prefix=True)
|
33 |
+
def augment(
|
34 |
+
audio_folder: Path,
|
35 |
+
dest_folder: Path,
|
36 |
+
n_augmentations: int = 10,
|
37 |
+
):
|
38 |
+
"""
|
39 |
+
Augment a folder of audio files by applying audiotools and pedalboard transforms.
|
40 |
+
|
41 |
+
The dest foler will contain a folder for each of the clean dataset's files.
|
42 |
+
Under each of these folders, there will be a clean file and many augmented files.
|
43 |
+
"""
|
44 |
+
|
45 |
+
audio_files = at.util.find_audio(audio_folder)
|
46 |
+
|
47 |
+
for audio_file in tqdm.tqdm(audio_files):
|
48 |
+
subtree = dest_folder / audio_file.relative_to(audio_folder).parent
|
49 |
+
subdir = subtree / audio_file.stem
|
50 |
+
subdir.mkdir(parents=True, exist_ok=True)
|
51 |
+
|
52 |
+
# apply pedalboard transforms
|
53 |
+
for i in range(n_augmentations):
|
vampnet/interface.py
CHANGED
@@ -97,17 +97,36 @@ class Interface(torch.nn.Module):
|
|
97 |
|
98 |
def lora_load(
|
99 |
self,
|
100 |
-
|
101 |
-
|
|
|
102 |
):
|
103 |
-
if
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
|
113 |
def s2t(self, seconds: float):
|
@@ -290,6 +309,7 @@ class Interface(torch.nn.Module):
|
|
290 |
z,
|
291 |
mask,
|
292 |
return_mask=False,
|
|
|
293 |
**kwargs
|
294 |
):
|
295 |
# coarse z
|
@@ -301,7 +321,8 @@ class Interface(torch.nn.Module):
|
|
301 |
cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
|
302 |
cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
|
303 |
|
304 |
-
|
|
|
305 |
codec=self.codec,
|
306 |
time_steps=cz.shape[-1],
|
307 |
start_tokens=cz,
|
@@ -310,8 +331,6 @@ class Interface(torch.nn.Module):
|
|
310 |
**kwargs
|
311 |
)
|
312 |
|
313 |
-
# replace the mask token in cz_masked with random tokens
|
314 |
-
# so that we can decode it
|
315 |
if return_mask:
|
316 |
return c_vamp, cz_masked
|
317 |
|
@@ -320,53 +339,48 @@ class Interface(torch.nn.Module):
|
|
320 |
|
321 |
if __name__ == "__main__":
|
322 |
import audiotools as at
|
|
|
|
|
|
|
|
|
323 |
|
324 |
interface = Interface(
|
325 |
coarse_ckpt="./models/spotdl/coarse.pth",
|
326 |
coarse2fine_ckpt="./models/spotdl/c2f.pth",
|
327 |
codec_ckpt="./models/spotdl/codec.pth",
|
328 |
-
device="
|
329 |
)
|
330 |
|
331 |
-
sig = at.AudioSignal('
|
332 |
|
333 |
z = interface.encode(sig)
|
334 |
|
335 |
-
mask = linear_random(z, 0
|
336 |
-
print(mask)
|
337 |
-
mask = mask_and(
|
338 |
-
mask, inpaint(
|
339 |
-
z,
|
340 |
-
interface.s2t(3),
|
341 |
-
interface.s2t(3)
|
342 |
-
)
|
343 |
-
)
|
344 |
-
print(mask)
|
345 |
mask = mask_and(
|
346 |
mask, periodic_mask(
|
347 |
z,
|
348 |
-
|
349 |
1,
|
350 |
random_roll=True
|
351 |
)
|
352 |
)
|
353 |
-
mask = dropout(mask, 0.0)
|
354 |
-
mask = codebook_unmask(mask, 0)
|
355 |
|
356 |
|
357 |
zv, mask_z = interface.coarse_vamp(
|
358 |
z,
|
359 |
mask=mask,
|
360 |
-
sampling_steps=
|
361 |
-
temperature=
|
362 |
-
return_mask=True
|
|
|
363 |
)
|
364 |
|
365 |
use_coarse2fine = False
|
366 |
if use_coarse2fine:
|
367 |
zv = interface.coarse_to_fine(zv)
|
368 |
|
369 |
-
print(mask_z)
|
370 |
mask = interface.to_signal(mask_z).cpu()
|
371 |
|
372 |
sig = interface.to_signal(zv).cpu()
|
|
|
97 |
|
98 |
def lora_load(
|
99 |
self,
|
100 |
+
coarse_ckpt: str = None,
|
101 |
+
c2f_ckpt: str = None,
|
102 |
+
full_ckpts: bool = False,
|
103 |
):
|
104 |
+
if full_ckpts:
|
105 |
+
if coarse_ckpt is not None:
|
106 |
+
self.coarse = _load_model(
|
107 |
+
ckpt=coarse_ckpt,
|
108 |
+
device=self.device,
|
109 |
+
chunk_size_s=self.coarse.chunk_size_s,
|
110 |
+
)
|
111 |
+
if c2f_ckpt is not None:
|
112 |
+
self.c2f = _load_model(
|
113 |
+
ckpt=c2f_ckpt,
|
114 |
+
device=self.device,
|
115 |
+
chunk_size_s=self.c2f.chunk_size_s,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
if coarse_ckpt is not None:
|
119 |
+
self.coarse.to("cpu")
|
120 |
+
state_dict = torch.load(coarse_ckpt, map_location="cpu")
|
121 |
+
|
122 |
+
self.coarse.load_state_dict(state_dict, strict=False)
|
123 |
+
self.coarse.to(self.device)
|
124 |
+
if c2f_ckpt is not None:
|
125 |
+
self.c2f.to("cpu")
|
126 |
+
state_dict = torch.load(c2f_ckpt, map_location="cpu")
|
127 |
+
|
128 |
+
self.c2f.load_state_dict(state_dict, strict=False)
|
129 |
+
self.c2f.to(self.device)
|
130 |
|
131 |
|
132 |
def s2t(self, seconds: float):
|
|
|
309 |
z,
|
310 |
mask,
|
311 |
return_mask=False,
|
312 |
+
gen_fn=None,
|
313 |
**kwargs
|
314 |
):
|
315 |
# coarse z
|
|
|
321 |
cz_masked, mask = apply_mask(cz, mask, self.coarse.mask_token)
|
322 |
cz_masked = cz_masked[:, : self.coarse.n_codebooks, :]
|
323 |
|
324 |
+
gen_fn = gen_fn or self.coarse.sample
|
325 |
+
c_vamp = gen_fn(
|
326 |
codec=self.codec,
|
327 |
time_steps=cz.shape[-1],
|
328 |
start_tokens=cz,
|
|
|
331 |
**kwargs
|
332 |
)
|
333 |
|
|
|
|
|
334 |
if return_mask:
|
335 |
return c_vamp, cz_masked
|
336 |
|
|
|
339 |
|
340 |
if __name__ == "__main__":
|
341 |
import audiotools as at
|
342 |
+
import logging
|
343 |
+
logger = logging.getLogger()
|
344 |
+
logger.setLevel(logging.INFO)
|
345 |
+
torch.set_printoptions(threshold=10000)
|
346 |
|
347 |
interface = Interface(
|
348 |
coarse_ckpt="./models/spotdl/coarse.pth",
|
349 |
coarse2fine_ckpt="./models/spotdl/c2f.pth",
|
350 |
codec_ckpt="./models/spotdl/codec.pth",
|
351 |
+
device="cuda"
|
352 |
)
|
353 |
|
354 |
+
sig = at.AudioSignal('introspection ii-1.mp3', duration=10)
|
355 |
|
356 |
z = interface.encode(sig)
|
357 |
|
358 |
+
mask = linear_random(z, 1.0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
mask = mask_and(
|
360 |
mask, periodic_mask(
|
361 |
z,
|
362 |
+
32,
|
363 |
1,
|
364 |
random_roll=True
|
365 |
)
|
366 |
)
|
367 |
+
# mask = dropout(mask, 0.0)
|
368 |
+
# mask = codebook_unmask(mask, 0)
|
369 |
|
370 |
|
371 |
zv, mask_z = interface.coarse_vamp(
|
372 |
z,
|
373 |
mask=mask,
|
374 |
+
sampling_steps=36,
|
375 |
+
temperature=6.0,
|
376 |
+
return_mask=True,
|
377 |
+
# gen_fn=interface.coarse.generate
|
378 |
)
|
379 |
|
380 |
use_coarse2fine = False
|
381 |
if use_coarse2fine:
|
382 |
zv = interface.coarse_to_fine(zv)
|
383 |
|
|
|
384 |
mask = interface.to_signal(mask_z).cpu()
|
385 |
|
386 |
sig = interface.to_signal(zv).cpu()
|
vampnet/mask.py
CHANGED
@@ -6,7 +6,7 @@ from audiotools import AudioSignal
|
|
6 |
from .util import scalar_to_batch_tensor
|
7 |
|
8 |
def _gamma(r):
|
9 |
-
return (r * torch.pi / 2).cos()
|
10 |
|
11 |
def _invgamma(y):
|
12 |
if not torch.is_tensor(y):
|
|
|
6 |
from .util import scalar_to_batch_tensor
|
7 |
|
8 |
def _gamma(r):
|
9 |
+
return (r * torch.pi / 2).cos().clamp(1e-10, 1.0)
|
10 |
|
11 |
def _invgamma(y):
|
12 |
if not torch.is_tensor(y):
|
vampnet/modules/transformer.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import math
|
|
|
2 |
from typing import Optional, Tuple, Union
|
3 |
|
4 |
import numpy as np
|
@@ -19,17 +20,17 @@ from ..mask import _gamma
|
|
19 |
|
20 |
LORA_R = 8
|
21 |
|
22 |
-
def log(t, eps=1e-20):
|
23 |
-
|
24 |
|
25 |
|
26 |
-
def
|
27 |
-
noise = torch.zeros_like(t).uniform_(
|
28 |
-
return -log(-log(noise))
|
29 |
|
30 |
|
31 |
def gumbel_sample(t, temperature=1.0, dim=-1):
|
32 |
-
return ((t / max(temperature, 1e-10)) +
|
33 |
|
34 |
|
35 |
class RMSNorm(nn.Module):
|
@@ -477,23 +478,16 @@ class VampNet(at.ml.BaseModel):
|
|
477 |
self.flash_attn = flash_attn
|
478 |
self.noise_mode = noise_mode
|
479 |
|
480 |
-
|
481 |
-
special_tokens = ["MASK"]
|
482 |
-
elif noise_mode == "random":
|
483 |
-
special_tokens = None
|
484 |
-
else:
|
485 |
-
raise ValueError(f"Unknown noise mode: {noise_mode}")
|
486 |
|
487 |
self.embedding = CodebookEmbedding(
|
488 |
latent_dim=latent_dim,
|
489 |
n_codebooks=n_codebooks,
|
490 |
vocab_size=vocab_size,
|
491 |
emb_dim=embedding_dim,
|
492 |
-
special_tokens=
|
493 |
)
|
494 |
-
|
495 |
-
if noise_mode == "mask":
|
496 |
-
self.mask_token = self.embedding.special_idxs["MASK"]
|
497 |
|
498 |
self.transformer = TransformerStack(
|
499 |
d_model=embedding_dim,
|
@@ -584,23 +578,20 @@ class VampNet(at.ml.BaseModel):
|
|
584 |
z_hat,
|
585 |
mask,
|
586 |
):
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
z_hat
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
)
|
598 |
|
599 |
-
|
600 |
|
601 |
-
|
602 |
-
else:
|
603 |
-
raise ValueError(f"invalid noise mode for adding truth to logits {self.noise_mode}")
|
604 |
|
605 |
return z_hat
|
606 |
|
@@ -742,6 +733,272 @@ class VampNet(at.ml.BaseModel):
|
|
742 |
else:
|
743 |
return z
|
744 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
745 |
def sample_from_logits(
|
746 |
logits,
|
747 |
top_k: int = None,
|
@@ -798,7 +1055,6 @@ def sample_from_logits(
|
|
798 |
return inferred
|
799 |
|
800 |
|
801 |
-
|
802 |
if __name__ == "__main__":
|
803 |
# import argbind
|
804 |
from .layers import num_params
|
|
|
1 |
import math
|
2 |
+
import logging
|
3 |
from typing import Optional, Tuple, Union
|
4 |
|
5 |
import numpy as np
|
|
|
20 |
|
21 |
LORA_R = 8
|
22 |
|
23 |
+
# def log(t, eps=1e-20):
|
24 |
+
# return torch.log(t + eps)
|
25 |
|
26 |
|
27 |
+
def gumbel_noise_like(t):
|
28 |
+
noise = torch.zeros_like(t).uniform_(1e-20, 1)
|
29 |
+
return -torch.log(-torch.log(noise))
|
30 |
|
31 |
|
32 |
def gumbel_sample(t, temperature=1.0, dim=-1):
|
33 |
+
return ((t / max(temperature, 1e-10)) + gumbel_noise_like(t)).argmax(dim=dim)
|
34 |
|
35 |
|
36 |
class RMSNorm(nn.Module):
|
|
|
478 |
self.flash_attn = flash_attn
|
479 |
self.noise_mode = noise_mode
|
480 |
|
481 |
+
assert self.noise_mode == "mask", "deprecated"
|
|
|
|
|
|
|
|
|
|
|
482 |
|
483 |
self.embedding = CodebookEmbedding(
|
484 |
latent_dim=latent_dim,
|
485 |
n_codebooks=n_codebooks,
|
486 |
vocab_size=vocab_size,
|
487 |
emb_dim=embedding_dim,
|
488 |
+
special_tokens=["MASK"],
|
489 |
)
|
490 |
+
self.mask_token = self.embedding.special_idxs["MASK"]
|
|
|
|
|
491 |
|
492 |
self.transformer = TransformerStack(
|
493 |
d_model=embedding_dim,
|
|
|
578 |
z_hat,
|
579 |
mask,
|
580 |
):
|
581 |
+
z_true = z_true[:, self.n_conditioning_codebooks :, :]
|
582 |
+
mask = mask[:, self.n_conditioning_codebooks :, :]
|
583 |
+
|
584 |
+
truth = F.one_hot(z_true, self.vocab_size)
|
585 |
+
mask = mask[:, :, :, None].expand(-1, -1, -1, self.vocab_size)
|
586 |
+
z_hat = rearrange(
|
587 |
+
z_hat,
|
588 |
+
"b p (t c) -> b c t p",
|
589 |
+
c=self.n_codebooks - self.n_conditioning_codebooks,
|
590 |
+
)
|
|
|
591 |
|
592 |
+
z_hat = z_hat * mask + truth * (1 - mask)
|
593 |
|
594 |
+
z_hat = rearrange(z_hat, "b c t p -> b p (t c)")
|
|
|
|
|
595 |
|
596 |
return z_hat
|
597 |
|
|
|
733 |
else:
|
734 |
return z
|
735 |
|
736 |
+
@torch.no_grad()
|
737 |
+
def generate(
|
738 |
+
self,
|
739 |
+
codec,
|
740 |
+
time_steps: int = 300,
|
741 |
+
sampling_steps: int = 36,
|
742 |
+
start_tokens: Optional[torch.Tensor] = None,
|
743 |
+
mask: Optional[torch.Tensor] = None,
|
744 |
+
temperature: Union[float, Tuple[float, float]] = 0.8,
|
745 |
+
typical_filtering=False,
|
746 |
+
typical_mass=0.2,
|
747 |
+
typical_min_tokens=1,
|
748 |
+
return_signal=True,
|
749 |
+
):
|
750 |
+
logging.info(f"beginning generation with {sampling_steps} steps")
|
751 |
+
|
752 |
+
#####################
|
753 |
+
# resolve temperature #
|
754 |
+
#####################
|
755 |
+
if isinstance(temperature, float):
|
756 |
+
temperature = torch.tensor(temperature).repeat(sampling_steps)
|
757 |
+
elif isinstance(temperature, tuple):
|
758 |
+
assert len(temperature) == 2
|
759 |
+
l, h = temperature
|
760 |
+
temperature = torch.linspace(l, h, sampling_steps)
|
761 |
+
else:
|
762 |
+
raise TypeError(f"invalid type for temperature")
|
763 |
+
|
764 |
+
logging.info(f"temperature: {temperature}")
|
765 |
+
|
766 |
+
|
767 |
+
#####################
|
768 |
+
# resolve initial z #
|
769 |
+
#####################
|
770 |
+
z = start_tokens
|
771 |
+
|
772 |
+
if z is None:
|
773 |
+
z = torch.full((1, self.n_codebooks, time_steps), self.mask_token).to(
|
774 |
+
self.device
|
775 |
+
)
|
776 |
+
|
777 |
+
logging.info(f"created z with shape {z.shape}")
|
778 |
+
|
779 |
+
|
780 |
+
#################
|
781 |
+
# resolve mask #
|
782 |
+
#################
|
783 |
+
|
784 |
+
if mask is None:
|
785 |
+
mask = torch.ones_like(z).to(self.device).int()
|
786 |
+
mask[:, : self.n_conditioning_codebooks, :] = 0.0
|
787 |
+
if mask.ndim == 2:
|
788 |
+
mask = mask[:, None, :].repeat(1, z.shape[1], 1)
|
789 |
+
# init_mask = mask.clone()
|
790 |
+
|
791 |
+
logging.info(f"created mask with shape {mask.shape}")
|
792 |
+
|
793 |
+
|
794 |
+
###########
|
795 |
+
# set up #
|
796 |
+
##########
|
797 |
+
# apply the mask to z
|
798 |
+
z_masked = z.masked_fill(mask.bool(), self.mask_token)
|
799 |
+
# logging.info(f"z_masked: {z_masked}")
|
800 |
+
|
801 |
+
# how many mask tokens to begin with?
|
802 |
+
num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
|
803 |
+
logging.info(f"num mask tokens at start: {num_mask_tokens_at_start}")
|
804 |
+
|
805 |
+
# our r steps
|
806 |
+
r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
|
807 |
+
logging.info(f"r steps: {r_steps}")
|
808 |
+
|
809 |
+
# how many codebooks are we inferring vs conditioning on?
|
810 |
+
n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
|
811 |
+
logging.info(f"n infer codebooks: {n_infer_codebooks}")
|
812 |
+
|
813 |
+
#################
|
814 |
+
# begin sampling #
|
815 |
+
#################
|
816 |
+
|
817 |
+
for i in range(sampling_steps):
|
818 |
+
logging.info(f"step {i} of {sampling_steps}")
|
819 |
+
|
820 |
+
# our current temperature
|
821 |
+
tmpt = temperature[i]
|
822 |
+
logging.info(f"temperature: {tmpt}")
|
823 |
+
|
824 |
+
# our current schedule step
|
825 |
+
r = r_steps[i : i + 1]
|
826 |
+
logging.info(f"r: {r}")
|
827 |
+
|
828 |
+
# get latents
|
829 |
+
latents = self.embedding.from_codes(z_masked, codec)
|
830 |
+
logging.info(f"computed latents with shape: {latents.shape}")
|
831 |
+
|
832 |
+
|
833 |
+
# infer from latents
|
834 |
+
# NOTE: this collapses the codebook dimension into the sequence dimension
|
835 |
+
logits = self.forward(latents, r) # b, prob, seq
|
836 |
+
logits = logits.permute(0, 2, 1) # b, seq, prob
|
837 |
+
if typical_filtering:
|
838 |
+
typical_filter(logits,
|
839 |
+
typical_mass=typical_mass,
|
840 |
+
typical_min_tokens=typical_min_tokens
|
841 |
+
)
|
842 |
+
|
843 |
+
|
844 |
+
logging.info(f"permuted logits with shape: {logits.shape}")
|
845 |
+
|
846 |
+
|
847 |
+
# logits2probs
|
848 |
+
probs = torch.softmax(logits, dim=-1)
|
849 |
+
logging.info(f"computed probs with shape: {probs.shape}")
|
850 |
+
|
851 |
+
# flatten z_masked and mask, so we can deal with the sampling logic
|
852 |
+
# we'll unflatten them at the end of the loop for the next forward pass
|
853 |
+
z_masked = codebook_flatten(z_masked)
|
854 |
+
|
855 |
+
# sample from logits with multinomial sampling
|
856 |
+
b = probs.shape[0]
|
857 |
+
probs = rearrange(probs, "b seq prob -> (b seq) prob")
|
858 |
+
|
859 |
+
|
860 |
+
|
861 |
+
sampled_z = torch.multinomial(probs, 1).squeeze(-1)
|
862 |
+
|
863 |
+
sampled_z = rearrange(sampled_z, "(b seq)-> b seq", b=b)
|
864 |
+
probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
|
865 |
+
logging.info(f"sampled z with shape: {sampled_z.shape}")
|
866 |
+
|
867 |
+
# update the mask
|
868 |
+
mask = (z_masked == self.mask_token).int()
|
869 |
+
logging.info(f"updated mask with shape: {mask.shape}")
|
870 |
+
|
871 |
+
# add z back into sampled z where the mask was false
|
872 |
+
sampled_z = torch.where(
|
873 |
+
mask.bool(), sampled_z, z_masked
|
874 |
+
)
|
875 |
+
logging.info(f"added z back into sampled z with shape: {sampled_z.shape}")
|
876 |
+
|
877 |
+
|
878 |
+
# get the confidences: which tokens did we sample?
|
879 |
+
selected_probs = (
|
880 |
+
torch.take_along_dim(
|
881 |
+
probs, sampled_z.long().unsqueeze(-1),
|
882 |
+
dim=-1
|
883 |
+
).squeeze(-1)
|
884 |
+
)
|
885 |
+
|
886 |
+
# ignore any tokens that weren't masked
|
887 |
+
selected_probs = torch.where(
|
888 |
+
mask.bool(), selected_probs, torch.inf
|
889 |
+
)
|
890 |
+
|
891 |
+
# get the num tokens to mask, according to the schedule
|
892 |
+
num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
|
893 |
+
logging.info(f"num to mask: {num_to_mask}")
|
894 |
+
|
895 |
+
num_to_mask = torch.maximum(
|
896 |
+
torch.tensor(1),
|
897 |
+
torch.minimum(
|
898 |
+
mask.sum(dim=-1, keepdim=True) - 1,
|
899 |
+
num_to_mask
|
900 |
+
)
|
901 |
+
)
|
902 |
+
|
903 |
+
|
904 |
+
# get our new mask
|
905 |
+
# print(tmpt * (1-_gamma(r)))
|
906 |
+
mask = mask_by_random_topk(
|
907 |
+
num_to_mask, selected_probs, tmpt * (1-r)
|
908 |
+
)
|
909 |
+
|
910 |
+
# print(f"most confident tokens: ")
|
911 |
+
# print(torch.take_along_dim(
|
912 |
+
# sampled_z, selected_probs.argsort(descending=False), dim=-1)
|
913 |
+
# )
|
914 |
+
# print(sampled_z[~mask.bool()])
|
915 |
+
|
916 |
+
|
917 |
+
# update the mask
|
918 |
+
z_masked = torch.where(
|
919 |
+
mask.bool(), self.mask_token, sampled_z
|
920 |
+
)
|
921 |
+
logging.info(f"updated z_masked with shape: {z_masked.shape}")
|
922 |
+
|
923 |
+
|
924 |
+
z_masked = codebook_unflatten(z_masked, self.n_codebooks)
|
925 |
+
mask = codebook_unflatten(mask, self.n_codebooks)
|
926 |
+
logging.info(f"unflattened z_masked with shape: {z_masked.shape}")
|
927 |
+
|
928 |
+
|
929 |
+
logging.info(f"updated z_masked with shape: {z_masked.shape}")
|
930 |
+
|
931 |
+
|
932 |
+
logging.info(f"finished sampling")
|
933 |
+
z = codebook_unflatten(sampled_z, self.n_codebooks)
|
934 |
+
|
935 |
+
if return_signal:
|
936 |
+
return self.to_signal(z, codec)
|
937 |
+
else:
|
938 |
+
return z
|
939 |
+
|
940 |
+
|
941 |
+
def mask_by_random_topk(num_to_mask: int, probs: torch.Tensor, temperature: float = 1.0):
|
942 |
+
"""
|
943 |
+
Args:
|
944 |
+
num_to_mask (int): number of tokens to mask
|
945 |
+
probs (torch.Tensor): probabilities for each sampled event, shape (batch, seq)
|
946 |
+
temperature (float, optional): temperature. Defaults to 1.0.
|
947 |
+
"""
|
948 |
+
logging.info(f"masking by random topk")
|
949 |
+
logging.info(f"num to mask: {num_to_mask}")
|
950 |
+
logging.info(f"probs shape: {probs.shape}")
|
951 |
+
logging.info(f"temperature: {temperature}")
|
952 |
+
logging.info("")
|
953 |
+
|
954 |
+
confidence = torch.log(probs) + temperature * gumbel_noise_like(probs)
|
955 |
+
logging.info(f"confidence shape: {confidence.shape}")
|
956 |
+
|
957 |
+
sorted_confidence, sorted_idx = confidence.sort(dim=-1)
|
958 |
+
logging.info(f"sorted confidence shape: {sorted_confidence.shape}")
|
959 |
+
logging.info(f"sorted idx shape: {sorted_idx.shape}")
|
960 |
+
|
961 |
+
# get the cut off threshold, given the mask length
|
962 |
+
cut_off = torch.take_along_dim(
|
963 |
+
sorted_confidence, num_to_mask, axis=-1
|
964 |
+
)
|
965 |
+
logging.info(f"cut off shape: {cut_off.shape}")
|
966 |
+
|
967 |
+
# mask out the tokens
|
968 |
+
mask = confidence < cut_off
|
969 |
+
logging.info(f"mask shape: {mask.shape}")
|
970 |
+
|
971 |
+
return mask
|
972 |
+
|
973 |
+
def typical_filter(
|
974 |
+
logits,
|
975 |
+
typical_mass: float = 0.95,
|
976 |
+
typical_min_tokens: int = 1,):
|
977 |
+
nb, nt, _ = logits.shape
|
978 |
+
x_flat = rearrange(logits, "b t l -> (b t ) l")
|
979 |
+
x_flat_norm = torch.nn.functional.log_softmax(x_flat, dim=-1)
|
980 |
+
x_flat_norm_p = torch.exp(x_flat_norm)
|
981 |
+
entropy = -(x_flat_norm * x_flat_norm_p).nansum(-1, keepdim=True)
|
982 |
+
|
983 |
+
c_flat_shifted = torch.abs((-x_flat_norm) - entropy)
|
984 |
+
c_flat_sorted, x_flat_indices = torch.sort(c_flat_shifted, descending=False)
|
985 |
+
x_flat_cumsum = (
|
986 |
+
x_flat.gather(-1, x_flat_indices).softmax(dim=-1).cumsum(dim=-1)
|
987 |
+
)
|
988 |
+
|
989 |
+
last_ind = (x_flat_cumsum < typical_mass).sum(dim=-1)
|
990 |
+
sorted_indices_to_remove = c_flat_sorted > c_flat_sorted.gather(
|
991 |
+
1, last_ind.view(-1, 1)
|
992 |
+
)
|
993 |
+
if typical_min_tokens > 1:
|
994 |
+
sorted_indices_to_remove[..., :typical_min_tokens] = 0
|
995 |
+
indices_to_remove = sorted_indices_to_remove.scatter(
|
996 |
+
1, x_flat_indices, sorted_indices_to_remove
|
997 |
+
)
|
998 |
+
x_flat = x_flat.masked_fill(indices_to_remove, -float("Inf"))
|
999 |
+
logits = rearrange(x_flat, "(b t) l -> b t l", t=nt)
|
1000 |
+
return logits
|
1001 |
+
|
1002 |
def sample_from_logits(
|
1003 |
logits,
|
1004 |
top_k: int = None,
|
|
|
1055 |
return inferred
|
1056 |
|
1057 |
|
|
|
1058 |
if __name__ == "__main__":
|
1059 |
# import argbind
|
1060 |
from .layers import num_params
|