File size: 1,507 Bytes
3b36933
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""
Hugging Face compatible implementation of Open-MAGVIT2
Code reference: https://github.com/TencentARC/Open-MAGVIT2
"""


from transformers import PretrainedConfig


class EncoderDecoderConfig(PretrainedConfig):
    """Configuration registered under the ``resnet_encoder_decoder`` model type.

    All settings are supplied as keyword arguments. Any setting that is not
    passed falls back to its default:

    * ``ch``: 128
    * ``in_channels``: 3
    * ``out_ch``: 3
    * ``z_channels``: 18
    * ``num_res_blocks``: 2
    * ``ch_mult``: ``[1, 1, 2, 2, 4]``

    The full set of keyword arguments is also forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "resnet_encoder_decoder"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The table is rebuilt on every call, so the list default for
        # ``ch_mult`` is never shared between instances.
        fallbacks = {
            "ch": 128,
            "in_channels": 3,
            "out_ch": 3,
            "z_channels": 18,
            "num_res_blocks": 2,
            "ch_mult": [1, 1, 2, 2, 4],
        }
        for field_name, fallback in fallbacks.items():
            setattr(self, field_name, kwargs.get(field_name, fallback))


class QuantizerConfig(PretrainedConfig):
    """Configuration registered under the ``lfq_quantizer`` model type.

    All settings are supplied as keyword arguments. Any setting that is not
    passed falls back to its default:

    * ``dim``: 18
    * ``codebook_size``: 262144
    * ``batch_maximization_weight``: 1.0
    * ``sample_minimization_weight``: 1.0

    The full set of keyword arguments is also forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "lfq_quantizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        fallbacks = {
            "dim": 18,
            "codebook_size": 262144,
            "batch_maximization_weight": 1.0,
            "sample_minimization_weight": 1.0,
        }
        for field_name, fallback in fallbacks.items():
            setattr(self, field_name, kwargs.get(field_name, fallback))


class LFQTokenizerConfig(PretrainedConfig):
    r"""
    Configuration registered under the ``lfq_tokenizer`` model type.

    It is a composite configuration holding two nested configurations:

    * ``encoder_decoder_config``: an :class:`EncoderDecoderConfig`
    * ``quantizer_config``: a :class:`QuantizerConfig`

    Each nested configuration may be passed as a config instance or as a
    plain ``dict`` of that config's keyword arguments. When it is omitted
    or ``None``, a default-constructed config is used.

    Accepting ``dict`` values matters for serialization: a nested config is
    written out as a dictionary, so a config reloaded from its dictionary
    or JSON form receives dictionaries here and must rebuild the nested
    config objects from them.

    The full set of keyword arguments is also forwarded to
    ``PretrainedConfig.__init__``.
    """
    model_type = "lfq_tokenizer"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assigned after super().__init__ so the normalized config objects
        # replace any raw values the base class copied from kwargs.
        self.encoder_decoder_config = self._as_config(
            kwargs.get("encoder_decoder_config"), EncoderDecoderConfig
        )
        self.quantizer_config = self._as_config(
            kwargs.get("quantizer_config"), QuantizerConfig
        )

    @staticmethod
    def _as_config(value, config_cls):
        """Return ``value`` normalized to an instance of ``config_cls``.

        ``None`` yields a default-constructed ``config_cls``; a ``dict`` is
        expanded into ``config_cls(**value)``; anything else (normally an
        existing config instance) is returned unchanged.
        """
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            return config_cls(**value)
        return value