File size: 8,418 Bytes
73465f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
from dataclasses import asdict, dataclass
from typing import Dict, Optional, List
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
@dataclass
class GPTAudioConfig:
    """Audio parameters attached to the GPT configuration.

    The declaration order of the fields is also the positional order of the
    generated constructor, so it must not be rearranged.
    """

    # Default number of mel channels.
    mel_channels: int = 80
    # Default sample rate.
    sample_rate: int = 22050
    # Default output sample rate.
    output_sample_rate: int = 24000
@dataclass
class XTTSAudioConfig:
    """Audio-processing parameters attached to the XTTS configuration.

    The declaration order of the fields is also the positional order of the
    generated constructor, so it must not be rearranged.
    """

    # Sample rates.
    sample_rate: int = 22050
    output_sample_rate: int = 24000

    # Spectrogram parameters.
    mel_channels: int = 80
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    fmin: int = 0
    fmax: int = 8000
    power: float = 1.0

    # Optional mel-norms file; None when no file is configured.
    mel_norms_file: Optional[str] = None
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Every constructor argument is stored on the instance under the same name.
    The one exception is ``audio_config``, which is accepted as a plain dict
    (or ``None``) and stored as a ``GPTAudioConfig`` dataclass instance.
    """

    model_type = "xtts_gpt"

    def __init__(
        self,
        # Model architecture
        hidden_size: int = 1024,  # gpt_n_model_channels in original
        n_inner: int = 4096,
        num_hidden_layers: int = 30,  # gpt_layers in original
        num_attention_heads: int = 16,  # gpt_n_heads in original
        # Tokenizer settings
        vocab_size: int = 6681,  # gpt_number_text_tokens in original
        number_text_tokens: int = 6681,  # Explicit text token vocabulary size
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        # Audio token settings
        num_audio_tokens: int = 1026,  # gpt_num_audio_tokens in original
        start_audio_token: int = 1024,  # gpt_start_audio_token in original
        stop_audio_token: int = 1025,  # gpt_stop_audio_token in original
        # Sequence length settings
        max_audio_tokens: int = 605,  # gpt_max_audio_tokens in original
        max_text_tokens: int = 402,  # gpt_max_text_tokens in original
        max_prompt_tokens: int = 70,  # gpt_max_prompt_tokens in original
        gpt_max_audio_tokens: int = 605,  # Used for generation
        # Model behavior settings
        use_masking_gt_prompt_approach: bool = True,  # gpt_use_masking_gt_prompt_approach in original
        use_perceiver_resampler: bool = True,  # gpt_use_perceiver_resampler in original
        kv_cache: bool = True,
        enable_redaction: bool = False,
        # GPT batch settings
        gpt_batch_size: int = 1,
        # Audio processing
        audio_config: Optional[Dict] = None,
        # Architecture specifics
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        # Size settings for the decoder
        decoder_input_dim: int = 1024,
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs,
    ):
        """Build the GPT configuration.

        Args:
            audio_config: Keyword arguments for ``GPTAudioConfig``; ``None``
                selects the dataclass defaults.
            architectures: Architecture names; ``None`` selects
                ``["XttsGPT"]``.
            auto_map: Auto-class mapping; ``None`` selects the
                ``AstraMindAI/xtts2-gpt`` mapping.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        All remaining arguments are stored verbatim as attributes.
        """
        super().__init__(**kwargs)

        # The defaults are created here rather than in the signature so that
        # each instance owns its own list/dict instead of sharing one mutable
        # default object across all instances.
        if architectures is None:
            architectures = ["XttsGPT"]
        if auto_map is None:
            auto_map = {
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            }
        self.architectures = architectures
        self.auto_map = auto_map

        self.audio_config = GPTAudioConfig(
            **(audio_config if audio_config is not None else {})
        )

        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop

        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        self.gpt_batch_size = gpt_batch_size

        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Returns:
            The base-class dictionary with ``"audio_config"`` replaced by a
            plain dict built from the ``GPTAudioConfig`` dataclass.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        Delegates to ``PretrainedConfig.from_dict`` so that keyword overrides
        and ``return_unused_kwargs`` are honoured instead of being dropped.

        Args:
            config_dict: Constructor arguments, e.g. the output of
                :meth:`to_dict`.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Forwarded to ``PretrainedConfig.from_dict``.

        Returns:
            Whatever the base implementation returns: the config, or a
            ``(config, unused_kwargs)`` tuple when ``return_unused_kwargs``
            is requested.
        """
        # Pass a shallow copy in case the base implementation modifies the
        # mapping it receives; the caller's dict must stay untouched.
        return super().from_dict(dict(config_dict), **kwargs)
