|
from dataclasses import asdict, dataclass |
|
from typing import Dict, Optional, List |
|
from transformers.configuration_utils import PretrainedConfig |
|
from transformers.utils import logging |
|
|
|
# Module-level logger obtained from transformers' logging utilities, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
|
@dataclass()
class GPTAudioConfig:
    """Audio parameters attached to the GPT component's configuration.

    Constructor/field order: ``mel_channels``, ``sample_rate``,
    ``output_sample_rate``. Serialized with ``dataclasses.asdict`` by
    ``XTTSGPTConfig.to_dict``.
    """

    # Number of mel channels.
    mel_channels: int = 80
    # Sample rate of the audio fed in.
    sample_rate: int = 22050
    # Sample rate of the audio produced.
    output_sample_rate: int = 24000
|
|
|
@dataclass()
class XTTSAudioConfig:
    """Audio/spectrogram parameters for the non-GPT XTTS components.

    Field order defines the positional constructor and the key order produced
    by ``dataclasses.asdict`` (used by ``XTTSConfig.to_dict``).
    """

    # Sample rates of the audio fed in and the audio produced.
    sample_rate: int = 22050
    output_sample_rate: int = 24000
    # Number of mel channels.
    mel_channels: int = 80
    # STFT framing parameters.
    hop_length: int = 256
    win_length: int = 1024
    n_fft: int = 1024
    # Frequency bounds — presumably for the mel filterbank; confirm with the consumer.
    fmin: int = 0
    fmax: int = 8000
    # Spectrogram exponent — assumed 1.0 means magnitude; TODO confirm.
    power: float = 1.0
    # Optional path to a mel-normalisation file; None when not provided.
    mel_norms_file: Optional[str] = None
