Hugo Flores Garcia committed
Commit 0321745 · Parent(s): 2ddbddc

taking prompt c2f tokens into account

Files changed:
- app.py +3 -2
- vampnet/interface.py +40 -20
app.py
CHANGED

@@ -114,7 +114,7 @@ def _vamp(data, return_mask=False):
         )
 
     if use_coarse2fine:
-        zv = interface.coarse_to_fine(zv, temperature=data[temp])
+        zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
 
     sig = interface.to_signal(zv).cpu()
     print("done")

@@ -410,7 +410,8 @@ with gr.Blocks() as demo:
 
             use_coarse2fine = gr.Checkbox(
                 label="use coarse2fine",
-                value=True
+                value=True,
+                visible=False
            )
 
             num_steps = gr.Slider(
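For context, a rough sketch of the call flow this change produces inside _vamp. The variable names mirror the diff, while the surrounding setup (how data, z, and mask are built) is assumed rather than taken from this commit:

    # the mask used for the coarse vamp is now also forwarded to the coarse-to-fine
    # pass, so prompt (unmasked) tokens are carried through both generation stages
    zv, mask_z = interface.coarse_vamp(z, mask=mask, return_mask=True,
                                       gen_fn=interface.coarse.generate)
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, temperature=data[temp], mask=mask)
    sig = interface.to_signal(zv).cpu()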
vampnet/interface.py
CHANGED

@@ -22,6 +22,7 @@ def signal_concat(
 
     return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
 
+
 def _load_model(
     ckpt: str,
     lora_ckpt: str = None,

@@ -275,36 +276,47 @@ class Interface(torch.nn.Module):
 
     def coarse_to_fine(
         self,
-        coarse_z: torch.Tensor,
+        z: torch.Tensor,
+        mask: torch.Tensor = None,
         **kwargs
     ):
         assert self.c2f is not None, "No coarse2fine model loaded"
-        length = coarse_z.shape[-1]
+        length = z.shape[-1]
         chunk_len = self.s2t(self.c2f.chunk_size_s)
-        n_chunks = math.ceil(coarse_z.shape[-1] / chunk_len)
+        n_chunks = math.ceil(z.shape[-1] / chunk_len)
 
         # zero pad to chunk_len
         if length % chunk_len != 0:
             pad_len = chunk_len - (length % chunk_len)
-            coarse_z = torch.nn.functional.pad(coarse_z, (0, pad_len))
+            z = torch.nn.functional.pad(z, (0, pad_len))
+            mask = torch.nn.functional.pad(mask, (0, pad_len)) if mask is not None else None
 
-        n_codebooks_to_append = self.c2f.n_codebooks - coarse_z.shape[1]
+        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
         if n_codebooks_to_append > 0:
-            coarse_z = torch.cat([
-                coarse_z,
-                torch.zeros(coarse_z.shape[0], n_codebooks_to_append, coarse_z.shape[-1]).long().to(self.device)
+            z = torch.cat([
+                z,
+                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
             ], dim=1)
 
+        # set the mask to 0 for all conditioning codebooks
+        if mask is not None:
+            mask = mask.clone()
+            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0
+
         fine_z = []
         for i in range(n_chunks):
-            chunk = coarse_z[:, :, i * chunk_len : (i + 1) * chunk_len]
+            chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
+            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None
+
             chunk = self.c2f.generate(
                 codec=self.codec,
                 time_steps=chunk_len,
                 start_tokens=chunk,
                 return_signal=False,
+                mask=mask_chunk,
                 **kwargs
             )
+            breakpoint()
             fine_z.append(chunk)
 
         fine_z = torch.cat(fine_z, dim=-1)

@@ -337,6 +349,12 @@
             **kwargs
         )
 
+        # add the fine codes back in
+        c_vamp = torch.cat(
+            [c_vamp, z[:, self.coarse.n_codebooks :, :]],
+            dim=1
+        )
+
         if return_mask:
             return c_vamp, cz_masked
 

@@ -352,17 +370,18 @@ if __name__ == "__main__":
     at.util.seed(42)
 
     interface = Interface(
-        coarse_ckpt="./models/
-        coarse2fine_ckpt="./models/
-        codec_ckpt="./models/
+        coarse_ckpt="./models/vampnet/coarse.pth",
+        coarse2fine_ckpt="./models/vampnet/c2f.pth",
+        codec_ckpt="./models/vampnet/codec.pth",
         device="cuda",
         wavebeat_ckpt="./models/wavebeat.pth"
     )
 
 
-    sig = at.AudioSignal.
+    sig = at.AudioSignal('assets/example.wav')
 
     z = interface.encode(sig)
+    breakpoint()
 
     # mask = linear_random(z, 1.0)
     # mask = mask_and(

@@ -374,13 +393,14 @@ if __name__ == "__main__":
     # )
     # )
 
-    mask = interface.make_beat_mask(
-        sig, 0.0, 0.075
-    )
+    # mask = interface.make_beat_mask(
+    #     sig, 0.0, 0.075
+    # )
     # mask = dropout(mask, 0.0)
     # mask = codebook_unmask(mask, 0)
+
+    mask = inpaint(z, n_prefix=100, n_suffix=100)
 
-    breakpoint()
     zv, mask_z = interface.coarse_vamp(
         z,
         mask=mask,

@@ -389,16 +409,16 @@ if __name__ == "__main__":
         return_mask=True,
         gen_fn=interface.coarse.generate
     )
+
 
     use_coarse2fine = True
     if use_coarse2fine:
-        zv = interface.coarse_to_fine(zv, temperature=0.8)
+        zv = interface.coarse_to_fine(zv, temperature=0.8, mask=mask)
+        breakpoint()
 
     mask = interface.to_signal(mask_z).cpu()
 
     sig = interface.to_signal(zv).cpu()
     print("done")
 
-    sig.write("output3.wav")
-    mask.write("mask.wav")
 
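To make the new mask handling in coarse_to_fine concrete, here is a standalone sketch of the padding, conditioning-codebook, and chunking logic on dummy tensors. It is not part of the repo; the shapes, the chunk length, the 14/4 codebook split, and the 0 = keep / 1 = regenerate mask convention are illustrative assumptions:

    import math
    import torch

    n_codebooks, n_conditioning_codebooks, chunk_len = 14, 4, 310
    z = torch.randint(0, 1024, (1, n_codebooks, 500))   # (batch, codebooks, time) token ids
    mask = torch.ones_like(z)                            # 1 = regenerate, 0 = keep (prompt token)
    mask[:, :, :100] = 0                                 # e.g. keep the first 100 timesteps as a prompt

    # zero-pad tokens and mask together to a multiple of the c2f chunk length
    pad_len = (chunk_len - z.shape[-1] % chunk_len) % chunk_len
    z = torch.nn.functional.pad(z, (0, pad_len))
    mask = torch.nn.functional.pad(mask, (0, pad_len))

    # conditioning codebooks are never regenerated: force their mask to 0
    mask = mask.clone()
    mask[:, :n_conditioning_codebooks, :] = 0

    # slice tokens and mask into matching chunks; each pair would feed
    # c2f.generate(start_tokens=chunk, mask=mask_chunk, ...) in the real interface
    for i in range(math.ceil(z.shape[-1] / chunk_len)):
        chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
        mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]
        print(chunk.shape, mask_chunk.shape)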