Hugo Flores Garcia committed on
Commit
3346920
·
1 Parent(s): 3445a71

more sampling fixes

Browse files
sample.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import argbind
3
+
4
+ import audiotools as at
5
+
6
+ from vampnet.interface import Interface
7
+ import logging
8
+
9
+ logger = logging.getLogger()
10
+ logger.setLevel(logging.DEBUG)
11
+
12
+ Interface = argbind.bind(Interface)
13
+
14
+ with open("conf/interface/spotdl.yml") as f:
15
+ conf = yaml.safe_load(f)
16
+
17
+
18
+ with argbind.scope(conf):
19
+ interface = Interface()
20
+ interface.to("cuda")
21
+
22
+ loader = at.data.datasets.AudioLoader(sources=[
23
+ "input.wav",
24
+ ])
25
+
26
+ dataset = at.data.datasets.AudioDataset(
27
+ loader,
28
+ sample_rate=interface.codec.sample_rate,
29
+ duration=interface.coarse.chunk_size_s,
30
+ n_examples=200,
31
+ without_replacement=True,
32
+ )
33
+
34
+ import numpy as np
35
+ def load_random_audio():
36
+ index = np.random.randint(0, len(dataset))
37
+ sig = dataset[index]["signal"]
38
+ sig = interface.preprocess(sig)
39
+
40
+ return sig
41
+
42
+
43
+ sig = load_random_audio()
44
+ z = interface.encode(sig)
45
+
46
+ sig.write('input.wav')
47
+
48
+ from vampnet import mask as pmask
49
+
50
+ # build the mask
51
+ mask = pmask.linear_random(z, 1.0)
52
+
53
+ print("coarse")
54
+ zv, mask_z = interface.coarse_vamp(
55
+ z,
56
+ mask=mask,
57
+ sampling_steps=36,
58
+ temperature=8.0,
59
+ return_mask=True,
60
+ typical_filtering=False,
61
+ # typical_mass=data[typical_mass],
62
+ # typical_min_tokens=data[typical_min_tokens],
63
+ gen_fn=interface.coarse.generate,
64
+ )
65
+
66
+ print("coarse2fine")
67
+ zv = interface.coarse_to_fine(zv, temperature=0.8)
68
+
69
+ sig = interface.to_signal(zv).cpu()
70
+ sig.write('output-t=8.wav')
scripts/{utils/vamp_folder.py → exp/experiment.py} RENAMED
@@ -119,13 +119,15 @@ def beat_mask(ctx_time):
119
  def wrapper(sig, interface):
120
  beat_mask = interface.make_beat_mask(
121
  sig,
122
- before_beat_s=0.0,
123
- after_beat_s=ctx_time,
124
  invert=True
125
  )
 
126
  z = interface.encode(sig)
 
127
  zv = interface.coarse_vamp(
128
- z, beat_mask,
129
  )
130
 
131
  zv = interface.coarse_to_fine(zv)
