# diffsingerkr / Datasets.py
from argparse import Namespace
import torch
import numpy as np
import pickle, os, logging
from typing import Dict, List, Optional
import hgtk
from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration

def Decompose(syllable: str):
    # Split a Hangul syllable into onset, nucleus, and coda jamo via hgtk.
    # A trailing '_' marks the coda so that an empty coda is still a distinct token.
    onset, nucleus, coda = hgtk.letter.decompose(syllable)
    coda += '_'

    return onset, nucleus, coda
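# A quick check of the decomposition, as a hedged sketch (the jamo output is
# assumed from hgtk's documented behavior):
#   Decompose('밤') -> ('ㅂ', 'ㅏ', 'ㅁ_')
#   Decompose('나') -> ('ㄴ', 'ㅏ', '_')   # empty coda still carries the '_' marker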

def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
    return [
        token_dict[letter]
        for letter in list(lyric)
        ]

def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
    # Pad every token sequence to a common length with the '<X>' token, then stack to [Batch, Time].
    max_token_length = max_length or max([len(token) for token in tokens])
    tokens = np.stack(
        [np.pad(token[:max_token_length], [0, max_token_length - len(token[:max_token_length])], constant_values= token_dict['<X>']) for token in tokens],
        axis= 0
        )

    return tokens
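# Padding behavior, sketched with a toy vocabulary (a 0 index for '<X>' is an
# illustrative assumption; the real token_dict comes from preprocessing):
#   Token_Stack([[1, 2, 3], [4]], {'<X>': 0})
#   -> array([[1, 2, 3],
#             [4, 0, 0]])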

def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
    max_note_length = max_length or max([len(note) for note in notes])
    notes = np.stack(
        [np.pad(note[:max_note_length], [0, max_note_length - len(note[:max_note_length])], constant_values= 0) for note in notes],
        axis= 0
        )

    return notes

def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
    max_duration_length = max_length or max([len(duration) for duration in durations])
    durations = np.stack(
        [np.pad(duration[:max_duration_length], [0, max_duration_length - len(duration[:max_duration_length])], constant_values= 0) for duration in durations],
        axis= 0
        )

    return durations

def Feature_Stack(features: List[np.ndarray], max_length: Optional[int]= None):
    # Features are 2D arrays of shape [Time, Dim]; pad along the time axis with -1.0.
    max_feature_length = max_length or max([feature.shape[0] for feature in features])
    features = np.stack(
        [np.pad(feature, [[0, max_feature_length - feature.shape[0]], [0, 0]], constant_values= -1.0) for feature in features],
        axis= 0
        )

    return features

def Log_F0_Stack(log_f0s: List[np.ndarray], max_length: Optional[int]= None):
    max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])
    log_f0s = np.stack(
        [np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0) for log_f0 in log_f0s],
        axis= 0
        )

    return log_f0s
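# Padding conventions across the stack helpers above, as read from the code:
# tokens pad with token_dict['<X>'], notes and durations with 0, spectral
# features with -1.0, and log-F0 with 0.0.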

class Inference_Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        token_dict: Dict[str, int],
        singer_info_dict: Dict[str, int],
        genre_info_dict: Dict[str, int],
        durations: List[List[float]],
        lyrics: List[List[str]],
        notes: List[List[int]],
        singers: List[str],
        genres: List[str],
        sample_rate: int,
        frame_shift: int,
        equality_duration: bool= False,
        consonant_duration: int= 3
        ):
        super().__init__()
        self.token_dict = token_dict
        self.singer_info_dict = singer_info_dict
        self.genre_info_dict = genre_info_dict
        self.equality_duration = equality_duration
        self.consonant_duration = consonant_duration

        self.patterns = []
        for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
            if singer not in self.singer_info_dict:
                logging.warning('The singer \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(singer, index))
                continue
            if genre not in self.genre_info_dict:
                logging.warning('The genre \'{}\' is unknown. The pattern \'{}\' is ignored.'.format(genre, index))
                continue

            music = list(zip(duration, lyric, note))
            singer_label = singer
            text = lyric

            lyric, note, duration = Convert_Feature_Based_Music(
                music= music,
                sample_rate= sample_rate,
                frame_shift= frame_shift,
                consonant_duration= consonant_duration,
                equality_duration= equality_duration
                )
            lyric_expand, note_expand, duration_expand = Expand_by_Duration(lyric, note, duration)

            singer = self.singer_info_dict[singer]
            genre = self.genre_info_dict[genre]

            self.patterns.append((lyric_expand, note_expand, duration_expand, singer, genre, singer_label, text))

    def __getitem__(self, idx):
        lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]

        return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text

    def __len__(self):
        return len(self.patterns)
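# Construction sketch (hedged): the dicts plus sample_rate / frame_shift
# normally come from the experiment's hyper-parameters; every literal below,
# including the singer and genre names, is an illustrative assumption.
#   dataset = Inference_Dataset(
#       token_dict= token_dict,
#       singer_info_dict= {'CSD': 0},
#       genre_info_dict= {'Children': 0},
#       durations= [[0.5, 0.5, 1.0]],
#       lyrics= [['밤', '하', '늘']],
#       notes= [[64, 66, 67]],
#       singers= ['CSD'],
#       genres= ['Children'],
#       sample_rate= 22050,
#       frame_shift= 256
#       )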

class Inference_Collater:
    def __init__(
        self,
        token_dict: Dict[str, int]
        ):
        self.token_dict = token_dict

    def __call__(self, batch):
        tokens, notes, durations, singers, genres, singer_labels, lyrics = zip(*batch)
        lengths = np.array([len(token) for token in tokens])

        max_length = max(lengths)

        tokens = Token_Stack(tokens, self.token_dict, max_length)
        notes = Note_Stack(notes, max_length)
        durations = Duration_Stack(durations, max_length)

        tokens = torch.LongTensor(tokens)   # [Batch, Time]
        notes = torch.LongTensor(notes) # [Batch, Time]
        durations = torch.LongTensor(durations)    # [Batch, Time]
        lengths = torch.LongTensor(lengths)    # [Batch]
        singers = torch.LongTensor(singers)    # [Batch]
        genres = torch.LongTensor(genres)  # [Batch]

        # Rebuild readable lyric strings, mapping the '<X>' token back to a space for display.
        lyrics = [''.join([(x if x != '<X>' else ' ') for x in lyric]) for lyric in lyrics]

        return tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics
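# Typical wiring, as a sketch rather than the repo's confirmed entry point:
# hand the collater to a torch DataLoader so batches arrive already padded.
#   loader = torch.utils.data.DataLoader(
#       dataset= dataset,
#       collate_fn= Inference_Collater(token_dict= token_dict),
#       batch_size= 1
#       )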