Hugo Flores Garcia commited on
Commit
f3f4634
·
1 Parent(s): 3815be3

gooood outputs

Browse files
.gitignore CHANGED
@@ -174,3 +174,6 @@ runs-archive
174
  lyrebird-audiotools
175
  lyrebird-audio-codec
176
  samples-*/**
 
 
 
 
174
  lyrebird-audiotools
175
  lyrebird-audio-codec
176
  samples-*/**
177
+
178
+ gradio-outputs/
179
+ models/
conf/interface-jazzpop-exp.yml DELETED
@@ -1,9 +0,0 @@
1
- Interface.coarse_ckpt: /runs/jazzpop-coarse-1m-steps.pth
2
- Interface.coarse2fine_ckpt: /runs/jazzpop-c2f.pth
3
- Interface.codec_ckpt: /runs/codec-ckpt/codec.pth
4
- Interface.coarse_chunk_size_s: 5
5
- Interface.coarse2fine_chunk_size_s: 3
6
-
7
- AudioLoader.sources:
8
- - /data/spotdl/audio/val
9
- - /data/spotdl/audio/test
 
 
 
 
 
 
 
 
 
 
conf/interface-jazzpop.yml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Interface.coarse_ckpt: ./models/jazzpop/coarse.pth
2
+ Interface.coarse2fine_ckpt: ./models/jazzpop/c2f.pth
3
+ Interface.codec_ckpt: ./models/jazzpop/codec.pth
4
+ Interface.coarse_chunk_size_s: 5
5
+ Interface.coarse2fine_chunk_size_s: 3
6
+ Interface.wavebeat_ckpt: ./models/wavebeat.pth
7
+
8
+ AudioLoader.sources:
9
+ - /data/spotdl-jazzpop/audio/val
10
+ - /data/spotdl-jazzpop/audio/test
conf/interface-spotdl.yml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Interface.coarse_ckpt: ./models/spotdl/coarse.pth
2
+ Interface.coarse2fine_ckpt: ./models/spotdl/c2f.pth
3
+ Interface.codec_ckpt: ./models/spotdl/codec.pth
4
+ Interface.coarse_chunk_size_s: 10
5
+ Interface.coarse2fine_chunk_size_s: 3
6
+ Interface.wavebeat_ckpt: ./models/wavebeat.pth
7
+
8
+
9
+ AudioLoader.sources:
10
+ - /data/spotdl/audio/val
11
+ - /data/spotdl/audio/test
demo.py CHANGED
@@ -1,6 +1,8 @@
1
  from pathlib import Path
2
  from typing import Tuple
3
  import yaml
 
 
4
 
5
  import numpy as np
6
  import audiotools as at
@@ -9,13 +11,15 @@ import argbind
9
  import gradio as gr
10
  from vampnet.interface import Interface
11
 
12
- conf = yaml.safe_load(Path("conf/interface-jazzpop-exp.yml").read_text())
13
-
14
  Interface = argbind.bind(Interface)
15
  AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
 
 
 
16
  with argbind.scope(conf):
17
  interface = Interface()
18
  loader = AudioLoader()
 
19
 
20
  dataset = at.data.datasets.AudioDataset(
21
  loader,
@@ -26,6 +30,10 @@ dataset = at.data.datasets.AudioDataset(
26
  )
27
 
28
 
 
 
 
 
29
  def load_audio(file):
30
  print(file)
31
  filepath = file.name
@@ -35,87 +43,207 @@ def load_audio(file):
35
  )
36
  sig = interface.preprocess(sig)
37
 
38
- audio = sig.samples.numpy()[0]
39
- sr = sig.sample_rate
40
- return sr, audio.T
 
41
 
42
  def load_random_audio():
43
  index = np.random.randint(0, len(dataset))
44
  sig = dataset[index]["signal"]
45
  sig = interface.preprocess(sig)
46
 
47
- audio = sig.samples.numpy()[0]
48
- sr = sig.sample_rate
49
- return sr, audio.T
 
50
 
51
 
52
  def vamp(
53
- input_audio, prefix_s, suffix_s, rand_mask_intensity,
 
54
  mask_periodic_amt, beat_unmask_dur,
55
  mask_dwn_chk, dwn_factor,
56
  mask_up_chk, up_factor,
57
- num_vamps, mode
58
  ):
