File size: 2,733 Bytes
d72b2c3
 
 
 
 
 
 
 
 
 
d9889a1
d72b2c3
 
0a8807e
 
 
 
 
 
 
 
d9889a1
 
 
 
d72b2c3
 
 
 
 
 
 
 
 
 
 
d9889a1
d72b2c3
 
 
 
 
 
 
 
 
 
 
d9889a1
d72b2c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e366cd5
d72b2c3
e366cd5
d72b2c3
e366cd5
d72b2c3
e366cd5
d72b2c3
 
 
 
5067878
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import typing as tp
from einops import rearrange
import numpy as np
import torch
from torch import nn





class EncodecModel(nn.Module):
    """Decode-only EnCodec-style model: maps discrete codes back to a signal.

    Decoding runs in three steps:

    1. ``quantizer.decode`` maps the codes to a continuous latent.
    2. ``decoder`` is applied to that latent.
    3. The result is optionally rescaled by a ``scale`` previously produced by
       :meth:`preprocess` (only meaningful when ``renormalize`` is enabled).

    Args:
        decoder: Callable/module applied to the latent returned by the quantizer.
        quantizer: Object exposing ``decode``, ``total_codebooks``,
            ``num_codebooks``, ``set_num_codebooks`` and ``bins``.
        frame_rate: Stored on the instance; not read by this class.
        sample_rate: Stored on the instance; not read by this class.
        channels: Stored on the instance; not read by this class.
        causal: Whether the model is causal. Cannot be combined with
            ``renormalize``.
        renormalize: If true, :meth:`preprocess` divides the input by its
            volume and :meth:`postprocess` multiplies it back.

    Raises:
        AssertionError: if both ``causal`` and ``renormalize`` are true.
    """

    def __init__(self,
                 decoder=None,
                 quantizer=None,
                 frame_rate=None,
                 sample_rate=None,
                 channels=None,
                 causal=False,
                 renormalize=False):
        super().__init__()
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally normalize ``x`` by its RMS volume.

        When ``renormalize`` is enabled, the process is:

        * ``x`` is averaged over dim 1, keeping that dim.
        * The root-mean-square of that average over dim 2 is taken as the volume.
        * ``x`` is divided by ``1e-8 + volume``. The epsilon guards against
          division by zero on silent input.

        Presumably ``x`` is ``[B, C, T]``, which would make the returned scale
        ``[B, 1]``. TODO: confirm against callers.

        Args:
            x: Input tensor with at least 3 dims (dims 1 and 2 are reduced).

        Returns:
            ``(x, scale)``. Here ``x`` is the (possibly normalized) input.
            ``scale`` is the divisor flattened with ``view(-1, 1)``, or
            ``None`` when renormalization is disabled.
        """
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo :meth:`preprocess` normalization by multiplying ``x`` by ``scale``.

        Args:
            x: Tensor to rescale. ``scale`` is reshaped to ``(-1, 1, 1)`` so it
                broadcasts over a 3-D ``x``.
            scale: Scale returned by :meth:`preprocess`, or ``None`` to return
                ``x`` unchanged.

        Raises:
            AssertionError: if a scale is given but ``renormalize`` is disabled.
        """
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode discrete codes to the output signal.

        Args:
            codes: Discrete codes, presumably ``[B, K, T]`` (K codebooks).
            scale: Optional scale from :meth:`preprocess`. When given, it is
                applied to the decoder output via :meth:`postprocess`.

        Returns:
            The decoder output, rescaled when ``scale`` is not ``None``.
        """
        # B,K,T -> B,C,T
        emb = self.decode_latent(codes)
        out = self.decoder(emb)
        return self.postprocess(out, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)