Hugo Flores Garcia commited on
Commit
93b48cb
·
1 Parent(s): 128981d

more tweaks

Browse files
demo.py CHANGED
@@ -210,25 +210,30 @@ with gr.Blocks() as demo:
210
 
211
  """)
212
  gr.Markdown("## Input Audio")
213
- with gr.Column():
214
- gr.Markdown("""
215
- ## Mask Hints
216
- - most of the original audio will be masked and replaced with audio generated by vampnet
217
- - mask hints are used to guide vampnet to generate audio that sounds like the original
218
- - the more hints you give, the more the generated audio will sound like the original
219
 
220
- """)
221
  with gr.Column():
222
  gr.Markdown("""
223
  ### Tips
224
  - use the beat hint button so the output audio has the same beat structure as the input audio
225
- - if you want the generated audio to sound like the original, but with a different beat structure:
226
- - uncheck the beat hint button
227
- - decrease the periodic unmasking to anywhere from 2 to 8
228
  - if you want a more "random" generation:
229
- - uncheck the beat hint button (or reduce the beat unmask duration)
230
- - increase the periodic unmasking to 16 or more
231
  - increase the temperatures!
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
 
233
  """)
234
 
@@ -243,7 +248,8 @@ with gr.Blocks() as demo:
243
  num_vamps = gr.Number(
244
  label="number of vamps. more vamps = longer generated audio",
245
  value=1,
246
- precision=0
 
247
  )
248
 
249
  manual_audio_upload = gr.File(
@@ -286,7 +292,7 @@ with gr.Blocks() as demo:
286
  minimum=0,
287
  maximum=64,
288
  step=1,
289
- value=19,
290
  )
291
 
292
 
@@ -326,8 +332,8 @@ with gr.Blocks() as demo:
326
  )
327
 
328
  use_beats = gr.Checkbox(
329
- label="use beat hints",
330
- value=True
331
  )
332
 
333
  snap_to_beats = gr.Checkbox(
 
210
 
211
  """)
212
  gr.Markdown("## Input Audio")
 
 
 
 
 
 
213
 
 
214
  with gr.Column():
215
  gr.Markdown("""
216
  ### Tips
217
  - use the beat hint button so the output audio has the same beat structure as the input audio
218
+ - if you want more beat structure:
219
+ - enable beat hints
 
220
  - if you want a more "random" generation:
221
+ - increase the periodic unmasking to 12 or more
 
222
  - increase the temperatures!
223
+ - uncheck the beat hint button (or reduce the beat unmask duration)
224
+ - if you want the generated audio to sound like the original, but with a different beat structure:
225
+ - uncheck the beat hint button
226
+ - decrease the periodic unmasking to anywhere from 2 to 20
227
+ - slightly decrease the random intensity, to like .95
228
+
229
+
230
+ """)
231
+ with gr.Column():
232
+ gr.Markdown("""
233
+ ## Mask Hints
234
+ - most of the original audio will be masked and replaced with audio generated by vampnet
235
+ - mask hints are used to guide vampnet to generate audio that sounds like the original
236
+ - the more hints you give, the more the generated audio will sound like the original
237
 
238
  """)
239
 
 
248
  num_vamps = gr.Number(
249
  label="number of vamps. more vamps = longer generated audio",
250
  value=1,
251
+ precision=0,
252
+ visible=False
253
  )
254
 
255
  manual_audio_upload = gr.File(
 
292
  minimum=0,
293
  maximum=64,
294
  step=1,
295
+ value=9,
296
  )
297
 
298
 
 
332
  )
333
 
334
  use_beats = gr.Checkbox(
335
+ label="use beat hints (helps the output stick to the beat structure of the input)",
336
+ value=False
337
  )
338
 
