asigalov61 commited on
Commit
dcca7d2
·
verified ·
1 Parent(s): 6927483

Upload 6 files

Browse files
Files changed (6) hide show
  1. config.py +7 -0
  2. inference.py +171 -0
  3. models.py +353 -0
  4. piano_vad.py +130 -0
  5. pytorch_utils.py +66 -0
  6. utilities.py +564 -0
config.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ sample_rate = 16000
2
+ classes_num = 88 # Number of notes of piano
3
+ begin_note = 21 # MIDI note of A0, the lowest note of a piano.
4
+ segment_seconds = 10. # Training segment duration
5
+ hop_seconds = 1.
6
+ frames_per_second = 100
7
+ velocity_scale = 128
inference.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import time
4
+ import librosa
5
+ from pathlib import Path
6
+
7
+ import torch
8
+
9
+ from .utilities import (create_folder, get_filename, RegressionPostProcessor,
10
+ write_events_to_midi)
11
+ from .models import Regress_onset_offset_frame_velocity_CRNN, Note_pedal
12
+ from .pytorch_utils import move_data_to_device, forward
13
+ from . import config
14
+
15
+
16
+ class PianoTranscription(object):
17
+ def __init__(self, model_type='Note_pedal', checkpoint_path=None,
18
+ segment_samples=16000*10, device=torch.device('cuda')):
19
+ """Class for transcribing piano solo recording.
20
+
21
+ Args:
22
+ model_type: str
23
+ checkpoint_path: str
24
+ segment_samples: int
25
+ device: 'cuda' | 'cpu'
26
+ """
27
+ if not checkpoint_path:
28
+ checkpoint_path='{}/piano_transcription_inference_data/note_F1=0.9677_pedal_F1=0.9186.pth'.format(str(Path.home()))
29
+ print('Checkpoint path: {}'.format(checkpoint_path))
30
+
31
+ if not os.path.exists(checkpoint_path) or os.path.getsize(checkpoint_path) < 1.6e8:
32
+ create_folder(os.path.dirname(checkpoint_path))
33
+ print('Total size: ~165 MB')
34
+ zenodo_path = 'https://zenodo.org/record/4034264/files/CRNN_note_F1%3D0.9677_pedal_F1%3D0.9186.pth?download=1'
35
+ os.system('wget -O "{}" "{}"'.format(checkpoint_path, zenodo_path))
36
+
37
+ print('Using {} for inference.'.format(device))
38
+
39
+ self.segment_samples = segment_samples
40
+ self.frames_per_second = config.frames_per_second
41
+ self.classes_num = config.classes_num
42
+ self.onset_threshold = 0.3
43
+ self.offset_threshod = 0.3
44
+ self.frame_threshold = 0.1
45
+ self.pedal_offset_threshold = 0.2
46
+
47
+ # Build model
48
+ Model = eval(model_type)
49
+ self.model = Model(frames_per_second=self.frames_per_second,
50
+ classes_num=self.classes_num)
51
+
52
+ # Load model
53
+ checkpoint = torch.load(checkpoint_path, map_location=device)
54
+ self.model.load_state_dict(checkpoint['model'], strict=False)
55
+
56
+ # Parallel
57
+ if 'cuda' in str(device):
58
+ self.model.to(device)
59
+ print('GPU number: {}'.format(torch.cuda.device_count()))
60
+ self.model = torch.nn.DataParallel(self.model)
61
+ else:
62
+ print('Using CPU.')
63
+
64
+ def transcribe(self, audio, midi_path):
65
+ """Transcribe an audio recording.
66
+
67
+ Args:
68
+ audio: (audio_samples,)
69
+ midi_path: str, path to write out the transcribed MIDI.
70
+
71
+ Returns:
72
+ transcribed_dict, dict: {'output_dict':, ..., 'est_note_events': ...}
73
+
74
+ """
75
+ audio = audio[None, :] # (1, audio_samples)
76
+
77
+ # Pad audio to be evenly divided by segment_samples
78
+ audio_len = audio.shape[1]
79
+ pad_len = int(np.ceil(audio_len / self.segment_samples))\
80
+ * self.segment_samples - audio_len
81
+
82
+ audio = np.concatenate((audio, np.zeros((1, pad_len))), axis=1)
83
+
84
+ # Enframe to segments
85
+ segments = self.enframe(audio, self.segment_samples)
86
+ """(N, segment_samples)"""
87
+
88
+ # Forward
89
+ output_dict = forward(self.model, segments, batch_size=1)
90
+ """{'reg_onset_output': (N, segment_frames, classes_num), ...}"""
91
+
92
+ # Deframe to original length
93
+ for key in output_dict.keys():
94
+ output_dict[key] = self.deframe(output_dict[key])[0 : audio_len]
95
+ """output_dict: {
96
+ 'reg_onset_output': (N, segment_frames, classes_num),
97
+ 'reg_offset_output': (N, segment_frames, classes_num),
98
+ 'frame_output': (N, segment_frames, classes_num),
99
+ 'velocity_output': (N, segment_frames, classes_num)}"""
100
+
101
+ # Post processor
102
+ post_processor = RegressionPostProcessor(self.frames_per_second,
103
+ classes_num=self.classes_num, onset_threshold=self.onset_threshold,
104
+ offset_threshold=self.offset_threshod,
105
+ frame_threshold=self.frame_threshold,
106
+ pedal_offset_threshold=self.pedal_offset_threshold)
107
+
108
+ # Post process output_dict to MIDI events
109
+ (est_note_events, est_pedal_events) = \
110
+ post_processor.output_dict_to_midi_events(output_dict)
111
+
112
+ # Write MIDI events to file
113
+ if midi_path:
114
+ write_events_to_midi(start_time=0, note_events=est_note_events,
115
+ pedal_events=est_pedal_events, midi_path=midi_path)
116
+ print('Write out to {}'.format(midi_path))
117
+
118
+ transcribed_dict = {
119
+ 'output_dict': output_dict,
120
+ 'est_note_events': est_note_events,
121
+ 'est_pedal_events': est_pedal_events}
122
+
123
+ return transcribed_dict
124
+
125
+ def enframe(self, x, segment_samples):
126
+ """Enframe long sequence to short segments.
127
+
128
+ Args:
129
+ x: (1, audio_samples)
130
+ segment_samples: int
131
+
132
+ Returns:
133
+ batch: (N, segment_samples)
134
+ """
135
+ assert x.shape[1] % segment_samples == 0
136
+ batch = []
137
+
138
+ pointer = 0
139
+ while pointer + segment_samples <= x.shape[1]:
140
+ batch.append(x[:, pointer : pointer + segment_samples])
141
+ pointer += segment_samples // 2
142
+
143
+ batch = np.concatenate(batch, axis=0)
144
+ return batch
145
+
146
+ def deframe(self, x):
147
+ """Deframe predicted segments to original sequence.
148
+
149
+ Args:
150
+ x: (N, segment_frames, classes_num)
151
+
152
+ Returns:
153
+ y: (audio_frames, classes_num)
154
+ """
155
+ if x.shape[0] == 1:
156
+ return x[0]
157
+
158
+ else:
159
+ x = x[:, 0 : -1, :]
160
+ """Remove an extra frame in the end of each segment caused by the
161
+ 'center=True' argument when calculating spectrogram."""
162
+ (N, segment_samples, classes_num) = x.shape
163
+ assert segment_samples % 4 == 0
164
+
165
+ y = []
166
+ y.append(x[0, 0 : int(segment_samples * 0.75)])
167
+ for i in range(1, N - 1):
168
+ y.append(x[i, int(segment_samples * 0.25) : int(segment_samples * 0.75)])
169
+ y.append(x[-1, int(segment_samples * 0.25) :])
170
+ y = np.concatenate(y, axis=0)
171
+ return y
models.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import math
4
+ import time
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from torchlibrosa.stft import Spectrogram, LogmelFilterBank
13
+ from .pytorch_utils import move_data_to_device
14
+
15
+
16
+ def init_layer(layer):
17
+ """Initialize a Linear or Convolutional layer. """
18
+ nn.init.xavier_uniform_(layer.weight)
19
+
20
+ if hasattr(layer, 'bias'):
21
+ if layer.bias is not None:
22
+ layer.bias.data.fill_(0.)
23
+
24
+
25
+ def init_bn(bn):
26
+ """Initialize a Batchnorm layer. """
27
+ bn.bias.data.fill_(0.)
28
+ bn.weight.data.fill_(1.)
29
+
30
+
31
+ def init_gru(rnn):
32
+ """Initialize a GRU layer. """
33
+
34
+ def _concat_init(tensor, init_funcs):
35
+ (length, fan_out) = tensor.shape
36
+ fan_in = length // len(init_funcs)
37
+
38
+ for (i, init_func) in enumerate(init_funcs):
39
+ init_func(tensor[i * fan_in : (i + 1) * fan_in, :])
40
+
41
+ def _inner_uniform(tensor):
42
+ fan_in = nn.init._calculate_correct_fan(tensor, 'fan_in')
43
+ nn.init.uniform_(tensor, -math.sqrt(3 / fan_in), math.sqrt(3 / fan_in))
44
+
45
+ for i in range(rnn.num_layers):
46
+ _concat_init(
47
+ getattr(rnn, 'weight_ih_l{}'.format(i)),
48
+ [_inner_uniform, _inner_uniform, _inner_uniform]
49
+ )
50
+ torch.nn.init.constant_(getattr(rnn, 'bias_ih_l{}'.format(i)), 0)
51
+
52
+ _concat_init(
53
+ getattr(rnn, 'weight_hh_l{}'.format(i)),
54
+ [_inner_uniform, _inner_uniform, nn.init.orthogonal_]
55
+ )
56
+ torch.nn.init.constant_(getattr(rnn, 'bias_hh_l{}'.format(i)), 0)
57
+
58
+
59
+ class ConvBlock(nn.Module):
60
+ def __init__(self, in_channels, out_channels, momentum):
61
+
62
+ super(ConvBlock, self).__init__()
63
+
64
+ self.conv1 = nn.Conv2d(in_channels=in_channels,
65
+ out_channels=out_channels,
66
+ kernel_size=(3, 3), stride=(1, 1),
67
+ padding=(1, 1), bias=False)
68
+
69
+ self.conv2 = nn.Conv2d(in_channels=out_channels,
70
+ out_channels=out_channels,
71
+ kernel_size=(3, 3), stride=(1, 1),
72
+ padding=(1, 1), bias=False)
73
+
74
+ self.bn1 = nn.BatchNorm2d(out_channels, momentum)
75
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum)
76
+
77
+ self.init_weight()
78
+
79
+ def init_weight(self):
80
+ init_layer(self.conv1)
81
+ init_layer(self.conv2)
82
+ init_bn(self.bn1)
83
+ init_bn(self.bn2)
84
+
85
+
86
+ def forward(self, input, pool_size=(2, 2), pool_type='avg'):
87
+ """
88
+ Args:
89
+ input: (batch_size, in_channels, time_steps, freq_bins)
90
+ Outputs:
91
+ output: (batch_size, out_channels, classes_num)
92
+ """
93
+
94
+ x = F.relu_(self.bn1(self.conv1(input)))
95
+ x = F.relu_(self.bn2(self.conv2(x)))
96
+
97
+ if pool_type == 'avg':
98
+ x = F.avg_pool2d(x, kernel_size=pool_size)
99
+
100
+ return x
101
+
102
+
103
+ class AcousticModelCRnn8Dropout(nn.Module):
104
+ def __init__(self, classes_num, midfeat, momentum):
105
+ super(AcousticModelCRnn8Dropout, self).__init__()
106
+
107
+ self.conv_block1 = ConvBlock(in_channels=1, out_channels=48, momentum=momentum)
108
+ self.conv_block2 = ConvBlock(in_channels=48, out_channels=64, momentum=momentum)
109
+ self.conv_block3 = ConvBlock(in_channels=64, out_channels=96, momentum=momentum)
110
+ self.conv_block4 = ConvBlock(in_channels=96, out_channels=128, momentum=momentum)
111
+
112
+ self.fc5 = nn.Linear(midfeat, 768, bias=False)
113
+ self.bn5 = nn.BatchNorm1d(768, momentum=momentum)
114
+
115
+ self.gru = nn.GRU(input_size=768, hidden_size=256, num_layers=2,
116
+ bias=True, batch_first=True, dropout=0., bidirectional=True)
117
+
118
+ self.fc = nn.Linear(512, classes_num, bias=True)
119
+
120
+ self.init_weight()
121
+
122
+ def init_weight(self):
123
+ init_layer(self.fc5)
124
+ init_bn(self.bn5)
125
+ init_gru(self.gru)
126
+ init_layer(self.fc)
127
+
128
+ def forward(self, input):
129
+ """
130
+ Args:
131
+ input: (batch_size, channels_num, time_steps, freq_bins)
132
+ Outputs:
133
+ output: (batch_size, time_steps, classes_num)
134
+ """
135
+
136
+ x = self.conv_block1(input, pool_size=(1, 2), pool_type='avg')
137
+ x = F.dropout(x, p=0.2, training=self.training)
138
+ x = self.conv_block2(x, pool_size=(1, 2), pool_type='avg')
139
+ x = F.dropout(x, p=0.2, training=self.training)
140
+ x = self.conv_block3(x, pool_size=(1, 2), pool_type='avg')
141
+ x = F.dropout(x, p=0.2, training=self.training)
142
+ x = self.conv_block4(x, pool_size=(1, 2), pool_type='avg')
143
+ x = F.dropout(x, p=0.2, training=self.training)
144
+
145
+ x = x.transpose(1, 2).flatten(2)
146
+ x = F.relu(self.bn5(self.fc5(x).transpose(1, 2)).transpose(1, 2))
147
+ x = F.dropout(x, p=0.5, training=self.training, inplace=True)
148
+
149
+ (x, _) = self.gru(x)
150
+ x = F.dropout(x, p=0.5, training=self.training, inplace=False)
151
+ output = torch.sigmoid(self.fc(x))
152
+ return output
153
+
154
+
155
+ class Regress_onset_offset_frame_velocity_CRNN(nn.Module):
156
+ def __init__(self, frames_per_second, classes_num):
157
+ super(Regress_onset_offset_frame_velocity_CRNN, self).__init__()
158
+
159
+ sample_rate = 16000
160
+ window_size = 2048
161
+ hop_size = sample_rate // frames_per_second
162
+ mel_bins = 229
163
+ fmin = 30
164
+ fmax = sample_rate // 2
165
+
166
+ window = 'hann'
167
+ center = True
168
+ pad_mode = 'reflect'
169
+ ref = 1.0
170
+ amin = 1e-10
171
+ top_db = None
172
+
173
+ midfeat = 1792
174
+ momentum = 0.01
175
+
176
+ # Spectrogram extractor
177
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size,
178
+ hop_length=hop_size, win_length=window_size, window=window,
179
+ center=center, pad_mode=pad_mode, freeze_parameters=True)
180
+
181
+ # Logmel feature extractor
182
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate,
183
+ n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref,
184
+ amin=amin, top_db=top_db, freeze_parameters=True)
185
+
186
+ self.bn0 = nn.BatchNorm2d(mel_bins, momentum)
187
+
188
+ self.frame_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
189
+ self.reg_onset_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
190
+ self.reg_offset_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
191
+ self.velocity_model = AcousticModelCRnn8Dropout(classes_num, midfeat, momentum)
192
+
193
+ self.reg_onset_gru = nn.GRU(input_size=88 * 2, hidden_size=256, num_layers=1,
194
+ bias=True, batch_first=True, dropout=0., bidirectional=True)
195
+ self.reg_onset_fc = nn.Linear(512, classes_num, bias=True)
196
+
197
+ self.frame_gru = nn.GRU(input_size=88 * 3, hidden_size=256, num_layers=1,
198
+ bias=True, batch_first=True, dropout=0., bidirectional=True)
199
+ self.frame_fc = nn.Linear(512, classes_num, bias=True)
200
+
201
+ self.init_weight()
202
+
203
+ def init_weight(self):
204
+ init_bn(self.bn0)
205
+ init_gru(self.reg_onset_gru)
206
+ init_gru(self.frame_gru)
207
+ init_layer(self.reg_onset_fc)
208
+ init_layer(self.frame_fc)
209
+
210
+ def forward(self, input):
211
+ """
212
+ Args:
213
+ input: (batch_size, data_length)
214
+ Outputs:
215
+ output_dict: dict, {
216
+ 'reg_onset_output': (batch_size, time_steps, classes_num),
217
+ 'reg_offset_output': (batch_size, time_steps, classes_num),
218
+ 'frame_output': (batch_size, time_steps, classes_num),
219
+ 'velocity_output': (batch_size, time_steps, classes_num)
220
+ }
221
+ """
222
+
223
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
224
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
225
+
226
+ x = x.transpose(1, 3)
227
+ x = self.bn0(x)
228
+ x = x.transpose(1, 3)
229
+
230
+ frame_output = self.frame_model(x) # (batch_size, time_steps, classes_num)
231
+ reg_onset_output = self.reg_onset_model(x) # (batch_size, time_steps, classes_num)
232
+ reg_offset_output = self.reg_offset_model(x) # (batch_size, time_steps, classes_num)
233
+ velocity_output = self.velocity_model(x) # (batch_size, time_steps, classes_num)
234
+
235
+ # Use velocities to condition onset regression
236
+ x = torch.cat((reg_onset_output, (reg_onset_output ** 0.5) * velocity_output.detach()), dim=2)
237
+ (x, _) = self.reg_onset_gru(x)
238
+ x = F.dropout(x, p=0.5, training=self.training, inplace=False)
239
+ reg_onset_output = torch.sigmoid(self.reg_onset_fc(x))
240
+ """(batch_size, time_steps, classes_num)"""
241
+
242
+ # Use onsets and offsets to condition frame-wise classification
243
+ x = torch.cat((frame_output, reg_onset_output.detach(), reg_offset_output.detach()), dim=2)
244
+ (x, _) = self.frame_gru(x)
245
+ x = F.dropout(x, p=0.5, training=self.training, inplace=False)
246
+ frame_output = torch.sigmoid(self.frame_fc(x)) # (batch_size, time_steps, classes_num)
247
+ """(batch_size, time_steps, classes_num)"""
248
+
249
+ output_dict = {
250
+ 'reg_onset_output': reg_onset_output,
251
+ 'reg_offset_output': reg_offset_output,
252
+ 'frame_output': frame_output,
253
+ 'velocity_output': velocity_output}
254
+
255
+ return output_dict
256
+
257
+
258
+ class Regress_pedal_CRNN(nn.Module):
259
+ def __init__(self, frames_per_second, classes_num):
260
+ super(Regress_pedal_CRNN, self).__init__()
261
+
262
+ sample_rate = 16000
263
+ window_size = 2048
264
+ hop_size = sample_rate // frames_per_second
265
+ mel_bins = 229
266
+ fmin = 30
267
+ fmax = sample_rate // 2
268
+
269
+ window = 'hann'
270
+ center = True
271
+ pad_mode = 'reflect'
272
+ ref = 1.0
273
+ amin = 1e-10
274
+ top_db = None
275
+
276
+ midfeat = 1792
277
+ momentum = 0.01
278
+
279
+ # Spectrogram extractor
280
+ self.spectrogram_extractor = Spectrogram(n_fft=window_size,
281
+ hop_length=hop_size, win_length=window_size, window=window,
282
+ center=center, pad_mode=pad_mode, freeze_parameters=True)
283
+
284
+ # Logmel feature extractor
285
+ self.logmel_extractor = LogmelFilterBank(sr=sample_rate,
286
+ n_fft=window_size, n_mels=mel_bins, fmin=fmin, fmax=fmax, ref=ref,
287
+ amin=amin, top_db=top_db, freeze_parameters=True)
288
+
289
+ self.bn0 = nn.BatchNorm2d(mel_bins, momentum)
290
+
291
+ self.reg_pedal_onset_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
292
+ self.reg_pedal_offset_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
293
+ self.reg_pedal_frame_model = AcousticModelCRnn8Dropout(1, midfeat, momentum)
294
+
295
+ self.init_weight()
296
+
297
+ def init_weight(self):
298
+ init_bn(self.bn0)
299
+
300
+ def forward(self, input):
301
+ """
302
+ Args:
303
+ input: (batch_size, data_length)
304
+ Outputs:
305
+ output_dict: dict, {
306
+ 'reg_onset_output': (batch_size, time_steps, classes_num),
307
+ 'reg_offset_output': (batch_size, time_steps, classes_num),
308
+ 'frame_output': (batch_size, time_steps, classes_num),
309
+ 'velocity_output': (batch_size, time_steps, classes_num)
310
+ }
311
+ """
312
+
313
+ x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins)
314
+ x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins)
315
+
316
+ x = x.transpose(1, 3)
317
+ x = self.bn0(x)
318
+ x = x.transpose(1, 3)
319
+
320
+ reg_pedal_onset_output = self.reg_pedal_onset_model(x) # (batch_size, time_steps, classes_num)
321
+ reg_pedal_offset_output = self.reg_pedal_offset_model(x) # (batch_size, time_steps, classes_num)
322
+ pedal_frame_output = self.reg_pedal_frame_model(x) # (batch_size, time_steps, classes_num)
323
+
324
+ output_dict = {
325
+ 'reg_pedal_onset_output': reg_pedal_onset_output,
326
+ 'reg_pedal_offset_output': reg_pedal_offset_output,
327
+ 'pedal_frame_output': pedal_frame_output}
328
+
329
+ return output_dict
330
+
331
+
332
+ # This model is not trained, but is combined from the trained note and pedal models.
333
+ class Note_pedal(nn.Module):
334
+ def __init__(self, frames_per_second, classes_num):
335
+ """The combination of note and pedal model.
336
+ """
337
+ super(Note_pedal, self).__init__()
338
+
339
+ self.note_model = Regress_onset_offset_frame_velocity_CRNN(frames_per_second, classes_num)
340
+ self.pedal_model = Regress_pedal_CRNN(frames_per_second, classes_num)
341
+
342
+ def load_state_dict(self, m, strict=False):
343
+ self.note_model.load_state_dict(m['note_model'], strict=strict)
344
+ self.pedal_model.load_state_dict(m['pedal_model'], strict=strict)
345
+
346
+ def forward(self, input):
347
+ note_output_dict = self.note_model(input)
348
+ pedal_output_dict = self.pedal_model(input)
349
+
350
+ full_output_dict = {}
351
+ full_output_dict.update(note_output_dict)
352
+ full_output_dict.update(pedal_output_dict)
353
+ return full_output_dict
piano_vad.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def note_detection_with_onset_offset_regress(frame_output, onset_output,
5
+ onset_shift_output, offset_output, offset_shift_output, velocity_output,
6
+ frame_threshold):
7
+ """Process prediction matrices to note events information.
8
+ First, detect onsets with onset outputs. Then, detect offsets
9
+ with frame and offset outputs.
10
+
11
+ Args:
12
+ frame_output: (frames_num,)
13
+ onset_output: (frames_num,)
14
+ onset_shift_output: (frames_num,)
15
+ offset_output: (frames_num,)
16
+ offset_shift_output: (frames_num,)
17
+ velocity_output: (frames_num,)
18
+ frame_threshold: float
19
+ Returns:
20
+ output_tuples: list of [bgn, fin, onset_shift, offset_shift, normalized_velocity],
21
+ e.g., [
22
+ [1821, 1909, 0.47498, 0.3048533, 0.72119445],
23
+ [1909, 1947, 0.30730522, -0.45764327, 0.64200014],
24
+ ...]
25
+ """
26
+ output_tuples = []
27
+ bgn = None
28
+ frame_disappear = None
29
+ offset_occur = None
30
+
31
+ for i in range(onset_output.shape[0]):
32
+ if onset_output[i] == 1:
33
+ """Onset detected"""
34
+ if bgn:
35
+ """Consecutive onsets. E.g., pedal is not released, but two
36
+ consecutive notes being played."""
37
+ fin = max(i - 1, 0)
38
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
39
+ 0, velocity_output[bgn]])
40
+ frame_disappear, offset_occur = None, None
41
+ bgn = i
42
+
43
+ if bgn and i > bgn:
44
+ """If onset found, then search offset"""
45
+ if frame_output[i] <= frame_threshold and not frame_disappear:
46
+ """Frame disappear detected"""
47
+ frame_disappear = i
48
+
49
+ if offset_output[i] == 1 and not offset_occur:
50
+ """Offset detected"""
51
+ offset_occur = i
52
+
53
+ if frame_disappear:
54
+ if offset_occur and offset_occur - bgn > frame_disappear - offset_occur:
55
+ """bgn --------- offset_occur --- frame_disappear"""
56
+ fin = offset_occur
57
+ else:
58
+ """bgn --- offset_occur --------- frame_disappear"""
59
+ fin = frame_disappear
60
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
61
+ offset_shift_output[fin], velocity_output[bgn]])
62
+ bgn, frame_disappear, offset_occur = None, None, None
63
+
64
+ if bgn and (i - bgn >= 600 or i == onset_output.shape[0] - 1):
65
+ """Offset not detected"""
66
+ fin = i
67
+ output_tuples.append([bgn, fin, onset_shift_output[bgn],
68
+ offset_shift_output[fin], velocity_output[bgn]])
69
+ bgn, frame_disappear, offset_occur = None, None, None
70
+
71
+ # Sort pairs by onsets
72
+ output_tuples.sort(key=lambda pair: pair[0])
73
+
74
+ return output_tuples
75
+
76
+
77
+ def pedal_detection_with_onset_offset_regress(frame_output, offset_output,
78
+ offset_shift_output, frame_threshold):
79
+ """Process prediction array to pedal events information.
80
+
81
+ Args:
82
+ frame_output: (frames_num,)
83
+ offset_output: (frames_num,)
84
+ offset_shift_output: (frames_num,)
85
+ frame_threshold: float
86
+ Returns:
87
+ output_tuples: list of [bgn, fin, onset_shift, offset_shift],
88
+ e.g., [
89
+ [1821, 1909, 0.4749851, 0.3048533],
90
+ [1909, 1947, 0.30730522, -0.45764327],
91
+ ...]
92
+ """
93
+ output_tuples = []
94
+ bgn = None
95
+ frame_disappear = None
96
+ offset_occur = None
97
+
98
+ for i in range(1, frame_output.shape[0]):
99
+ if frame_output[i] >= frame_threshold and frame_output[i] > frame_output[i - 1]:
100
+ """Pedal onset detected"""
101
+ if bgn:
102
+ pass
103
+ else:
104
+ bgn = i
105
+
106
+ if bgn and i > bgn:
107
+ """If onset found, then search offset"""
108
+ if frame_output[i] <= frame_threshold and not frame_disappear:
109
+ """Frame disappear detected"""
110
+ frame_disappear = i
111
+
112
+ if offset_output[i] == 1 and not offset_occur:
113
+ """Offset detected"""
114
+ offset_occur = i
115
+
116
+ if offset_occur:
117
+ fin = offset_occur
118
+ output_tuples.append([bgn, fin, 0., offset_shift_output[fin]])
119
+ bgn, frame_disappear, offset_occur = None, None, None
120
+
121
+ if frame_disappear and i - frame_disappear >= 10:
122
+ """offset not detected but frame disappear"""
123
+ fin = frame_disappear
124
+ output_tuples.append([bgn, fin, 0., offset_shift_output[fin]])
125
+ bgn, frame_disappear, offset_occur = None, None, None
126
+
127
+ # Sort pairs by onsets
128
+ output_tuples.sort(key=lambda pair: pair[0])
129
+
130
+ return output_tuples
pytorch_utils.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import time
4
+ import torch
5
+
6
+ from .utilities import pad_truncate_sequence
7
+
8
+
9
+ def move_data_to_device(x, device):
10
+ if 'float' in str(x.dtype):
11
+ x = torch.Tensor(x)
12
+ elif 'int' in str(x.dtype):
13
+ x = torch.LongTensor(x)
14
+ else:
15
+ return x
16
+
17
+ return x.to(device)
18
+
19
+
20
+ def append_to_dict(dict, key, value):
21
+ if key in dict.keys():
22
+ dict[key].append(value)
23
+ else:
24
+ dict[key] = [value]
25
+
26
+
27
+ def forward(model, x, batch_size):
28
+ """Forward data to model in mini-batch.
29
+
30
+ Args:
31
+ model: object
32
+ x: (N, segment_samples)
33
+ batch_size: int
34
+
35
+ Returns:
36
+ output_dict: dict, e.g. {
37
+ 'frame_output': (segments_num, frames_num, classes_num),
38
+ 'onset_output': (segments_num, frames_num, classes_num),
39
+ ...}
40
+ """
41
+
42
+ output_dict = {}
43
+ device = next(model.parameters()).device
44
+
45
+ pointer = 0
46
+ total_segments = int(np.ceil(len(x) / batch_size))
47
+
48
+ while True:
49
+ print('Segment {} / {}'.format(pointer, total_segments))
50
+ if pointer >= len(x):
51
+ break
52
+
53
+ batch_waveform = move_data_to_device(x[pointer : pointer + batch_size], device)
54
+ pointer += batch_size
55
+
56
+ with torch.no_grad():
57
+ model.eval()
58
+ batch_output_dict = model(batch_waveform)
59
+
60
+ for key in batch_output_dict.keys():
61
+ append_to_dict(output_dict, key, batch_output_dict[key].data.cpu().numpy())
62
+
63
+ for key in output_dict.keys():
64
+ output_dict[key] = np.concatenate(output_dict[key], axis=0)
65
+
66
+ return output_dict
utilities.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import audioread
4
+ import librosa
5
+ from mido import MidiFile
6
+
7
+ from .piano_vad import (note_detection_with_onset_offset_regress,
8
+ pedal_detection_with_onset_offset_regress)
9
+ from . import config
10
+
11
+
12
+ def create_folder(fd):
13
+ if not os.path.exists(fd):
14
+ os.makedirs(fd)
15
+
16
+
17
+ def get_filename(path):
18
+ path = os.path.realpath(path)
19
+ na_ext = path.split('/')[-1]
20
+ na = os.path.splitext(na_ext)[0]
21
+ return na
22
+
23
+
24
+ def note_to_freq(piano_note):
25
+ return 2 ** ((piano_note - 39) / 12) * 440
26
+
27
+
28
+ def float32_to_int16(x):
29
+ assert np.max(np.abs(x)) <= 1.
30
+ return (x * 32767.).astype(np.int16)
31
+
32
+
33
+ def int16_to_float32(x):
34
+ return (x / 32767.).astype(np.float32)
35
+
36
+
37
+ def pad_truncate_sequence(x, max_len):
38
+ if len(x) < max_len:
39
+ return np.concatenate((x, np.zeros(max_len - len(x))))
40
+ else:
41
+ return x[0 : max_len]
42
+
43
+
44
+ def read_midi(midi_path):
45
+ """Parse MIDI file.
46
+
47
+ Args:
48
+ midi_path: str
49
+
50
+ Returns:
51
+ midi_dict: dict, e.g. {
52
+ 'midi_event': [
53
+ 'program_change channel=0 program=0 time=0',
54
+ 'control_change channel=0 control=64 value=127 time=0',
55
+ 'control_change channel=0 control=64 value=63 time=236',
56
+ ...],
57
+ 'midi_event_time': [0., 0, 0.98307292, ...]}
58
+ """
59
+
60
+ midi_file = MidiFile(midi_path)
61
+ ticks_per_beat = midi_file.ticks_per_beat
62
+
63
+ assert len(midi_file.tracks) == 2
64
+ """The first track contains tempo, time signature. The second track
65
+ contains piano events."""
66
+
67
+ microseconds_per_beat = midi_file.tracks[0][0].tempo
68
+ beats_per_second = 1e6 / microseconds_per_beat
69
+ ticks_per_second = ticks_per_beat * beats_per_second
70
+
71
+ message_list = []
72
+
73
+ ticks = 0
74
+ time_in_second = []
75
+
76
+ for message in midi_file.tracks[1]:
77
+ message_list.append(str(message))
78
+ ticks += message.time
79
+ time_in_second.append(ticks / ticks_per_second)
80
+
81
+ midi_dict = {
82
+ 'midi_event': np.array(message_list),
83
+ 'midi_event_time': np.array(time_in_second)}
84
+
85
+ return midi_dict
86
+
87
+
88
+ def write_events_to_midi(start_time, note_events, pedal_events, midi_path):
89
+ """Write out note events to MIDI file.
90
+
91
+ Args:
92
+ start_time: float
93
+ note_events: list of dict, e.g. [
94
+ {'midi_note': 51, 'onset_time': 696.63544, 'offset_time': 696.9948, 'velocity': 44},
95
+ {'midi_note': 58, 'onset_time': 696.99585, 'offset_time': 697.18646, 'velocity': 50}
96
+ ...]
97
+ midi_path: str
98
+ """
99
+ from mido import Message, MidiFile, MidiTrack, MetaMessage
100
+
101
+ # This configuration is the same as MIDIs in MAESTRO dataset
102
+ ticks_per_beat = 384
103
+ beats_per_second = 2
104
+ ticks_per_second = ticks_per_beat * beats_per_second
105
+ microseconds_per_beat = int(1e6 // beats_per_second)
106
+
107
+ midi_file = MidiFile()
108
+ midi_file.ticks_per_beat = ticks_per_beat
109
+
110
+ # Track 0
111
+ track0 = MidiTrack()
112
+ track0.append(MetaMessage('set_tempo', tempo=microseconds_per_beat, time=0))
113
+ track0.append(MetaMessage('time_signature', numerator=4, denominator=4, time=0))
114
+ track0.append(MetaMessage('end_of_track', time=1))
115
+ midi_file.tracks.append(track0)
116
+
117
+ # Track 1
118
+ track1 = MidiTrack()
119
+
120
+ # Message rolls of MIDI
121
+ message_roll = []
122
+
123
+ for note_event in note_events:
124
+ # Onset
125
+ message_roll.append({
126
+ 'time': note_event['onset_time'],
127
+ 'midi_note': note_event['midi_note'],
128
+ 'velocity': note_event['velocity']})
129
+
130
+ # Offset
131
+ message_roll.append({
132
+ 'time': note_event['offset_time'],
133
+ 'midi_note': note_event['midi_note'],
134
+ 'velocity': 0})
135
+
136
+ if pedal_events:
137
+ for pedal_event in pedal_events:
138
+ message_roll.append({'time': pedal_event['onset_time'], 'control_change': 64, 'value': 127})
139
+ message_roll.append({'time': pedal_event['offset_time'], 'control_change': 64, 'value': 0})
140
+
141
+ # Sort MIDI messages by time
142
+ message_roll.sort(key=lambda note_event: note_event['time'])
143
+
144
+ previous_ticks = 0
145
+ for message in message_roll:
146
+ this_ticks = int((message['time'] - start_time) * ticks_per_second)
147
+ if this_ticks >= 0:
148
+ diff_ticks = this_ticks - previous_ticks
149
+ previous_ticks = this_ticks
150
+ if 'midi_note' in message.keys():
151
+ track1.append(Message('note_on', note=message['midi_note'], velocity=message['velocity'], time=diff_ticks))
152
+ elif 'control_change' in message.keys():
153
+ track1.append(Message('control_change', channel=0, control=message['control_change'], value=message['value'], time=diff_ticks))
154
+ track1.append(MetaMessage('end_of_track', time=1))
155
+ midi_file.tracks.append(track1)
156
+
157
+ midi_file.save(midi_path)
158
+
159
+
160
+ class RegressionPostProcessor(object):
161
+ def __init__(self, frames_per_second, classes_num, onset_threshold,
162
+ offset_threshold, frame_threshold, pedal_offset_threshold):
163
+ """Postprocess the output probabilities of a transription model to MIDI
164
+ events.
165
+
166
+ Args:
167
+ frames_per_second: int
168
+ classes_num: int
169
+ onset_threshold: float
170
+ offset_threshold: float
171
+ frame_threshold: float
172
+ pedal_offset_threshold: float
173
+ """
174
+ self.frames_per_second = frames_per_second
175
+ self.classes_num = classes_num
176
+ self.onset_threshold = onset_threshold
177
+ self.offset_threshold = offset_threshold
178
+ self.frame_threshold = frame_threshold
179
+ self.pedal_offset_threshold = pedal_offset_threshold
180
+ self.begin_note = config.begin_note
181
+ self.velocity_scale = config.velocity_scale
182
+
183
+ def output_dict_to_midi_events(self, output_dict):
184
+ """Main function. Post process model outputs to MIDI events.
185
+
186
+ Args:
187
+ output_dict: {
188
+ 'reg_onset_output': (segment_frames, classes_num),
189
+ 'reg_offset_output': (segment_frames, classes_num),
190
+ 'frame_output': (segment_frames, classes_num),
191
+ 'velocity_output': (segment_frames, classes_num),
192
+ 'reg_pedal_onset_output': (segment_frames, 1),
193
+ 'reg_pedal_offset_output': (segment_frames, 1),
194
+ 'pedal_frame_output': (segment_frames, 1)}
195
+
196
+ Outputs:
197
+ est_note_events: list of dict, e.g. [
198
+ {'onset_time': 39.74, 'offset_time': 39.87, 'midi_note': 27, 'velocity': 83},
199
+ {'onset_time': 11.98, 'offset_time': 12.11, 'midi_note': 33, 'velocity': 88}]
200
+
201
+ est_pedal_events: list of dict, e.g. [
202
+ {'onset_time': 0.17, 'offset_time': 0.96},
203
+ {'osnet_time': 1.17, 'offset_time': 2.65}]
204
+ """
205
+
206
+ # Post process piano note outputs to piano note and pedal events information
207
+ (est_on_off_note_vels, est_pedal_on_offs) = \
208
+ self.output_dict_to_note_pedal_arrays(output_dict)
209
+ """est_on_off_note_vels: (events_num, 4), the four columns are: [onset_time, offset_time, piano_note, velocity],
210
+ est_pedal_on_offs: (pedal_events_num, 2), the two columns are: [onset_time, offset_time]"""
211
+
212
+ # Reformat notes to MIDI events
213
+ est_note_events = self.detected_notes_to_events(est_on_off_note_vels)
214
+
215
+ if est_pedal_on_offs is None:
216
+ est_pedal_events = None
217
+ else:
218
+ est_pedal_events = self.detected_pedals_to_events(est_pedal_on_offs)
219
+
220
+ return est_note_events, est_pedal_events
221
+
222
+ def output_dict_to_note_pedal_arrays(self, output_dict):
223
+ """Postprocess the output probabilities of a transription model to MIDI
224
+ events.
225
+
226
+ Args:
227
+ output_dict: dict, {
228
+ 'reg_onset_output': (frames_num, classes_num),
229
+ 'reg_offset_output': (frames_num, classes_num),
230
+ 'frame_output': (frames_num, classes_num),
231
+ 'velocity_output': (frames_num, classes_num),
232
+ ...}
233
+
234
+ Returns:
235
+ est_on_off_note_vels: (events_num, 4), the 4 columns are onset_time,
236
+ offset_time, piano_note and velocity. E.g. [
237
+ [39.74, 39.87, 27, 0.65],
238
+ [11.98, 12.11, 33, 0.69],
239
+ ...]
240
+
241
+ est_pedal_on_offs: (pedal_events_num, 2), the 2 columns are onset_time
242
+ and offset_time. E.g. [
243
+ [0.17, 0.96],
244
+ [1.17, 2.65],
245
+ ...]
246
+ """
247
+
248
+ # ------ 1. Process regression outputs to binarized outputs ------
249
+ # For example, onset or offset of [0., 0., 0.15, 0.30, 0.40, 0.35, 0.20, 0.05, 0., 0.]
250
+ # will be processed to [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]
251
+
252
+ # Calculate binarized onset output from regression output
253
+ (onset_output, onset_shift_output) = \
254
+ self.get_binarized_output_from_regression(
255
+ reg_output=output_dict['reg_onset_output'],
256
+ threshold=self.onset_threshold, neighbour=2)
257
+
258
+ output_dict['onset_output'] = onset_output # Values are 0 or 1
259
+ output_dict['onset_shift_output'] = onset_shift_output
260
+
261
+ # Calculate binarized offset output from regression output
262
+ (offset_output, offset_shift_output) = \
263
+ self.get_binarized_output_from_regression(
264
+ reg_output=output_dict['reg_offset_output'],
265
+ threshold=self.offset_threshold, neighbour=4)
266
+
267
+ output_dict['offset_output'] = offset_output # Values are 0 or 1
268
+ output_dict['offset_shift_output'] = offset_shift_output
269
+
270
+ if 'reg_pedal_onset_output' in output_dict.keys():
271
+ """Pedal onsets are not used in inference. Instead, frame-wise pedal
272
+ predictions are used to detect onsets. We empirically found this is
273
+ more accurate to detect pedal onsets."""
274
+ pass
275
+
276
+ if 'reg_pedal_offset_output' in output_dict.keys():
277
+ # Calculate binarized pedal offset output from regression output
278
+ (pedal_offset_output, pedal_offset_shift_output) = \
279
+ self.get_binarized_output_from_regression(
280
+ reg_output=output_dict['reg_pedal_offset_output'],
281
+ threshold=self.pedal_offset_threshold, neighbour=4)
282
+
283
+ output_dict['pedal_offset_output'] = pedal_offset_output # Values are 0 or 1
284
+ output_dict['pedal_offset_shift_output'] = pedal_offset_shift_output
285
+
286
+ # ------ 2. Process matrices results to event results ------
287
+ # Detect piano notes from output_dict
288
+ est_on_off_note_vels = self.output_dict_to_detected_notes(output_dict)
289
+
290
+ if 'reg_pedal_onset_output' in output_dict.keys():
291
+ # Detect piano pedals from output_dict
292
+ est_pedal_on_offs = self.output_dict_to_detected_pedals(output_dict)
293
+
294
+ else:
295
+ est_pedal_on_offs = None
296
+
297
+ return est_on_off_note_vels, est_pedal_on_offs
298
+
299
+ def get_binarized_output_from_regression(self, reg_output, threshold, neighbour):
300
+ """Calculate binarized output and shifts of onsets or offsets from the
301
+ regression results.
302
+
303
+ Args:
304
+ reg_output: (frames_num, classes_num)
305
+ threshold: float
306
+ neighbour: int
307
+
308
+ Returns:
309
+ binary_output: (frames_num, classes_num)
310
+ shift_output: (frames_num, classes_num)
311
+ """
312
+ binary_output = np.zeros_like(reg_output)
313
+ shift_output = np.zeros_like(reg_output)
314
+ (frames_num, classes_num) = reg_output.shape
315
+
316
+ for k in range(classes_num):
317
+ x = reg_output[:, k]
318
+ for n in range(neighbour, frames_num - neighbour):
319
+ if x[n] > threshold and self.is_monotonic_neighbour(x, n, neighbour):
320
+ binary_output[n, k] = 1
321
+
322
+ """See Section III-D in [1] for deduction.
323
+ [1] Q. Kong, et al., High-resolution Piano Transcription
324
+ with Pedals by Regressing Onsets and Offsets Times, 2020."""
325
+ if x[n - 1] > x[n + 1]:
326
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n + 1]) / 2
327
+ else:
328
+ shift = (x[n + 1] - x[n - 1]) / (x[n] - x[n - 1]) / 2
329
+ shift_output[n, k] = shift
330
+
331
+ return binary_output, shift_output
332
+
333
+ def is_monotonic_neighbour(self, x, n, neighbour):
334
+ """Detect if values are monotonic in both side of x[n].
335
+
336
+ Args:
337
+ x: (frames_num,)
338
+ n: int
339
+ neighbour: int
340
+
341
+ Returns:
342
+ monotonic: bool
343
+ """
344
+ monotonic = True
345
+ for i in range(neighbour):
346
+ if x[n - i] < x[n - i - 1]:
347
+ monotonic = False
348
+ if x[n + i] < x[n + i + 1]:
349
+ monotonic = False
350
+
351
+ return monotonic
352
+
353
+ def output_dict_to_detected_notes(self, output_dict):
354
+ """Postprocess output_dict to piano notes.
355
+
356
+ Args:
357
+ output_dict: dict, e.g. {
358
+ 'onset_output': (frames_num, classes_num),
359
+ 'onset_shift_output': (frames_num, classes_num),
360
+ 'offset_output': (frames_num, classes_num),
361
+ 'offset_shift_output': (frames_num, classes_num),
362
+ 'frame_output': (frames_num, classes_num),
363
+ 'onset_output': (frames_num, classes_num),
364
+ ...}
365
+
366
+ Returns:
367
+ est_on_off_note_vels: (notes, 4), the four columns are onsets, offsets,
368
+ MIDI notes and velocities. E.g.,
369
+ [[39.7375, 39.7500, 27., 0.6638],
370
+ [11.9824, 12.5000, 33., 0.6892],
371
+ ...]
372
+ """
373
+ est_tuples = []
374
+ est_midi_notes = []
375
+ classes_num = output_dict['frame_output'].shape[-1]
376
+
377
+ for piano_note in range(classes_num):
378
+ """Detect piano notes"""
379
+ est_tuples_per_note = note_detection_with_onset_offset_regress(
380
+ frame_output=output_dict['frame_output'][:, piano_note],
381
+ onset_output=output_dict['onset_output'][:, piano_note],
382
+ onset_shift_output=output_dict['onset_shift_output'][:, piano_note],
383
+ offset_output=output_dict['offset_output'][:, piano_note],
384
+ offset_shift_output=output_dict['offset_shift_output'][:, piano_note],
385
+ velocity_output=output_dict['velocity_output'][:, piano_note],
386
+ frame_threshold=self.frame_threshold)
387
+
388
+ est_tuples += est_tuples_per_note
389
+ est_midi_notes += [piano_note + self.begin_note] * len(est_tuples_per_note)
390
+
391
+ est_tuples = np.array(est_tuples) # (notes, 5)
392
+ """(notes, 5), the five columns are onset, offset, onset_shift,
393
+ offset_shift and normalized_velocity"""
394
+
395
+ est_midi_notes = np.array(est_midi_notes) # (notes,)
396
+
397
+ if len(est_tuples) == 0:
398
+ return np.array([])
399
+
400
+ else:
401
+ onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
402
+ offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
403
+ velocities = est_tuples[:, 4]
404
+
405
+ est_on_off_note_vels = np.stack((onset_times, offset_times, est_midi_notes, velocities), axis=-1)
406
+ """(notes, 3), the three columns are onset_times, offset_times and velocity."""
407
+
408
+ est_on_off_note_vels = est_on_off_note_vels.astype(np.float32)
409
+
410
+ return est_on_off_note_vels
411
+
412
+ def output_dict_to_detected_pedals(self, output_dict):
413
+ """Postprocess output_dict to piano pedals.
414
+
415
+ Args:
416
+ output_dict: dict, e.g. {
417
+ 'pedal_frame_output': (frames_num,),
418
+ 'pedal_offset_output': (frames_num,),
419
+ 'pedal_offset_shift_output': (frames_num,),
420
+ ...}
421
+
422
+ Returns:
423
+ est_on_off: (notes, 2), the two columns are pedal onsets and pedal
424
+ offsets. E.g.,
425
+ [[0.1800, 0.9669],
426
+ [1.1400, 2.6458],
427
+ ...]
428
+ """
429
+ frames_num = output_dict['pedal_frame_output'].shape[0]
430
+
431
+ est_tuples = pedal_detection_with_onset_offset_regress(
432
+ frame_output=output_dict['pedal_frame_output'][:, 0],
433
+ offset_output=output_dict['pedal_offset_output'][:, 0],
434
+ offset_shift_output=output_dict['pedal_offset_shift_output'][:, 0],
435
+ frame_threshold=0.5)
436
+
437
+ est_tuples = np.array(est_tuples)
438
+ """(notes, 2), the two columns are pedal onsets and pedal offsets"""
439
+
440
+ if len(est_tuples) == 0:
441
+ return np.array([])
442
+
443
+ else:
444
+ onset_times = (est_tuples[:, 0] + est_tuples[:, 2]) / self.frames_per_second
445
+ offset_times = (est_tuples[:, 1] + est_tuples[:, 3]) / self.frames_per_second
446
+ est_on_off = np.stack((onset_times, offset_times), axis=-1)
447
+ est_on_off = est_on_off.astype(np.float32)
448
+ return est_on_off
449
+
450
+ def detected_notes_to_events(self, est_on_off_note_vels):
451
+ """Reformat detected notes to midi events.
452
+
453
+ Args:
454
+ est_on_off_vels: (notes, 3), the three columns are onset_times,
455
+ offset_times and velocity. E.g.
456
+ [[32.8376, 35.7700, 0.7932],
457
+ [37.3712, 39.9300, 0.8058],
458
+ ...]
459
+
460
+ Returns:
461
+ midi_events, list, e.g.,
462
+ [{'onset_time': 39.7376, 'offset_time': 39.75, 'midi_note': 27, 'velocity': 84},
463
+ {'onset_time': 11.9824, 'offset_time': 12.50, 'midi_note': 33, 'velocity': 88},
464
+ ...]
465
+ """
466
+ midi_events = []
467
+ for i in range(est_on_off_note_vels.shape[0]):
468
+ midi_events.append({
469
+ 'onset_time': est_on_off_note_vels[i][0],
470
+ 'offset_time': est_on_off_note_vels[i][1],
471
+ 'midi_note': int(est_on_off_note_vels[i][2]),
472
+ 'velocity': int(est_on_off_note_vels[i][3] * self.velocity_scale)})
473
+
474
+ return midi_events
475
+
476
+ def detected_pedals_to_events(self, pedal_on_offs):
477
+ """Reformat detected pedal onset and offsets to events.
478
+
479
+ Args:
480
+ pedal_on_offs: (notes, 2), the two columns are pedal onsets and pedal
481
+ offsets. E.g.,
482
+ [[0.1800, 0.9669],
483
+ [1.1400, 2.6458],
484
+ ...]
485
+
486
+ Returns:
487
+ pedal_events: list of dict, e.g.,
488
+ [{'onset_time': 0.1800, 'offset_time': 0.9669},
489
+ {'onset_time': 1.1400, 'offset_time': 2.6458},
490
+ ...]
491
+ """
492
+ pedal_events = []
493
+ for i in range(len(pedal_on_offs)):
494
+ pedal_events.append({
495
+ 'onset_time': pedal_on_offs[i, 0],
496
+ 'offset_time': pedal_on_offs[i, 1]})
497
+
498
+ return pedal_events
499
+
500
+
501
+ def load_audio(path, sr=22050, mono=True, offset=0.0, duration=None,
502
+ dtype=np.float32, res_type='kaiser_best',
503
+ backends=[audioread.ffdec.FFmpegAudioFile]):
504
+ """Load audio. Copied from librosa.core.load() except that ffmpeg backend is
505
+ always used in this function."""
506
+
507
+ y = []
508
+ with audioread.audio_open(os.path.realpath(path), backends=backends) as input_file:
509
+ sr_native = input_file.samplerate
510
+ n_channels = input_file.channels
511
+
512
+ s_start = int(np.round(sr_native * offset)) * n_channels
513
+
514
+ if duration is None:
515
+ s_end = np.inf
516
+ else:
517
+ s_end = s_start + (int(np.round(sr_native * duration))
518
+ * n_channels)
519
+
520
+ n = 0
521
+
522
+ for frame in input_file:
523
+ frame = frame = librosa.util.buf_to_float(frame, n_bytes=2, dtype=dtype)
524
+ n_prev = n
525
+ n = n + len(frame)
526
+
527
+ if n < s_start:
528
+ # offset is after the current frame
529
+ # keep reading
530
+ continue
531
+
532
+ if s_end < n_prev:
533
+ # we're off the end. stop reading
534
+ break
535
+
536
+ if s_end < n:
537
+ # the end is in this frame. crop.
538
+ frame = frame[:s_end - n_prev]
539
+
540
+ if n_prev <= s_start <= n:
541
+ # beginning is in this frame
542
+ frame = frame[(s_start - n_prev):]
543
+
544
+ # tack on the current frame
545
+ y.append(frame)
546
+
547
+ if y:
548
+ y = np.concatenate(y)
549
+
550
+ if n_channels > 1:
551
+ y = y.reshape((-1, n_channels)).T
552
+ if mono:
553
+ y = librosa.to_mono(y)
554
+
555
+ if sr is not None:
556
+ y = librosa.resample(y, orig_sr=sr_native, target_sr=sr, res_type=res_type)
557
+
558
+ else:
559
+ sr = sr_native
560
+
561
+ # Final cleanup for dtype and contiguity
562
+ y = np.ascontiguousarray(y, dtype=dtype)
563
+
564
+ return (y, sr)