|
|
|
|
|
class XTTSGPTConfig(PretrainedConfig):
    """Configuration class for the GPT component of XTTS.

    Every constructor argument is stored as an attribute of the same name,
    except ``audio_config``, which is stored as a :class:`GPTAudioConfig`.
    Remaining keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        audio_config: ``None`` (use ``GPTAudioConfig`` defaults), a dict of
            ``GPTAudioConfig`` fields (as produced by :meth:`to_dict`), or a
            ``GPTAudioConfig`` instance.
        architectures: Value for ``self.architectures``; ``None`` selects
            ``["XttsGPT"]``.
        auto_map: Value for ``self.auto_map``; ``None`` selects the default
            AstraMindAI remote-code mapping.
    """

    model_type = "xtts_gpt"

    def __init__(
        self,
        hidden_size: int = 1024,
        n_inner: int = 4096,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 16,
        vocab_size: int = 6681,
        number_text_tokens: int = 6681,
        start_text_token: Optional[int] = None,
        stop_text_token: Optional[int] = None,
        num_audio_tokens: int = 1026,
        start_audio_token: int = 1024,
        stop_audio_token: int = 1025,
        max_audio_tokens: int = 605,
        max_text_tokens: int = 402,
        max_prompt_tokens: int = 70,
        gpt_max_audio_tokens: int = 605,
        use_masking_gt_prompt_approach: bool = True,
        use_perceiver_resampler: bool = True,
        kv_cache: bool = True,
        enable_redaction: bool = False,
        gpt_batch_size: int = 1,
        audio_config: Optional[Dict] = None,
        layer_norm_epsilon: float = 1e-5,
        initializer_range: float = 0.02,
        add_cross_attention: bool = False,
        scale_attn_by_inverse_layer_idx: bool = False,
        reorder_and_upcast_attn: bool = False,
        decoder_input_dim: int = 1024,
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        activation_function: str = "gelu",
        attn_pdrop: float = 0.1,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Build the defaults per instance: list/dict default arguments are
        # evaluated once and would be shared (and mutable) across all configs.
        self.architectures = ["XttsGPT"] if architectures is None else architectures
        if auto_map is None:
            auto_map = {
                "AutoConfig": "AstraMindAI/xtts2-gpt--gpt_config.XTTSGPTConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2-gpt--xtts2_gpt_modeling.XttsGPT",
            }
        self.auto_map = auto_map

        # Accept an existing dataclass instance as well as a plain dict
        # (the form written by to_dict / read back from JSON).
        if isinstance(audio_config, GPTAudioConfig):
            self.audio_config = audio_config
        else:
            self.audio_config = GPTAudioConfig(
                **(audio_config if audio_config is not None else {})
            )

        self.activation_function = activation_function
        self.attn_pdrop = attn_pdrop
        self.hidden_size = hidden_size
        self.n_inner = n_inner
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        self.vocab_size = vocab_size
        self.number_text_tokens = number_text_tokens
        self.start_text_token = start_text_token
        self.stop_text_token = stop_text_token

        self.num_audio_tokens = num_audio_tokens
        self.start_audio_token = start_audio_token
        self.stop_audio_token = stop_audio_token

        self.max_audio_tokens = max_audio_tokens
        self.max_text_tokens = max_text_tokens
        self.max_prompt_tokens = max_prompt_tokens
        self.gpt_max_audio_tokens = gpt_max_audio_tokens

        self.use_masking_gt_prompt_approach = use_masking_gt_prompt_approach
        self.use_perceiver_resampler = use_perceiver_resampler
        self.kv_cache = kv_cache
        self.enable_redaction = enable_redaction

        self.gpt_batch_size = gpt_batch_size

        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.add_cross_attention = add_cross_attention
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        self.decoder_input_dim = decoder_input_dim

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        Uses ``PretrainedConfig.to_dict`` and replaces the ``audio_config``
        dataclass with a plain dict so the result can be serialized and passed
        back to ``__init__``.
        """
        output = super().to_dict()
        output["audio_config"] = asdict(self.audio_config)
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSGPTConfig":
        """Create a config from a dictionary.

        Delegates to ``PretrainedConfig.from_dict`` so that loader keyword
        arguments are honoured: matching kwargs override attributes, and
        ``return_unused_kwargs=True`` returns a ``(config, unused_kwargs)``
        tuple as transformers' Auto* loaders request. Extra positional
        arguments are accepted for backward compatibility and ignored.
        """
        # Shallow copy: the base implementation may insert keys into the dict it receives.
        return super().from_dict(dict(config_dict), **kwargs)
|
|
|
|
|
class XTTSConfig(PretrainedConfig):
    """Configuration class for XTTS model components except GPT.

    Constructor arguments are stored as attributes of the same name, with
    these exceptions:

    * ``audio_config`` is stored as an :class:`XTTSAudioConfig`.
    * ``gpt_config`` is stored as ``self.gpt``, an :class:`XTTSGPTConfig`.
    * ``languages`` defaults to a built-in list of language codes.

    Remaining keyword arguments are forwarded to ``PretrainedConfig``.

    Args:
        audio_config: ``None`` (defaults), a dict of ``XTTSAudioConfig``
            fields, or an ``XTTSAudioConfig`` instance.
        gpt_config: ``None`` (defaults), a dict of ``XTTSGPTConfig``
            constructor arguments (as produced by :meth:`to_dict`), or an
            ``XTTSGPTConfig`` instance.
        architectures: Value for ``self.architectures``; ``None`` selects
            ``["Xtts"]``.
        auto_map: Value for ``self.auto_map``; ``None`` selects the default
            AstraMindAI remote-code mapping.
    """

    model_type = "xtts"

    def __init__(
        self,
        audio_config: Optional[Dict] = None,
        input_sample_rate: int = 22050,
        output_sample_rate: int = 24000,
        output_hop_length: int = 256,
        decoder_input_dim: int = 1024,
        d_vector_dim: int = 512,
        cond_d_vector_in_each_upsampling_layer: bool = True,
        gpt_code_stride_len: int = 1024,
        duration_const: int = 102400,
        tokenizer_file: str = "",
        num_chars: int = 255,
        languages: Optional[List[str]] = None,
        gpt_config: Optional[Dict] = None,
        architectures: Optional[List[str]] = None,
        auto_map: Optional[Dict[str, str]] = None,
        **kwargs
    ):
        super().__init__(**kwargs)

        # Build the defaults per instance: list/dict default arguments are
        # evaluated once and would be shared (and mutable) across all configs.
        self.architectures = ["Xtts"] if architectures is None else architectures
        if auto_map is None:
            auto_map = {
                "AutoConfig": "AstraMindAI/xtts2--xtts2_config.XTTSConfig",
                "AutoModelForCausalLM": "AstraMindAI/xtts2--xtts2_modeling.Xtts",
            }
        self.auto_map = auto_map

        # Accept an existing dataclass instance as well as a plain dict.
        if isinstance(audio_config, XTTSAudioConfig):
            self.audio_config = audio_config
        else:
            self.audio_config = XTTSAudioConfig(
                **(audio_config if audio_config is not None else {})
            )

        self.input_sample_rate = input_sample_rate
        self.output_sample_rate = output_sample_rate
        self.output_hop_length = output_hop_length

        self.decoder_input_dim = decoder_input_dim
        self.d_vector_dim = d_vector_dim
        self.cond_d_vector_in_each_upsampling_layer = cond_d_vector_in_each_upsampling_layer

        self.gpt_code_stride_len = gpt_code_stride_len
        self.duration_const = duration_const

        self.tokenizer_file = tokenizer_file
        self.num_chars = num_chars

        # Accept an existing XTTSGPTConfig as well as a plain dict of its arguments.
        if isinstance(gpt_config, XTTSGPTConfig):
            self.gpt = gpt_config
        else:
            self.gpt = XTTSGPTConfig(**(gpt_config if gpt_config is not None else {}))

        if languages is None:
            self.languages = [
                "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
                "nl", "cs", "ar", "zh-cn", "hu", "ko", "ja", "hi"
            ]
        else:
            self.languages = languages

    def to_dict(self) -> Dict:
        """Convert the config to a dictionary.

        The nested configs are emitted in the shapes ``__init__`` reads back:
        ``audio_config`` as a plain dict and the GPT config as a dict under
        ``"gpt_config"``.
        """
        output = super().to_dict()
        # The base to_dict copies every instance attribute, including the live
        # ``gpt`` config object. Depending on the transformers version that
        # entry is either not JSON-serializable or a duplicate of
        # "gpt_config", so drop it; __init__ rebuilds self.gpt from "gpt_config".
        output.pop("gpt", None)
        output["audio_config"] = asdict(self.audio_config)
        output["gpt_config"] = self.gpt.to_dict()
        return output

    @classmethod
    def from_dict(cls, config_dict: Dict, *args, **kwargs) -> "XTTSConfig":
        """Create a config from a dictionary.

        ``gpt_config`` needs no special handling; it is passed to ``__init__``
        like any other key. Delegates to ``PretrainedConfig.from_dict`` so that
        loader keyword arguments are honoured: matching kwargs override
        attributes, and ``return_unused_kwargs=True`` returns a
        ``(config, unused_kwargs)`` tuple as transformers' Auto* loaders
        request. Extra positional arguments are accepted for backward
        compatibility and ignored.
        """
        # Shallow copy: the base implementation may insert keys into the dict it receives.
        return super().from_dict(dict(config_dict), **kwargs)