import os
from pathlib import Path
import math
import logging

import torch
import numpy as np
from audiotools import AudioSignal
import tqdm

from .modules.transformer import VampNet
from .beats import WaveBeat
from .mask import *

# from dac.model.dac import DAC
from lac.model.lac import LAC as DAC
def signal_concat(
    audio_signals: list,
):
    audio_data = torch.cat([x.audio_data for x in audio_signals], dim=-1)
    return AudioSignal(audio_data, sample_rate=audio_signals[0].sample_rate)
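# Example usage (a minimal sketch; "a.wav" and "b.wav" are hypothetical files,
# and the signals are assumed to share a sample rate, since only the first
# signal's rate is kept):
#
#     a = AudioSignal("a.wav")
#     b = AudioSignal("b.wav").resample(a.sample_rate)
#     joined = signal_concat([a, b])  # concatenated along the time (last) axis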
def _load_model(
    ckpt: str,
    lora_ckpt: str = None,
    device: str = "cpu",
    chunk_size_s: int = 10,
):
    # we need to set strict to False if the model has lora weights to add later
    model = VampNet.load(location=Path(ckpt), map_location="cpu", strict=False)

    # load lora weights if needed
    if lora_ckpt is not None:
        if not Path(lora_ckpt).exists():
            should_cont = input(
                f"lora checkpoint {lora_ckpt} does not exist. continue? (y/n) "
            )
            if should_cont != "y":
                raise Exception("aborting")
        else:
            model.load_state_dict(torch.load(lora_ckpt, map_location="cpu"), strict=False)

    model.to(device)
    model.eval()
    model.chunk_size_s = chunk_size_s
    return model
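# Example usage (a sketch; the checkpoint paths are hypothetical):
#
#     coarse = _load_model(
#         ckpt="./models/vampnet/coarse.pth",
#         lora_ckpt="./models/finetuned/my_lora.pth",  # optional LoRA weights
#         device="cuda",
#         chunk_size_s=10,
#     )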
class Interface(torch.nn.Module):
    def __init__(
        self,
        coarse_ckpt: str = None,
        coarse_lora_ckpt: str = None,
        coarse2fine_ckpt: str = None,
        coarse2fine_lora_ckpt: str = None,
        codec_ckpt: str = None,
        wavebeat_ckpt: str = None,
        device: str = "cpu",
        coarse_chunk_size_s: int = 10,
        coarse2fine_chunk_size_s: int = 3,
        compile: bool = True,
    ):
        super().__init__()
        assert codec_ckpt is not None, "must provide a codec checkpoint"
        self.codec = DAC.load(Path(codec_ckpt))
        self.codec.eval()
        self.codec.to(device)
        self.codec_path = Path(codec_ckpt)

        assert coarse_ckpt is not None, "must provide a coarse checkpoint"
        self.coarse = _load_model(
            ckpt=coarse_ckpt,
            lora_ckpt=coarse_lora_ckpt,
            device=device,
            chunk_size_s=coarse_chunk_size_s,
        )
        self.coarse_path = Path(coarse_ckpt)

        # check if we have a coarse2fine ckpt
        if coarse2fine_ckpt is not None:
            self.c2f_path = Path(coarse2fine_ckpt)
            self.c2f = _load_model(
                ckpt=coarse2fine_ckpt,
                lora_ckpt=coarse2fine_lora_ckpt,
                device=device,
                chunk_size_s=coarse2fine_chunk_size_s,
            )
        else:
            self.c2f_path = None
            self.c2f = None

        if wavebeat_ckpt is not None:
            logging.debug(f"loading wavebeat from {wavebeat_ckpt}")
            self.beat_tracker = WaveBeat(wavebeat_ckpt)
            self.beat_tracker.model.to(device)
        else:
            self.beat_tracker = None

        self.device = device
        self.loudness = -24.0

        if compile:
            logging.debug("compiling models")
            self.coarse = torch.compile(self.coarse)
            if self.c2f is not None:
                self.c2f = torch.compile(self.c2f)
            self.codec = torch.compile(self.codec)
    @classmethod
    def default(cls):
        from . import download_codec, download_default
        print("loading default vampnet")

        codec_path = download_codec()
        coarse_path, c2f_path = download_default()

        return Interface(
            coarse_ckpt=coarse_path,
            coarse2fine_ckpt=c2f_path,
            codec_ckpt=codec_path,
        )
    @classmethod
    def available_models(cls):
        from . import list_finetuned
        return list_finetuned() + ["default"]
    def load_finetuned(self, name: str):
        assert name in self.available_models(), f"{name} is not a valid model name"
        from . import download_finetuned, download_default
        if name == "default":
            coarse_path, c2f_path = download_default()
        else:
            coarse_path, c2f_path = download_finetuned(name)
        self.reload(
            coarse_ckpt=coarse_path,
            c2f_ckpt=c2f_path,
        )
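    # Example usage (a sketch; "my-finetune" stands in for whatever name
    # available_models() actually reports):
    #
    #     interface = Interface.default()
    #     print(interface.available_models())
    #     interface.load_finetuned("my-finetune")  # hot-swaps coarse + c2f weights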
    def reload(
        self,
        coarse_ckpt: str = None,
        c2f_ckpt: str = None,
    ):
        if coarse_ckpt is not None:
            # check if we already loaded, if so, don't reload
            if self.coarse_path == Path(coarse_ckpt):
                logging.debug(f"already loaded {coarse_ckpt}")
            else:
                self.coarse = _load_model(
                    ckpt=coarse_ckpt,
                    device=self.device,
                    chunk_size_s=self.coarse.chunk_size_s,
                )
                self.coarse_path = Path(coarse_ckpt)
                logging.debug(f"loaded {coarse_ckpt}")

        if c2f_ckpt is not None:
            if self.c2f_path == Path(c2f_ckpt):
                logging.debug(f"already loaded {c2f_ckpt}")
            else:
                self.c2f = _load_model(
                    ckpt=c2f_ckpt,
                    device=self.device,
                    chunk_size_s=self.c2f.chunk_size_s,
                )
                self.c2f_path = Path(c2f_ckpt)
                logging.debug(f"loaded {c2f_ckpt}")
    def s2t(self, seconds: float):
        """seconds to tokens"""
        if isinstance(seconds, np.ndarray):
            return np.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)
        else:
            return math.ceil(seconds * self.codec.sample_rate / self.codec.hop_length)

    def s2t2s(self, seconds: float):
        """seconds to tokens to seconds"""
        return self.t2s(self.s2t(seconds))

    def t2s(self, tokens: int):
        """tokens to seconds"""
        return tokens * self.codec.hop_length / self.codec.sample_rate
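    # Example of the conversion arithmetic (a sketch; the numbers assume a
    # 44.1 kHz codec with hop_length 768):
    #
    #     interface.s2t(1.0)    # ceil(1.0 * 44100 / 768) = 58 tokens
    #     interface.t2s(58)     # 58 * 768 / 44100 ≈ 1.01 seconds
    #     interface.s2t2s(1.0)  # 1.0 s snapped to the token grid, ≈ 1.01 s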
    def to(self, device):
        self.device = device
        self.coarse.to(device)
        self.codec.to(device)

        if self.c2f is not None:
            self.c2f.to(device)

        if self.beat_tracker is not None:
            self.beat_tracker.model.to(device)

        return self

    def decode(self, z: torch.Tensor):
        return self.coarse.decode(z, self.codec)
    def _preprocess(self, signal: AudioSignal):
        signal = (
            signal.clone()
            .resample(self.codec.sample_rate)
            .to_mono()
            .normalize(self.loudness)
            .ensure_max_of_audio(1.0)
        )
        logging.debug(f"length before codec preproc: {signal.samples.shape}")
        signal.samples, length = self.codec.preprocess(signal.samples, signal.sample_rate)
        logging.debug(f"length after codec preproc: {signal.samples.shape}")
        return signal
    def encode(self, signal: AudioSignal):
        signal = signal.to(self.device)
        signal = self._preprocess(signal)
        z = self.codec.encode(signal.samples, signal.sample_rate)["codes"]
        return z
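    # Round-trip example (a sketch; "input.wav" is a hypothetical file):
    #
    #     sig = AudioSignal("input.wav")
    #     z = interface.encode(sig)  # codes of shape (batch, n_codebooks, n_tokens)
    #     interface.decode(z).write("roundtrip.wav")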
    def snap_to_beats(
        self,
        signal: AudioSignal
    ):
        assert self.beat_tracker is not None, "No beat tracker loaded"
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # trim the signal to the span between the first and last beat times
        samples_begin = int(beats[0] * signal.sample_rate)
        samples_end = int(beats[-1] * signal.sample_rate)

        logging.debug(beats[0])
        signal = signal.clone().trim(samples_begin, signal.length - samples_end)

        return signal
    def make_beat_mask(self,
        signal: AudioSignal,
        before_beat_s: float = 0.0,
        after_beat_s: float = 0.02,
        mask_downbeats: bool = True,
        mask_upbeats: bool = True,
        downbeat_downsample_factor: int = None,
        beat_downsample_factor: int = None,
        dropout: float = 0.0,
        invert: bool = True,
    ):
        """make a beat-synced mask. that is, make a mask that
        places 1s at and around the beats, and 0s everywhere else.
        """
        assert self.beat_tracker is not None, "No beat tracker loaded"

        # get the beat times
        beats, downbeats = self.beat_tracker.extract_beats(signal)

        # get the beat indices in z
        beats_z, downbeats_z = self.s2t(beats), self.s2t(downbeats)

        # remove downbeats from beats
        beats_z = torch.tensor(beats_z)[~torch.isin(torch.tensor(beats_z), torch.tensor(downbeats_z))]
        beats_z = beats_z.tolist()
        downbeats_z = downbeats_z.tolist()

        # make the mask
        seq_len = self.s2t(signal.duration)
        mask = torch.zeros(seq_len, device=self.device)

        mask_b4 = self.s2t(before_beat_s)
        mask_after = self.s2t(after_beat_s)
        if beat_downsample_factor is not None:
            if beat_downsample_factor < 1:
                raise ValueError("beat_downsample_factor must be >= 1 or None")
        else:
            beat_downsample_factor = 1

        if downbeat_downsample_factor is not None:
            if downbeat_downsample_factor < 1:
                raise ValueError("downbeat_downsample_factor must be >= 1 or None")
        else:
            downbeat_downsample_factor = 1

        beats_z = beats_z[::beat_downsample_factor]
        downbeats_z = downbeats_z[::downbeat_downsample_factor]
logging.debug(f"beats_z: {len(beats_z)}") | |
logging.debug(f"downbeats_z: {len(downbeats_z)}") | |
if mask_upbeats: | |
for beat_idx in beats_z: | |
_slice = int(beat_idx - mask_b4), int(beat_idx + mask_after) | |
num_steps = mask[_slice[0]:_slice[1]].shape[0] | |
_m = torch.ones(num_steps, device=self.device) | |
_m_mask = torch.bernoulli(_m * (1 - dropout)) | |
_m = _m * _m_mask.long() | |
mask[_slice[0]:_slice[1]] = _m | |
if mask_downbeats: | |
for downbeat_idx in downbeats_z: | |
_slice = int(downbeat_idx - mask_b4), int(downbeat_idx + mask_after) | |
num_steps = mask[_slice[0]:_slice[1]].shape[0] | |
_m = torch.ones(num_steps, device=self.device) | |
_m_mask = torch.bernoulli(_m * (1 - dropout)) | |
_m = _m * _m_mask.long() | |
mask[_slice[0]:_slice[1]] = _m | |
mask = mask.clamp(0, 1) | |
if invert: | |
mask = 1 - mask | |
mask = mask[None, None, :].bool().long() | |
if self.c2f is not None: | |
mask = mask.repeat(1, self.c2f.n_codebooks, 1) | |
else: | |
mask = mask.repeat(1, self.coarse.n_codebooks, 1) | |
return mask | |
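    # Example usage (a sketch; requires an Interface constructed with a
    # wavebeat_ckpt):
    #
    #     beat_mask = interface.make_beat_mask(
    #         sig,
    #         after_beat_s=0.05,  # keep ~50 ms of tokens after each beat
    #         invert=True,        # so 1 = regenerate, 0 = keep
    #     )
    #     zv = interface.coarse_vamp(z, mask=beat_mask)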
    def set_chunk_size(self, chunk_size_s: float):
        self.coarse.chunk_size_s = chunk_size_s
    def coarse_to_fine(
        self,
        z: torch.Tensor,
        mask: torch.Tensor = None,
        return_mask: bool = False,
        **kwargs
    ):
        assert self.c2f is not None, "No coarse2fine model loaded"
        length = z.shape[-1]
        chunk_len = self.s2t(self.c2f.chunk_size_s)
        n_chunks = math.ceil(z.shape[-1] / chunk_len)

        # zero pad to a multiple of chunk_len
        if length % chunk_len != 0:
            pad_len = chunk_len - (length % chunk_len)
            z = torch.nn.functional.pad(z, (0, pad_len))
            mask = torch.nn.functional.pad(mask, (0, pad_len), value=1) if mask is not None else None

        n_codebooks_to_append = self.c2f.n_codebooks - z.shape[1]
        if n_codebooks_to_append > 0:
            z = torch.cat([
                z,
                torch.zeros(z.shape[0], n_codebooks_to_append, z.shape[-1]).long().to(self.device)
            ], dim=1)
            logging.debug(f"appended {n_codebooks_to_append} codebooks to z")

        # set the mask to 0 for all conditioning codebooks
        if mask is not None:
            mask = mask.clone()
            mask[:, :self.c2f.n_conditioning_codebooks, :] = 0

        fine_z = []
        for i in range(n_chunks):
            chunk = z[:, :, i * chunk_len : (i + 1) * chunk_len]
            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len] if mask is not None else None

            with torch.autocast("cuda", dtype=torch.bfloat16):
                chunk = self.c2f.generate(
                    codec=self.codec,
                    time_steps=chunk_len,
                    start_tokens=chunk,
                    return_signal=False,
                    mask=mask_chunk,
                    cfg_guidance=None,
                    **kwargs
                )
            fine_z.append(chunk)

        fine_z = torch.cat(fine_z, dim=-1)
        if return_mask:
            return (
                fine_z[:, :, :length].clone(),
                apply_mask(fine_z, mask, self.c2f.mask_token)[0][:, :, :length].clone(),
            )
        return fine_z[:, :, :length].clone()
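    # Example usage (a sketch; `zv` is a coarse output such as the one
    # coarse_vamp below returns):
    #
    #     fine_z = interface.coarse_to_fine(zv, mask=mask)
    #     interface.decode(fine_z).write("fine.wav")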
    def coarse_vamp(
        self,
        z,
        mask,
        return_mask=False,
        gen_fn=None,
        **kwargs
    ):
        # coarse z
        cz = z[:, : self.coarse.n_codebooks, :].clone()
        mask = mask[:, : self.coarse.n_codebooks, :]
        # assert cz.shape[-1] <= self.s2t(self.coarse.chunk_size_s), f"the sequence of tokens provided must match the one specified in the coarse chunk size, but got {cz.shape[-1]} and {self.s2t(self.coarse.chunk_size_s)}"

        # cut into chunks, keep the last chunk separate if it's too small
        chunk_len = self.s2t(self.coarse.chunk_size_s)
        n_chunks = math.ceil(cz.shape[-1] / chunk_len)
        last_chunk_len = cz.shape[-1] % chunk_len

        cz_chunks = []
        mask_chunks = []
        for i in range(n_chunks):
            chunk = cz[:, :, i * chunk_len : (i + 1) * chunk_len]
            mask_chunk = mask[:, :, i * chunk_len : (i + 1) * chunk_len]

            # make sure that the very first and last timestep of each chunk is 0 so that we don't get a weird
            # discontinuity when we stitch the chunks back together
            # only if there's already a 0 somewhere in the chunk
            if torch.any(mask_chunk == 0):
                mask_chunk[:, :, 0] = 0
                mask_chunk[:, :, -1] = 0

            cz_chunks.append(chunk)
            mask_chunks.append(mask_chunk)

        # now vamp each chunk
        cz_masked_chunks = []
        cz_vamped_chunks = []
        for chunk, mask_chunk in zip(cz_chunks, mask_chunks):
            cz_masked_chunk, mask_chunk = apply_mask(chunk, mask_chunk, self.coarse.mask_token)
            cz_masked_chunk = cz_masked_chunk[:, : self.coarse.n_codebooks, :]
            cz_masked_chunks.append(cz_masked_chunk)

            gen_fn = gen_fn or self.coarse.generate
            with torch.autocast("cuda", dtype=torch.bfloat16):
                c_vamp_chunk = gen_fn(
                    codec=self.codec,
                    time_steps=chunk_len,
                    start_tokens=cz_masked_chunk,
                    return_signal=False,
                    mask=mask_chunk,
                    **kwargs
                )
            cz_vamped_chunks.append(c_vamp_chunk)

        # stitch the chunks back together
        cz_masked = torch.cat(cz_masked_chunks, dim=-1)
        c_vamp = torch.cat(cz_vamped_chunks, dim=-1)

        # add the fine codes back in
        c_vamp = torch.cat(
            [c_vamp, z[:, self.coarse.n_codebooks :, :]],
            dim=1
        )

        if return_mask:
            return c_vamp, cz_masked

        return c_vamp
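    # Example usage (a sketch): keep every 7th coarse token as a prompt and
    # regenerate the rest, chunk by chunk:
    #
    #     mask = interface.build_mask(z, sig=sig, periodic_prompt=7)
    #     zv, masked_z = interface.coarse_vamp(z, mask=mask, return_mask=True)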
    def build_mask(self,
        z: torch.Tensor,
        sig: AudioSignal = None,
        rand_mask_intensity: float = 1.0,
        prefix_s: float = 0.0,
        suffix_s: float = 0.0,
        periodic_prompt: int = 7,
        periodic_prompt_width: int = 1,
        onset_mask_width: int = 0,
        _dropout: float = 0.0,
        upper_codebook_mask: int = 3,
        ncc: int = 0,
    ):
        mask = linear_random(z, rand_mask_intensity)
        mask = mask_and(
            mask,
            inpaint(z, self.s2t(prefix_s), self.s2t(suffix_s)),
        )

        pmask = periodic_mask(z, periodic_prompt, periodic_prompt_width, random_roll=True)
        mask = mask_and(mask, pmask)

        if onset_mask_width > 0:
            assert sig is not None, "must provide a signal to use onset mask"
            mask = mask_and(
                mask, onset_mask(
                    sig, z, self,
                    width=onset_mask_width
                )
            )

        mask = dropout(mask, _dropout)
        mask = codebook_unmask(mask, ncc)

        mask = codebook_mask(mask, int(upper_codebook_mask), None)
        return mask
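    # Example usage (a sketch): an inpainting-style mask that keeps the first
    # and last two seconds, prompts every 7th token in between, and leaves
    # codebooks 3 and up untouched:
    #
    #     mask = interface.build_mask(
    #         z, sig=sig,
    #         prefix_s=2.0, suffix_s=2.0,
    #         periodic_prompt=7,
    #         upper_codebook_mask=3,
    #     )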
    def vamp(
        self,
        codes: torch.Tensor,
        mask: torch.Tensor,
        batch_size: int = 1,
        feedback_steps: int = 1,
        time_stretch_factor: int = 1,
        return_mask: bool = False,
        **kwargs,
    ):
        z = codes

        # expand z to batch size
        z = z.expand(batch_size, -1, -1)
        mask = mask.expand(batch_size, -1, -1)

        # stretch mask and z to match the time stretch factor
        # we'll add (stretch_factor - 1) mask tokens in between each timestep of z
        # and we'll make the mask 1 in all the new slots we added
        if time_stretch_factor > 1:
            z = z.repeat_interleave(time_stretch_factor, dim=-1)
            mask = mask.repeat_interleave(time_stretch_factor, dim=-1)
            added_mask = torch.ones_like(mask)
            added_mask[:, :, ::time_stretch_factor] = 0
            mask = mask.bool() | added_mask.bool()
            mask = mask.long()

        # the forward pass
        logging.debug(z.shape)
        logging.debug("coarse!")
        zv, mask_z = self.coarse_vamp(
            z,
            mask=mask,
            return_mask=True,
            **kwargs
        )

        # add the top codebooks back in
        if zv.shape[1] < z.shape[1]:
            logging.debug(f"adding {z.shape[1] - zv.shape[1]} codebooks back in")
            zv = torch.cat(
                [zv, z[:, self.coarse.n_codebooks :, :]],
                dim=1
            )

        # now, coarse2fine
        logging.debug("coarse2fine!")
        zv, fine_zv_mask = self.coarse_to_fine(
            zv,
            mask=mask,
            typical_filtering=True,
            _sampling_steps=2,
            return_mask=True
        )
        mask_z = torch.cat(
            [mask_z[:, :self.coarse.n_codebooks, :], fine_zv_mask[:, self.coarse.n_codebooks:, :]],
            dim=1
        )

        z = zv

        if return_mask:
            return z, mask_z.cpu()
        else:
            return z
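    # End-to-end example (a sketch; mirrors the __main__ block below):
    #
    #     z = interface.encode(sig)
    #     mask = interface.build_mask(z, sig=sig)
    #     zv = interface.vamp(z, mask, batch_size=1, time_stretch_factor=2)
    #     interface.decode(zv).write("vamped.wav")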
    def visualize_codes(self, z: torch.Tensor):
        import matplotlib.pyplot as plt
        # plot the first item in the batch as a codebook-by-time image
        fig = plt.figure(figsize=(10, 7))
        fig.add_subplot(1, 1, 1)
        plt.imshow(z[0].cpu().numpy(), aspect='auto', origin='lower', cmap="tab20")
        plt.title("codes")
        plt.ylabel("codebook index")
        plt.xlabel("token index")
        return fig
if __name__ == "__main__":
    import audiotools as at

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    torch.set_printoptions(threshold=10000)
    at.util.seed(42)

    interface = Interface(
        coarse_ckpt="./models/vampnet/coarse.pth",
        coarse2fine_ckpt="./models/vampnet/c2f.pth",
        codec_ckpt="./models/vampnet/codec.pth",
        device="cuda",
        wavebeat_ckpt="./models/wavebeat.pth"
    )

    sig = at.AudioSignal('assets/example.wav')

    z = interface.encode(sig)

    mask = interface.build_mask(
        z=z,
        sig=sig,
        rand_mask_intensity=1.0,
        prefix_s=0.0,
        suffix_s=0.0,
        periodic_prompt=7,
        periodic_prompt_width=1,
        onset_mask_width=5,
        _dropout=0.0,
        upper_codebook_mask=3,
        ncc=0,
    )

    zv, mask_z = interface.coarse_vamp(
        z,
        mask=mask,
        return_mask=True,
        gen_fn=interface.coarse.generate
    )

    use_coarse2fine = True
    if use_coarse2fine:
        zv = interface.coarse_to_fine(zv, mask=mask)

    masked_sig = interface.decode(mask_z).cpu()
    sig = interface.decode(zv).cpu()

    logging.debug("done")