"""Tokenization class for VITS.""" |
|
|
|
import json |
|
import os |
|
import re |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
from transformers.tokenization_utils import PreTrainedTokenizer |
|
from transformers.utils import logging |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json"} |
|
|
|
|
|
def has_non_roman_characters(input_string):
    """Return `True` if `input_string` contains any non-Roman (non-ASCII) characters."""
    # Any character outside the ASCII range is treated as non-Roman.
    non_roman_pattern = re.compile(r"[^\x00-\x7F]")

    match = non_roman_pattern.search(input_string)
    has_non_roman = match is not None
    return has_non_roman
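
# Illustrative behaviour of the helper above (the sample strings are ours, not
# part of the original module): ASCII-only text is considered Roman, Devanagari
# text is not.
#   has_non_roman_characters("namaste")  -> False
#   has_non_roman_characters("नमस्ते")      -> True

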
class IndicVitsTokenizer(PreTrainedTokenizer):
    """
    Construct an Indic VITS tokenizer.

    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer
    to this superclass for more information regarding those methods.

    Args:
        vocab_file (`str`):
            Path to the vocabulary file.
        language (`str`, *optional*):
            Language identifier.
        add_blank (`bool`, *optional*, defaults to `True`):
            Whether to insert token id 0 in between the other tokens.
        normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the input text by removing all casing and punctuation.
        phonemize (`bool`, *optional*, defaults to `True`):
            Whether to convert the input text into phonemes.
        is_uroman (`bool`, *optional*, defaults to `False`):
            Whether the `uroman` romanizer needs to be applied to the input text prior to tokenizing.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file,
        pad_token="<pad>",
        unk_token="<unk>",
        language=None,
        add_blank=True,
        normalize=True,
        phonemize=True,
        is_uroman=False,
        **kwargs,
    ) -> None:
        with open(vocab_file, encoding="utf-8") as vocab_handle:
            self.encoder = json.load(vocab_handle)

        self.decoder = {v: k for k, v in self.encoder.items()}
        self.language = language
        self.add_blank = add_blank
        self.normalize = normalize
        self.phonemize = phonemize
        self.is_uroman = is_uroman

        super().__init__(
            pad_token=pad_token,
            unk_token=unk_token,
            language=language,
            add_blank=add_blank,
            normalize=normalize,
            phonemize=phonemize,
            is_uroman=is_uroman,
            **kwargs,
        )

    @property
    def vocab_size(self):
        return len(self.encoder)

    def get_vocab(self):
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def normalize_text(self, input_string):
        """Lowercase the input string, preserving any vocabulary or added tokens that may be partly or entirely upper-cased."""
        all_vocabulary = list(self.encoder.keys()) + list(self.added_tokens_encoder.keys())
        filtered_text = ""

        i = 0
        while i < len(input_string):
            found_match = False
            # Copy a vocabulary or added token verbatim if it starts at this position.
            for word in all_vocabulary:
                if input_string[i : i + len(word)] == word:
                    filtered_text += word
                    i += len(word)
                    found_match = True
                    break

            # Otherwise lowercase the single character and move on.
            if not found_match:
                filtered_text += input_string[i].lower()
                i += 1

        return filtered_text
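
    # A small illustration of `normalize_text` (the vocabulary contents below are
    # an assumption, not taken from this module): if the vocabulary holds
    # lower-case characters and "<unk>" is an added token, matching entries are
    # copied verbatim while every other character is lower-cased, e.g.
    #   normalize_text("Namaste <unk>")  -> "namaste <unk>"
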
    def prepare_for_tokenization(
        self, text: str, is_split_into_words: bool = False, normalize: Optional[bool] = None, **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        Performs any necessary transformations before tokenization.

        This method should pop the arguments from kwargs and return the remaining `kwargs` as well. We test the
        `kwargs` at the end of the encoding process to be sure all the arguments have been used.

        Args:
            text (`str`):
                The text to prepare.
            is_split_into_words (`bool`, *optional*, defaults to `False`):
                Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
                tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
                which it will tokenize.
            normalize (`bool`, *optional*, defaults to `None`):
                Whether or not to apply punctuation and casing normalization to the text inputs. Typically, VITS is
                trained on lower-cased and un-punctuated text. Hence, normalization is used to ensure that the input
                text consists only of lower-case characters.
            kwargs (`Dict[str, Any]`, *optional*):
                Keyword arguments to use for the tokenization.

        Returns:
            `Tuple[str, Dict[str, Any]]`: The prepared text and the unused kwargs.
        """
        normalize = normalize if normalize is not None else self.normalize

        if normalize:
            text = self.normalize_text(text)

        # Drop any characters that are not part of the vocabulary.
        text = "".join(list(filter(lambda char: char in self.encoder, text))).strip()
        return text, kwargs
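
    # Illustrative only (it assumes a vocabulary of lower-case Latin characters
    # plus the space character, which this module does not guarantee):
    #   prepare_for_tokenization("Hello, World!")  -> ("hello world", {})
    # i.e. the text is lower-cased by `normalize_text`, and any character missing
    # from the vocabulary (here "," and "!") is dropped before tokenization.
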
    def _tokenize(self, text: str) -> List[str]:
        """Tokenize a string into a list of single characters; blank tokens are inserted later in `convert_tokens_to_ids`."""
        tokens = list(text)
        return tokens

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        if self.add_blank and len(tokens) > 1:
            # Decoded token sequences are interleaved with the blank token at the
            # even positions, so keep only the odd positions.
            tokens = tokens[1::2]
        return "".join(tokens)
    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]:
        ids = []
        for token in tokens:
            ids.append(self._convert_token_to_id(token))

        if self.add_blank:
            # Intersperse the blank token (id 0) between and around the ids:
            # [i0, i1, ...] -> [0, i0, 0, i1, ..., 0].
            interspersed = [0] * (len(ids) * 2 + 1)
            interspersed[1::2] = ids
            ids = interspersed

        return ids
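
    # Hypothetical example (the actual ids depend on the vocabulary file): if the
    # vocabulary maps "क" -> 12 and "ख" -> 13, then with `add_blank=True`
    #   convert_tokens_to_ids(["क", "ख"])  -> [0, 12, 0, 13, 0]
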
    def _convert_token_to_id(self, token):
        """Converts a token (str) into an id using the vocab."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) into a token (str) using the vocab."""
        return self.decoder.get(index)

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Union[Tuple[str], None]:
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return

        vocab_file = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")

        return (vocab_file,)
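

# A minimal usage sketch, not part of the original module. It assumes a local
# character-level "vocab.json" file; the file name and the sample text are
# illustrative, and the printed ids depend entirely on that vocabulary.
if __name__ == "__main__":
    tokenizer = IndicVitsTokenizer("vocab.json")

    # `__call__` is inherited from `PreTrainedTokenizer`; with the default
    # `add_blank=True`, the returned `input_ids` are interleaved with id 0.
    encoding = tokenizer("namaste duniya")
    print(encoding["input_ids"])
    print(encoding["attention_mask"])

    # `save_vocabulary` writes the vocabulary back out under the given directory.
    os.makedirs("saved_tokenizer", exist_ok=True)
    print(tokenizer.save_vocabulary("saved_tokenizer"))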