|
import typing as tp |
|
from einops import rearrange |
|
import numpy as np |
|
import torch |
|
from torch import nn |
|
|
|
|
|
|
|
|
|
|
|
class EncodecModel(nn.Module):
    """Neural audio codec model bundling a decoder and a residual quantizer.

    The model decodes discrete codebook indices back to a waveform via
    ``quantizer.decode`` followed by ``decoder``, optionally undoing the
    per-item volume renormalization applied in :meth:`preprocess`.

    Args:
        decoder: Module (or callable) mapping continuous latents to audio.
        quantizer: Quantizer exposing ``decode``, ``total_codebooks``,
            ``num_codebooks``, ``set_num_codebooks`` and ``bins``.
        frame_rate (int, optional): Latent frame rate in Hz.
        sample_rate (int, optional): Audio sample rate in Hz.
        channels (int, optional): Number of audio channels.
        causal (bool): Whether the model is causal (streaming). A causal
            model cannot renormalize, since renormalization needs the full
            waveform statistics up front.
        renormalize (bool): Whether to renormalize input audio to unit RMS
            in :meth:`preprocess` and restore the scale in :meth:`postprocess`.
    """

    def __init__(self,
                 decoder=None,
                 quantizer=None,
                 frame_rate=None,
                 sample_rate=None,
                 channels=None,
                 causal=False,
                 renormalize=False):
        super().__init__()
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # NOTE(review): `assert` is stripped under `python -O`; consider
            # raising ValueError if this must always be enforced.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self) -> int:
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self) -> int:
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int) -> None:
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally renormalize the input audio to (approximately) unit RMS.

        Args:
            x: Audio tensor of shape ``[B, C, T]`` (batch, channels, time) —
                implied by the ``dim=1`` / ``dim=2`` reductions below.

        Returns:
            Tuple of the (possibly rescaled) audio and the per-item scale of
            shape ``[B, 1]``, or ``None`` when ``renormalize`` is off.
        """
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            # RMS volume of the mono mixdown, per batch item.
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            # Epsilon keeps the division safe on silent inputs.
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the renormalization applied in :meth:`preprocess`.

        Args:
            x: Decoded audio of shape ``[B, C, T]``.
            scale: Per-item scale returned by :meth:`preprocess`, or ``None``.

        Returns:
            The audio multiplied back by ``scale`` when one is given.
        """
        if scale is not None:
            # Scale is only produced when renormalize is enabled; receiving
            # one otherwise indicates a caller bug.
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Decode discrete codes to a waveform.

        Args:
            codes: Integer code tensor, as accepted by ``quantizer.decode``.
            scale: Optional per-item scale from :meth:`preprocess` used to
                restore the original volume.

        Returns:
            The decoded (and rescaled, if applicable) audio tensor.
        """
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        return out

    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)