class XTTSConfig(PretrainedConfig):
    """Configuration class for XTTS model components except GPT.

    ``audio_config`` and ``gpt_config`` are accepted as plain dicts (or
    ``None``) and stored as an ``XTTSAudioConfig`` dataclass in
    ``self.audio_config`` and an ``XTTSGPTConfig`` in ``self.gpt``.
    """

    model_type = "xtts"

    def __init__(
        self,
        # Audio settings
        audio_config: Optional[Dict] = None,
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,
        # Model architecture
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,
        # Training settings
        gpt_code_stride_len: int = 1024,
        duration_const: int = 102400,
        # Tokenizer settings
        tokenizer_file: str = "",
        num_chars: int = 255,
        # Language support
        languages: Optional[List[str]] = None,
        # GPT configuration
        gpt_config: Optional[Dict] = None,
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        **kwargs,
    ):
        """Build the XTTS configuration.

        Args:
            audio_config: Keyword arguments for ``XTTSAudioConfig``; ``None``
                selects the dataclass defaults.
            languages: Language codes; ``None`` selects the built-in list of
                17 codes.
            gpt_config: Keyword arguments for ``XTTSGPTConfig``; ``None``
                selects its defaults.
            architectures: Architecture names; ``None`` selects ``["Xtts"]``.
            auto_map: Auto-class mapping; ``None`` selects the
                ``AstraMindAI/xtts2`` mapping.
            **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

        All remaining arguments are stored verbatim as attributes.
        """
        super().__init__(**kwargs)

        # The defaults are created here rather than in the signature so that
        # each instance owns its own list/dict instead of sharing one mutable
        # default object across all instances.
        if architectures is None:
            architectures = ["Xtts"]
        if auto_map is None:
            auto_map = {
                "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts",
            }
        self.architectures = architectures
        self.auto_map = auto_map

        # Initialize audio config
        self.audio_config = XTTSAudioConfig(
            **(audio_config if audio_config is not None else {})
        )

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        self.gpt_code_stride_len = gpt_code_stride_len
        self.duration_const = duration_const

        self.tokenizer_file = tokenizer_file
        self.num_chars = num_chars

        # Initialize GPT config
        self.gpt = XTTSGPTConfig(**(gpt_config if gpt_config is not None else {}))

        if languages is None:
            self.languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
                "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi",
            ]
        else:
            self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Returns:
            The base-class dictionary with ``"audio_config"`` replaced by a
            plain dict built from the dataclass, and ``"gpt_config"`` set to
            the nested GPT config's own ``to_dict()`` output.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        # NOTE(review): the base to_dict() may also emit the nested config
        # under its attribute name "gpt" — confirm whether having both keys in
        # the output is intended.
        output["gpt_config"] = self.gpt.to_dict()
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSConfig":
        """Create a config from a dictionary.

        ``gpt_config`` is an ordinary constructor parameter, so a
        ``"gpt_config"`` entry in ``config_dict`` needs no special handling.
        Delegates to ``PretrainedConfig.from_dict`` so that keyword overrides
        and ``return_unused_kwargs`` are honoured instead of being dropped.

        Args:
            config_dict: Constructor arguments, e.g. the output of
                :meth:`to_dict`.
            *args: Accepted for signature compatibility and ignored.
            **kwargs: Forwarded to ``PretrainedConfig.from_dict``.

        Returns:
            Whatever the base implementation returns: the config, or a
            ``(config, unused_kwargs)`` tuple when ``return_unused_kwargs``
            is requested.
        """
        # Pass a shallow copy in case the base implementation modifies the
        # mapping it receives; the caller's dict must stay untouched.
        return super().from_dict(dict(config_dict), **kwargs)