Spaces:
Sleeping
Sleeping
Hugo Flores Garcia
commited on
Commit
·
322cc3a
1
Parent(s):
8544bbf
moving out
Browse files- demo.py +60 -8
- scripts/utils/vamp_folder.py +1 -1
- vampnet/interface.py +14 -3
- vampnet/modules/base.py +25 -1
demo.py
CHANGED
@@ -65,7 +65,10 @@ def vamp(
|
|
65 |
mask_periodic_amt, beat_unmask_dur,
|
66 |
mask_dwn_chk, dwn_factor,
|
67 |
mask_up_chk, up_factor,
|
68 |
-
num_vamps, mode, use_beats, num_steps, snap_to_beats
|
|
|
|
|
|
|
69 |
):
|
70 |
# try:
|
71 |
print(input_audio)
|
@@ -89,7 +92,7 @@ def vamp(
|
|
89 |
mask_upbeats=mask_up_chk,
|
90 |
downbeat_downsample_factor=dwn_factor if dwn_factor > 0 else None,
|
91 |
beat_downsample_factor=up_factor if up_factor > 0 else None,
|
92 |
-
dropout=
|
93 |
invert=True
|
94 |
)
|
95 |
print(beat_mask)
|
@@ -106,6 +109,10 @@ def vamp(
|
|
106 |
suffix_dur_s=suffix_s,
|
107 |
num_vamps=num_vamps,
|
108 |
downsample_factor=mask_periodic_amt,
|
|
|
|
|
|
|
|
|
109 |
intensity=rand_mask_intensity,
|
110 |
ext_mask=beat_mask,
|
111 |
verbose=True,
|
@@ -126,6 +133,7 @@ def vamp(
|
|
126 |
suffix_dur_s=prefix_s, # suffix should be same length as prefix
|
127 |
num_loops=num_vamps,
|
128 |
downsample_factor=mask_periodic_amt,
|
|
|
129 |
intensity=rand_mask_intensity,
|
130 |
ext_mask=beat_mask,
|
131 |
verbose=True,
|
@@ -150,7 +158,9 @@ def save_vamp(
|
|
150 |
mask_periodic_amt, beat_unmask_dur,
|
151 |
mask_dwn_chk, dwn_factor,
|
152 |
mask_up_chk, up_factor,
|
153 |
-
num_vamps, mode, output_audio, notes, use_beats, num_steps, snap_to_beats
|
|
|
|
|
154 |
):
|
155 |
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
156 |
out_dir.mkdir(parents=True, exist_ok=True)
|
@@ -179,6 +189,11 @@ def save_vamp(
|
|
179 |
"snap_to_beats": snap_to_beats,
|
180 |
"mode": mode,
|
181 |
"notes": notes,
|
|
|
|
|
|
|
|
|
|
|
182 |
}
|
183 |
|
184 |
# save with yaml
|
@@ -287,6 +302,12 @@ with gr.Blocks() as demo:
|
|
287 |
# mask settings
|
288 |
with gr.Column():
|
289 |
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
mask_periodic_amt = gr.Slider(
|
291 |
label="periodic hint (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
|
292 |
minimum=0,
|
@@ -294,11 +315,29 @@ with gr.Blocks() as demo:
|
|
294 |
step=1,
|
295 |
value=9,
|
296 |
)
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
|
299 |
rand_mask_intensity = gr.Slider(
|
300 |
label="random mask intensity. (If this is less than 1, scatters tiny hints throughout the audio, should be between 0.9 and 1.0)",
|
301 |
-
minimum=0.
|
302 |
maximum=1.0,
|
303 |
value=1.0
|
304 |
)
|
@@ -343,7 +382,7 @@ with gr.Blocks() as demo:
|
|
343 |
|
344 |
num_steps = gr.Slider(
|
345 |
label="number of steps (should normally be between 12 and 36)",
|
346 |
-
minimum=
|
347 |
maximum=128,
|
348 |
step=1,
|
349 |
value=36
|
@@ -379,6 +418,13 @@ with gr.Blocks() as demo:
|
|
379 |
maximum=3.0,
|
380 |
value=0.07
|
381 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
382 |
with gr.Accordion("downbeat settings", open=False):
|
383 |
mask_dwn_chk = gr.Checkbox(
|
384 |
label="hint downbeats",
|
@@ -427,7 +473,10 @@ with gr.Blocks() as demo:
|
|
427 |
mask_periodic_amt, beat_unmask_dur,
|
428 |
mask_dwn_chk, dwn_factor,
|
429 |
mask_up_chk, up_factor,
|
430 |
-
num_vamps, mode, use_beats, num_steps, snap_to_beats
|
|
|
|
|
|
|
431 |
],
|
432 |
outputs=[output_audio, audio_mask]
|
433 |
)
|
@@ -442,7 +491,10 @@ with gr.Blocks() as demo:
|
|
442 |
mask_up_chk, up_factor,
|
443 |
num_vamps, mode,
|
444 |
output_audio,
|
445 |
-
notes_text, use_beats, num_steps, snap_to_beats
|
|
|
|
|
|
|
446 |
],
|
447 |
outputs=[thank_you, download_file]
|
448 |
)
|
|
|
65 |
mask_periodic_amt, beat_unmask_dur,
|
66 |
mask_dwn_chk, dwn_factor,
|
67 |
mask_up_chk, up_factor,
|
68 |
+
num_vamps, mode, use_beats, num_steps, snap_to_beats,
|
69 |
+
beat_unmask_drop, mask_periodic_width,
|
70 |
+
mask_periodic_dropout, mask_periodic_width_dropout,
|
71 |
+
n_conditioning_codebooks
|
72 |
):
|
73 |
# try:
|
74 |
print(input_audio)
|
|
|
92 |
mask_upbeats=mask_up_chk,
|
93 |
downbeat_downsample_factor=dwn_factor if dwn_factor > 0 else None,
|
94 |
beat_downsample_factor=up_factor if up_factor > 0 else None,
|
95 |
+
dropout=beat_unmask_drop,
|
96 |
invert=True
|
97 |
)
|
98 |
print(beat_mask)
|
|
|
109 |
suffix_dur_s=suffix_s,
|
110 |
num_vamps=num_vamps,
|
111 |
downsample_factor=mask_periodic_amt,
|
112 |
+
periodic_width=mask_periodic_width,
|
113 |
+
periodic_dropout=mask_periodic_dropout,
|
114 |
+
periodic_width_dropout=mask_periodic_width_dropout,
|
115 |
+
n_conditioning_codebooks=n_conditioning_codebooks if n_conditioning_codebooks > 0 else None,
|
116 |
intensity=rand_mask_intensity,
|
117 |
ext_mask=beat_mask,
|
118 |
verbose=True,
|
|
|
133 |
suffix_dur_s=prefix_s, # suffix should be same length as prefix
|
134 |
num_loops=num_vamps,
|
135 |
downsample_factor=mask_periodic_amt,
|
136 |
+
periodic_width=mask_periodic_width,
|
137 |
intensity=rand_mask_intensity,
|
138 |
ext_mask=beat_mask,
|
139 |
verbose=True,
|
|
|
158 |
mask_periodic_amt, beat_unmask_dur,
|
159 |
mask_dwn_chk, dwn_factor,
|
160 |
mask_up_chk, up_factor,
|
161 |
+
num_vamps, mode, output_audio, notes, use_beats, num_steps, snap_to_beats,
|
162 |
+
beat_unmask_drop, mask_periodic_width, mask_periodic_dropout, mask_periodic_width_dropout,
|
163 |
+
n_conditioning_codebooks
|
164 |
):
|
165 |
out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
|
166 |
out_dir.mkdir(parents=True, exist_ok=True)
|
|
|
189 |
"snap_to_beats": snap_to_beats,
|
190 |
"mode": mode,
|
191 |
"notes": notes,
|
192 |
+
"beat_unmask_drop": beat_unmask_drop,
|
193 |
+
"mask_periodic_width": mask_periodic_width,
|
194 |
+
"mask_periodic_dropout": mask_periodic_dropout,
|
195 |
+
"mask_periodic_width_dropout": mask_periodic_width_dropout,
|
196 |
+
"n_conditioning_codebooks": n_conditioning_codebooks
|
197 |
}
|
198 |
|
199 |
# save with yaml
|
|
|
302 |
# mask settings
|
303 |
with gr.Column():
|
304 |
|
305 |
+
n_conditioning_codebooks = gr.Number(
|
306 |
+
label="number of conditioning codebooks. probably 0",
|
307 |
+
value=0,
|
308 |
+
precision=0,
|
309 |
+
)
|
310 |
+
|
311 |
mask_periodic_amt = gr.Slider(
|
312 |
label="periodic hint (0.0 means no hint, 2 - lots of hints, 8 - a couple of hints, 16 - occasional hint, 32 - very occasional hint, etc)",
|
313 |
minimum=0,
|
|
|
315 |
step=1,
|
316 |
value=9,
|
317 |
)
|
318 |
+
mask_periodic_width = gr.Slider(
|
319 |
+
label="periodic hint width (steps, 1 step ~= 10milliseconds",
|
320 |
+
minimum=1,
|
321 |
+
maximum=100,
|
322 |
+
step=1,
|
323 |
+
value=1,
|
324 |
+
)
|
325 |
+
mask_periodic_dropout = gr.Slider(
|
326 |
+
label="periodic hint dropout (0.0 means no dropout, 1.0 means all dropout)",
|
327 |
+
minimum=0.0,
|
328 |
+
maximum=1.0,
|
329 |
+
value=0.0,
|
330 |
+
)
|
331 |
+
mask_periodic_width_dropout = gr.Slider(
|
332 |
+
label="periodic hint width dropout (0.0 means no dropout, 1.0 means all dropout)",
|
333 |
+
minimum=0.0,
|
334 |
+
maximum=1.0,
|
335 |
+
value=0.0,
|
336 |
+
)
|
337 |
|
338 |
rand_mask_intensity = gr.Slider(
|
339 |
label="random mask intensity. (If this is less than 1, scatters tiny hints throughout the audio, should be between 0.9 and 1.0)",
|
340 |
+
minimum=0.8,
|
341 |
maximum=1.0,
|
342 |
value=1.0
|
343 |
)
|
|
|
382 |
|
383 |
num_steps = gr.Slider(
|
384 |
label="number of steps (should normally be between 12 and 36)",
|
385 |
+
minimum=1,
|
386 |
maximum=128,
|
387 |
step=1,
|
388 |
value=36
|
|
|
418 |
maximum=3.0,
|
419 |
value=0.07
|
420 |
)
|
421 |
+
beat_unmask_drop = gr.Slider(
|
422 |
+
label="dropout (within beat)",
|
423 |
+
minimum=0.0,
|
424 |
+
maximum=1.0,
|
425 |
+
value=0.0
|
426 |
+
)
|
427 |
+
|
428 |
with gr.Accordion("downbeat settings", open=False):
|
429 |
mask_dwn_chk = gr.Checkbox(
|
430 |
label="hint downbeats",
|
|
|
473 |
mask_periodic_amt, beat_unmask_dur,
|
474 |
mask_dwn_chk, dwn_factor,
|
475 |
mask_up_chk, up_factor,
|
476 |
+
num_vamps, mode, use_beats, num_steps, snap_to_beats,
|
477 |
+
beat_unmask_drop, mask_periodic_width,
|
478 |
+
mask_periodic_dropout, mask_periodic_width_dropout,
|
479 |
+
n_conditioning_codebooks
|
480 |
],
|
481 |
outputs=[output_audio, audio_mask]
|
482 |
)
|
|
|
491 |
mask_up_chk, up_factor,
|
492 |
num_vamps, mode,
|
493 |
output_audio,
|
494 |
+
notes_text, use_beats, num_steps, snap_to_beats,
|
495 |
+
beat_unmask_drop, mask_periodic_width,
|
496 |
+
mask_periodic_dropout, mask_periodic_width_dropout,
|
497 |
+
n_conditioning_codebooks
|
498 |
],
|
499 |
outputs=[thank_you, download_file]
|
500 |
)
|
scripts/utils/vamp_folder.py
CHANGED
@@ -220,7 +220,7 @@ EXP_REGISTRY["mask-ratio"] = {
|
|
220 |
|
221 |
EXP_REGISTRY["sampling-steps"] = {
|
222 |
"codec": reconstructed,
|
223 |
-
**{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72
|
224 |
}
|
225 |
|
226 |
EXP_REGISTRY["baseline"] = {
|
|
|
220 |
|
221 |
EXP_REGISTRY["sampling-steps"] = {
|
222 |
"codec": reconstructed,
|
223 |
+
**{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72]},
|
224 |
}
|
225 |
|
226 |
EXP_REGISTRY["baseline"] = {
|
vampnet/interface.py
CHANGED
@@ -134,7 +134,7 @@ class Interface:
|
|
134 |
mask_upbeats: bool = True,
|
135 |
downbeat_downsample_factor: int = None,
|
136 |
beat_downsample_factor: int = None,
|
137 |
-
dropout: float = 0.
|
138 |
invert: bool = True,
|
139 |
):
|
140 |
"""make a beat synced mask. that is, make a mask that
|
@@ -182,7 +182,8 @@ class Interface:
|
|
182 |
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
183 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
184 |
_m = torch.ones(num_steps, device=self.device)
|
185 |
-
|
|
|
186 |
|
187 |
mask[_slice[0]:_slice[1]] = _m
|
188 |
|
@@ -191,7 +192,8 @@ class Interface:
|
|
191 |
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
192 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
193 |
_m = torch.ones(num_steps, device=self.device)
|
194 |
-
|
|
|
195 |
|
196 |
mask[_slice[0]:_slice[1]] = _m
|
197 |
|
@@ -342,6 +344,9 @@ class Interface:
|
|
342 |
suffix_dur_s: float = 0.0,
|
343 |
num_vamps: int = 1,
|
344 |
downsample_factor: int = None,
|
|
|
|
|
|
|
345 |
intensity: float = 1.0,
|
346 |
debug=False,
|
347 |
swap_prefix_suffix=False,
|
@@ -383,6 +388,9 @@ class Interface:
|
|
383 |
n_prefix=n_prefix,
|
384 |
n_suffix=n_suffix,
|
385 |
downsample_factor=downsample_factor,
|
|
|
|
|
|
|
386 |
mask=cz_mask,
|
387 |
ext_mask=ext_mask,
|
388 |
n_conditioning_codebooks=n_conditioning_codebooks
|
@@ -481,8 +489,11 @@ class Interface:
|
|
481 |
suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
|
482 |
c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
|
483 |
|
|
|
|
|
484 |
if return_mask:
|
485 |
return c_vamp, cz_masked
|
|
|
486 |
return c_vamp
|
487 |
|
488 |
# create a variation of an audio signal
|
|
|
134 |
mask_upbeats: bool = True,
|
135 |
downbeat_downsample_factor: int = None,
|
136 |
beat_downsample_factor: int = None,
|
137 |
+
dropout: float = 0.0,
|
138 |
invert: bool = True,
|
139 |
):
|
140 |
"""make a beat synced mask. that is, make a mask that
|
|
|
182 |
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after)
|
183 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
184 |
_m = torch.ones(num_steps, device=self.device)
|
185 |
+
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
186 |
+
_m = _m * _m_mask.long()
|
187 |
|
188 |
mask[_slice[0]:_slice[1]] = _m
|
189 |
|
|
|
192 |
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after)
|
193 |
num_steps = mask[_slice[0]:_slice[1]].shape[0]
|
194 |
_m = torch.ones(num_steps, device=self.device)
|
195 |
+
_m_mask = torch.bernoulli(_m * (1 - dropout))
|
196 |
+
_m = _m * _m_mask.long()
|
197 |
|
198 |
mask[_slice[0]:_slice[1]] = _m
|
199 |
|
|
|
344 |
suffix_dur_s: float = 0.0,
|
345 |
num_vamps: int = 1,
|
346 |
downsample_factor: int = None,
|
347 |
+
periodic_width: int = 1,
|
348 |
+
periodic_dropout=0.0,
|
349 |
+
periodic_width_dropout=0.0,
|
350 |
intensity: float = 1.0,
|
351 |
debug=False,
|
352 |
swap_prefix_suffix=False,
|
|
|
388 |
n_prefix=n_prefix,
|
389 |
n_suffix=n_suffix,
|
390 |
downsample_factor=downsample_factor,
|
391 |
+
periodic_width=periodic_width,
|
392 |
+
periodic_dropout=periodic_dropout,
|
393 |
+
periodic_width_dropout=periodic_width_dropout,
|
394 |
mask=cz_mask,
|
395 |
ext_mask=ext_mask,
|
396 |
n_conditioning_codebooks=n_conditioning_codebooks
|
|
|
489 |
suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
|
490 |
c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
|
491 |
|
492 |
+
# replace the mask token in cz_masked with random tokens
|
493 |
+
# so that we can decode it
|
494 |
if return_mask:
|
495 |
return c_vamp, cz_masked
|
496 |
+
|
497 |
return c_vamp
|
498 |
|
499 |
# create a variation of an audio signal
|
vampnet/modules/base.py
CHANGED
@@ -41,6 +41,9 @@ class VampBase(at.ml.BaseModel):
|
|
41 |
n_prefix: Optional[torch.Tensor] = None,
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
|
|
|
|
|
|
44 |
n_conditioning_codebooks: Optional[int] = None,
|
45 |
noise_mode: str = None,
|
46 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
@@ -75,7 +78,20 @@ class VampBase(at.ml.BaseModel):
|
|
75 |
continue
|
76 |
for j in range(probs.shape[-1]):
|
77 |
if j % factor == 0:
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
80 |
mask = torch.bernoulli(probs)
|
81 |
mask = mask.round().long()
|
@@ -175,6 +191,14 @@ class VampBase(at.ml.BaseModel):
|
|
175 |
codec.sample_rate,
|
176 |
)
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
return signal
|
179 |
|
180 |
@torch.no_grad()
|
|
|
41 |
n_prefix: Optional[torch.Tensor] = None,
|
42 |
n_suffix: Optional[torch.Tensor] = None,
|
43 |
downsample_factor: Optional[int] = None,
|
44 |
+
periodic_width: int = 1,
|
45 |
+
periodic_width_dropout: float = 0.0,
|
46 |
+
periodic_dropout: float = 0.0,
|
47 |
n_conditioning_codebooks: Optional[int] = None,
|
48 |
noise_mode: str = None,
|
49 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
78 |
continue
|
79 |
for j in range(probs.shape[-1]):
|
80 |
if j % factor == 0:
|
81 |
+
# if we have periodic dropout
|
82 |
+
if periodic_dropout > 0:
|
83 |
+
# flip a coin
|
84 |
+
if torch.bernoulli(torch.tensor(periodic_dropout)).item() == 1:
|
85 |
+
# if we win, skip
|
86 |
+
continue
|
87 |
+
# figure out how wide the mask should be
|
88 |
+
j_start = max(0, j - periodic_width // 2)
|
89 |
+
j_end = min(probs.shape[-1] - 1, j + periodic_width // 2) + 1
|
90 |
+
# flip a coin for each position in the mask
|
91 |
+
j_mask = torch.bernoulli(torch.ones(j_end - j_start) * periodic_width_dropout)
|
92 |
+
j_fill = torch.ones_like(j_mask) * (1 - j_mask)
|
93 |
+
# fill
|
94 |
+
probs[i, :, j_start:j_end] = 1 - j_fill
|
95 |
|
96 |
mask = torch.bernoulli(probs)
|
97 |
mask = mask.round().long()
|
|
|
191 |
codec.sample_rate,
|
192 |
)
|
193 |
|
194 |
+
# find where the mask token is and replace it with silence in the audio
|
195 |
+
for tstep in range(z.shape[-1]):
|
196 |
+
if torch.any(z[:, :, tstep] == self.mask_token):
|
197 |
+
print("mask token found at step", tstep)
|
198 |
+
sample_idx_0 = tstep * codec.hop_length
|
199 |
+
sample_idx_1 = sample_idx_0 + codec.hop_length
|
200 |
+
signal.samples[:, :, sample_idx_0:sample_idx_1] = 0.0
|
201 |
+
|
202 |
return signal
|
203 |
|
204 |
@torch.no_grad()
|