"""
Hugging Face compatible implementation of Open-MAGVIt2
Code reference: https://github.com/TencentARC/Open-MAGVIT2
"""
from transformers import PretrainedConfig
class EncoderDecoderConfig(PretrainedConfig):
    """Configuration for the Open-MAGVIT2 ResNet-style encoder/decoder.

    Field names and defaults mirror the reference implementation
    (https://github.com/TencentARC/Open-MAGVIT2).
    """

    model_type = "resnet_encoder_decoder"

    def __init__(
        self,
        ch: int = 128,
        in_channels: int = 3,
        out_ch: int = 3,
        z_channels: int = 18,
        num_res_blocks: int = 2,
        ch_mult=None,
        **kwargs,
    ):
        """
        Args:
            ch: Base channel width of the network.
            in_channels: Number of channels in the input image.
            out_ch: Number of channels in the reconstructed output.
            z_channels: Number of channels in the latent representation.
            num_res_blocks: Residual blocks per resolution level.
            ch_mult: Per-level channel multipliers; defaults to
                ``[1, 1, 2, 2, 4]`` when not given.
            **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
        """
        super().__init__(**kwargs)
        self.ch = ch
        self.in_channels = in_channels
        self.out_ch = out_ch
        self.z_channels = z_channels
        self.num_res_blocks = num_res_blocks
        # None sentinel avoids a shared mutable default list.
        self.ch_mult = [1, 1, 2, 2, 4] if ch_mult is None else ch_mult
class QuantizerConfig(PretrainedConfig):
    """Configuration for the lookup-free quantizer (LFQ) of Open-MAGVIT2.

    Field names and defaults mirror the reference implementation
    (https://github.com/TencentARC/Open-MAGVIT2).
    """

    model_type = "lfq_quantizer"

    def __init__(
        self,
        dim: int = 18,
        codebook_size: int = 262144,
        batch_maximization_weight: float = 1.0,
        sample_minimization_weight: float = 1.0,
        **kwargs,
    ):
        """
        Args:
            dim: Dimensionality of the quantized latent (2**dim must equal
                ``codebook_size`` for a binary LFQ codebook — note the
                defaults satisfy 2**18 == 262144).
            codebook_size: Number of entries in the implicit codebook.
            batch_maximization_weight: Weight of the batch-entropy
                maximization term in the quantizer loss.
            sample_minimization_weight: Weight of the per-sample entropy
                minimization term in the quantizer loss.
            **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
        """
        super().__init__(**kwargs)
        self.dim = dim
        self.codebook_size = codebook_size
        self.batch_maximization_weight = batch_maximization_weight
        self.sample_minimization_weight = sample_minimization_weight
class LFQTokenizerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an LFQ
    tokenizer (Open-MAGVIT2). It bundles an encoder/decoder configuration
    and a lookup-free quantizer configuration into a single
    :class:`~transformers.PretrainedConfig`.
    """

    model_type = "lfq_tokenizer"

    def __init__(self, encoder_decoder_config=None, quantizer_config=None, **kwargs):
        """
        Args:
            encoder_decoder_config: An ``EncoderDecoderConfig`` instance, or a
                plain dict of its fields (the form produced by JSON
                serialization round-trips); defaults to a fresh
                ``EncoderDecoderConfig``.
            quantizer_config: A ``QuantizerConfig`` instance or dict;
                defaults to a fresh ``QuantizerConfig``.
            **kwargs: Forwarded to :class:`~transformers.PretrainedConfig`.
        """
        super().__init__(**kwargs)
        # Rehydrate dicts into config objects so configs loaded via
        # `from_pretrained` (which deserializes sub-configs as plain dicts)
        # behave the same as freshly constructed ones. Defaults are built
        # lazily — only when no value was supplied.
        if isinstance(encoder_decoder_config, dict):
            encoder_decoder_config = EncoderDecoderConfig(**encoder_decoder_config)
        self.encoder_decoder_config = (
            EncoderDecoderConfig() if encoder_decoder_config is None else encoder_decoder_config
        )
        if isinstance(quantizer_config, dict):
            quantizer_config = QuantizerConfig(**quantizer_config)
        self.quantizer_config = (
            QuantizerConfig() if quantizer_config is None else quantizer_config
        )