59
- try:
60
  print(input_audio)
61
 
62
- sig = at.AudioSignal(
63
- input_audio[1],
64
- sample_rate=input_audio[0]
65
- )
66
 
67
- if beat_unmask_dur > 0.0:
68
  beat_mask = interface.make_beat_mask(
69
- sig,
70
- before_beat_s=0.01,
71
  after_beat_s=beat_unmask_dur,
72
  mask_downbeats=mask_dwn_chk,
73
  mask_upbeats=mask_up_chk,
74
- downbeat_downsample_factor=dwn_factor,
75
- beat_downsample_factor=up_factor,
76
  dropout=0.7,
77
  invert=True
78
  )
 
79
  else:
80
  beat_mask = None
81
 
82
  if mode == "standard":
83
- zv = interface.coarse_vamp_v2(
 
84
  sig,
 
 
85
  prefix_dur_s=prefix_s,
86
  suffix_dur_s=suffix_s,
87
  num_vamps=num_vamps,
88
  downsample_factor=mask_periodic_amt,
89
  intensity=rand_mask_intensity,
90
- ext_mask=beat_mask
 
 
91
  )
 
 
 
 
 
 
92
  elif mode == "loop":
93
- zv = interface.loop(
94
- zv,
 
 
95
  prefix_dur_s=prefix_s,
96
  suffix_dur_s=suffix_s,
97
  num_loops=num_vamps,
98
  downsample_factor=mask_periodic_amt,
99
  intensity=rand_mask_intensity,
100
- ext_mask=beat_mask
 
 
101
  )
 
 
 
102
 
103
- zv = interface.coarse_to_fine(zv)
104
- sig = interface.to_signal(zv)
105
- return sig.sample_rate, sig.samples[0].T
106
- except Exception as e:
107
- raise gr.Error(f"failed with error: {e}")
 
 
 
108
 
 
 
 
 
 
 
 
 
 
 
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  with gr.Blocks() as demo:
112
 
113
- gr.Markdown('# Vampnet')
114
-
115
  with gr.Row():
116
  # input audio
