Hugo Flores Garcia committed on
Commit
2f3fb32
·
1 Parent(s): 3346920
Files changed (4) hide show
  1. .gitignore +2 -1
  2. demo.py +27 -5
  3. vampnet/interface.py +23 -12
  4. vampnet/mask.py +4 -4
.gitignore CHANGED
@@ -176,4 +176,5 @@ lyrebird-audio-codec
176
  samples-*/**
177
 
178
  gradio-outputs/
179
- models/
 
 
176
  samples-*/**
177
 
178
  gradio-outputs/
179
+ models/
180
+ samples*/
demo.py CHANGED
@@ -130,8 +130,6 @@ def _vamp(data, return_mask=False):
130
  out_dir = OUT_DIR / str(uuid.uuid4())
131
  out_dir.mkdir()
132
  sig = at.AudioSignal(data[input_audio])
133
- #pitch shift input
134
- sig = sig.shift_pitch(data[input_pitch_shift])
135
 
136
  # TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
137
 
@@ -160,10 +158,20 @@ def _vamp(data, return_mask=False):
160
  mask = pmask.mask_or(
161
  mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
162
  )
 
 
 
 
 
 
 
 
 
163
  # these should be the last two mask ops
164
  mask = pmask.dropout(mask, data[dropout])
165
  mask = pmask.codebook_unmask(mask, ncc)
166
 
 
167
  print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
168
  # save the mask as a txt file
169
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
@@ -322,6 +330,18 @@ with gr.Blocks() as demo:
322
  value=5,
323
  )
324
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  with gr.Accordion("extras ", open=False):
326
  n_conditioning_codebooks = gr.Number(
327
  label="number of conditioning codebooks. probably 0",
@@ -355,14 +375,14 @@ with gr.Blocks() as demo:
355
  temp = gr.Slider(
356
  label="temperature",
357
  minimum=0.0,
358
- maximum=1.5,
359
  value=0.8
360
  )
361
 
362
  with gr.Accordion("sampling settings", open=False):
363
  typical_filtering = gr.Checkbox(
364
  label="typical filtering ",
365
- value=True
366
  )
367
  typical_mass = gr.Slider(
368
  label="typical mass (should probably stay between 0.1 and 0.5)",
@@ -440,7 +460,9 @@ with gr.Blocks() as demo:
440
  typical_filtering,
441
  typical_mass,
442
  typical_min_tokens,
443
- checkpoint_key
 
 
444
  }
445
 
446
  # connect widgets
 
130
  out_dir = OUT_DIR / str(uuid.uuid4())
131
  out_dir.mkdir()
132
  sig = at.AudioSignal(data[input_audio])
 
 
133
 
134
  # TODO: random pitch shift of segments in the signal to prompt! window size should be a parameter, pitch shift width should be a parameter
135
 
 
158
  mask = pmask.mask_or(
159
  mask, pmask.onset_mask(sig, z, interface, width=data[onset_mask_width])
160
  )
161
+ if data[beat_mask_width] > 0:
162
+ beat_mask = interface.make_beat_mask(
163
+ sig,
164
+ before_beat_s=(data[beat_mask_width]/1000)/2,
165
+ after_beat_s=(data[beat_mask_width]/1000)/2,
166
+ mask_upbeats=not data[beat_mask_downbeats],
167
+ )
168
+ mask = pmask.mask_and(mask, beat_mask)
169
+
170
  # these should be the last two mask ops
171
  mask = pmask.dropout(mask, data[dropout])
172
  mask = pmask.codebook_unmask(mask, ncc)
173
 
174
+
175
  print(f"created mask with: linear random {data[rand_mask_intensity]}, inpaint {data[prefix_s]}:{data[suffix_s]}, periodic {data[periodic_p]}:{data[periodic_w]}, dropout {data[dropout]}, codebook unmask {ncc}, onset mask {data[onset_mask_width]}, num steps {data[num_steps]}, init temp {data[temp]}, use coarse2fine {data[use_coarse2fine]}")
176
  # save the mask as a txt file
177
  np.savetxt(out_dir / "mask.txt", mask[:,0,:].long().cpu().numpy())
 
330
  value=5,
331
  )
332
 
333
+ beat_mask_width = gr.Slider(
334
+ label="beat mask width (in milliseconds)",
335
+ minimum=0,
336
+ maximum=200,
337
+ value=0,
338
+ )
339
+ beat_mask_downbeats = gr.Checkbox(
340
+ label="beat mask downbeats only?",
341
+ value=False
342
+ )
343
+
344
+
345
  with gr.Accordion("extras ", open=False):
