Hugo Flores Garcia committed
Commit 3346920 · Parent: 3445a71

more sampling fixes
Files changed:
- sample.py +70 -0
- scripts/{utils/vamp_folder.py → exp/experiment.py} +6 -7
- scripts/utils/parallel-gpu.sh +0 -23
- vampnet/modules/transformer.py +23 -34
sample.py
ADDED
@@ -0,0 +1,70 @@
+import yaml
+import argbind
+
+import audiotools as at
+
+from vampnet.interface import Interface
+import logging
+
+logger = logging.getLogger()
+logger.setLevel(logging.DEBUG)
+
+Interface = argbind.bind(Interface)
+
+with open("conf/interface/spotdl.yml") as f:
+    conf = yaml.safe_load(f)
+
+
+with argbind.scope(conf):
+    interface = Interface()
+    interface.to("cuda")
+
+loader = at.data.datasets.AudioLoader(sources=[
+    "input.wav",
+])
+
+dataset = at.data.datasets.AudioDataset(
+    loader,
+    sample_rate=interface.codec.sample_rate,
+    duration=interface.coarse.chunk_size_s,
+    n_examples=200,
+    without_replacement=True,
+)
+
+import numpy as np
+def load_random_audio():
+    index = np.random.randint(0, len(dataset))
+    sig = dataset[index]["signal"]
+    sig = interface.preprocess(sig)
+
+    return sig
+
+
+sig = load_random_audio()
+z = interface.encode(sig)
+
+sig.write('input.wav')
+
+from vampnet import mask as pmask
+
+# build the mask
+mask = pmask.linear_random(z, 1.0)
+
+print("coarse")
+zv, mask_z = interface.coarse_vamp(
+    z,
+    mask=mask,
+    sampling_steps=36,
+    temperature=8.0,
+    return_mask=True,
+    typical_filtering=False,
+    # typical_mass=data[typical_mass],
+    # typical_min_tokens=data[typical_min_tokens],
+    gen_fn=interface.coarse.generate,
+)
+
+print("coarse2fine")
+zv = interface.coarse_to_fine(zv, temperature=0.8)
+
+sig = interface.to_signal(zv).cpu()
+sig.write('output-t=8.wav')
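The new script hard-codes temperature=8.0 and writes a single output-t=8.wav. A hypothetical follow-on sweep, reusing only the Interface calls that appear in sample.py above (the temperature values are illustrative, not part of the commit):

```python
# hypothetical sweep over sampling temperatures, built from the same
# calls sample.py makes above
for t in [1.0, 2.5, 8.0]:  # illustrative values
    zv, _ = interface.coarse_vamp(
        z,
        mask=pmask.linear_random(z, 1.0),
        sampling_steps=36,
        temperature=t,
        return_mask=True,
        gen_fn=interface.coarse.generate,
    )
    zv = interface.coarse_to_fine(zv, temperature=0.8)
    interface.to_signal(zv).cpu().write(f"output-t={t}.wav")
```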
scripts/{utils/vamp_folder.py → exp/experiment.py}
RENAMED
@@ -119,13 +119,15 @@ def beat_mask(ctx_time):
     def wrapper(sig, interface):
         beat_mask = interface.make_beat_mask(
             sig,
-            before_beat_s=
-            after_beat_s=ctx_time,
+            before_beat_s=ctx_time/2,
+            after_beat_s=ctx_time/2,
             invert=True
         )
+
         z = interface.encode(sig)
+
         zv = interface.coarse_vamp(
-            z, beat_mask
+            z, beat_mask
         )
 
         zv = interface.coarse_to_fine(zv)
@@ -185,9 +187,6 @@ EXP_REGISTRY["sampling-steps"] = {
 
 
 EXP_REGISTRY["musical-sampling"] = {
-    "baseline": baseline,
-    "codec": reconstructed,
-    **{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
     **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
     **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
 }
@@ -195,7 +194,7 @@ EXP_REGISTRY["musical-sampling"] = {
 @argbind.bind(without_prefix=True)
 def main(
     sources=[
-        "/media/CHONK/hugo/spotdl/
+        "/media/CHONK/hugo/spotdl/val",
     ],
     output_dir: str = "./samples",
     max_excerpts: int = 2000,
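The rename also rebalances the beat mask: the context window is now split evenly around each beat (previously after_beat_s alone was ctx_time), so the total unmasked context per beat stays ctx_time. A quick check with the one value registered above:

```python
# the only ctx_time registered in EXP_REGISTRY["musical-sampling"] above
ctx_time = 0.075
before_beat_s = ctx_time / 2  # 37.5 ms kept before each beat
after_beat_s = ctx_time / 2   # 37.5 ms kept after each beat
assert before_beat_s + after_beat_s == ctx_time  # total context unchanged
```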
scripts/utils/parallel-gpu.sh
DELETED
@@ -1,23 +0,0 @@
-#!/bin/bash
-
-# Get the command to execute from the user
-command_to_execute="$1"
-
-# Get the maximum number of GPUs to use from the user
-max_gpus="$2"
-
-# Get the number of instances to start per GPU from the user
-instances_per_gpu="$3"
-
-# Set the CUDA_VISIBLE_DEVICES flag for each GPU
-for gpu_id in $(seq 0 $(($max_gpus - 1))); do
-    export CUDA_VISIBLE_DEVICES="$gpu_id"
-    # Start the specified number of instances for this GPU
-    for i in $(seq 1 "$instances_per_gpu"); do
-        # Run the command in the background
-        $command_to_execute &
-    done
-done
-
-# Wait for all instances to finish
-wait
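For reference, the deleted launcher took three positional arguments: a command string, a GPU count, and an instances-per-GPU count. A minimal Python stand-in with the same semantics (hypothetical, not part of this commit; launch_parallel is an illustrative name):

```python
import os
import subprocess
import sys

def launch_parallel(command: str, max_gpus: int, instances_per_gpu: int):
    """Hypothetical equivalent of the deleted parallel-gpu.sh."""
    procs = []
    for gpu_id in range(max_gpus):
        # pin each batch of instances to one GPU
        env = dict(os.environ, CUDA_VISIBLE_DEVICES=str(gpu_id))
        for _ in range(instances_per_gpu):
            # run the command in the background
            procs.append(subprocess.Popen(command, shell=True, env=env))
    # wait for all instances to finish
    for p in procs:
        p.wait()

if __name__ == "__main__":
    launch_parallel(sys.argv[1], int(sys.argv[2]), int(sys.argv[3]))
```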
vampnet/modules/transformer.py
CHANGED
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
         sampling_steps: int = 24,
         start_tokens: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
-        temperature: Union[float, Tuple[float, float]] =
+        temperature: Union[float, Tuple[float, float]] = 2.5,
         typical_filtering=False,
         typical_mass=0.2,
         typical_min_tokens=1,
@@ -592,15 +592,7 @@ class VampNet(at.ml.BaseModel):
         #####################
         # resolve temperature #
         #####################
-        if isinstance(temperature, float):
-            temperature = torch.tensor(temperature).repeat(sampling_steps)
-        elif isinstance(temperature, tuple):
-            assert len(temperature) == 2
-            l, h = temperature
-            temperature = torch.linspace(l, h, sampling_steps)
-        else:
-            raise TypeError(f"invalid type for temperature")
-
+        assert isinstance(temperature, float)
         logging.debug(f"temperature: {temperature}")
 
 
@@ -642,10 +634,6 @@ class VampNet(at.ml.BaseModel):
         num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
         logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
 
-        # our r steps
-        r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
-        logging.debug(f"r steps: {r_steps}")
-
         # how many codebooks are we inferring vs conditioning on?
         n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
         logging.debug(f"n infer codebooks: {n_infer_codebooks}")
@@ -658,11 +646,13 @@ class VampNet(at.ml.BaseModel):
             logging.debug(f"step {i} of {sampling_steps}")
 
             # our current temperature
-            tmpt = temperature[i]
-            logging.debug(f"temperature: {tmpt}")
+            logging.debug(f"temperature: {temperature}")
 
             # our current schedule step
-            r = r_steps[i]
+            r = scalar_to_batch_tensor(
+                (i + 1) / sampling_steps,
+                z.shape[0]
+            ).to(z.device)
             logging.debug(f"r: {r}")
 
             # get latents
@@ -699,11 +689,18 @@ class VampNet(at.ml.BaseModel):
             probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
             logging.debug(f"sampled z with shape: {sampled_z.shape}")
 
+            # get the confidences: which tokens did we sample?
+            selected_probs = (
+                torch.take_along_dim(
+                    probs, sampled_z.long().unsqueeze(-1),
+                    dim=-1
+                ).squeeze(-1)
+            )
 
             # flatten z_masked and mask, so we can deal with the sampling logic
             # we'll unflatten them at the end of the loop for the next forward pass
             # remove conditioning codebooks, we'll add them back at the end
-            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
+            z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
 
             mask = (z_masked == self.mask_token).int()
 
@@ -715,15 +712,6 @@ class VampNet(at.ml.BaseModel):
             )
             logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
 
-
-            # get the confidences: which tokens did we sample?
-            selected_probs = (
-                torch.take_along_dim(
-                    probs, sampled_z.long().unsqueeze(-1),
-                    dim=-1
-                ).squeeze(-1)
-            )
-
             # ignore any tokens that weren't masked
             selected_probs = torch.where(
                 mask.bool(), selected_probs, torch.inf
@@ -733,18 +721,19 @@ class VampNet(at.ml.BaseModel):
             num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
             logging.debug(f"num to mask: {num_to_mask}")
 
-
-            torch.
-
-
-
+            if i != (sampling_steps - 1):
+                num_to_mask = torch.maximum(
+                    torch.tensor(1),
+                    torch.minimum(
+                        mask.sum(dim=-1, keepdim=True) - 1,
+                        num_to_mask
+                    )
                 )
-            )
 
 
             # get our new mask
             mask = mask_by_random_topk(
-                num_to_mask, selected_probs,
+                num_to_mask, selected_probs, temperature * (1-r)
             )
 
             # update the mask
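Taken together, the transformer.py changes fix the sampling schedule: temperature becomes a plain scalar, the schedule ratio r is computed per step as (i+1)/sampling_steps instead of being precomputed, the confidence gather moves before the flattening step, the number of re-masked tokens is clamped away from 0 and from "everything" on all but the last step, and the top-k masking noise anneals as temperature * (1 - r). A minimal sketch of that schedule in isolation, assuming a cosine _gamma(r) = cos(r * pi / 2) as used by MaskGIT-style samplers (the constants are illustrative):

```python
import torch

def _gamma(r: torch.Tensor) -> torch.Tensor:
    # assumed cosine schedule; vampnet defines its own _gamma
    return (r * torch.pi / 2).cos()

sampling_steps = 24
num_mask_tokens_at_start = 100  # illustrative
temperature = 2.5               # new scalar default from the diff

for i in range(sampling_steps):
    # per-step ratio, as in the new code: runs from 1/steps up to 1
    r = torch.tensor((i + 1) / sampling_steps)

    # tokens to re-mask this step, per the schedule
    num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).long()

    # clamp: re-mask at least 1 token but never all remaining ones,
    # except on the final step, where everything may resolve
    if i != (sampling_steps - 1):
        num_to_mask = torch.clamp(num_to_mask, 1, num_mask_tokens_at_start - 1)

    # confidence noise for mask_by_random_topk anneals to 0 as r -> 1
    topk_temperature = temperature * (1 - r)
```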