Hugo Flores Garcia commited on
Commit
6f6fd13
·
1 Parent(s): 4687dd9
scripts/utils/process_folder-c2f.py CHANGED
@@ -15,57 +15,47 @@ def coarse2fine_infer(
15
  model,
16
  vqvae,
17
  device,
18
- signal_window=3,
19
- signal_hop=1.5,
20
- max_excerpts=20,
21
  ):
22
- output = defaultdict(list)
23
-
24
- # split into 3 seconds
25
- windows = [s for s in signal.clone().windows(signal_window, signal_hop)]
26
- windows = windows[1:] # skip first window since it's half zero padded
27
- random.shuffle(windows)
28
- for w in windows[:max_excerpts]:
29
- # batch the signal into chunks of 3
30
- with torch.no_grad():
31
- # get codes
32
- w = w.to(device)
33
- z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
34
-
35
- model.to(device)
36
- output["reconstructed"] = model.to_signal(z, vqvae).cpu()
37
-
38
- # make a full mask
39
- mask = torch.ones_like(z)
40
- mask[:, :model.n_conditioning_codebooks, :] = 0
41
-
42
- output["sampled"].append(model.sample(
43
- codec=vqvae,
44
- time_steps=z.shape[-1],
45
- sampling_steps=12,
46
- start_tokens=z,
47
- mask=mask,
48
- temperature=0.85,
49
- top_k=None,
50
- sample="gumbel",
51
- typical_filtering=True,
52
- return_signal=True
53
- ).cpu())
54
-
55
- output["argmax"].append(model.sample(
56
- codec=vqvae,
57
- time_steps=z.shape[-1],
58
- sampling_steps=1,
59
- start_tokens=z,
60
- mask=mask,
61
- temperature=1.0,
62
- top_k=None,
63
- sample="argmax",
64
- typical_filtering=True,
65
- return_signal=True
66
- ).cpu())
67
-
68
- return output
69
 
70
 
71
  @argbind.bind(without_prefix=True)