339
  snap_to_beats = gr.Checkbox(
scripts/exp/eval.py CHANGED
@@ -5,6 +5,7 @@ from functools import partial
5
  from frechet_audio_distance import FrechetAudioDistance
6
  import pandas
7
  import argbind
 
8
  from tqdm import tqdm
9
 
10
  import audiotools
@@ -21,15 +22,16 @@ def eval(
21
  assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
22
 
23
  # set up our metrics
24
- sisdr_loss = audiotools.metrics.distance.SISDRLoss()
25
- stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
26
  mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
27
  frechet = FrechetAudioDistance(
28
  use_pca=False,
29
  use_activation=False,
30
- verbose=True
 
31
  )
32
- visqol = partial(audiotools.metrics.quality.visqol, mode="audio")
33
 
34
  # figure out what conditions we have
35
  conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
@@ -44,7 +46,7 @@ def eval(
44
  baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
45
 
46
  metrics = []
47
- for condition in conditions:
48
  cond_dir = exp_dir / condition
49
  cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
50
 
@@ -68,14 +70,17 @@ def eval(
68
  cond_sig.resample(baseline_sig.sample_rate)
69
  cond_sig.truncate_samples(baseline_sig.length)
70
 
71
- # compute the metrics
72
- # try:
73
- # vsq = visqol(baseline_sig, cond_sig)
74
- # except:
75
- # vsq = 0.0
 
 
 
76
  return {
77
- "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
78
- "stft": stft_loss(baseline_sig, cond_sig).item(),
79
  "mel": mel_loss(baseline_sig, cond_sig).item(),
80
  "frechet": frechet_score,
81
  # "visqol": vsq,
 
5
  from frechet_audio_distance import FrechetAudioDistance
6
  import pandas
7
  import argbind
8
+ import torch
9
  from tqdm import tqdm
10
 
11
  import audiotools
 
22
  assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"
23
 
24
  # set up our metrics
25
+ # sisdr_loss = audiotools.metrics.distance.SISDRLoss()
26
+ # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss()
27
  mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
28
  frechet = FrechetAudioDistance(
29
  use_pca=False,
30
  use_activation=False,
31
+ verbose=True,
32
+ audio_load_worker=4,
33
  )
34
+ frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")
35
 
36
  # figure out what conditions we have
37
  conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
 
46
  baseline_files = sorted(list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
47
 
48
  metrics = []
49
+ for condition in tqdm(conditions):
50
  cond_dir = exp_dir / condition
51
  cond_files = sorted(list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem))
52
 
 
70
  cond_sig.resample(baseline_sig.sample_rate)
71
  cond_sig.truncate_samples(baseline_sig.length)
72
 
73
+ # if our condition is inpainting, we need to trim the conditioning off
74
+ if "inpaint" in condition:
75
+ ctx_amt = float(condition.split("_")[-1])
76
+ ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
77
+ print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
78
+ cond_sig.trim(ctx_samples, ctx_samples)
79
+ baseline_sig.trim(ctx_samples, ctx_samples)
80
+
81
  return {
82
+ # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(),
83
+ # "stft": stft_loss(baseline_sig, cond_sig).item(),
84
  "mel": mel_loss(baseline_sig, cond_sig).item(),
85
  "frechet": frechet_score,
86
  # "visqol": vsq,
scripts/utils/vamp_folder.py CHANGED
@@ -6,7 +6,7 @@ import subprocess
6
 
7
  import argbind
8
  from tqdm import tqdm
9
- import argbind
10
 
11
  from vampnet.interface import Interface
12
  import audiotools as at
@@ -48,7 +48,6 @@ def coarse2fine_argmax(sig, interface):
48
  )
49
  return interface.to_signal(z)
50
 
51
-
52
  class CoarseCond:
53
 
54
  def __init__(self, num_codebooks, downsample_factor):
@@ -59,13 +58,12 @@ class CoarseCond:
59
  n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
60
  zv = interface.coarse_vamp_v2(sig,
61
  n_conditioning_codebooks=n_conditioning_codebooks,
62
- downsample_factor=self.downsample_factor
63
  )
64
 
65
  zv = interface.coarse_to_fine(zv)
66
  return interface.to_signal(zv)
67
 
68
-
69
  def opus(sig, interface, bitrate=128):
70
  sig = interface.preprocess(sig)
71
 
@@ -97,8 +95,78 @@ def opus(sig, interface, bitrate=128):
97
  )
98
  return sig
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- COARSE_SAMPLE_CONDS ={
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  "baseline": baseline,
103
  "reconstructed": reconstructed,
104
  "coarse2fine": coarse2fine,
@@ -119,23 +187,55 @@ COARSE_SAMPLE_CONDS ={
119
 
120
  }
121
 
122
- OPUS_JAZZPOP_SAMPLE_CONDS = {
123
  f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
124
  for bitrate in [5620, 1875, 1250, 625]
125
  }
126
 
127
- OPUS_SPOTDL_SAMPLE_CONDS = {
128
  f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
129
  for bitrate in [8036, 2296, 1148, 574]
130
  }
131
 
132
- C2F_SAMPLE_CONDS = {
 
 
 
 
 
133
  "baseline": baseline,
134
  "reconstructed": reconstructed,
135
  "coarse2fine": coarse2fine,
136
  "coarse2fine_argmax": coarse2fine_argmax,
137
  }
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @argbind.bind(without_prefix=True)
140
  def main(
141
  sources=[
@@ -162,14 +262,8 @@ def main(
162
  without_replacement=True,
163
  )
164
 
165
- if exp_type == "opus-jazzpop":
166
- SAMPLE_CONDS = OPUS_JAZZPOP_SAMPLE_CONDS
167
- elif exp_type == "opus-spotdl":
168
- SAMPLE_CONDS = OPUS_SPOTDL_SAMPLE_CONDS
169
- elif exp_type == "coarse":
170
- SAMPLE_CONDS = COARSE_SAMPLE_CONDS
171
- elif exp_type == "c2f":
172
- SAMPLE_CONDS = C2F_SAMPLE_CONDS
173
  else:
174
  raise ValueError(f"Unknown exp_type {exp_type}")
175
 
@@ -178,12 +272,12 @@ def main(
178
  random.shuffle(indices)
179
  for i in tqdm(indices):
180
  # if all our files are already there, skip
181
- # done = []
182
- # for name in SAMPLE_CONDS:
183
- # o_dir = Path(output_dir) / name
184
- # done.append((o_dir / f"{i}.wav").exists())
185
- # if all(done):
186
- # continue
187
 
188
  sig = dataset[i]["signal"]
189
  results = {
 
6
 
7
  import argbind
8
  from tqdm import tqdm
9
+ import torch
10
 
11
  from vampnet.interface import Interface
12
  import audiotools as at
 
48
  )
49
  return interface.to_signal(z)
50
 
 
51
  class CoarseCond:
52
 
53
  def __init__(self, num_codebooks, downsample_factor):
 
58
  n_conditioning_codebooks = interface.coarse.n_codebooks - self.num_codebooks
59
  zv = interface.coarse_vamp_v2(sig,
60
  n_conditioning_codebooks=n_conditioning_codebooks,
61
+ downsample_factor=self.downsample_factor,
62
  )
63
 
64
  zv = interface.coarse_to_fine(zv)
65
  return interface.to_signal(zv)
66
 
 
67
  def opus(sig, interface, bitrate=128):
68
  sig = interface.preprocess(sig)
69
 
 
95
  )
96
  return sig
97
 
98
+ def token_noise(ratio=1.0):
99
+ def wrapper(sig, interface):
100
+ z = interface.encode(sig)
101
+ r = interface.coarse.invgamma(ratio).to(interface.device)
102
+ print(f'adding noise with ratio {ratio}')
103
+ z, mask = interface.coarse.add_noise(
104
+ z,
105
+ r,
106
+ noise_mode="random"
107
+ )
108
+ return interface.to_signal(z)
109
+ return wrapper
110
+
111
+ def mask_ratio_1_step(ratio=1.0):
112
+ def wrapper(sig, interface):
113
+ r = interface.coarse.invgamma(ratio).to(interface.device)
114
+ intensity = 1-r
115
+
116
+ zv = interface.coarse_vamp_v2(
117
+ sig,
118
+ sample='argmax',
119
+ sampling_steps=1,
120
+ intensity=intensity
121
+ )
122
+
123
+ return interface.to_signal(zv)
124
+ return wrapper
125
+
126
+ def num_sampling_steps(num_steps=1):
127
+ def wrapper(sig, interface):
128
+ zv = interface.coarse_vamp_v2(
129
+ sig,
130
+ downsample_factor=16,
131
+ sampling_steps=num_steps,
132
+ )
133
 
134
+ zv = interface.coarse_to_fine(zv)
135
+ return interface.to_signal(zv)
136
+ return wrapper
137
+
138
+ def beat_mask(ctx_time):
139
+ def wrapper(sig, interface):
140
+ beat_mask = interface.make_beat_mask(
141
+ sig,
142
+ before_beat_s=0.0,
143
+ after_beat_s=ctx_time,
144
+ invert=True
145
+ )
146
+ zv = interface.coarse_vamp_v2(
147
+ sig,
148
+ ext_mask=beat_mask,
149
+ )
150
+
151
+ zv = interface.coarse_to_fine(zv)
152
+ return interface.to_signal(zv)
153
+ return wrapper
154
+
155
+ def inpaint(ctx_time):
156
+ def wrapper(sig, interface):
157
+ zv = interface.coarse_vamp_v2(
158
+ sig,
159
+ prefix_dur_s=ctx_time,
160
+ suffix_dur_s=ctx_time,
161
+ )
162
+
163
+ zv = interface.coarse_to_fine(zv)
164
+ return interface.to_signal(zv)
165
+ return wrapper
166
+
167
+ EXP_REGISTRY = {}
168
+
169
+ EXP_REGISTRY["gen-compression"] = {
170
  "baseline": baseline,
171
  "reconstructed": reconstructed,
172
  "coarse2fine": coarse2fine,
 
187
 
188
  }
189
 
190
+ EXP_REGISTRY["opus-jazzpop"] = {
191
  f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
192
  for bitrate in [5620, 1875, 1250, 625]
193
  }
194
 
195
+ EXP_REGISTRY["opus-spotdl"] = {
196
  f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
197
  for bitrate in [8036, 2296, 1148, 574]
198
  }
199
 
200
+ EXP_REGISTRY["opus-baseline"] = {
201
+ f"opus_{bitrate}": lambda sig, interface: opus(sig, interface, bitrate=bitrate)
202
+ for bitrate in [8000, 12000, 16000]
203
+ }
204
+
205
+ EXP_REGISTRY["c2f"] = {
206
  "baseline": baseline,
207
  "reconstructed": reconstructed,
208
  "coarse2fine": coarse2fine,
209
  "coarse2fine_argmax": coarse2fine_argmax,
210
  }
211
 
212
+ EXP_REGISTRY["token-noise"] = {
213
+ f"token_noise_{r}": token_noise(r) for r in [0.25, 0.5, 0.75, 1.0]
214
+ }
215
+
216
+ EXP_REGISTRY["mask-ratio"] = {
217
+ "codec": reconstructed,
218
+ **{f"mask_ratio_{r}": mask_ratio_1_step(r) for r in [0.25, 0.5, 0.75, 0.9]}
219
+ }
220
+
221
+ EXP_REGISTRY["sampling-steps"] = {
222
+ "codec": reconstructed,
223
+ **{f"steps_{n}": num_sampling_steps(n) for n in [1, 4, 12, 24, 36, 64, 72, 128]},
224
+ }
225
+
226
+ EXP_REGISTRY["baseline"] = {
227
+ "baseline": baseline,
228
+ "codec": reconstructed,
229
+ }
230
+
231
+ EXP_REGISTRY["musical-sampling"] = {
232
+ "baseline": baseline,
233
+ "codec": reconstructed,
234
+ **{f"downsample_{x}x": CoarseCond(4, downsample_factor=x) for x in [16, 32]},
235
+ **{f"beat_mask_{t}": beat_mask(t) for t in [0.075]},
236
+ **{f"inpaint_{t}": inpaint(t) for t in [0.5, 1.0,]}, # multiply these by 2 (they go left and right)
237
+ }
238
+
239
  @argbind.bind(without_prefix=True)
240
  def main(
241
  sources=[
 
262
  without_replacement=True,
263
  )
264
 
265
+ if exp_type in EXP_REGISTRY:
266
+ SAMPLE_CONDS = EXP_REGISTRY[exp_type]
 
 
 
 
 
 
267
  else:
268
  raise ValueError(f"Unknown exp_type {exp_type}")
269
 
 
272
  random.shuffle(indices)
273
  for i in tqdm(indices):
274
  # if all our files are already there, skip
275
+ done = []
276
+ for name in SAMPLE_CONDS:
277
+ o_dir = Path(output_dir) / name
278
+ done.append((o_dir / f"{i}.wav").exists())
279
+ if all(done):
280
+ continue
281
 
282
  sig = dataset[i]["signal"]
283
  results = {
vampnet/interface.py CHANGED
@@ -183,10 +183,8 @@ class Interface:
183
  num_steps = mask[_slice[0]:_slice[1]].shape[0]
184
  _m = torch.ones(num_steps, device=self.device)
185
  _m = torch.nn.functional.dropout(_m, p=dropout)
186
- print(_m)
187
 
188
  mask[_slice[0]:_slice[1]] = _m
189
- print(mask)
190
 
191
  if mask_downbeats:
192
  for downbeat_idx in downbeats_z:
 
183
  num_steps = mask[_slice[0]:_slice[1]].shape[0]
184
  _m = torch.ones(num_steps, device=self.device)
185
  _m = torch.nn.functional.dropout(_m, p=dropout)
 
186
 
187
  mask[_slice[0]:_slice[1]] = _m
 
188
 
189
  if mask_downbeats:
190
  for downbeat_idx in downbeats_z:
vampnet/modules/base.py CHANGED
@@ -42,6 +42,7 @@ class VampBase(at.ml.BaseModel):
42
  n_suffix: Optional[torch.Tensor] = None,
43
  downsample_factor: Optional[int] = None,
44
  n_conditioning_codebooks: Optional[int] = None,
 
45
  ) -> Tuple[torch.Tensor, torch.Tensor]:
46
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
47
 
@@ -89,13 +90,14 @@ class VampBase(at.ml.BaseModel):
89
  if random_x is None:
90
  random_x = torch.randint_like(x, 0, self.vocab_size)
91
 
92
- if self.noise_mode == "mask":
 
93
  random_x = torch.full_like(x, self.mask_token)
94
- elif self.noise_mode == "random":
95
  if random_x is None:
96
  random_x = torch.randint_like(x, 0, self.vocab_size)
97
  else:
98
- raise ValueError(f"invalid noise mode {self.noise_mode}")
99
 
100
  # add the external mask if we were given one
101
  if ext_mask is not None:
@@ -132,6 +134,11 @@ class VampBase(at.ml.BaseModel):
132
  def gamma(self, r):
133
  return (r * torch.pi / 2).cos()
134
 
 
 
 
 
 
135
  def r_embed(self, r, max_positions=10000):
136
  """ """
137
  assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"
 
42
  n_suffix: Optional[torch.Tensor] = None,
43
  downsample_factor: Optional[int] = None,
44
  n_conditioning_codebooks: Optional[int] = None,
45
+ noise_mode: str = None,
46
  ) -> Tuple[torch.Tensor, torch.Tensor]:
47
  assert x.ndim == 3, "x must be (batch, n_codebooks, seq)"
48
 
 
90
  if random_x is None:
91
  random_x = torch.randint_like(x, 0, self.vocab_size)
92
 
93
+ noise_mode = noise_mode if noise_mode is not None else self.noise_mode
94
+ if noise_mode == "mask":
95
  random_x = torch.full_like(x, self.mask_token)
96
+ elif noise_mode == "random":
97
  if random_x is None:
98
  random_x = torch.randint_like(x, 0, self.vocab_size)
99
  else:
100
+ raise ValueError(f"invalid noise mode {noise_mode}")
101
 
102
  # add the external mask if we were given one
103
  if ext_mask is not None:
 
134
  def gamma(self, r):
135
  return (r * torch.pi / 2).cos()
136
 
137
+ def invgamma(self, y):
138
+ if not torch.is_tensor(y):
139
+ y = torch.tensor(y)[None]
140
+ return 2 * y.acos() / torch.pi
141
+
142
  def r_embed(self, r, max_positions=10000):
143
  """ """
144
  assert hasattr(self, "r_cond_dim"), "must set r_cond_dim before calling r_embed"