codejin commited on
Commit
67d041f
·
1 Parent(s): 749e9f5

initial commit

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.pts filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
Arg_Parser.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+
3
+ def Recursive_Parse(args_dict):
4
+ parsed_dict = {}
5
+ for key, value in args_dict.items():
6
+ if isinstance(value, dict):
7
+ value = Recursive_Parse(value)
8
+ parsed_dict[key]= value
9
+
10
+ args = Namespace()
11
+ args.__dict__ = parsed_dict
12
+ return args
13
+
14
+ def To_Non_Recursive_Dict(
15
+ args: Namespace
16
+ ):
17
+ parsed_dict = {}
18
+ for key, value in args.__dict__.items():
19
+ if isinstance(value, Namespace):
20
+ value_dict = To_Non_Recursive_Dict(value)
21
+ for sub_key, sub_value in value_dict.items():
22
+ parsed_dict[f'{key}.{sub_key}'] = sub_value
23
+ else:
24
+ parsed_dict[key] = value
25
+
26
+ return parsed_dict
27
+
28
+
29
+
30
+
Checkpoint/S_200000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6482992a43b8a98554e7ef9e487a381c2717c5828d564e6dfc6cac16a0e16092
3
+ size 682529563
Datasets.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import torch
3
+ import numpy as np
4
+ import pickle, os, logging
5
+ from typing import Dict, List, Optional
6
+ import hgtk
7
+
8
+ from Pattern_Generator import Convert_Feature_Based_Music, Expand_by_Duration
9
+
10
+ def Decompose(syllable: str):
11
+ onset, nucleus, coda = hgtk.letter.decompose(syllable)
12
+ coda += '_'
13
+
14
+ return onset, nucleus, coda
15
+
16
+ def Lyric_to_Token(lyric: List[str], token_dict: Dict[str, int]):
17
+ return [
18
+ token_dict[letter]
19
+ for letter in list(lyric)
20
+ ]
21
+
22
+ def Token_Stack(tokens: List[List[int]], token_dict: Dict[str, int], max_length: Optional[int]= None):
23
+ max_token_length = max_length or max([len(token) for token in tokens])
24
+ tokens = np.stack(
25
+ [np.pad(token[:max_token_length], [0, max_token_length - len(token[:max_token_length])], constant_values= token_dict['<X>']) for token in tokens],
26
+ axis= 0
27
+ )
28
+ return tokens
29
+
30
+ def Note_Stack(notes: List[List[int]], max_length: Optional[int]= None):
31
+ max_note_length = max_length or max([len(note) for note in notes])
32
+ notes = np.stack(
33
+ [np.pad(note[:max_note_length], [0, max_note_length - len(note[:max_note_length])], constant_values= 0) for note in notes],
34
+ axis= 0
35
+ )
36
+ return notes
37
+
38
+ def Duration_Stack(durations: List[List[int]], max_length: Optional[int]= None):
39
+ max_duration_length = max_length or max([len(duration) for duration in durations])
40
+ durations = np.stack(
41
+ [np.pad(duration[:max_duration_length], [0, max_duration_length - len(duration[:max_duration_length])], constant_values= 0) for duration in durations],
42
+ axis= 0
43
+ )
44
+ return durations
45
+
46
+ def Feature_Stack(features: List[np.array], max_length: Optional[int]= None):
47
+ max_feature_length = max_length or max([feature.shape[0] for feature in features])
48
+ features = np.stack(
49
+ [np.pad(feature, [[0, max_feature_length - feature.shape[0]], [0, 0]], constant_values= -1.0) for feature in features],
50
+ axis= 0
51
+ )
52
+ return features
53
+
54
+ def Log_F0_Stack(log_f0s: List[np.array], max_length: int= None):
55
+ max_log_f0_length = max_length or max([len(log_f0) for log_f0 in log_f0s])
56
+ log_f0s = np.stack(
57
+ [np.pad(log_f0, [0, max_log_f0_length - len(log_f0)], constant_values= 0.0) for log_f0 in log_f0s],
58
+ axis= 0
59
+ )
60
+ return log_f0s
61
+
62
+ class Inference_Dataset(torch.utils.data.Dataset):
63
+ def __init__(
64
+ self,
65
+ token_dict: Dict[str, int],
66
+ singer_info_dict: Dict[str, int],
67
+ genre_info_dict: Dict[str, int],
68
+ durations: List[List[float]],
69
+ lyrics: List[List[str]],
70
+ notes: List[List[int]],
71
+ singers: List[str],
72
+ genres: List[str],
73
+ sample_rate: int,
74
+ frame_shift: int,
75
+ equality_duration: bool= False,
76
+ consonant_duration: int= 3
77
+ ):
78
+ super().__init__()
79
+ self.token_dict = token_dict
80
+ self.singer_info_dict = singer_info_dict
81
+ self.genre_info_dict = genre_info_dict
82
+ self.equality_duration = equality_duration
83
+ self.consonant_duration = consonant_duration
84
+
85
+ self.patterns = []
86
+ for index, (duration, lyric, note, singer, genre) in enumerate(zip(durations, lyrics, notes, singers, genres)):
87
+ if not singer in self.singer_info_dict.keys():
88
+ logging.warn('The singer \'{}\' is incorrect. The pattern \'{}\' is ignoired.'.format(singer, index))
89
+ continue
90
+ if not genre in self.genre_info_dict.keys():
91
+ logging.warn('The genre \'{}\' is incorrect. The pattern \'{}\' is ignoired.'.format(genre, index))
92
+ continue
93
+
94
+ music = [x for x in zip(duration, lyric, note)]
95
+ singer_label = singer
96
+ text = lyric
97
+
98
+ lyric, note, duration = Convert_Feature_Based_Music(
99
+ music= music,
100
+ sample_rate= sample_rate,
101
+ frame_shift= frame_shift,
102
+ consonant_duration= consonant_duration,
103
+ equality_duration= equality_duration
104
+ )
105
+ lyric_expand, note_expand, duration_expand = Expand_by_Duration(lyric, note, duration)
106
+
107
+ singer = self.singer_info_dict[singer]
108
+ genre = self.genre_info_dict[genre]
109
+
110
+ self.patterns.append((lyric_expand, note_expand, duration_expand, singer, genre, singer_label, text))
111
+
112
+ def __getitem__(self, idx):
113
+ lyric, note, duration, singer, genre, singer_label, text = self.patterns[idx]
114
+
115
+ return Lyric_to_Token(lyric, self.token_dict), note, duration, singer, genre, singer_label, text
116
+
117
+ def __len__(self):
118
+ return len(self.patterns)
119
+
120
+ class Inference_Collater:
121
+ def __init__(self,
122
+ token_dict: Dict[str, int]
123
+ ):
124
+ self.token_dict = token_dict
125
+
126
+ def __call__(self, batch):
127
+ tokens, notes, durations, singers, genres, singer_labels, lyrics = zip(*batch)
128
+
129
+ lengths = np.array([len(token) for token in tokens])
130
+
131
+ max_length = max(lengths)
132
+
133
+ tokens = Token_Stack(tokens, self.token_dict, max_length)
134
+ notes = Note_Stack(notes, max_length)
135
+ durations = Duration_Stack(durations, max_length)
136
+
137
+ tokens = torch.LongTensor(tokens) # [Batch, Time]
138
+ notes = torch.LongTensor(notes) # [Batch, Time]
139
+ durations = torch.LongTensor(durations) # [Batch, Time]
140
+ lengths = torch.LongTensor(lengths) # [Batch]
141
+ singers = torch.LongTensor(singers) # [Batch]
142
+ genres = torch.LongTensor(genres) # [Batch]
143
+
144
+ lyrics = [''.join([(x if x != '<X>' else ' ') for x in lyric]) for lyric in lyrics]
145
+
146
+ return tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics
Hyper_Parameters.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Sound:
2
+ N_FFT: 2048
3
+ Mel_Dim: 80
4
+ Frame_Length: 1024
5
+ Frame_Shift: 256
6
+ Sample_Rate: 22050
7
+ Mel_F_Min: 0
8
+ Mel_F_Max: 8000
9
+
10
+ Feature_Type: 'Mel' #'Spectrogram', 'Mel'
11
+
12
+ Tokens: 77
13
+ Notes: 128
14
+ Durations: 5000
15
+ Genres: 1
16
+ Singers: 1
17
+ Duration:
18
+ Equality: false
19
+ Consonant_Duration: 3 # This is only used when Equality is False.
20
+
21
+ Encoder:
22
+ Size: 384
23
+ ConvFFT:
24
+ Stack: 6
25
+ Head: 2
26
+ Dropout_Rate: 0.1
27
+ Conv:
28
+ Stack: 2
29
+ Kernel_Size: 5
30
+ FFN:
31
+ Kernel_Size: 17
32
+
33
+ Diffusion:
34
+ Max_Step: 100
35
+ Size: 256
36
+ Kernel_Size: 5
37
+ Stack: 20
38
+
39
+ Token_Path: './YAML/Token.yaml'
40
+ Spectrogram_Range_Info_Path: './YAML/Spectrogram_Range_Info.yaml'
41
+ Mel_Range_Info_Path: './YAML/Mel_Range_Info.yaml'
42
+ Log_F0_Info_Path: './YAML/Log_F0_Info.yaml'
43
+ Log_Energy_Info_Path: './YAML/Log_Energy_Info.yaml'
44
+ Singer_Info_Path: './YAML/Singer_Info.yaml'
45
+ Genre_Info_Path: './YAML/Genre_Info.yaml'
Inference.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import logging, yaml, os, sys, argparse, math
4
+ import matplotlib.pyplot as plt
5
+ from tqdm import tqdm
6
+ from librosa import griffinlim
7
+
8
+ from Modules.Modules import DiffSinger
9
+ from Datasets import Inference_Dataset as Dataset, Inference_Collater as Collater
10
+ from meldataset import spectral_de_normalize_torch
11
+ from Arg_Parser import Recursive_Parse
12
+
13
+ import matplotlib as mpl
14
+ # 유니코드 깨짐현상 해결
15
+ mpl.rcParams['axes.unicode_minus'] = False
16
+ # 나눔고딕 폰트 적용
17
+ plt.rcParams["font.family"] = 'NanumGothic'
18
+
19
+ logging.basicConfig(
20
+ level=logging.INFO, stream=sys.stdout,
21
+ format= '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
22
+ )
23
+
24
+ class Inferencer:
25
+ def __init__(
26
+ self,
27
+ hp_path: str,
28
+ checkpoint_path: str,
29
+ batch_size= 1
30
+ ):
31
+ self.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
32
+
33
+ self.hp = Recursive_Parse(yaml.load(
34
+ open(hp_path, encoding='utf-8'),
35
+ Loader=yaml.Loader
36
+ ))
37
+
38
+ self.model = DiffSinger(self.hp).to(self.device)
39
+ if self.hp.Feature_Type == 'Mel':
40
+ self.vocoder = torch.jit.load('vocoder.pts', map_location='cpu').to(self.device)
41
+
42
+ if self.hp.Feature_Type == 'Spectrogram':
43
+ self.feature_range_info_dict = yaml.load(open(self.hp.Spectrogram_Range_Info_Path), Loader=yaml.Loader)
44
+ if self.hp.Feature_Type == 'Mel':
45
+ self.feature_range_info_dict = yaml.load(open(self.hp.Mel_Range_Info_Path), Loader=yaml.Loader)
46
+ self.index_singer_dict = {
47
+ value: key
48
+ for key, value in yaml.load(open(self.hp.Singer_Info_Path), Loader=yaml.Loader).items()
49
+ }
50
+
51
+ if self.hp.Feature_Type == 'Spectrogram':
52
+ self.feature_size = self.hp.Sound.N_FFT // 2 + 1
53
+ elif self.hp.Feature_Type == 'Mel':
54
+ self.feature_size = self.hp.Sound.Mel_Dim
55
+ else:
56
+ raise ValueError('Unknown feature type: {}'.format(self.hp.Feature_Type))
57
+
58
+ self.Load_Checkpoint(checkpoint_path)
59
+ self.batch_size = batch_size
60
+
61
+ def Dataset_Generate(self, message_times_list, lyrics, notes, singers, genres):
62
+ token_dict = yaml.load(open(self.hp.Token_Path), Loader=yaml.Loader)
63
+ singer_info_dict = yaml.load(open(self.hp.Singer_Info_Path), Loader=yaml.Loader)
64
+ genre_info_dict = yaml.load(open(self.hp.Genre_Info_Path), Loader=yaml.Loader)
65
+
66
+ return torch.utils.data.DataLoader(
67
+ dataset= Dataset(
68
+ token_dict= token_dict,
69
+ singer_info_dict= singer_info_dict,
70
+ genre_info_dict= genre_info_dict,
71
+ durations= message_times_list,
72
+ lyrics= lyrics,
73
+ notes= notes,
74
+ singers= singers,
75
+ genres= genres,
76
+ sample_rate= self.hp.Sound.Sample_Rate,
77
+ frame_shift= self.hp.Sound.Frame_Shift,
78
+ equality_duration= self.hp.Duration.Equality,
79
+ consonant_duration= self.hp.Duration.Consonant_Duration
80
+ ),
81
+ shuffle= False,
82
+ collate_fn= Collater(
83
+ token_dict= token_dict
84
+ ),
85
+ batch_size= self.batch_size,
86
+ num_workers= 0,
87
+ pin_memory= True
88
+ )
89
+
90
+ def Load_Checkpoint(self, path):
91
+ state_dict = torch.load(path, map_location= 'cpu')
92
+ self.model.load_state_dict(state_dict['Model']['DiffSVS'])
93
+ self.steps = state_dict['Steps']
94
+
95
+ self.model.eval()
96
+
97
+ logging.info('Checkpoint loaded at {} steps.'.format(self.steps))
98
+
99
+ @torch.inference_mode()
100
+ def Inference_Step(self, tokens, notes, durations, lengths, singers, genres, singer_labels, ddim_steps):
101
+ tokens = tokens.to(self.device, non_blocking=True)
102
+ notes = notes.to(self.device, non_blocking=True)
103
+ durations = durations.to(self.device, non_blocking=True)
104
+ lengths = lengths.to(self.device, non_blocking=True)
105
+ singers = singers.to(self.device, non_blocking=True)
106
+ genres = genres.to(self.device, non_blocking=True)
107
+
108
+ linear_predictions, diffusion_predictions, _, _ = self.model(
109
+ tokens= tokens,
110
+ notes= notes,
111
+ durations= durations,
112
+ lengths= lengths,
113
+ genres= genres,
114
+ singers= singers,
115
+ ddim_steps= ddim_steps
116
+ )
117
+ linear_predictions = linear_predictions.clamp(-1.0, 1.0)
118
+ diffusion_predictions = diffusion_predictions.clamp(-1.0, 1.0)
119
+
120
+ linear_prediction_list, diffusion_prediction_list = [], []
121
+ for linear_prediction, diffusion_prediction, singer in zip(linear_predictions, diffusion_predictions, singer_labels):
122
+ feature_max = self.feature_range_info_dict[singer]['Max']
123
+ feature_min = self.feature_range_info_dict[singer]['Min']
124
+ linear_prediction_list.append((linear_prediction + 1.0) / 2.0 * (feature_max - feature_min) + feature_min)
125
+ diffusion_prediction_list.append((diffusion_prediction + 1.0) / 2.0 * (feature_max - feature_min) + feature_min)
126
+ linear_predictions = torch.stack(linear_prediction_list, dim= 0)
127
+ diffusion_predictions = torch.stack(diffusion_prediction_list, dim= 0)
128
+
129
+ if self.hp.Feature_Type == 'Mel':
130
+ audios = self.vocoder(diffusion_predictions)
131
+ if audios.ndim == 1: # This is temporal because of the vocoder problem.
132
+ audios = audios.unsqueeze(0)
133
+ audios = [
134
+ audio[:min(length * self.hp.Sound.Frame_Shift, audio.size(0))].cpu().numpy()
135
+ for audio, length in zip(audios, lengths)
136
+ ]
137
+ elif self.hp.Feature_Type == 'Spectrogram':
138
+ audios = []
139
+ for prediction, length in zip(
140
+ diffusion_predictions,
141
+ lengths
142
+ ):
143
+ prediction = spectral_de_normalize_torch(prediction).cpu().numpy()
144
+ audio = griffinlim(prediction)[:min(prediction.size(1), length) * self.hp.Sound.Frame_Shift]
145
+ audio = (audio / np.abs(audio).max() * 32767.5).astype(np.int16)
146
+ audios.append(audio)
147
+
148
+ return audios
149
+
150
+ def Inference_Epoch(self, message_times_list, lyrics, notes, singers, genres, ddim_steps= None, use_tqdm= True):
151
+ dataloader = self.Dataset_Generate(
152
+ message_times_list= message_times_list,
153
+ lyrics= lyrics,
154
+ notes= notes,
155
+ singers= singers,
156
+ genres= genres
157
+ )
158
+ if use_tqdm:
159
+ dataloader = tqdm(
160
+ dataloader,
161
+ desc='[Inference]',
162
+ total= math.ceil(len(dataloader.dataset) / self.batch_size)
163
+ )
164
+ audios = []
165
+ for tokens, notes, durations, lengths, singers, genres, singer_labels, lyrics in dataloader:
166
+ audios.extend(self.Inference_Step(tokens, notes, durations, lengths, singers, genres, singer_labels, ddim_steps))
167
+
168
+ return audios
Modules/Diffusion.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import math
3
+ from argparse import Namespace
4
+ from typing import Optional, List, Dict, Union
5
+ from tqdm import tqdm
6
+
7
+ from .Layer import Conv1d, Lambda
8
+
9
+ class Diffusion(torch.nn.Module):
10
+ def __init__(
11
+ self,
12
+ hyper_parameters: Namespace
13
+ ):
14
+ super().__init__()
15
+ self.hp = hyper_parameters
16
+
17
+ if self.hp.Feature_Type == 'Mel':
18
+ self.feature_size = self.hp.Sound.Mel_Dim
19
+ elif self.hp.Feature_Type == 'Spectrogram':
20
+ self.feature_size = self.hp.Sound.N_FFT // 2 + 1
21
+
22
+ self.denoiser = Denoiser(
23
+ hyper_parameters= self.hp
24
+ )
25
+
26
+ self.timesteps = self.hp.Diffusion.Max_Step
27
+ betas = torch.linspace(1e-4, 0.06, self.timesteps)
28
+ alphas = 1.0 - betas
29
+ alphas_cumprod = torch.cumprod(alphas, axis= 0)
30
+ alphas_cumprod_prev = torch.cat([torch.tensor([1.0]), alphas_cumprod[:-1]])
31
+
32
+ # calculations for diffusion q(x_t | x_{t-1}) and others
33
+ self.register_buffer('alphas_cumprod', alphas_cumprod) # [Diffusion_t]
34
+ self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) # [Diffusion_t]
35
+ self.register_buffer('sqrt_alphas_cumprod', alphas_cumprod.sqrt())
36
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', (1.0 - alphas_cumprod).sqrt())
37
+ self.register_buffer('sqrt_recip_alphas_cumprod', (1.0 / alphas_cumprod).sqrt())
38
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', (1.0 / alphas_cumprod - 1.0).sqrt())
39
+
40
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
41
+ posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
42
+
43
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
44
+ self.register_buffer('posterior_log_variance', torch.maximum(posterior_variance, torch.tensor([1e-20])).log())
45
+ self.register_buffer('posterior_mean_coef1', betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod))
46
+ self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod))
47
+
48
+ def forward(
49
+ self,
50
+ encodings: torch.Tensor,
51
+ features: torch.Tensor= None
52
+ ):
53
+ '''
54
+ encodings: [Batch, Enc_d, Enc_t]
55
+ features: [Batch, Feature_d, Feature_t]
56
+ feature_lengths: [Batch]
57
+ '''
58
+ if not features is None: # train
59
+ diffusion_steps = torch.randint(
60
+ low= 0,
61
+ high= self.timesteps,
62
+ size= (encodings.size(0),),
63
+ dtype= torch.long,
64
+ device= encodings.device
65
+ ) # random single step
66
+
67
+ noises, epsilons = self.Get_Noise_Epsilon_for_Train(
68
+ features= features,
69
+ encodings= encodings,
70
+ diffusion_steps= diffusion_steps,
71
+ )
72
+ return None, noises, epsilons
73
+ else: # inference
74
+ features = self.Sampling(
75
+ encodings= encodings,
76
+ )
77
+ return features, None, None
78
+
79
+ def Sampling(
80
+ self,
81
+ encodings: torch.Tensor,
82
+ ):
83
+ features = torch.randn(
84
+ size= (encodings.size(0), self.feature_size, encodings.size(2)),
85
+ device= encodings.device
86
+ )
87
+ for diffusion_step in reversed(range(self.timesteps)):
88
+ features = self.P_Sampling(
89
+ features= features,
90
+ encodings= encodings,
91
+ diffusion_steps= torch.full(
92
+ size= (encodings.size(0), ),
93
+ fill_value= diffusion_step,
94
+ dtype= torch.long,
95
+ device= encodings.device
96
+ ),
97
+ )
98
+
99
+ return features
100
+
101
+ def P_Sampling(
102
+ self,
103
+ features: torch.Tensor,
104
+ encodings: torch.Tensor,
105
+ diffusion_steps: torch.Tensor,
106
+ ):
107
+ posterior_means, posterior_log_variances = self.Get_Posterior(
108
+ features= features,
109
+ encodings= encodings,
110
+ diffusion_steps= diffusion_steps,
111
+ )
112
+
113
+ noises = torch.randn_like(features) # [Batch, Feature_d, Feature_d]
114
+ masks = (diffusion_steps > 0).float().unsqueeze(1).unsqueeze(1) #[Batch, 1, 1]
115
+
116
+ return posterior_means + masks * (0.5 * posterior_log_variances).exp() * noises
117
+
118
+ def Get_Posterior(
119
+ self,
120
+ features: torch.Tensor,
121
+ encodings: torch.Tensor,
122
+ diffusion_steps: torch.Tensor
123
+ ):
124
+ noised_predictions = self.denoiser(
125
+ features= features,
126
+ encodings= encodings,
127
+ diffusion_steps= diffusion_steps
128
+ )
129
+
130
+ epsilons = \
131
+ features * self.sqrt_recip_alphas_cumprod[diffusion_steps][:, None, None] - \
132
+ noised_predictions * self.sqrt_recipm1_alphas_cumprod[diffusion_steps][:, None, None]
133
+ epsilons.clamp_(-1.0, 1.0) # clipped
134
+
135
+ posterior_means = \
136
+ epsilons * self.posterior_mean_coef1[diffusion_steps][:, None, None] + \
137
+ features * self.posterior_mean_coef2[diffusion_steps][:, None, None]
138
+ posterior_log_variances = \
139
+ self.posterior_log_variance[diffusion_steps][:, None, None]
140
+
141
+ return posterior_means, posterior_log_variances
142
+
143
+ def Get_Noise_Epsilon_for_Train(
144
+ self,
145
+ features: torch.Tensor,
146
+ encodings: torch.Tensor,
147
+ diffusion_steps: torch.Tensor,
148
+ ):
149
+ noises = torch.randn_like(features)
150
+
151
+ noised_features = \
152
+ features * self.sqrt_alphas_cumprod[diffusion_steps][:, None, None] + \
153
+ noises * self.sqrt_one_minus_alphas_cumprod[diffusion_steps][:, None, None]
154
+
155
+ epsilons = self.denoiser(
156
+ features= noised_features,
157
+ encodings= encodings,
158
+ diffusion_steps= diffusion_steps
159
+ )
160
+
161
+ return noises, epsilons
162
+
163
+ def DDIM(
164
+ self,
165
+ encodings: torch.Tensor,
166
+ ddim_steps: int,
167
+ eta: float= 0.0,
168
+ temperature: float= 1.0,
169
+ use_tqdm: bool= False
170
+ ):
171
+ ddim_timesteps = self.Get_DDIM_Steps(
172
+ ddim_steps= ddim_steps
173
+ )
174
+ sigmas, alphas, alphas_prev = self.Get_DDIM_Sampling_Parameters(
175
+ ddim_timesteps= ddim_timesteps,
176
+ eta= eta
177
+ )
178
+ sqrt_one_minus_alphas = (1. - alphas).sqrt()
179
+
180
+ features = torch.randn(
181
+ size= (encodings.size(0), self.feature_size, encodings.size(2)),
182
+ device= encodings.device
183
+ )
184
+
185
+ setp_range = reversed(range(ddim_steps))
186
+ if use_tqdm:
187
+ tqdm(
188
+ setp_range,
189
+ desc= '[Diffusion]',
190
+ total= ddim_steps
191
+ )
192
+
193
+ for diffusion_steps in setp_range:
194
+ noised_predictions = self.denoiser(
195
+ features= features,
196
+ encodings= encodings,
197
+ diffusion_steps= torch.full(
198
+ size= (encodings.size(0), ),
199
+ fill_value= diffusion_steps,
200
+ dtype= torch.long,
201
+ device= encodings.device
202
+ )
203
+ )
204
+
205
+ feature_starts = (features - sqrt_one_minus_alphas[diffusion_steps] * noised_predictions) / alphas[diffusion_steps].sqrt()
206
+ direction_pointings = (1.0 - alphas_prev[diffusion_steps] - sigmas[diffusion_steps].pow(2.0)) * noised_predictions
207
+ noises = sigmas[diffusion_steps] * torch.randn_like(features) * temperature
208
+
209
+ features = alphas_prev[diffusion_steps].sqrt() * feature_starts + direction_pointings + noises
210
+
211
+ return features
212
+
213
+ # https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py
214
+ def Get_DDIM_Steps(
215
+ self,
216
+ ddim_steps: int,
217
+ ddim_discr_method: str= 'uniform'
218
+ ):
219
+ if ddim_discr_method == 'uniform':
220
+ ddim_timesteps = torch.arange(0, self.timesteps, self.timesteps // ddim_steps).long()
221
+ elif ddim_discr_method == 'quad':
222
+ ddim_timesteps = torch.linspace(0, (torch.tensor(self.timesteps) * 0.8).sqrt(), ddim_steps).pow(2.0).long()
223
+ else:
224
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
225
+
226
+ ddim_timesteps[-1] = self.timesteps - 1
227
+
228
+ return ddim_timesteps
229
+
230
+ def Get_DDIM_Sampling_Parameters(self, ddim_timesteps, eta):
231
+ alphas = self.alphas_cumprod[ddim_timesteps]
232
+ alphas_prev = self.alphas_cumprod_prev[ddim_timesteps]
233
+ sigmas = eta * ((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)).sqrt()
234
+
235
+ return sigmas, alphas, alphas_prev
236
+
237
+ class Denoiser(torch.nn.Module):
238
+ def __init__(
239
+ self,
240
+ hyper_parameters: Namespace
241
+ ):
242
+ super().__init__()
243
+ self.hp = hyper_parameters
244
+
245
+ if self.hp.Feature_Type == 'Mel':
246
+ feature_size = self.hp.Sound.Mel_Dim
247
+ elif self.hp.Feature_Type == 'Spectrogram':
248
+ feature_size = self.hp.Sound.N_FFT // 2 + 1
249
+
250
+ self.prenet = torch.nn.Sequential(
251
+ Conv1d(
252
+ in_channels= feature_size,
253
+ out_channels= self.hp.Diffusion.Size,
254
+ kernel_size= 1,
255
+ w_init_gain= 'relu'
256
+ ),
257
+ torch.nn.Mish()
258
+ )
259
+
260
+ self.step_ffn = torch.nn.Sequential(
261
+ Diffusion_Embedding(
262
+ channels= self.hp.Diffusion.Size
263
+ ),
264
+ Lambda(lambda x: x.unsqueeze(2)),
265
+ Conv1d(
266
+ in_channels= self.hp.Diffusion.Size,
267
+ out_channels= self.hp.Diffusion.Size * 4,
268
+ kernel_size= 1,
269
+ w_init_gain= 'relu'
270
+ ),
271
+ torch.nn.Mish(),
272
+ Conv1d(
273
+ in_channels= self.hp.Diffusion.Size * 4,
274
+ out_channels= self.hp.Diffusion.Size,
275
+ kernel_size= 1,
276
+ w_init_gain= 'linear'
277
+ )
278
+ )
279
+
280
+ self.residual_blocks = torch.nn.ModuleList([
281
+ Residual_Block(
282
+ in_channels= self.hp.Diffusion.Size,
283
+ kernel_size= self.hp.Diffusion.Kernel_Size,
284
+ condition_channels= self.hp.Encoder.Size + feature_size
285
+ )
286
+ for _ in range(self.hp.Diffusion.Stack)
287
+ ])
288
+
289
+ self.projection = torch.nn.Sequential(
290
+ Conv1d(
291
+ in_channels= self.hp.Diffusion.Size,
292
+ out_channels= self.hp.Diffusion.Size,
293
+ kernel_size= 1,
294
+ w_init_gain= 'relu'
295
+ ),
296
+ torch.nn.ReLU(),
297
+ Conv1d(
298
+ in_channels= self.hp.Diffusion.Size,
299
+ out_channels= feature_size,
300
+ kernel_size= 1
301
+ ),
302
+ )
303
+ torch.nn.init.zeros_(self.projection[-1].weight) # This is key factor....
304
+
305
+ def forward(
306
+ self,
307
+ features: torch.Tensor,
308
+ encodings: torch.Tensor,
309
+ diffusion_steps: torch.Tensor
310
+ ):
311
+ '''
312
+ features: [Batch, Feature_d, Feature_t]
313
+ encodings: [Batch, Enc_d, Feature_t]
314
+ diffusion_steps: [Batch]
315
+ '''
316
+ x = self.prenet(features)
317
+
318
+ diffusion_steps = self.step_ffn(diffusion_steps) # [Batch, Res_d, 1]
319
+
320
+ skips_list = []
321
+ for residual_block in self.residual_blocks:
322
+ x, skips = residual_block(
323
+ x= x,
324
+ conditions= encodings,
325
+ diffusion_steps= diffusion_steps
326
+ )
327
+ skips_list.append(skips)
328
+
329
+ x = torch.stack(skips_list, dim= 0).sum(dim= 0) / math.sqrt(self.hp.Diffusion.Stack)
330
+ x = self.projection(x)
331
+
332
+ return x
333
+
334
+ class Diffusion_Embedding(torch.nn.Module):
335
+ def __init__(
336
+ self,
337
+ channels: int
338
+ ):
339
+ super().__init__()
340
+ self.channels = channels
341
+
342
+ def forward(self, x: torch.Tensor):
343
+ half_channels = self.channels // 2 # sine and cosine
344
+ embeddings = math.log(10000.0) / (half_channels - 1)
345
+ embeddings = torch.exp(torch.arange(half_channels, device= x.device) * -embeddings)
346
+ embeddings = x.unsqueeze(1) * embeddings.unsqueeze(0)
347
+ embeddings = torch.cat([embeddings.sin(), embeddings.cos()], dim= -1)
348
+
349
+ return embeddings
350
+
351
+ class Residual_Block(torch.nn.Module):
352
+ def __init__(
353
+ self,
354
+ in_channels: int,
355
+ kernel_size: int,
356
+ condition_channels: int
357
+ ):
358
+ super().__init__()
359
+ self.in_channels = in_channels
360
+
361
+ self.condition = Conv1d(
362
+ in_channels= condition_channels,
363
+ out_channels= in_channels * 2,
364
+ kernel_size= 1
365
+ )
366
+ self.diffusion_step = Conv1d(
367
+ in_channels= in_channels,
368
+ out_channels= in_channels,
369
+ kernel_size= 1
370
+ )
371
+
372
+ self.conv = Conv1d(
373
+ in_channels= in_channels,
374
+ out_channels= in_channels * 2,
375
+ kernel_size= kernel_size,
376
+ padding= kernel_size // 2
377
+ )
378
+
379
+ self.projection = Conv1d(
380
+ in_channels= in_channels,
381
+ out_channels= in_channels * 2,
382
+ kernel_size= 1
383
+ )
384
+
385
+ def forward(
386
+ self,
387
+ x: torch.Tensor,
388
+ conditions: torch.Tensor,
389
+ diffusion_steps: torch.Tensor
390
+ ):
391
+ residuals = x
392
+
393
+ conditions = self.condition(conditions)
394
+ diffusion_steps = self.diffusion_step(diffusion_steps)
395
+
396
+ x = self.conv(x + diffusion_steps) + conditions
397
+ x_a, x_b = x.chunk(chunks= 2, dim= 1)
398
+ x = x_a.sigmoid() * x_b.tanh()
399
+
400
+ x = self.projection(x)
401
+ x, skips = x.chunk(chunks= 2, dim= 1)
402
+
403
+ return (x + residuals) / math.sqrt(2.0), skips
Modules/Layer.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Conv1d(torch.nn.Conv1d):
4
+ def __init__(self, w_init_gain= 'linear', *args, **kwargs):
5
+ self.w_init_gain = w_init_gain
6
+ super().__init__(*args, **kwargs)
7
+
8
+ def reset_parameters(self):
9
+ if self.w_init_gain in ['zero']:
10
+ torch.nn.init.zeros_(self.weight)
11
+ elif self.w_init_gain is None:
12
+ pass
13
+ elif self.w_init_gain in ['relu', 'leaky_relu']:
14
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
15
+ elif self.w_init_gain == 'glu':
16
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
17
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
18
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
19
+ elif self.w_init_gain == 'gate':
20
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
21
+ torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
22
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
23
+ else:
24
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
25
+ if not self.bias is None:
26
+ torch.nn.init.zeros_(self.bias)
27
+
28
+ class ConvTranspose1d(torch.nn.ConvTranspose1d):
29
+ def __init__(self, w_init_gain= 'linear', *args, **kwargs):
30
+ self.w_init_gain = w_init_gain
31
+ super().__init__(*args, **kwargs)
32
+
33
+ def reset_parameters(self):
34
+ if self.w_init_gain in ['zero']:
35
+ torch.nn.init.zeros_(self.weight)
36
+ elif self.w_init_gain in ['relu', 'leaky_relu']:
37
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
38
+ elif self.w_init_gain == 'glu':
39
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
40
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
41
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
42
+ elif self.w_init_gain == 'gate':
43
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
44
+ torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
45
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
46
+ else:
47
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
48
+ if not self.bias is None:
49
+ torch.nn.init.zeros_(self.bias)
50
+
51
+ class Conv2d(torch.nn.Conv2d):
52
+ def __init__(self, w_init_gain= 'linear', *args, **kwargs):
53
+ self.w_init_gain = w_init_gain
54
+ super().__init__(*args, **kwargs)
55
+
56
+ def reset_parameters(self):
57
+ if self.w_init_gain in ['zero']:
58
+ torch.nn.init.zeros_(self.weight)
59
+ elif self.w_init_gain in ['relu', 'leaky_relu']:
60
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
61
+ elif self.w_init_gain == 'glu':
62
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
63
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
64
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
65
+ elif self.w_init_gain == 'gate':
66
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
67
+ torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
68
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
69
+ else:
70
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
71
+ if not self.bias is None:
72
+ torch.nn.init.zeros_(self.bias)
73
+
74
+ class ConvTranspose2d(torch.nn.ConvTranspose2d):
75
+ def __init__(self, w_init_gain= 'linear', *args, **kwargs):
76
+ self.w_init_gain = w_init_gain
77
+ super().__init__(*args, **kwargs)
78
+
79
+ def reset_parameters(self):
80
+ if self.w_init_gain in ['zero']:
81
+ torch.nn.init.zeros_(self.weight)
82
+ elif self.w_init_gain in ['relu', 'leaky_relu']:
83
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
84
+ elif self.w_init_gain == 'glu':
85
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
86
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
87
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
88
+ elif self.w_init_gain == 'gate':
89
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
90
+ torch.nn.init.xavier_uniform_(self.weight[:self.out_channels // 2], gain= torch.nn.init.calculate_gain('tanh'))
91
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
92
+ else:
93
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
94
+ if not self.bias is None:
95
+ torch.nn.init.zeros_(self.bias)
96
+
97
+ class Linear(torch.nn.Linear):
98
+ def __init__(self, w_init_gain= 'linear', *args, **kwargs):
99
+ self.w_init_gain = w_init_gain
100
+ super().__init__(*args, **kwargs)
101
+
102
+ def reset_parameters(self):
103
+ if self.w_init_gain in ['zero']:
104
+ torch.nn.init.zeros_(self.weight)
105
+ elif self.w_init_gain in ['relu', 'leaky_relu']:
106
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
107
+ elif self.w_init_gain == 'glu':
108
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
109
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
110
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
111
+ else:
112
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
113
+ if not self.bias is None:
114
+ torch.nn.init.zeros_(self.bias)
115
+
116
+ class Lambda(torch.nn.Module):
117
+ def __init__(self, lambd):
118
+ super().__init__()
119
+ self.lambd = lambd
120
+
121
+ def forward(self, x):
122
+ return self.lambd(x)
123
+
124
+ class Residual(torch.nn.Module):
125
+ def __init__(self, module):
126
+ super().__init__()
127
+ self.module = module
128
+
129
+ def forward(self, *args, **kwargs):
130
+ return self.module(*args, **kwargs)
131
+
132
+ class LayerNorm(torch.nn.Module):
133
+ def __init__(self, num_features: int, eps: float= 1e-5):
134
+ super().__init__()
135
+
136
+ self.eps = eps
137
+ self.gamma = torch.nn.Parameter(torch.ones(num_features))
138
+ self.beta = torch.nn.Parameter(torch.zeros(num_features))
139
+
140
+
141
+ def forward(self, inputs: torch.Tensor):
142
+ means = inputs.mean(dim= 1, keepdim= True)
143
+ variances = (inputs - means).pow(2.0).mean(dim= 1, keepdim= True)
144
+
145
+ x = (inputs - means) * (variances + self.eps).rsqrt()
146
+
147
+ shape = [1, -1] + [1] * (x.ndim - 2)
148
+
149
+ return x * self.gamma.view(*shape) + self.beta.view(*shape)
150
+
151
+ class LightweightConv1d(torch.nn.Module):
152
+ '''
153
+ Args:
154
+ input_size: # of channels of the input and output
155
+ kernel_size: convolution channels
156
+ padding: padding
157
+ num_heads: number of heads used. The weight is of shape
158
+ `(num_heads, 1, kernel_size)`
159
+ weight_softmax: normalize the weight with softmax before the convolution
160
+
161
+ Shape:
162
+ Input: BxCxT, i.e. (batch_size, input_size, timesteps)
163
+ Output: BxCxT, i.e. (batch_size, input_size, timesteps)
164
+
165
+ Attributes:
166
+ weight: the learnable weights of the module of shape
167
+ `(num_heads, 1, kernel_size)`
168
+ bias: the learnable bias of the module of shape `(input_size)`
169
+ '''
170
+
171
+ def __init__(
172
+ self,
173
+ input_size,
174
+ kernel_size=1,
175
+ padding=0,
176
+ num_heads=1,
177
+ weight_softmax=False,
178
+ bias=False,
179
+ weight_dropout=0.0,
180
+ w_init_gain= 'linear'
181
+ ):
182
+ super().__init__()
183
+ self.input_size = input_size
184
+ self.kernel_size = kernel_size
185
+ self.num_heads = num_heads
186
+ self.padding = padding
187
+ self.weight_softmax = weight_softmax
188
+ self.weight = torch.nn.Parameter(torch.Tensor(num_heads, 1, kernel_size))
189
+ self.w_init_gain = w_init_gain
190
+
191
+ if bias:
192
+ self.bias = torch.nn.Parameter(torch.Tensor(input_size))
193
+ else:
194
+ self.bias = None
195
+ self.weight_dropout_module = FairseqDropout(
196
+ weight_dropout, module_name=self.__class__.__name__
197
+ )
198
+ self.reset_parameters()
199
+
200
+ def reset_parameters(self):
201
+ if self.w_init_gain in ['relu', 'leaky_relu']:
202
+ torch.nn.init.kaiming_uniform_(self.weight, nonlinearity= self.w_init_gain)
203
+ elif self.w_init_gain == 'glu':
204
+ assert self.out_channels % 2 == 0, 'The out_channels of GLU requires even number.'
205
+ torch.nn.init.kaiming_uniform_(self.weight[:self.out_channels // 2], nonlinearity= 'linear')
206
+ torch.nn.init.xavier_uniform_(self.weight[self.out_channels // 2:], gain= torch.nn.init.calculate_gain('sigmoid'))
207
+ else:
208
+ torch.nn.init.xavier_uniform_(self.weight, gain= torch.nn.init.calculate_gain(self.w_init_gain))
209
+ if not self.bias is None:
210
+ torch.nn.init.zeros_(self.bias)
211
+
212
+ def forward(self, input):
213
+ """
214
+ input size: B x C x T
215
+ output size: B x C x T
216
+ """
217
+ B, C, T = input.size()
218
+ H = self.num_heads
219
+
220
+ weight = self.weight
221
+ if self.weight_softmax:
222
+ weight = weight.softmax(dim=-1)
223
+
224
+ weight = self.weight_dropout_module(weight)
225
+ # Merge every C/H entries into the batch dimension (C = self.input_size)
226
+ # B x C x T -> (B * C/H) x H x T
227
+ # One can also expand the weight to C x 1 x K by a factor of C/H
228
+ # and do not reshape the input instead, which is slow though
229
+ input = input.view(-1, H, T)
230
+ output = torch.nn.functional.conv1d(input, weight, padding=self.padding, groups=self.num_heads)
231
+ output = output.view(B, C, T)
232
+ if self.bias is not None:
233
+ output = output + self.bias.view(1, -1, 1)
234
+
235
+ return output
236
+
237
+ class FairseqDropout(torch.nn.Module):
238
+ def __init__(self, p, module_name=None):
239
+ super().__init__()
240
+ self.p = p
241
+ self.module_name = module_name
242
+ self.apply_during_inference = False
243
+
244
+ def forward(self, x, inplace: bool = False):
245
+ if self.training or self.apply_during_inference:
246
+ return torch.nn.functional.dropout(x, p=self.p, training=True, inplace=inplace)
247
+ else:
248
+ return x
249
+
250
+ class LinearAttention(torch.nn.Module):
251
+ def __init__(
252
+ self,
253
+ channels: int,
254
+ calc_channels: int,
255
+ num_heads: int,
256
+ dropout_rate: float= 0.1,
257
+ use_scale: bool= True,
258
+ use_residual: bool= True,
259
+ use_norm: bool= True
260
+ ):
261
+ super().__init__()
262
+ assert calc_channels % num_heads == 0
263
+ self.calc_channels = calc_channels
264
+ self.num_heads = num_heads
265
+ self.use_scale = use_scale
266
+ self.use_residual = use_residual
267
+ self.use_norm = use_norm
268
+
269
+ self.prenet = Conv1d(
270
+ in_channels= channels,
271
+ out_channels= calc_channels * 3,
272
+ kernel_size= 1,
273
+ bias=False,
274
+ w_init_gain= 'linear'
275
+ )
276
+ self.projection = Conv1d(
277
+ in_channels= calc_channels,
278
+ out_channels= channels,
279
+ kernel_size= 1,
280
+ w_init_gain= 'linear'
281
+ )
282
+ self.dropout = torch.nn.Dropout(p= dropout_rate)
283
+
284
+ if use_scale:
285
+ self.scale = torch.nn.Parameter(torch.zeros(1))
286
+
287
+ if use_norm:
288
+ self.norm = LayerNorm(num_features= channels)
289
+
290
+ def forward(self, x: torch.Tensor, *args, **kwargs):
291
+ '''
292
+ x: [Batch, Enc_d, Enc_t]
293
+ '''
294
+ residuals = x
295
+
296
+ x = self.prenet(x) # [Batch, Calc_d * 3, Enc_t]
297
+ x = x.view(x.size(0), self.num_heads, x.size(1) // self.num_heads, x.size(2)) # [Batch, Head, Calc_d // Head * 3, Enc_t]
298
+ queries, keys, values = x.chunk(chunks= 3, dim= 2) # [Batch, Head, Calc_d // Head, Enc_t] * 3
299
+ keys = (keys + 1e-5).softmax(dim= 3)
300
+
301
+ contexts = keys @ values.permute(0, 1, 3, 2) # [Batch, Head, Calc_d // Head, Calc_d // Head]
302
+ contexts = contexts.permute(0, 1, 3, 2) @ queries # [Batch, Head, Calc_d // Head, Enc_t]
303
+ contexts = contexts.view(contexts.size(0), contexts.size(1) * contexts.size(2), contexts.size(3)) # [Batch, Calc_d, Enc_t]
304
+ contexts = self.projection(contexts) # [Batch, Enc_d, Enc_t]
305
+
306
+ if self.use_scale:
307
+ contexts = self.scale * contexts
308
+
309
+ contexts = self.dropout(contexts)
310
+
311
+ if self.use_residual:
312
+ contexts = contexts + residuals
313
+
314
+ if self.use_norm:
315
+ contexts = self.norm(contexts)
316
+
317
+ return contexts
Modules/Modules.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import Namespace
2
+ import torch
3
+ import math
4
+ from typing import Union
5
+
6
+ from .Layer import Conv1d, LayerNorm, LinearAttention
7
+ from .Diffusion import Diffusion
8
+
9
+ class DiffSinger(torch.nn.Module):
10
+ def __init__(self, hyper_parameters: Namespace):
11
+ super().__init__()
12
+ self.hp = hyper_parameters
13
+
14
+ self.encoder = Encoder(self.hp)
15
+ self.diffusion = Diffusion(self.hp)
16
+
17
+ def forward(
18
+ self,
19
+ tokens: torch.LongTensor,
20
+ notes: torch.LongTensor,
21
+ durations: torch.LongTensor,
22
+ lengths: torch.LongTensor,
23
+ genres: torch.LongTensor,
24
+ singers: torch.LongTensor,
25
+ features: Union[torch.FloatTensor, None]= None,
26
+ ddim_steps: Union[int, None]= None
27
+ ):
28
+ encodings, linear_predictions = self.encoder(
29
+ tokens= tokens,
30
+ notes= notes,
31
+ durations= durations,
32
+ lengths= lengths,
33
+ genres= genres,
34
+ singers= singers
35
+ ) # [Batch, Enc_d, Feature_t]
36
+
37
+ encodings = torch.cat([encodings, linear_predictions], dim= 1) # [Batch, Enc_d + Feature_d, Feature_t]
38
+
39
+ if not features is None or ddim_steps is None or ddim_steps == self.hp.Diffusion.Max_Step:
40
+ diffusion_predictions, noises, epsilons = self.diffusion(
41
+ encodings= encodings,
42
+ features= features,
43
+ )
44
+ else:
45
+ noises, epsilons = None, None
46
+ diffusion_predictions = self.diffusion.DDIM(
47
+ encodings= encodings,
48
+ ddim_steps= ddim_steps
49
+ )
50
+
51
+ return linear_predictions, diffusion_predictions, noises, epsilons
52
+
53
+
54
+ class Encoder(torch.nn.Module):
55
+ def __init__(
56
+ self,
57
+ hyper_parameters: Namespace
58
+ ):
59
+ super().__init__()
60
+ self.hp = hyper_parameters
61
+
62
+ if self.hp.Feature_Type == 'Mel':
63
+ self.feature_size = self.hp.Sound.Mel_Dim
64
+ elif self.hp.Feature_Type == 'Spectrogram':
65
+ self.feature_size = self.hp.Sound.N_FFT // 2 + 1
66
+
67
+ self.token_embedding = torch.nn.Embedding(
68
+ num_embeddings= self.hp.Tokens,
69
+ embedding_dim= self.hp.Encoder.Size
70
+ )
71
+ self.note_embedding = torch.nn.Embedding(
72
+ num_embeddings= self.hp.Notes,
73
+ embedding_dim= self.hp.Encoder.Size
74
+ )
75
+ self.duration_embedding = Duration_Positional_Encoding(
76
+ num_embeddings= self.hp.Durations,
77
+ embedding_dim= self.hp.Encoder.Size
78
+ )
79
+ self.genre_embedding = torch.nn.Embedding(
80
+ num_embeddings= self.hp.Genres,
81
+ embedding_dim= self.hp.Encoder.Size,
82
+ )
83
+ self.singer_embedding = torch.nn.Embedding(
84
+ num_embeddings= self.hp.Singers,
85
+ embedding_dim= self.hp.Encoder.Size,
86
+ )
87
+ torch.nn.init.xavier_uniform_(self.token_embedding.weight)
88
+ torch.nn.init.xavier_uniform_(self.note_embedding.weight)
89
+ torch.nn.init.xavier_uniform_(self.genre_embedding.weight)
90
+ torch.nn.init.xavier_uniform_(self.singer_embedding.weight)
91
+
92
+ self.fft_blocks = torch.nn.ModuleList([
93
+ FFT_Block(
94
+ channels= self.hp.Encoder.Size,
95
+ num_head= self.hp.Encoder.ConvFFT.Head,
96
+ ffn_kernel_size= self.hp.Encoder.ConvFFT.FFN.Kernel_Size,
97
+ dropout_rate= self.hp.Encoder.ConvFFT.Dropout_Rate
98
+ )
99
+ for _ in range(self.hp.Encoder.ConvFFT.Stack)
100
+ ])
101
+
102
+ self.linear_projection = Conv1d(
103
+ in_channels= self.hp.Encoder.Size,
104
+ out_channels= self.feature_size,
105
+ kernel_size= 1,
106
+ bias= True,
107
+ w_init_gain= 'linear'
108
+ )
109
+
110
+ def forward(
111
+ self,
112
+ tokens: torch.Tensor,
113
+ notes: torch.Tensor,
114
+ durations: torch.Tensor,
115
+ lengths: torch.Tensor,
116
+ genres: torch.Tensor,
117
+ singers: torch.Tensor
118
+ ):
119
+ x = \
120
+ self.token_embedding(tokens) + \
121
+ self.note_embedding(notes) + \
122
+ self.duration_embedding(durations) + \
123
+ self.genre_embedding(genres).unsqueeze(1) + \
124
+ self.singer_embedding(singers).unsqueeze(1)
125
+ x = x.permute(0, 2, 1) # [Batch, Enc_d, Enc_t]
126
+
127
+ for block in self.fft_blocks:
128
+ x = block(x, lengths) # [Batch, Enc_d, Enc_t]
129
+
130
+ linear_predictions = self.linear_projection(x) # [Batch, Feature_d, Enc_t]
131
+
132
+ return x, linear_predictions
133
+
134
+ class FFT_Block(torch.nn.Module):
135
+ def __init__(
136
+ self,
137
+ channels: int,
138
+ num_head: int,
139
+ ffn_kernel_size: int,
140
+ dropout_rate: float= 0.1,
141
+ ) -> None:
142
+ super().__init__()
143
+
144
+ self.attention = LinearAttention(
145
+ channels= channels,
146
+ calc_channels= channels,
147
+ num_heads= num_head,
148
+ dropout_rate= dropout_rate
149
+ )
150
+
151
+ self.ffn = FFN(
152
+ channels= channels,
153
+ kernel_size= ffn_kernel_size,
154
+ dropout_rate= dropout_rate
155
+ )
156
+
157
+ def forward(
158
+ self,
159
+ x: torch.Tensor,
160
+ lengths: torch.Tensor
161
+ ) -> torch.Tensor:
162
+ '''
163
+ x: [Batch, Dim, Time]
164
+ '''
165
+ masks = (~Mask_Generate(lengths= lengths, max_length= torch.ones_like(x[0, 0]).sum())).unsqueeze(1).float() # float mask
166
+
167
+ # Attention + Dropout + LayerNorm
168
+ x = self.attention(x)
169
+
170
+ # FFN + Dropout + LayerNorm
171
+ x = self.ffn(x, masks)
172
+
173
+ return x * masks
174
+
175
+ class FFN(torch.nn.Module):
176
+ def __init__(
177
+ self,
178
+ channels: int,
179
+ kernel_size: int,
180
+ dropout_rate: float= 0.1,
181
+ ) -> None:
182
+ super().__init__()
183
+ self.conv_0 = Conv1d(
184
+ in_channels= channels,
185
+ out_channels= channels,
186
+ kernel_size= kernel_size,
187
+ padding= (kernel_size - 1) // 2,
188
+ w_init_gain= 'relu'
189
+ )
190
+ self.relu = torch.nn.ReLU()
191
+ self.dropout = torch.nn.Dropout(p= dropout_rate)
192
+ self.conv_1 = Conv1d(
193
+ in_channels= channels,
194
+ out_channels= channels,
195
+ kernel_size= kernel_size,
196
+ padding= (kernel_size - 1) // 2,
197
+ w_init_gain= 'linear'
198
+ )
199
+ self.norm = LayerNorm(
200
+ num_features= channels,
201
+ )
202
+
203
+ def forward(
204
+ self,
205
+ x: torch.Tensor,
206
+ masks: torch.Tensor
207
+ ) -> torch.Tensor:
208
+ '''
209
+ x: [Batch, Dim, Time]
210
+ '''
211
+ residuals = x
212
+
213
+ x = self.conv_0(x * masks)
214
+ x = self.relu(x)
215
+ x = self.dropout(x)
216
+ x = self.conv_1(x * masks)
217
+ x = self.dropout(x)
218
+ x = self.norm(x + residuals)
219
+
220
+ return x * masks
221
+
222
+ # https://pytorch.org/tutorials/beginner/transformer_tutorial.html
223
+ # https://github.com/soobinseo/Transformer-TTS/blob/master/network.py
224
+ class Duration_Positional_Encoding(torch.nn.Embedding):
225
+ def __init__(
226
+ self,
227
+ num_embeddings: int,
228
+ embedding_dim: int,
229
+ ):
230
+ positional_embedding = torch.zeros(num_embeddings, embedding_dim)
231
+ position = torch.arange(0, num_embeddings, dtype=torch.float).unsqueeze(1)
232
+ div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
233
+ positional_embedding[:, 0::2] = torch.sin(position * div_term)
234
+ positional_embedding[:, 1::2] = torch.cos(position * div_term)
235
+ super().__init__(
236
+ num_embeddings= num_embeddings,
237
+ embedding_dim= embedding_dim,
238
+ _weight= positional_embedding
239
+ )
240
+ self.weight.requires_grad = False
241
+
242
+ self.alpha = torch.nn.Parameter(
243
+ data= torch.ones(1) * 0.01,
244
+ requires_grad= True
245
+ )
246
+
247
+ def forward(self, durations):
248
+ '''
249
+ durations: [Batch, Length]
250
+ '''
251
+ return self.alpha * super().forward(durations) # [Batch, Dim, Length]
252
+
253
+ @torch.jit.script
254
+ def get_pe(x: torch.Tensor, pe: torch.Tensor):
255
+ pe = pe.repeat(1, 1, math.ceil(x.size(2) / pe.size(2)))
256
+ return pe[:, :, :x.size(2)]
257
+
258
+ def Mask_Generate(lengths: torch.Tensor, max_length: Union[torch.Tensor, int, None]= None):
259
+ '''
260
+ lengths: [Batch]
261
+ max_lengths: an int value. If None, max_lengths == max(lengths)
262
+ '''
263
+ max_length = max_length or torch.max(lengths)
264
+ sequence = torch.arange(max_length)[None, :].to(lengths.device)
265
+ return sequence >= lengths[:, None] # [Batch, Time]
Pattern_Generator.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import mido, os, pickle, yaml, argparse, math, librosa, hgtk, logging
3
+ from tqdm import tqdm
4
+ from pysptk.sptk import rapt
5
+ from typing import List, Tuple
6
+ from argparse import Namespace # for type
7
+ import torch
8
+ from typing import Dict
9
+
10
+ from meldataset import mel_spectrogram, spectrogram, spec_energy
11
+ from Arg_Parser import Recursive_Parse
12
+
13
+ def Convert_Feature_Based_Music(
14
+ music: List[Tuple[float, str, int]],
15
+ sample_rate: int,
16
+ frame_shift: int,
17
+ consonant_duration: int= 3,
18
+ equality_duration: bool= False
19
+ ):
20
+ previous_used = 0
21
+ lyrics = []
22
+ notes = []
23
+ durations = []
24
+ for message_time, lyric, note in music:
25
+ duration = round(message_time * sample_rate) + previous_used
26
+ previous_used = duration % frame_shift
27
+ duration = duration // frame_shift
28
+
29
+ if lyric == '<X>':
30
+ lyrics.append(lyric)
31
+ notes.append(note)
32
+ durations.append(duration)
33
+ else:
34
+ lyrics.extend(Decompose(lyric))
35
+ notes.extend([note] * 3)
36
+ if equality_duration or duration < consonant_duration * 3:
37
+ split_duration = [duration // 3] * 3
38
+ split_duration[1] += duration % 3
39
+ durations.extend(split_duration)
40
+ else:
41
+ durations.extend([
42
+ consonant_duration, # onset
43
+ duration - consonant_duration * 2, # nucleus
44
+ consonant_duration # coda
45
+ ])
46
+
47
+ return lyrics, notes, durations
48
+
49
+ def Expand_by_Duration(
50
+ lyrics: List[str],
51
+ notes: List[int],
52
+ durations: List[int],
53
+ ):
54
+ lyrics = sum([[lyric] * duration for lyric, duration in zip(lyrics, durations)], [])
55
+ notes = sum([*[[note] * duration for note, duration in zip(notes, durations)]], [])
56
+ durations = [index for duration in durations for index in range(duration)]
57
+
58
+ return lyrics, notes, durations
59
+
60
+ def Decompose(syllable: str):
61
+ onset, nucleus, coda = hgtk.letter.decompose(syllable)
62
+ coda += '_'
63
+
64
+ return onset, nucleus, coda
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Diffsingerkr
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: yellow
6
  sdk: streamlit
7
  sdk_version: 1.17.0
8
  app_file: app.py
 
1
  ---
2
+ title: Diffsvs
3
+ emoji: 🐢
4
+ colorFrom: blue
5
+ colorTo: blue
6
  sdk: streamlit
7
  sdk_version: 1.17.0
8
  app_file: app.py
YAML/Genre_Info.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ Children: 0
YAML/Log_Energy_Info.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ CSD:
2
+ Mean: 3.540642499923706
3
+ Std: 2.1372854709625244
YAML/Log_F0_Info.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ CSD:
2
+ Mean: 5.851496696472168
3
+ Std: 0.2526451647281647
YAML/Mel_Range_Info.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ CSD:
2
+ Max: 2.6226840019226074
3
+ Min: -11.512925148010254
YAML/Singer_Info.yaml ADDED
@@ -0,0 +1 @@
 
 
1
+ CSD: 0
YAML/Spectrogram_Range_Info.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ CSD:
2
+ Max: 5.292316913604736
3
+ Min: -10.36163330078125
YAML/Token.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <E>: 1
2
+ <S>: 0
3
+ <X>: 2
4
+ _: 3
5
+ "\u3131": 4
6
+ "\u3131_": 5
7
+ "\u3132": 6
8
+ "\u3132_": 7
9
+ "\u3133_": 8
10
+ "\u3134": 9
11
+ "\u3134_": 10
12
+ "\u3135_": 11
13
+ "\u3136_": 12
14
+ "\u3137": 13
15
+ "\u3137_": 14
16
+ "\u3138": 15
17
+ "\u3139": 16
18
+ "\u3139_": 17
19
+ "\u313A_": 18
20
+ "\u313B_": 19
21
+ "\u313C_": 20
22
+ "\u313D_": 21
23
+ "\u313E_": 22
24
+ "\u313F_": 23
25
+ "\u3140_": 24
26
+ "\u3141": 25
27
+ "\u3141_": 26
28
+ "\u3142": 27
29
+ "\u3142_": 28
30
+ "\u3143": 29
31
+ "\u3144_": 30
32
+ "\u3145": 31
33
+ "\u3145_": 32
34
+ "\u3146": 33
35
+ "\u3146_": 34
36
+ "\u3147": 35
37
+ "\u3147_": 36
38
+ "\u3148": 37
39
+ "\u3148_": 38
40
+ "\u3149": 39
41
+ "\u314A": 40
42
+ "\u314A_": 41
43
+ "\u314B": 42
44
+ "\u314B_": 43
45
+ "\u314C": 44
46
+ "\u314C_": 45
47
+ "\u314D": 46
48
+ "\u314D_": 47
49
+ "\u314E": 48
50
+ "\u314E_": 49
51
+ "\u314F": 50
52
+ "\u3150": 51
53
+ "\u3151": 52
54
+ "\u3152": 53
55
+ "\u3153": 54
56
+ "\u3154": 55
57
+ "\u3155": 56
58
+ "\u3156": 57
59
+ "\u3157": 58
60
+ "\u3158": 59
61
+ "\u3159": 60
62
+ "\u315A": 61
63
+ "\u315B": 62
64
+ "\u315C": 63
65
+ "\u315D": 64
66
+ "\u315E": 65
67
+ "\u315F": 66
68
+ "\u3160": 67
69
+ "\u3161": 68
70
+ "\u3162": 69
71
+ "\u3163": 70
app.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from Inference import Inferencer
4
+
5
+ def app_diffsingerkr():
6
+ if not 'diffsingerkr_duration' in st.session_state.keys():
7
+ st.session_state.diffsingerkr_duration = ''
8
+ if not 'diffsingerkr_lyric' in st.session_state.keys():
9
+ st.session_state.diffsingerkr_lyric = ''
10
+ if not 'diffsingerkr_note' in st.session_state.keys():
11
+ st.session_state.diffsingerkr_note = ''
12
+ if not 'inferencer' in st.session_state.keys():
13
+ st.session_state.inferencer = Inferencer(
14
+ hp_path= 'Hyper_Parameters.yaml',
15
+ checkpoint_path= 'Checkpoint/S_200000.pt',
16
+ batch_size= 1
17
+ )
18
+
19
+ st.title('DiffSinger-KR')
20
+ st.markdown('* This code is an implementation of DiffSinger for Korean.')
21
+ st.markdown('* When music score which is note, duration, and lyric information are entered, singing voices are synthesized accordingly.')
22
+ st.markdown('* Due to the range of the trained dataset, the supported notes are between 65 and 89.')
23
+ st.markdown('* Please refer to the [here](https://github.com/CODEJIN/DiffSingerKR) for the source code for training the model.')
24
+
25
+ st.markdown('''---''')
26
+ status_indicator = st.empty()
27
+ status_indicator.header('Insert the music!')
28
+ st.markdown('''---''')
29
+ example1_col, example2_col, example3_col, _ = st.columns(4)
30
+ if example1_col.button('Example 1'):
31
+ st.session_state.diffsingerkr_duration = '0.52,0.17,0.35,0.35,0.35,0.35,0.70,0.35,0.35,0.70,0.35,0.35,0.70,0.52,0.17,0.35,0.35,0.35,0.35,0.70,0.35,0.35,0.35,0.35,1.39'
32
+ st.session_state.diffsingerkr_lyric = '떴,다,떴,다,비,행,기,날,아,라,날,아,라,높,이,높,이,날,아,라,우,리,비,행,기'
33
+ st.session_state.diffsingerkr_note = '76,74,72,74,76,76,76,74,74,74,76,79,79,76,74,72,74,76,76,76,74,74,76,74,72'
34
+ st.experimental_rerun()
35
+ if example2_col.button('Example 2'):
36
+ st.session_state.diffsingerkr_duration = '0.53,0.52,0.50,0.57,0.58,0.46,0.48,0.50,0.37,0.13,0.43,0.21,0.57,0.43,0.49,1.44,0.26,0.49,0.14,0.13,0.57,0.26,0.06,0.15,0.63,0.26,0.51,0.20,0.48,0.72,0.22'
37
+ st.session_state.diffsingerkr_lyric = '만,나,고,<X>,난,외,로,움,을,<X>,알,았,어,내,겐,<X>,관,심,조,<X>,차,<X>,없,<X>,다,는,걸,<X>,알,면,서'
38
+ st.session_state.diffsingerkr_note = '76,78,79,0,71,74,72,71,72,0,71,69,69,71,74,0,79,78,79,0,71,0,74,0,74,72,72,0,71,71,69'
39
+ st.experimental_rerun()
40
+ if example3_col.button('Example 3'):
41
+ st.session_state.diffsingerkr_duration = '0.33,0.16,0.33,0.49,0.33,0.16,0.81,0.33,0.16,0.16,0.33,0.16,0.49,0.16,0.82,0.33,0.16,0.33,0.49,0.33,0.16,0.33,0.49,0.33,0.33,0.16,0.33,1.47,0.33,0.16,0.33,0.49,0.33,0.16,0.81,0.33,0.16,0.16,0.33,0.16,0.49,0.16,0.82,0.33,0.16,0.33,0.16,0.33,0.49,0.16,0.33,0.33,0.33,0.33,0.16,0.33,0.82'
42
+ st.session_state.diffsingerkr_lyric = '마,음,울,적,한,날,에,<X>,거,리,를,걸,어,보,고,향,기,로,운,칵,테,일,에,취,해,도,보,고,한,편,의,시,가,있,는,<X>,전,시,회,장,도,가,고,밤,새,도,<X>,록,그,리,움,에,편,질,쓰,고,파'
43
+ st.session_state.diffsingerkr_note = '80,80,80,87,85,84,82,0,84,84,84,85,84,79,79,77,77,77,80,80,78,77,75,77,80,79,80,82,80,80,80,87,85,84,82,0,84,84,84,85,84,79,79,77,77,77,79,80,80,77,75,75,77,80,79,82,80'
44
+ st.experimental_rerun()
45
+ st.markdown('''---''')
46
+ duration = st.text_input('Duration', value= st.session_state.diffsingerkr_duration)
47
+ lyric = st.text_input('Lyric', value= st.session_state.diffsingerkr_lyric)
48
+ note = st.text_input('Note', value= st.session_state.diffsingerkr_note)
49
+ singer = 'CSD'
50
+ genre = 'Children'
51
+ key_adjustment = st.select_slider(
52
+ label= 'Key adjustment',
53
+ options= [x for x in range(-6, 7)],
54
+ value= 0
55
+ )
56
+
57
+ if st.button("Generate!"):
58
+ if duration != '' and lyric != '' and note != '':
59
+ status_indicator.header('Generating...')
60
+ audio = st.session_state.inferencer.Inference_Epoch(
61
+ message_times_list= [[float(x) for x in duration.strip().split(',')]],
62
+ lyrics= [[x for x in lyric.strip().split(',')]],
63
+ notes= [[
64
+ (int(x) + key_adjustment if int(x) != 0 else int(x))
65
+ for x in note.strip().split(',')
66
+ ]],
67
+ singers= [singer],
68
+ genres= [genre]
69
+ )[0]
70
+
71
+ st.audio(
72
+ audio,
73
+ format="audio/wav",
74
+ start_time=0,
75
+ sample_rate= st.session_state.inferencer.hp.Sound.Sample_Rate
76
+ )
77
+
78
+ status_indicator.header('Done.')
79
+
80
+ if __name__ == '__main__':
81
+ app_diffsingerkr()
meldataset.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###############################################################################
2
+ # MIT License
3
+ #
4
+ # Copyright (c) 2020 Jungil Kong
5
+ #
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+ #
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+ #
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+ ###############################################################################
24
+
25
+ import math
26
+ import os
27
+ import random
28
+ import torch
29
+ import torch.utils.data
30
+ import numpy as np
31
+ from librosa.util import normalize
32
+ from scipy.io.wavfile import read
33
+ from librosa.filters import mel as librosa_mel_fn
34
+
35
+ MAX_WAV_VALUE = 32768.0
36
+
37
+
38
+ def load_wav(full_path):
39
+ sampling_rate, data = read(full_path)
40
+ return data, sampling_rate
41
+
42
+
43
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
44
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
45
+
46
+
47
+ def dynamic_range_decompression(x, C=1):
48
+ return np.exp(x) / C
49
+
50
+
51
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
52
+ return torch.log(torch.clamp(x, min=clip_val) * C)
53
+
54
+
55
+ def dynamic_range_decompression_torch(x, C=1):
56
+ return torch.exp(x) / C
57
+
58
+
59
+ def spectral_normalize_torch(magnitudes):
60
+ output = dynamic_range_compression_torch(magnitudes)
61
+ return output
62
+
63
+
64
+ def spectral_de_normalize_torch(magnitudes):
65
+ output = dynamic_range_decompression_torch(magnitudes)
66
+ return output
67
+
68
+
69
+ mel_basis = {}
70
+ hann_window = {}
71
+
72
+
73
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
74
+ if torch.min(y) < -1.:
75
+ print('min value is ', torch.min(y))
76
+ if torch.max(y) > 1.:
77
+ print('max value is ', torch.max(y))
78
+
79
+ global mel_basis, hann_window
80
+ if fmax not in mel_basis:
81
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
82
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
83
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
84
+
85
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
86
+ y = y.squeeze(1)
87
+
88
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
89
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
90
+
91
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
92
+
93
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
94
+ spec = spectral_normalize_torch(spec)
95
+
96
+ return spec
97
+
98
+ def spectrogram(y, n_fft, hop_size, win_size, center=False):
99
+ if torch.min(y) < -1.:
100
+ print('min value is ', torch.min(y))
101
+ if torch.max(y) > 1.:
102
+ print('max value is ', torch.max(y))
103
+
104
+ global hann_window
105
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
106
+
107
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
108
+ y = y.squeeze(1)
109
+
110
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
111
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
112
+
113
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
114
+ spec = spectral_normalize_torch(spec)
115
+
116
+ return spec
117
+
118
+ def spec_energy(y, n_fft, hop_size, win_size, center=False):
119
+ if torch.min(y) < -1.:
120
+ print('min value is ', torch.min(y))
121
+ if torch.max(y) > 1.:
122
+ print('max value is ', torch.max(y))
123
+
124
+ global hann_window
125
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
126
+
127
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
128
+ y = y.squeeze(1)
129
+
130
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
131
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
132
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
133
+ energy = torch.norm(spec, dim= 1)
134
+
135
+ return energy
136
+
137
+ def get_dataset_filelist(a):
138
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
139
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
140
+ for x in fi.read().split('\n') if len(x) > 0]
141
+
142
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
143
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
144
+ for x in fi.read().split('\n') if len(x) > 0]
145
+ return training_files, validation_files
146
+
147
+
148
+ class MelDataset(torch.utils.data.Dataset):
149
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
150
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
151
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
152
+ self.audio_files = training_files
153
+ random.seed(1234)
154
+ if shuffle:
155
+ random.shuffle(self.audio_files)
156
+ self.segment_size = segment_size
157
+ self.sampling_rate = sampling_rate
158
+ self.split = split
159
+ self.n_fft = n_fft
160
+ self.num_mels = num_mels
161
+ self.hop_size = hop_size
162
+ self.win_size = win_size
163
+ self.fmin = fmin
164
+ self.fmax = fmax
165
+ self.fmax_loss = fmax_loss
166
+ self.cached_wav = None
167
+ self.n_cache_reuse = n_cache_reuse
168
+ self._cache_ref_count = 0
169
+ self.device = device
170
+ self.fine_tuning = fine_tuning
171
+ self.base_mels_path = base_mels_path
172
+
173
+ def __getitem__(self, index):
174
+ filename = self.audio_files[index]
175
+ if self._cache_ref_count == 0:
176
+ audio, sampling_rate = load_wav(filename)
177
+ audio = audio / MAX_WAV_VALUE
178
+ if not self.fine_tuning:
179
+ audio = normalize(audio) * 0.95
180
+ self.cached_wav = audio
181
+ if sampling_rate != self.sampling_rate:
182
+ raise ValueError("{} SR doesn't match target {} SR".format(
183
+ sampling_rate, self.sampling_rate))
184
+ self._cache_ref_count = self.n_cache_reuse
185
+ else:
186
+ audio = self.cached_wav
187
+ self._cache_ref_count -= 1
188
+
189
+ audio = torch.FloatTensor(audio)
190
+ audio = audio.unsqueeze(0)
191
+
192
+ if not self.fine_tuning:
193
+ if self.split:
194
+ if audio.size(1) >= self.segment_size:
195
+ max_audio_start = audio.size(1) - self.segment_size
196
+ audio_start = random.randint(0, max_audio_start)
197
+ audio = audio[:, audio_start:audio_start+self.segment_size]
198
+ else:
199
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
200
+
201
+ mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
202
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
203
+ center=False)
204
+ else:
205
+ mel = np.load(
206
+ os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
207
+ mel = torch.from_numpy(mel)
208
+
209
+ if len(mel.shape) < 3:
210
+ mel = mel.unsqueeze(0)
211
+
212
+ if self.split:
213
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
214
+
215
+ if audio.size(1) >= self.segment_size:
216
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
217
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
218
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
219
+ else:
220
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
221
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
222
+
223
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
224
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
225
+ center=False)
226
+
227
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
228
+
229
+ def __len__(self):
230
+ return len(self.audio_files)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ librosa
4
+ mido
5
+ hgtk
6
+ pysptk
7
+ matplotlib
vocoder.pts ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b47a5d03d744861f94ee973294317f738ccc6dc6d27bafa5d8db5ed18f95566
3
+ size 55884400