@@ -185,9 +187,6 @@ EXP_REGISTRY["sampling-steps"] = {
185
 
186
 
187
  EXP_REGISTRY["musical-sampling"] = {
188
- "baseline": baseline,
189
- "codec": reconstructed,
190
- **{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
191
  **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
192
  **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
193
  }
@@ -195,7 +194,7 @@ EXP_REGISTRY["musical-sampling"] = {
195
  @argbind.bind(without_prefix=True)
196
  def main(
197
  sources=[
198
- "/media/CHONK/hugo/spotdl/audio-test",
199
  ],
200
  output_dir: str = "./samples",
201
  max_excerpts: int = 2000,
 
119
  def wrapper(sig, interface):
120
  beat_mask = interface.make_beat_mask(
121
  sig,
122
+ before_beat_s=ctx_time/2,
123
+ after_beat_s=ctx_time/2,
124
  invert=True
125
  )
126
+
127
  z = interface.encode(sig)
128
+
129
  zv = interface.coarse_vamp(
130
+ z, beat_mask
131
  )
132
 
133
  zv = interface.coarse_to_fine(zv)
 
187
 
188
 
189
  EXP_REGISTRY["musical-sampling"] = {
 
 
 
190
  **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
191
  **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
192
  }
 
194
  @argbind.bind(without_prefix=True)
195
  def main(
196
  sources=[
197
+ "/media/CHONK/hugo/spotdl/val",
198
  ],
199
  output_dir: str = "./samples",
200
  max_excerpts: int = 2000,
scripts/utils/parallel-gpu.sh DELETED
@@ -1,23 +0,0 @@
1
- #!/bin/bash
2
-
3
- # Get the command to execute from the user
4
- command_to_execute="$1"
5
-
6
- # Get the maximum number of GPUs to use from the user
7
- max_gpus="$2"
8
-
9
- # Get the number of instances to start per GPU from the user
10
- instances_per_gpu="$3"
11
-
12
- # Set the CUDA_VISIBLE_DEVICES flag for each GPU
13
- for gpu_id in $(seq 0 $(($max_gpus - 1))); do
14
- export CUDA_VISIBLE_DEVICES="$gpu_id"
15
- # Start the specified number of instances for this GPU
16
- for i in $(seq 1 "$instances_per_gpu"); do
17
- # Run the command in the background
18
- $command_to_execute &
19
- done
20
- done
21
-
22
- # Wait for all instances to finish
23
- wait
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vampnet/modules/transformer.py CHANGED
@@ -581,7 +581,7 @@ class VampNet(at.ml.BaseModel):
581
  sampling_steps: int = 24,
582
  start_tokens: Optional[torch.Tensor] = None,
583
  mask: Optional[torch.Tensor] = None,
584
- temperature: Union[float, Tuple[float, float]] = 8.0,
585
  typical_filtering=False,
586
  typical_mass=0.2,
587
  typical_min_tokens=1,
@@ -592,15 +592,7 @@ class VampNet(at.ml.BaseModel):
592
  #####################
593
  # resolve temperature #
594
  #####################
595
- if isinstance(temperature, float):
596
- temperature = torch.tensor(temperature).repeat(sampling_steps)
597
- elif isinstance(temperature, tuple):
598
- assert len(temperature) == 2
599
- l, h = temperature
600
- temperature = torch.linspace(l, h, sampling_steps)
601
- else:
602
- raise TypeError(f"invalid type for temperature")
603
-
604
  logging.debug(f"temperature: {temperature}")
605
 
606
 
@@ -642,10 +634,6 @@ class VampNet(at.ml.BaseModel):
642
  num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
643
  logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
644
 
645
- # our r steps
646
- r_steps = torch.linspace(1e-10, 1, sampling_steps+1)[1:].to(self.device)
647
- logging.debug(f"r steps: {r_steps}")
648
-
649
  # how many codebooks are we inferring vs conditioning on?
650
  n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
651
  logging.debug(f"n infer codebooks: {n_infer_codebooks}")
@@ -658,11 +646,13 @@ class VampNet(at.ml.BaseModel):
658
  logging.debug(f"step {i} of {sampling_steps}")
659
 
660
  # our current temperature
661
- tmpt = temperature[i]
662
- logging.debug(f"temperature: {tmpt}")
663
 
664
  # our current schedule step
665
- r = r_steps[i : i + 1]
 
 
 
666
  logging.debug(f"r: {r}")
667
 
668
  # get latents
@@ -699,11 +689,18 @@ class VampNet(at.ml.BaseModel):
699
  probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
700
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
701
 
 
 
 
 
 
 
 
702
 
703
  # flatten z_masked and mask, so we can deal with the sampling logic
704
  # we'll unflatten them at the end of the loop for the next forward pass
705
  # remove conditioning codebooks, we'll add them back at the end
706
- z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
707
 
708
  mask = (z_masked == self.mask_token).int()
709
 
@@ -715,15 +712,6 @@ class VampNet(at.ml.BaseModel):
715
  )
716
  logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
717
 
718
-
719
- # get the confidences: which tokens did we sample?
720
- selected_probs = (
721
- torch.take_along_dim(
722
- probs, sampled_z.long().unsqueeze(-1),
723
- dim=-1
724
- ).squeeze(-1)
725
- )
726
-
727
  # ignore any tokens that weren't masked
728
  selected_probs = torch.where(
729
  mask.bool(), selected_probs, torch.inf
@@ -733,18 +721,19 @@ class VampNet(at.ml.BaseModel):
733
  num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
734
  logging.debug(f"num to mask: {num_to_mask}")
735
 
736
- num_to_mask = torch.maximum(
737
- torch.tensor(1),
738
- torch.minimum(
739
- mask.sum(dim=-1, keepdim=True) - 1,
740
- num_to_mask
 
 
741
  )
742
- )
743
 
744
 
745
  # get our new mask
746
  mask = mask_by_random_topk(
747
- num_to_mask, selected_probs, tmpt * (1-r)
748
  )
749
 
750
  # update the mask
 
581
  sampling_steps: int = 24,
582
  start_tokens: Optional[torch.Tensor] = None,
583
  mask: Optional[torch.Tensor] = None,
584
+ temperature: Union[float, Tuple[float, float]] = 2.5,
585
  typical_filtering=False,
586
  typical_mass=0.2,
587
  typical_min_tokens=1,
 
592
  #####################
593
  # resolve temperature #
594
  #####################
595
+ assert isinstance(temperature, float)
 
 
 
 
 
 
 
 
596
  logging.debug(f"temperature: {temperature}")
597
 
598
 
 
634
  num_mask_tokens_at_start = (z_masked == self.mask_token).sum()
635
  logging.debug(f"num mask tokens at start: {num_mask_tokens_at_start}")
636
 
 
 
 
 
637
  # how many codebooks are we inferring vs conditioning on?
638
  n_infer_codebooks = self.n_codebooks - self.n_conditioning_codebooks
639
  logging.debug(f"n infer codebooks: {n_infer_codebooks}")
 
646
  logging.debug(f"step {i} of {sampling_steps}")
647
 
648
  # our current temperature
649
+ logging.debug(f"temperature: {temperature}")
 
650
 
651
  # our current schedule step
652
+ r = scalar_to_batch_tensor(
653
+ (i + 1) / sampling_steps,
654
+ z.shape[0]
655
+ ).to(z.device)
656
  logging.debug(f"r: {r}")
657
 
658
  # get latents
 
689
  probs = rearrange(probs, "(b seq) prob -> b seq prob", b=b)
690
  logging.debug(f"sampled z with shape: {sampled_z.shape}")
691
 
692
+ # get the confidences: which tokens did we sample?
693
+ selected_probs = (
694
+ torch.take_along_dim(
695
+ probs, sampled_z.long().unsqueeze(-1),
696
+ dim=-1
697
+ ).squeeze(-1)
698
+ )
699
 
700
  # flatten z_masked and mask, so we can deal with the sampling logic
701
  # we'll unflatten them at the end of the loop for the next forward pass
702
  # remove conditioning codebooks, we'll add them back at the end
703
+ z_masked = codebook_flatten(z_masked[:, self.n_conditioning_codebooks:, :])
704
 
705
  mask = (z_masked == self.mask_token).int()
706
 
 
712
  )
713
  logging.debug(f"added z back into sampled z with shape: {sampled_z.shape}")
714
 
 
 
 
 
 
 
 
 
 
715
  # ignore any tokens that weren't masked
716
  selected_probs = torch.where(
717
  mask.bool(), selected_probs, torch.inf
 
721
  num_to_mask = torch.floor(_gamma(r) * num_mask_tokens_at_start).unsqueeze(1).long()
722
  logging.debug(f"num to mask: {num_to_mask}")
723
 
724
+ if i != (sampling_steps - 1):
725
+ num_to_mask = torch.maximum(
726
+ torch.tensor(1),
727
+ torch.minimum(
728
+ mask.sum(dim=-1, keepdim=True) - 1,
729
+ num_to_mask
730
+ )
731
  )
 
732
 
733
 
734
  # get our new mask
735
  mask = mask_by_random_topk(
736
+ num_to_mask, selected_probs, temperature * (1-r)
737
  )
738
 
739
  # update the mask