@@ -73,11 +63,10 @@ def main(
73
  sources=[
74
  "/data/spotdl/audio/val", "/data/spotdl/audio/test"
75
  ],
76
- audio_ext="mp3",
77
  exp_name="noise_mode",
78
  model_paths=[
79
- "runs/c2f-exp-03.22.23/ckpt/mask/best/vampnet/weights.pth",
80
- "runs/c2f-exp-03.22.23/ckpt/random/best/vampnet/weights.pth",
81
  ],
82
  model_keys=[
83
  "mask",
@@ -86,10 +75,11 @@ def main(
86
  vqvae_path: str = "runs/codec-ckpt/codec.pth",
87
  device: str = "cuda",
88
  output_dir: str = ".",
 
 
89
  ):
90
  from vampnet.modules.transformer import VampNet
91
  from lac.model.lac import LAC
92
- from audiotools.post import audio_zip
93
 
94
  models = {
95
  k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
@@ -105,26 +95,26 @@ def main(
105
 
106
  output_dir = Path(output_dir) / f"{exp_name}-samples"
107
 
108
- for source in sources:
109
- print(f"Processing {source}...")
110
- source_files = list(Path(source).glob(f"**/*.{audio_ext}"))
111
- random.shuffle(source_files)
112
- for path in tqdm(source_files):
113
- sig = AudioSignal(path)
114
- sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
115
-
116
- out_dir = output_dir / path.stem
 
 
 
 
 
 
 
117
  out_dir.mkdir(parents=True, exist_ok=True)
118
- if out_dir.exists():
119
- print(f"Skipping {path.stem} since {out_dir} already exists.")
120
- continue
121
-
122
- for model_key, model in models.items():
123
- out = coarse2fine_infer(sig, model, vqvae, device)
124
- for k, sig_list in out.items():
125
- for i, s in enumerate(sig_list):
126
- s.write(out_dir / f"{model_key}-{k}-{i}.wav")
127
-
128
 
129
  if __name__ == "__main__":
130
  args = argbind.parse_args()
 
15
  model,
16
  vqvae,
17
  device,
 
 
 
18
  ):
19
+ output = {}
20
+ w = signal
21
+ w = w.to(device)
22
+ z = vqvae.encode(w.audio_data, w.sample_rate)["codes"]
23
+
24
+ model.to(device)
25
+ output["reconstructed"] = model.to_signal(z, vqvae).cpu()
26
+
27
+ # make a full mask
28
+ mask = torch.ones_like(z)
29
+ mask[:, :model.n_conditioning_codebooks, :] = 0
30
+
31
+ output["sampled"] = model.sample(
32
+ codec=vqvae,
33
+ time_steps=z.shape[-1],
34
+ sampling_steps=12,
35
+ start_tokens=z,
36
+ mask=mask,
37
+ temperature=0.85,
38
+ top_k=None,
39
+ sample="gumbel",
40
+ typical_filtering=True,
41
+ return_signal=True
42
+ ).cpu()
43
+
44
+ output["argmax"] = model.sample(
45
+ codec=vqvae,
46
+ time_steps=z.shape[-1],
47
+ sampling_steps=1,
48
+ start_tokens=z,
49
+ mask=mask,
50
+ temperature=1.0,
51
+ top_k=None,
52
+ sample="argmax",
53
+ typical_filtering=True,
54
+ return_signal=True
55
+ ).cpu()
56
+
57
+ return output
58
+
 
 
 
 
 
 
 
59
 
60
 
61
  @argbind.bind(without_prefix=True)
 
63
  sources=[
64
  "/data/spotdl/audio/val", "/data/spotdl/audio/test"
65
  ],
 
66
  exp_name="noise_mode",
67
  model_paths=[
68
+ "runs/c2f-exp-03.22.23/ckpt/mask/epoch=400/vampnet/weights.pth",
69
+ "runs/c2f-exp-03.22.23/ckpt/random/epoch=400/vampnet/weights.pth",
70
  ],
71
  model_keys=[
72
  "mask",
 
75
  vqvae_path: str = "runs/codec-ckpt/codec.pth",
76
  device: str = "cuda",
77
  output_dir: str = ".",
78
+ max_excerpts: int = 5000,
79
+ duration: float = 3.0,
80
  ):
81
  from vampnet.modules.transformer import VampNet
82
  from lac.model.lac import LAC
 
83
 
84
  models = {
85
  k: VampNet.load(p) for k, p in zip(model_keys, model_paths)
 
95
 
96
  output_dir = Path(output_dir) / f"{exp_name}-samples"
97
 
98
+ from audiotools.data.datasets import AudioLoader, AudioDataset
99
+
100
+ loader = AudioLoader(sources=sources)
101
+ dataset = AudioDataset(loader,
102
+ sample_rate=vqvae.sample_rate,
103
+ duration=duration,
104
+ n_examples=max_excerpts,
105
+ without_replacement=True,
106
+ )
107
+ for i in tqdm(range(max_excerpts)):
108
+ sig = dataset[i]["signal"]
109
+ sig.resample(vqvae.sample_rate).normalize(-24).ensure_max_of_audio(1.0)
110
+
111
+ for model_key, model in models.items():
112
+ out = coarse2fine_infer(sig, model, vqvae, device)
113
+ out_dir = output_dir / model_key / Path(sig.path_to_file).stem
114
  out_dir.mkdir(parents=True, exist_ok=True)
115
+ for k, s in out.items():
116
+ s.write(out_dir / f"{k}.wav")
117
+
 
 
 
 
 
 
 
118
 
119
  if __name__ == "__main__":
120
  args = argbind.parse_args()
scripts/utils/vamp_folder.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ import argbind
4
+ from tqdm import tqdm
5
+ import torch
6
+
7
+ from vampnet.interface import Interface
8
+
9
+ Interface = argbind.bind(Interface, positional=True)
10
+
11
+ def baseline(sig, interface):
12
+ return sig
13
+
14
+ def reconstructed(sig, interface):
15
+ return interface.to_signal(
16
+ interface.encode(sig)
17
+ )
18
+
19
+ def coarse2fine(sig, interface):
20
+ z = interface.encode(sig)
21
+ z = z[:, :interface.c2f.n_conditioning_codebooks, :]
22
+
23
+ z = interface.coarse_to_fine(z)
24
+ return interface.to_signal(z)
25
+
26
+ def one_codebook(sig, interface):
27
+ z = interface.encode(sig)
28
+
29
+ mask = torch.zeros_like(z)
30
+ mask[:, 1:, :] = 1
31
+
32
+ zv = interface.coarse_vamp_v2(
33
+ sig, ext_mask=mask,
34
+ )
35
+ zv = interface.coarse_to_fine(zv)
36
+
37
+ return interface.to_signal(zv)
38
+
39
+ def four_codebooks_downsampled_4x(sig, interface):
40
+ zv = interface.coarse_vamp_v2(
41
+ sig, downsample_factor=4
42
+ )
43
+ zv = interface.coarse_to_fine(zv)
44
+ return interface.to_signal(zv)
45
+
46
+ def two_codebooks_downsampled_4x(sig, interface):
47
+ z = interface.encode(sig)
48
+
49
+ mask = torch.zeros_like(z)
50
+ mask[:, 2:, :] = 1
51
+
52
+ zv = interface.coarse_vamp_v2(
53
+ sig, ext_mask=mask, downsample_factor=4
54
+ )
55
+ zv = interface.coarse_to_fine(zv)
56
+
57
+ return interface.to_signal(zv)
58
+
59
+ def four_codebooks_downsampled_8x(sig, interface):
60
+ zv = interface.coarse_vamp_v2(
61
+ sig, downsample_factor=8
62
+ )
63
+ zv = interface.coarse_to_fine(zv)
64
+ return interface.to_signal(zv)
65
+
66
+
67
+
68
+
69
+
70
+ SAMPLE_CONDS ={
71
+ "baseline": baseline,
72
+ "reconstructed": reconstructed,
73
+ "coarse2fine": coarse2fine,
74
+ "one_codebook": one_codebook,
75
+ "four_codebooks_downsampled_4x": four_codebooks_downsampled_4x,
76
+ "two_codebooks_downsampled_4x": two_codebooks_downsampled_4x,
77
+ "four_codebooks_downsampled_8x": four_codebooks_downsampled_8x,
78
+ }
79
+
80
+
81
+ @argbind.bind(without_prefix=True)
82
+ def main(
83
+ sources=[
84
+ "/data/spotdl/audio/val", "/data/spotdl/audio/test"
85
+ ],
86
+ output_dir: str = "./samples",
87
+ max_excerpts: int = 5000,
88
+ ):
89
+ interface = Interface()
90
+
91
+ output_dir = Path(output_dir)
92
+ output_dir.mkdir(exist_ok=True, parents=True)
93
+
94
+ from audiotools.data.datasets import AudioLoader, AudioDataset
95
+
96
+ loader = AudioLoader(sources=sources)
97
+ dataset = AudioDataset(loader,
98
+ sample_rate=interface.codec.sample_rate,
99
+ duration=interface.coarse.chunk_size_s,
100
+ n_examples=max_excerpts,
101
+ without_replacement=True,
102
+ )
103
+
104
+ for i in tqdm(range(max_excerpts)):
105
+ sig = dataset[i]["signal"]
106
+
107
+ results = {
108
+ name: cond(sig, interface)
109
+ for name, cond in SAMPLE_CONDS.items()
110
+ }
111
+
112
+ for name, sig in results.items():
113
+ output_dir = Path(output_dir) / name
114
+ output_dir.mkdir(exist_ok=True, parents=True)
115
+
116
+ sig.write(output_dir / f"{i}.wav")
117
+
118
+ if __name__ == "__main__":
119
+ args = argbind.parse_args()
120
+
121
+ with argbind.scope(args):
122
+ main()
vampnet/interface.py CHANGED
@@ -196,6 +196,7 @@ class Interface:
196
  time_steps=chunk_len,
197
  start_tokens=chunk,
198
  return_signal=False,
 
199
  )
200
  fine_z.append(chunk)
201
 
 
196
  time_steps=chunk_len,
197
  start_tokens=chunk,
198
  return_signal=False,
199
+ **kwargs
200
  )
201
  fine_z.append(chunk)
202
 
vampnet/modules/base.py CHANGED
@@ -288,12 +288,12 @@ class VampBase(at.ml.BaseModel):
288
  self,
289
  codec,
290
  time_steps: int = 300,
291
- sampling_steps: int = 24,
292
  start_tokens: Optional[torch.Tensor] = None,
293
  mask: Optional[torch.Tensor] = None,
294
  temperature: Union[float, Tuple[float, float]] = 0.8,
295
  top_k: int = None,
296
- sample: str = "multinomial",
297
  typical_filtering=False,
298
  typical_mass=0.2,
299
  typical_min_tokens=1,
 
288
  self,
289
  codec,
290
  time_steps: int = 300,
291
+ sampling_steps: int = 12,
292
  start_tokens: Optional[torch.Tensor] = None,
293
  mask: Optional[torch.Tensor] = None,
294
  temperature: Union[float, Tuple[float, float]] = 0.8,
295
  top_k: int = None,
296
+ sample: str = "gumbel",
297
  typical_filtering=False,
298
  typical_mass=0.2,
299
  typical_min_tokens=1,