117
  with gr.Column():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  gr.Markdown("## Input Audio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  manual_audio_upload = gr.File(
121
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
@@ -126,9 +254,13 @@ with gr.Blocks() as demo:
126
  input_audio = gr.Audio(
127
  label="input audio",
128
  interactive=False,
 
129
  )
130
- input_audio_viz = gr.HTML(
131
- label="input audio",
 
 
 
132
  )
133
 
134
  # connect widgets
@@ -147,113 +279,160 @@ with gr.Blocks() as demo:
147
 
148
  # mask settings
149
  with gr.Column():
150
- gr.Markdown("## Mask Settings")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  prefix_s = gr.Slider(
152
- label="prefix length (seconds)",
153
  minimum=0.0,
154
  maximum=10.0,
155
  value=0.0
156
  )
157
  suffix_s = gr.Slider(
158
- label="suffix length (seconds)",
159
  minimum=0.0,
160
  maximum=10.0,
161
  value=0.0
162
  )
163
 
164
- rand_mask_intensity = gr.Slider(
165
- label="random mask intensity (lower means more freedom)",
 
166
  minimum=0.0,
167
- maximum=1.0,
168
- value=1.0
 
 
 
 
 
 
169
  )
170
 
171
- mask_periodic_amt = gr.Slider(
172
- label="periodic unmasking factor (higher means more freedom)",
173
- minimum=0,
174
- maximum=32,
 
 
 
 
 
175
  step=1,
176
- value=2,
177
  )
178
- compute_mask_button = gr.Button("compute mask")
179
- mask_output = gr.Audio(
180
- label="masked audio",
 
 
 
181
  interactive=False,
182
- visible=False
183
- )
184
- mask_output_viz = gr.Video(
185
- label="masked audio",
186
- interactive=False
187
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  with gr.Column():
190
- gr.Markdown("## Beat Unmasking")
191
- with gr.Accordion(label="beat unmask"):
192
  beat_unmask_dur = gr.Slider(
193
  label="duration",
194
  minimum=0.0,
195
  maximum=3.0,
196
  value=0.1
197
  )
198
- with gr.Accordion("downbeat settings"):
199
  mask_dwn_chk = gr.Checkbox(
200
- label="unmask downbeats",
201
  value=True
202
  )
203
  dwn_factor = gr.Slider(
204
- label="downbeat downsample factor (unmask every Nth downbeat)",
205
- value=1,
206
- minimum=1,
207
  maximum=16,
208
  step=1
209
  )
210
- with gr.Accordion("upbeat settings"):
211
  mask_up_chk = gr.Checkbox(
212
- label="unmask upbeats",
213
  value=True
214
  )
215
  up_factor = gr.Slider(
216
- label="upbeat downsample factor (unmask every Nth upbeat)",
217
- value=1,
218
- minimum=1,
219
  maximum=16,
220
  step=1
221
  )
222
-
223
- # process and output
224
- with gr.Row():
225
- with gr.Column():
226
- gr.Markdown("**NOTE**: for loop mode, both prefix and suffix must be greater than 0.")
227
- mode = gr.Radio(
228
- label="mode",
229
- choices=["standard", "loop"],
230
- value="standard"
231
- )
232
- num_vamps = gr.Number(
233
- label="number of vamps",
234
- value=1,
235
- precision=0
236
- )
237
- vamp_button = gr.Button("vamp")
238
 
239
- output_audio = gr.Audio(
240
- label="output audio",
241
- interactive=False,
242
- visible=False
243
- )
 
 
 
 
 
 
 
 
244
 
 
245
  # connect widgets
246
  vamp_button.click(
247
  fn=vamp,
248
- inputs=[input_audio,
249
  prefix_s, suffix_s, rand_mask_intensity,
250
  mask_periodic_amt, beat_unmask_dur,
251
  mask_dwn_chk, dwn_factor,
252
  mask_up_chk, up_factor,
253
- num_vamps, mode
254
  ],
255
- outputs=[output_audio]
256
  )
257
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- demo.launch(share=True, server_name="0.0.0.0")
 
1
  from pathlib import Path
2
  from typing import Tuple
3
  import yaml
4
+ import tempfile
5
+ import uuid
6
 
7
  import numpy as np
8
  import audiotools as at
 
11
  import gradio as gr
12
  from vampnet.interface import Interface
13
 
 
 
14
  Interface = argbind.bind(Interface)
15
  AudioLoader = argbind.bind(at.data.datasets.AudioLoader)
16
+
17
+ conf = argbind.parse_args()
18
+
19
  with argbind.scope(conf):
20
  interface = Interface()
21
  loader = AudioLoader()
22
+ print(f"interface device is {interface.device}")
23
 
24
  dataset = at.data.datasets.AudioDataset(
25
  loader,
 
30
  )
31
 
32
 
33
+ OUT_DIR = Path("gradio-outputs")
34
+ OUT_DIR.mkdir(exist_ok=True, parents=True)
35
+
36
+
37
  def load_audio(file):
38
  print(file)
39
  filepath = file.name
 
43
  )
44
  sig = interface.preprocess(sig)
45
 
46
+ out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
47
+ out_dir.mkdir(parents=True, exist_ok=True)
48
+ sig.write(out_dir / "input.wav")
49
+ return sig.path_to_file
50
 
51
  def load_random_audio():
52
  index = np.random.randint(0, len(dataset))
53
  sig = dataset[index]["signal"]
54
  sig = interface.preprocess(sig)
55
 
56
+ out_dir = OUT_DIR / "tmp" / str(uuid.uuid4())
57
+ out_dir.mkdir(parents=True, exist_ok=True)
58
+ sig.write(out_dir / "input.wav")
59
+ return sig.path_to_file
60
 
61
 
62
  def vamp(
63
+ input_audio, init_temp, final_temp,
64
+ prefix_s, suffix_s, rand_mask_intensity,
65
  mask_periodic_amt, beat_unmask_dur,
66
  mask_dwn_chk, dwn_factor,
67
  mask_up_chk, up_factor,
68
+ num_vamps, mode, use_beats, num_steps
69
  ):
70
+ # try:
71
  print(input_audio)
72
 
73
+ sig = at.AudioSignal(input_audio.name)
 
 
 
74
 
75
+ if beat_unmask_dur > 0.0 and use_beats:
76
  beat_mask = interface.make_beat_mask(
77
+ sig,
78
+ before_beat_s=0.0,
79
  after_beat_s=beat_unmask_dur,
80
  mask_downbeats=mask_dwn_chk,
81
  mask_upbeats=mask_up_chk,
82
+ downbeat_downsample_factor=dwn_factor if dwn_factor > 0 else None,
83
+ beat_downsample_factor=up_factor if up_factor > 0 else None,
84
  dropout=0.7,
85
  invert=True
86
  )
87
+ print(beat_mask)
88
  else:
89
  beat_mask = None
90
 
91
  if mode == "standard":
92
+ print(f"running standard vampnet with {num_vamps} vamps")
93
+ zv, mask_z = interface.coarse_vamp_v2(
94
  sig,
95
+ sampling_steps=num_steps,
96
+ temperature=(init_temp, final_temp),
97
  prefix_dur_s=prefix_s,
98
  suffix_dur_s=suffix_s,
99
  num_vamps=num_vamps,
100
  downsample_factor=mask_periodic_amt,
101
  intensity=rand_mask_intensity,
102
+ ext_mask=beat_mask,
103
+ verbose=True,
104
+ return_mask=True
105
  )
106
+
107
+ zv = interface.coarse_to_fine(zv)
108
+ mask = interface.to_signal(mask_z).cpu()
109
+
110
+ sig = interface.to_signal(zv).cpu()
111
+ print("done")
112
  elif mode == "loop":
113
+ print(f"running loop vampnet with {num_vamps} vamps")
114
+ sig, mask = interface.loop(
115
+ sig,
116
+ temperature=(init_temp, final_temp),
117
  prefix_dur_s=prefix_s,
118
  suffix_dur_s=suffix_s,
119
  num_loops=num_vamps,
120
  downsample_factor=mask_periodic_amt,
121
  intensity=rand_mask_intensity,
122
+ ext_mask=beat_mask,
123
+ verbose=True,
124
+ return_mask=True
125
  )
126
+ sig = sig.cpu()
127
+ mask = mask.cpu()
128
+ print("done")
129
 
130
+
131
+ out_dir = OUT_DIR / str(uuid.uuid4())
132
+ out_dir.mkdir()
133
+ sig.write(out_dir / "output.wav")
134
+ mask.write(out_dir / "mask.wav")
135
+ return sig.path_to_file, mask.path_to_file
136
+ # except Exception as e:
137
+ # raise gr.Error(f"failed with error: {e}")
138
 
139
+ def save_vamp(
140
+ input_audio, init_temp, final_temp,
141
+ prefix_s, suffix_s, rand_mask_intensity,
142
+ mask_periodic_amt, beat_unmask_dur,
143
+ mask_dwn_chk, dwn_factor,
144
+ mask_up_chk, up_factor,
145
+ num_vamps, mode, output_audio, notes, use_beats, num_steps
146
+ ):
147
+ out_dir = OUT_DIR / "saved" / str(uuid.uuid4())
148
+ out_dir.mkdir(parents=True, exist_ok=True)
149
 
150
+ sig_in = at.AudioSignal(input_audio.name)
151
+ sig_out = at.AudioSignal(output_audio.name)
152
+
153
+ sig_in.write(out_dir / "input.wav")
154
+ sig_out.write(out_dir / "output.wav")
155
+
156
+ data = {
157
+ "init_temp": init_temp,
158
+ "final_temp": final_temp,
159
+ "prefix_s": prefix_s,
160
+ "suffix_s": suffix_s,
161
+ "rand_mask_intensity": rand_mask_intensity,
162
+ "mask_periodic_amt": mask_periodic_amt,
163
+ "use_beats": use_beats,
164
+ "beat_unmask_dur": beat_unmask_dur,
165
+ "mask_dwn_chk": mask_dwn_chk,
166
+ "dwn_factor": dwn_factor,
167
+ "mask_up_chk": mask_up_chk,
168
+ "up_factor": up_factor,
169
+ "num_vamps": num_vamps,
170
+ "num_steps": num_steps,
171
+ "mode": mode,
172
+ "notes": notes,
173
+ }
174
+
175
+ # save with yaml
176
+ with open(out_dir / "data.yaml", "w") as f:
177
+ yaml.dump(data, f)
178
+
179
+ import zipfile
180
+ zip_path = out_dir.with_suffix(".zip")
181
+ with zipfile.ZipFile(zip_path, "w") as zf:
182
+ for file in out_dir.iterdir():
183
+ zf.write(file, file.name)
184
+
185
+ return f"saved! your save code is {out_dir.stem}", zip_path
186
 
187
  with gr.Blocks() as demo:
188
 
 
 
189
  with gr.Row():
190
  # input audio
191
  with gr.Column():
192
+ gr.Markdown("""
193
+ # Vampnet
194
+ **Instructions**:
195
+ 1. Upload some audio (or click the load random audio button)
196
+ 2. Adjust the mask hints. The more hints, the more the generated music will follow the input music
197
+ 3. Adjust the vampnet parameters. The more vamps, the longer the generated music will be
198
+ 4. Click the "vamp" button
199
+ 5. Listen to the generated audio
200
+ 6. If you noticed something you liked, write some notes, click the "save vamp" button, and copy the save code
201
+
202
+
203
+
204
+
205
+ """)
206
  gr.Markdown("## Input Audio")
207
+ with gr.Column():
208
+ gr.Markdown("""
209
+ ## Mask Hints
210
+ - most of the original audio will be masked and replaced with audio generated by vampnet
211
+ - mask hints are used to guide vampnet to generate audio that sounds like the original
212
+ - the more hints you give, the more the generated audio will sound like the original
213
+
214
+
215
+
216
+
217
+
218
+
219
+
220
+ """)
221
+ with gr.Column():
222
+ gr.Markdown("""
223
+ ### Tips
224
+ - use the beat sync button so the output audio has the same beat structure as the input audio
225
+ - if you want the generated audio to sound like the original, but with a different beat structure:
226
+ - uncheck the beat sync button
227
+ - decrease the periodic unmasking to anywhere from 2 to 8
228
+ - if you want a more "random" generation:
229
+ - uncheck the beat sync button (or reduce the beat unmask duration)
230
+ - increase the periodic unmasking to 16 or more
231
+
232
+ """)
233
+
234
+
235
+ with gr.Row():
236
+ with gr.Column():
237
+ mode = gr.Radio(
238
+ label="**mode**. note that loop mode requires a prefix and suffix longer than 0",
239
+ choices=["standard", "loop"],
240
+ value="standard"
241
+ )
242
+ num_vamps = gr.Number(
243
+ label="number of vamps (or loops). more vamps = longer generated audio",
244
+ value=1,
245
+ precision=0
246
+ )
247
 
248
  manual_audio_upload = gr.File(
249
  label=f"upload some audio (will be randomly trimmed to max of {interface.coarse.chunk_size_s:.2f}s)",
 
254
  input_audio = gr.Audio(
255
  label="input audio",
256
  interactive=False,
257
+ type="file",
258
  )
259
+
260
+ audio_mask = gr.Audio(
261
+ label="audio mask (listen to this to hear the mask hints)",
262
+ interactive=False,
263
+ type="file",
264
  )
265
 
266
  # connect widgets
 
279
 
280
  # mask settings
281
  with gr.Column():
282
+
283
+ mask_periodic_amt = gr.Slider(
284
+ label="periodic unmasking factor (provides a rhythmic, periodic hint). 0.0 means no hint, 2 means one hint every 2 timesteps, etc, 4 means one hint every 4 timesteps, etc.",
285
+ minimum=0,
286
+ maximum=32,
287
+ step=1,
288
+ value=16,
289
+ )
290
+
291
+
292
+ rand_mask_intensity = gr.Slider(
293
+ label="random mask intensity. (If this is less than 1, scatters tiny hints throughout the audio, should be between 0.9 and 1.0)",
294
+ minimum=0.0,
295
+ maximum=1.0,
296
+ value=1.0
297
+ )
298
+
299
  prefix_s = gr.Slider(
300
+ label="prefix hint length (seconds)",
301
  minimum=0.0,
302
  maximum=10.0,
303
  value=0.0
304
  )
305
  suffix_s = gr.Slider(
306
+ label="suffix hint length (seconds)",
307
  minimum=0.0,
308
  maximum=10.0,
309
  value=0.0
310
  )
311
 
312
+
313
+ init_temp = gr.Slider(
314
+ label="initial temperature (should probably stay between 0.6 and 1)",
315
  minimum=0.0,
316
+ maximum=1.5,
317
+ value=0.8
318
+ )
319
+ final_temp = gr.Slider(
320
+ label="final temperature (should probably stay between 0.7 and 2)",
321
+ minimum=0.0,
322
+ maximum=2.0,
323
+ value=0.9
324
  )
325
 
326
+ use_beats = gr.Checkbox(
327
+ label="use beat hints",
328
+ value=True
329
+ )
330
+
331
+ num_steps = gr.Slider(
332
+ label="number of steps (should normally be between 12 and 36)",
333
+ minimum=4,
334
+ maximum=128,
335
  step=1,
336
+ value=24
337
  )
338
+
339
+
340
+ vamp_button = gr.Button("vamp!!!")
341
+
342
+ output_audio = gr.Audio(
343
+ label="output audio",
344
  interactive=False,
345
+ type="file"
 
 
 
 
346
  )
347
+
348
+
349
+ # gr.Markdown("**NOTE**: for loop mode, both prefix and suffix must be greater than 0.")
350
+ # compute_mask_button = gr.Button("compute mask")
351
+ # mask_output = gr.Audio(
352
+ # label="masked audio",
353
+ # interactive=False,
354
+ # visible=False
355
+ # )
356
+ # mask_output_viz = gr.Video(
357
+ # label="masked audio",
358
+ # interactive=False
359
+ # )
360
 
361
  with gr.Column():
362
+ with gr.Accordion(label="beat unmask (how much time around the beat should be hinted?)"):
363
+
364
  beat_unmask_dur = gr.Slider(
365
  label="duration",
366
  minimum=0.0,
367
  maximum=3.0,
368
  value=0.1
369
  )
370
+ with gr.Accordion("downbeat settings", open=False):
371
  mask_dwn_chk = gr.Checkbox(
372
+ label="hint downbeats",
373
  value=True
374
  )
375
  dwn_factor = gr.Slider(
376
+ label="downbeat downsample factor (hint only every Nth downbeat)",
377
+ value=0,
378
+ minimum=0,
379
  maximum=16,
380
  step=1
381
  )
382
+ with gr.Accordion("upbeat settings", open=False):
383
  mask_up_chk = gr.Checkbox(
384
+ label="hint upbeats",
385
  value=True
386
  )
387
  up_factor = gr.Slider(
388
+ label="upbeat downsample factor (hint only every Nth upbeat)",
389
+ value=0,
390
+ minimum=0,
391
  maximum=16,
392
  step=1
393
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
+ notes_text = gr.Textbox(
396
+ label="type any notes about the generated audio here",
397
+ value="",
398
+ interactive=True
399
+ )
400
+ save_button = gr.Button("download vamp")
401
+ download_file = gr.File(
402
+ label="vamp to download will appear here",
403
+ interactive=False
404
+ )
405
+
406
+
407
+ thank_you = gr.Markdown("")
408
 
409
+
410
  # connect widgets
411
  vamp_button.click(
412
  fn=vamp,
413
+ inputs=[input_audio, init_temp,final_temp,
414
  prefix_s, suffix_s, rand_mask_intensity,
415
  mask_periodic_amt, beat_unmask_dur,
416
  mask_dwn_chk, dwn_factor,
417
  mask_up_chk, up_factor,
418
+ num_vamps, mode, use_beats, num_steps
419
  ],
420
+ outputs=[output_audio, audio_mask]
421
  )
422
 
423
+ save_button.click(
424
+ fn=save_vamp,
425
+ inputs=[
426
+ input_audio, init_temp, final_temp,
427
+ prefix_s, suffix_s, rand_mask_intensity,
428
+ mask_periodic_amt, beat_unmask_dur,
429
+ mask_dwn_chk, dwn_factor,
430
+ mask_up_chk, up_factor,
431
+ num_vamps, mode,
432
+ output_audio,
433
+ notes_text, use_beats, num_steps
434
+ ],
435
+ outputs=[thank_you, download_file]
436
+ )
437
 
438
+ demo.launch(share=True, enable_queue=True)
vampnet/beats.py CHANGED
@@ -215,7 +215,7 @@ class WaveBeat(BeatTracker):
215
  beats, downbeats = self.model.predict_beats_from_array(
216
  audio=signal.audio_data.squeeze(0),
217
  sr=signal.sample_rate,
218
- use_gpu=self.device is not "cpu",
219
  )
220
 
221
  return beats, downbeats
 
215
  beats, downbeats = self.model.predict_beats_from_array(
216
  audio=signal.audio_data.squeeze(0),
217
  sr=signal.sample_rate,
218
+ use_gpu=self.device != "cpu",
219
  )
220
 
221
  return beats, downbeats
vampnet/interface.py CHANGED
@@ -26,6 +26,7 @@ class Interface:
26
  coarse_ckpt: str = None,
27
  coarse2fine_ckpt: str = None,
28
  codec_ckpt: str = None,
 
29
  device: str = "cpu",
30
  coarse_chunk_size_s: int = 5,
31
  coarse2fine_chunk_size_s: int = 3,
@@ -51,6 +52,13 @@ class Interface:
51
  else:
52
  self.c2f = None
53
 
 
 
 
 
 
 
 
54
  self.device = device
55
 
56
  def s2t(self, seconds: float):
@@ -71,8 +79,13 @@ class Interface:
71
  def to(self, device):
72
  self.device = device
73
  self.coarse.to(device)
74
- self.c2f.to(device)
75
  self.codec.to(device)
 
 
 
 
 
 
76
  return self
77
 
78
  def to_signal(self, z: torch.Tensor):
@@ -106,7 +119,7 @@ class Interface:
106
  mask_upbeats: bool = True,
107
  downbeat_downsample_factor: int = None,
108
  beat_downsample_factor: int = None,
109
- dropout: float = 0.7,
110
  invert: bool = True,
111
  ):
112
  """make a beat synced mask. that is, make a mask that
@@ -146,6 +159,8 @@ class Interface:
146
 
147
  beats_z = beats_z[::beat_downsample_factor]
148
  downbeats_z = downbeats_z[::downbeat_downsample_factor]
 
 
149
 
150
  if mask_upbeats:
151
  for beat_idx in beats_z:
@@ -153,8 +168,10 @@ class Interface:
153
  num_steps = mask[_slice[0]:_slice[1]].shape[0]
154
  _m = torch.ones(num_steps, device=self.device)
155
  _m = torch.nn.functional.dropout(_m, p=dropout)
 
156
 
157
  mask[_slice[0]:_slice[1]] = _m
 
158
 
159
  if mask_downbeats:
160
  for downbeat_idx in downbeats_z:
@@ -165,6 +182,7 @@ class Interface:
165
 
166
  mask[_slice[0]:_slice[1]] = _m
167
 
 
168
  if invert:
169
  mask = 1 - mask
170
 
@@ -317,6 +335,7 @@ class Interface:
317
  ext_mask=None,
318
  n_conditioning_codebooks=None,
319
  verbose=False,
 
320
  **kwargs
321
  ):
322
  z = self.encode(signal)
@@ -448,6 +467,9 @@ class Interface:
448
  prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
449
  suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
450
  c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
 
 
 
451
  return c_vamp
452
 
453
  # create a variation of an audio signal
@@ -527,6 +549,7 @@ class Interface:
527
  num_loops: int = 4,
528
  # overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
529
  verbose: bool = False,
 
530
  **kwargs,
531
  ):
532
  assert prefix_dur_s >= 0.0, "prefix duration must be >= 0"
@@ -549,8 +572,12 @@ class Interface:
549
  prefix_dur_s=prefix_dur_s,
550
  suffix_dur_s=suffix_dur_s,
551
  swap_prefix_suffix=is_flipped,
 
552
  **kwargs
553
  )
 
 
 
554
  # if we're flipped, we trim the prefix off of the end
555
  # otherwise we trim the suffix off of the end
556
  trim_len = prefix_len_tokens if is_flipped else suffix_len_tokens
@@ -568,6 +595,9 @@ class Interface:
568
  loops = [self.coarse_to_fine(l) for l in loops]
569
 
570
  loops = [self.to_signal(l) for l in loops]
571
-
 
 
 
572
  return signal_concat(loops)
573
 
 
26
  coarse_ckpt: str = None,
27
  coarse2fine_ckpt: str = None,
28
  codec_ckpt: str = None,
29
+ wavebeat_ckpt: str = None,
30
  device: str = "cpu",
31
  coarse_chunk_size_s: int = 5,
32
  coarse2fine_chunk_size_s: int = 3,
 
52
  else:
53
  self.c2f = None
54
 
55
+ if wavebeat_ckpt is not None:
56
+ print(f"loading wavebeat from {wavebeat_ckpt}")
57
+ self.beat_tracker = WaveBeat(wavebeat_ckpt)
58
+ self.beat_tracker.model.to(device)
59
+ else:
60
+ self.beat_tracker = None
61
+
62
  self.device = device
63
 
64
  def s2t(self, seconds: float):
 
79
  def to(self, device):
80
  self.device = device
81
  self.coarse.to(device)
 
82
  self.codec.to(device)
83
+
84
+ if self.c2f is not None:
85
+ self.c2f.to(device)
86
+
87
+ if self.beat_tracker is not None:
88
+ self.beat_tracker.model.to(device)
89
  return self
90
 
91
  def to_signal(self, z: torch.Tensor):
 
119
  mask_upbeats: bool = True,
120
  downbeat_downsample_factor: int = None,
121
  beat_downsample_factor: int = None,
122
+ dropout: float = 0.3,
123
  invert: bool = True,
124
  ):
125
  """make a beat synced mask. that is, make a mask that
 
159
 
160
  beats_z = beats_z[::beat_downsample_factor]
161
  downbeats_z = downbeats_z[::downbeat_downsample_factor]
162
+ print(f"beats_z: {len(beats_z)}")
163
+ print(f"downbeats_z: {len(downbeats_z)}")
164
 
165
  if mask_upbeats:
166
  for beat_idx in beats_z:
 
168
  num_steps = mask[_slice[0]:_slice[1]].shape[0]
169
  _m = torch.ones(num_steps, device=self.device)
170
  _m = torch.nn.functional.dropout(_m, p=dropout)
171
+ print(_m)
172
 
173
  mask[_slice[0]:_slice[1]] = _m
174
+ print(mask)
175
 
176
  if mask_downbeats:
177
  for downbeat_idx in downbeats_z:
 
182
 
183
  mask[_slice[0]:_slice[1]] = _m
184
 
185
+ mask = mask.clamp(0, 1)
186
  if invert:
187
  mask = 1 - mask
188
 
 
335
  ext_mask=None,
336
  n_conditioning_codebooks=None,
337
  verbose=False,
338
+ return_mask=False,
339
  **kwargs
340
  ):
341
  z = self.encode(signal)
 
467
  prefix_codes = torch.cat(c_vamp['prefix'], dim=-1)
468
  suffix_codes = torch.cat(c_vamp['suffix'], dim=-1)
469
  c_vamp = torch.cat([prefix_codes, suffix_codes], dim=-1)
470
+
471
+ if return_mask:
472
+ return c_vamp, cz_masked
473
  return c_vamp
474
 
475
  # create a variation of an audio signal
 
549
  num_loops: int = 4,
550
  # overlap_hop_ratio: float = 1.0, # TODO: should this be fixed to 1.0? or should we overlap and replace instead of overlap add
551
  verbose: bool = False,
552
+ return_mask: bool = False,
553
  **kwargs,
554
  ):
555
  assert prefix_dur_s >= 0.0, "prefix duration must be >= 0"
 
572
  prefix_dur_s=prefix_dur_s,
573
  suffix_dur_s=suffix_dur_s,
574
  swap_prefix_suffix=is_flipped,
575
+ return_mask=return_mask,
576
  **kwargs
577
  )
578
+ if return_mask:
579
+ vamped, mask = vamped
580
+
581
  # if we're flipped, we trim the prefix off of the end
582
  # otherwise we trim the suffix off of the end
583
  trim_len = prefix_len_tokens if is_flipped else suffix_len_tokens
 
595
  loops = [self.coarse_to_fine(l) for l in loops]
596
 
597
  loops = [self.to_signal(l) for l in loops]
598
+
599
+ if return_mask:
600
+ return signal_concat(loops), self.to_signal(mask)
601
+
602
  return signal_concat(loops)
603