from dataclasses import asdict, dataclass
from typing import Dict, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)


@dataclass
class GPTAudioConfig:
    """Configuration for GPT audio processing parameters"""
    mel_channels: int = 80
    sample_rate: int = 22050
    output_sample_rate: int = 24000

@dataclass
class XTTSAudioConfig:
    """Configuration for audio processing parameters"""
    sample_rate: int = 22050
    output_sample_rate: int = 24000
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0
    mel_norms_file: Optional[str] = None

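# Note (illustrative, derived from the defaults above): with sample_rate=22050 and
# hop_length=256, mel frames advance at 22050 / 256 ≈ 86.1 frames per second, and
# win_length = n_fft = 1024 corresponds to an analysis window of roughly 46 ms.
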

class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS."""
    model_type = "xtts_gpt"

    def __init__(
            self,
            # Model architecture
            hidden_size: int = 1024,  # gpt_n_model_channels in original
            n_inner: int = 4096,
            num_hidden_layers: int = 30,  # gpt_layers in original
            num_attention_heads: int = 16,  # gpt_n_heads in original

            # Tokenizer settings
            vocab_size: int = 6681,  # gpt_number_text_tokens in original
            number_text_tokens: int = 6681,  # Explicit text token vocabulary size
            start_text_token: Optional[int] = None,
            stop_text_token: Optional[int] = None,

            # Audio token settings
            num_audio_tokens: int = 1026,  # gpt_num_audio_tokens in original
            start_audio_token: int = 1024,  # gpt_start_audio_token in original
            stop_audio_token: int = 1025,  # gpt_stop_audio_token in original

            # Sequence length settings
            max_audio_tokens: int = 605,  # gpt_max_audio_tokens in original
            max_text_tokens: int = 402,  # gpt_max_text_tokens in original
            max_prompt_tokens: int = 70,  # gpt_max_prompt_tokens in original
            gpt_max_audio_tokens: int = 605,  # Used for generation

            # Model behavior settings
            use_masking_gt_prompt_approach: bool = True,  # gpt_use_masking_gt_prompt_approach in original
            use_perceiver_resampler: bool = True,  # gpt_use_perceiver_resampler in original
            kv_cache: bool = True,
            enable_redaction: bool = False,

            # GPT batch settings
            gpt_batch_size: int = 1,

            # Audio processing
            audio_config: Optional[Dict] = None,

            # Architecture specifics
            layer_norm_epsilon: float = 1e-5,
            initializer_range: float = 0.02,
            add_cross_attention: bool = False,
            scale_attn_by_inverse_layer_idx: bool = False,
            reorder_and_upcast_attn: bool = False,

            # Size settings for the decoder
            decoder_input_dim: int = 1024,
            architectures=["XttsGPT"],
            auto_map={
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            },
            activation_function: str = "gelu",
            attn_pdrop: float = 0.1,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.architectures = architectures
        self.auto_map = auto_map
        self.audio_config = GPTAudioConfig(
            **(audio_config if audio_config is not None else {})
        )
        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        self.gpt_batch_size = gpt_batch_size

        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary."""
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary."""
        return cls(**config_dict)


class XTTSConfig(PretrainedConfig):
    """Configuration class for XTTS model components except GPT."""
    model_type = "xtts"

    def __init__(
            self,
            # Audio settings
            audio_config: Optional[Dict] = None,
            input_sample_rate: int = 22050,
            output_sample_rate: int = 24000,
            output_hop_length: int = 256,

            # Model architecture
            decoder_input_dim: int = 1024,
            d_vector_dim: int = 512,
            cond_d_vector_in_each_upsampling_layer: bool = True,

            # Training settings
            gpt_code_stride_len: int = 1024,
            duration_const: int = 102400,

            # Tokenizer settings
            tokenizer_file: str = "",
            num_chars: int = 255,

            # Language support
            languages: Optional[List[str]] = None,

            # GPT configuration
            gpt_config: Optional[Dict] = None,
            architectures=["Xtts"],
            auto_map={
                "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts",
            },
            **kwargs
    ):
        super().__init__(**kwargs)
        self.architectures = architectures
        self.auto_map = auto_map
        # Initialize audio config
        self.audio_config = XTTSAudioConfig(
            **(audio_config if audio_config is not None else {})
        )

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        self.gpt_code_stride_len = gpt_code_stride_len
        self.duration_const = duration_const

        self.tokenizer_file = tokenizer_file
        self.num_chars = num_chars

        # Initialize GPT config
        self.gpt = XTTSGPTConfig(**(gpt_config if gpt_config is not None else {}))

        if languages is None:
            self.languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
                "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]
        else:
            self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary."""
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        output["gpt_config"] = self.gpt.to_dict()
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSConfig":
        """Create a config from a dictionary."""
        if "gpt_config" in config_dict:
            gpt_config = config_dict["gpt_config"]
            config_dict = {k: v for k, v in config_dict.items() if k != "gpt_config"}
            return cls(gpt_config=gpt_config, **config_dict)
        return cls(**config_dict)
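

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): builds the configs defined above,
# passing nested dicts the way a serialized checkpoint would provide them, and
# round-trips through to_dict()/from_dict(). The override values below are
# arbitrary examples taken from the defaults, not recommendations.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = XTTSConfig(
        gpt_config={"num_hidden_layers": 30, "hidden_size": 1024},
        audio_config={"sample_rate": 22050, "output_sample_rate": 24000},
    )

    # Nested components are materialized as typed objects by __init__.
    assert isinstance(config.gpt, XTTSGPTConfig)
    assert isinstance(config.audio_config, XTTSAudioConfig)

    # to_dict() serializes the nested configs under "audio_config"/"gpt_config",
    # so from_dict() can rebuild an equivalent object from the result.
    serialized = config.to_dict()
    restored = XTTSConfig.from_dict(serialized)
    assert restored.gpt.num_hidden_layers == config.gpt.num_hidden_layers
    print("XTTSConfig round-trip OK:", restored.model_type)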