346
  n_conditioning_codebooks = gr.Number(
347
  label="number of conditioning codebooks. probably 0",
 
375
  temp = gr.Slider(
376
  label="temperature",
377
  minimum=0.0,
378
+ maximum=3.0,
379
  value=0.8
380
  )
381
 
382
  with gr.Accordion("sampling settings", open=False):
383
  typical_filtering = gr.Checkbox(
384
  label="typical filtering ",
385
+ value=False
386
  )
387
  typical_mass = gr.Slider(
388
  label="typical mass (should probably stay between 0.1 and 0.5)",
 
460
  typical_filtering,
461
  typical_mass,
462
  typical_min_tokens,
463
+ checkpoint_key,
464
+ beat_mask_width,
465
+ beat_mask_downbeats
466
  }
467
 
468
  # connect widgets
vampnet/interface.py CHANGED
@@ -265,7 +265,12 @@ class Interface(torch.nn.Module):
265
  if invert:
266
  mask = 1 - mask
267
 
268
- return mask[None, None, :].bool().long()
 
 
 
 
 
269
 
270
  def coarse_to_fine(
271
  self,
@@ -349,26 +354,32 @@ if __name__ == "__main__":
349
  coarse_ckpt="./models/spotdl/coarse.pth",
350
  coarse2fine_ckpt="./models/spotdl/c2f.pth",
351
  codec_ckpt="./models/spotdl/codec.pth",
352
- device="cuda"
 
353
  )
354
 
355
- sig = at.AudioSignal('introspection ii-1.mp3', duration=10)
 
356
 
357
  z = interface.encode(sig)
358
 
359
- mask = linear_random(z, 1.0)
360
- mask = mask_and(
361
- mask, periodic_mask(
362
- z,
363
- 32,
364
- 1,
365
- random_roll=True
366
- )
 
 
 
 
367
  )
368
  # mask = dropout(mask, 0.0)
369
  # mask = codebook_unmask(mask, 0)
370
 
371
-
372
  zv, mask_z = interface.coarse_vamp(
373
  z,
374
  mask=mask,
 
265
  if invert:
266
  mask = 1 - mask
267
 
268
+ mask = mask[None, None, :].bool().long()
269
+ if self.c2f is not None:
270
+ mask = mask.repeat(1, self.c2f.n_codebooks, 1)
271
+ else:
272
+ mask = mask.repeat(1, self.coarse.n_codebooks, 1)
273
+ return mask
274
 
275
  def coarse_to_fine(
276
  self,
 
354
  coarse_ckpt="./models/spotdl/coarse.pth",
355
  coarse2fine_ckpt="./models/spotdl/c2f.pth",
356
  codec_ckpt="./models/spotdl/codec.pth",
357
+ device="cuda",
358
+ wavebeat_ckpt="./models/wavebeat.pth"
359
  )
360
 
361
+
362
+ sig = at.AudioSignal.zeros(duration=10, sample_rate=44100)
363
 
364
  z = interface.encode(sig)
365
 
366
+ # mask = linear_random(z, 1.0)
367
+ # mask = mask_and(
368
+ # mask, periodic_mask(
369
+ # z,
370
+ # 32,
371
+ # 1,
372
+ # random_roll=True
373
+ # )
374
+ # )
375
+
376
+ mask = interface.make_beat_mask(
377
+ sig, 0.0, 0.075
378
  )
379
  # mask = dropout(mask, 0.0)
380
  # mask = codebook_unmask(mask, 0)
381
 
382
+ breakpoint()
383
  zv, mask_z = interface.coarse_vamp(
384
  z,
385
  mask=mask,
vampnet/mask.py CHANGED
@@ -26,9 +26,9 @@ def apply_mask(
26
  mask: torch.Tensor,
27
  mask_token: int
28
  ):
29
- assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq)"
30
- assert mask.shape == x.shape, "mask must be same shape as x"
31
- assert mask.dtype == torch.long, "mask must be long dtype"
32
  assert ~torch.any(mask > 1), "mask must be binary"
33
  assert ~torch.any(mask < 0), "mask must be binary"
34
 
@@ -163,7 +163,7 @@ def mask_or(
163
  mask1: torch.Tensor,
164
  mask2: torch.Tensor
165
  ):
166
- assert mask1.shape == mask2.shape, "masks must be same shape"
167
  assert mask1.max() <= 1, "mask1 must be binary"
168
  assert mask2.max() <= 1, "mask2 must be binary"
169
  assert mask1.min() >= 0, "mask1 must be binary"
 
26
  mask: torch.Tensor,
27
  mask_token: int
28
  ):
29
+ assert mask.ndim == 3, "mask must be (batch, n_codebooks, seq), but got {mask.ndim}"
30
+ assert mask.shape == x.shape, f"mask must be same shape as x, but got {mask.shape} and {x.shape}"
31
+ assert mask.dtype == torch.long, "mask must be long dtype, but got {mask.dtype}"
32
  assert ~torch.any(mask > 1), "mask must be binary"
33
  assert ~torch.any(mask < 0), "mask must be binary"
34
 
 
163
  mask1: torch.Tensor,
164
  mask2: torch.Tensor
165
  ):
166
+ assert mask1.shape == mask2.shape, f"masks must be same shape, but got {mask1.shape} and {mask2.shape}"
167
  assert mask1.max() <= 1, "mask1 must be binary"
168
  assert mask2.max() <= 1, "mask2 must be binary"
169
  assert mask1.min() >= 0, "mask1 must be binary"