|
|
|
|
|
|
|
|
|
import os |
|
import time as reqtime |
|
import datetime |
|
from pytz import timezone |
|
|
|
import gradio as gr |
|
import spaces |
|
|
|
import os |
|
|
|
from tqdm import tqdm |
|
|
|
import numpy as np |
|
|
|
import torch |
|
|
|
from x_transformer_1_23_2 import * |
|
|
|
import random |
|
|
|
import TMIDIX |
|
|
|
from midi_to_colab_audio import midi_to_colab_audio |
|
|
|
|
|
|
|
|
|
def Harmonize_Melody(input_src_midi,
                     source_melody_transpose_value,
                     model_top_k_sampling_value,
                     texture_harmonized_chords,
                     melody_MIDI_patch_number,
                     harmonized_accompaniment_MIDI_patch_number,
                     base_MIDI_patch_number
                     ):
    """Harmonize an uploaded MIDI melody and render the result.

    Gradio click handler. It reads module-level globals that are only set in
    the ``__main__`` section of this file: ``PDT``, ``soundfont``,
    ``src_chunks`` and ``all_chords_ptcs_chunks``.

    Args:
        input_src_midi: Uploaded MIDI file. Either an object whose ``.name``
            attribute holds the file path, or a plain path string.
        source_melody_transpose_value: Semitones to transpose the melody by.
        model_top_k_sampling_value: ``k`` used for top-k sampling of chords.
        texture_harmonized_chords: If true, replace every predicted chord with
            a voicing looked up in the pitches-chords pairs data.
        melody_MIDI_patch_number: MIDI program for the melody (channel 3).
        harmonized_accompaniment_MIDI_patch_number: MIDI program for the
            chords (channel 0).
        base_MIDI_patch_number: MIDI program for the bass line (channel 2);
            ``-1`` disables the bass line.

    Returns:
        ``(output_audio, output_plot, output_midi, summary)``:
        ``output_audio`` is ``(16000, audio)``, ``output_plot`` is the TMIDIX
        score plot, ``output_midi`` is the path of the written ``.mid`` file
        and ``summary`` is a human-readable report.

    Raises:
        gr.Error: If no file was uploaded or the MIDI contains no notes.
    """
    print('=' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('=' * 70)

    start_time = reqtime.time()

    if input_src_midi is None:
        raise gr.Error('Please upload a source MIDI file first.')

    # Gradio may pass a tempfile-like object (path in .name) or a path string.
    midi_path = getattr(input_src_midi, 'name', input_src_midi)
    sfn = os.path.basename(midi_path)

    print('Input src MIDI name:', sfn)

    # Slider values may arrive as floats; torch.topk and MIDI events need ints.
    source_melody_transpose_value = int(source_melody_transpose_value)
    model_top_k_sampling_value = int(model_top_k_sampling_value)
    melody_MIDI_patch_number = int(melody_MIDI_patch_number)
    harmonized_accompaniment_MIDI_patch_number = int(harmonized_accompaniment_MIDI_patch_number)
    base_MIDI_patch_number = int(base_MIDI_patch_number)

    print('=' * 70)
    print('Requested settings:')
    print('Source melody transpose value:', source_melody_transpose_value)
    print('Model top_k sampling value:', model_top_k_sampling_value)
    print('Texture harmonized chords:', texture_harmonized_chords)
    print('Melody MIDI patch number:', melody_MIDI_patch_number)
    print('Harmonized accompaniment MIDI patch number:', harmonized_accompaniment_MIDI_patch_number)
    print('Base MIDI patch number:', base_MIDI_patch_number)
    print('=' * 70)

    print('=' * 70)
    print('Loading seed melody...')

    mel_score = _load_melody(midi_path, source_melody_transpose_value)

    if not mel_score:
        raise gr.Error('The uploaded MIDI file does not contain any notes.')

    print('=' * 70)
    print('Done!')
    print('=' * 70)

    # Pitch classes (0-11) are the model's melody tokens.
    mel_pitches = [p[4] % 12 for p in mel_score]

    print('Melody has', len(mel_pitches), 'notes')
    print('=' * 70)

    print('=' * 70)
    print('Melody Harmonizer Transformer')
    print('=' * 70)
    print('Loading Melody Harmonizer Transformer Model...')

    model = _load_harmonizer_model()

    print('Done!')
    print('=' * 70)
    print('Harmonizing...')
    print('=' * 70)

    song = _harmonize_pitches(model, mel_pitches, model_top_k_sampling_value)

    print('Harmonized', len(song) // 2, 'out of', len(mel_pitches), 'notes')
    print('Done!')
    print('=' * 70)

    # song interleaves [pitch, chord, pitch, chord, ...]. Pair every melody note
    # with its OWN chord slot and skip slots that did not receive a chord token
    # (12..140), so one bad sample cannot shift all later chords onto the wrong
    # melody notes.
    harmonized = [(note, tok) for note, tok in zip(mel_score, song[1::2]) if 11 < tok < 141]

    patches = _channel_patches(melody_MIDI_patch_number,
                               harmonized_accompaniment_MIDI_patch_number,
                               base_MIDI_patch_number)

    output_score = None

    if texture_harmonized_chords:
        print('=' * 70)
        print('Texturing harmonized chords...')
        print('=' * 70)

        harm_chords = [TMIDIX.ALL_CHORDS_FILTERED[tok - 12] for _, tok in harmonized]
        final_song = _texture_chords(harm_chords)

        print('=' * 70)
        print(len(final_song))
        print('=' * 70)
        print('Done!')
        print('=' * 70)

        if final_song:
            print('Rendering textured results...')
            print('=' * 70)
            output_score = _render_textured([note for note, _ in harmonized],
                                            final_song,
                                            melody_MIDI_patch_number,
                                            harmonized_accompaniment_MIDI_patch_number,
                                            base_MIDI_patch_number)
        else:
            # No textured voicing matched the first chunk; fall back to plain chords.
            print('Could not texture harmonized chords, rendering plain chords instead...')
            print('=' * 70)

    if output_score is None:
        print('Rendering results...')
        print('=' * 70)
        output_score = _render_plain(harmonized,
                                     melody_MIDI_patch_number,
                                     harmonized_accompaniment_MIDI_patch_number,
                                     base_MIDI_patch_number)

    # NOTE: fixed output name; concurrent requests would overwrite each other's
    # files, which is only safe while the Gradio queue runs one job at a time.
    fn1 = "Melody-Harmonizer-Transformer-Composition"

    TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(output_score,
                                             output_signature='Melody Harmonizer Transformer',
                                             output_file_name=fn1,
                                             track_name='Project Los Angeles',
                                             list_of_MIDI_patches=patches
                                             )

    new_fn = fn1 + '.mid'

    audio = midi_to_colab_audio(new_fn,
                                soundfont_path=soundfont,
                                sample_rate=16000,
                                volume_scale=10,
                                output_for_gradio=True
                                )

    output_midi = str(new_fn)
    output_audio = (16000, audio)
    output_plot = TMIDIX.plot_ms_SONG(output_score, plot_title=output_midi, return_plt=True)

    print('Done!')

    separator = '=' * 70 + '\n'
    harmonization_summary_string = (
        separator
        + f'Source melody has {len(mel_pitches)} monophonic pitches\n'
        + separator
        + f'Harmonized {len(song) // 2} out of {len(mel_pitches)} source melody pitches\n'
        + separator
    )

    print('-' * 70)
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (reqtime.time() - start_time), 'sec')

    return output_audio, output_plot, output_midi, harmonization_summary_string


# Process-wide cache so the checkpoint is read once, not on every request.
_HARMONIZER_MODEL_CACHE = {}


def _load_harmonizer_model():
    """Return the Melody Harmonizer Transformer on CPU in eval mode, building it on first use."""
    model = _HARMONIZER_MODEL_CACHE.get('model')
    if model is not None:
        return model

    # Must match the architecture the checkpoint below was trained with.
    seq_len = 75
    pad_idx = 144

    model = TransformerWrapper(
        num_tokens=pad_idx + 1,
        max_seq_len=seq_len,
        attn_layers=Decoder(dim=1024, depth=12, heads=16, attn_flash=True)
    )
    model = AutoregressiveWrapper(model, ignore_index=pad_idx, pad_value=pad_idx)

    model_path = 'Melody_Harmonizer_Transformer_Trained_Model_14961_steps_0.4155_loss_0.8664_acc.pth'
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.cpu()
    model.eval()

    _HARMONIZER_MODEL_CACHE['model'] = model
    return model


def _load_melody(midi_path, transpose_value):
    """Read *midi_path* and return a monophonic, transposed list of enhanced score notes ([] if no notes)."""
    raw_score = TMIDIX.midi2single_track_ms_score(midi_path)
    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]

    if not escore_notes:
        return []

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, timings_divider=16)

    # Keep only the first note of every chordified group -> one note per onset.
    cscore = [c[0] for c in TMIDIX.chordify_score([1000, escore_notes])]

    mel_score = TMIDIX.fix_monophonic_score_durations(TMIDIX.recalculate_score_timings(cscore))

    return TMIDIX.transpose_escore_notes(mel_score, transpose_value)


def _extend_to_chunk_multiple(pitches, chunk_size=24):
    """Cyclically extend *pitches* to the next multiple of *chunk_size* (always strictly longer).

    A sequence whose length is already a multiple still gains one full extra
    chunk. The sequence is repeated as often as needed, so melodies shorter
    than half a chunk are also padded to a full chunk (a single partial copy
    left them short, making the harmonization loop index past the window).
    """
    if not pitches:
        return []
    target_len = (len(pitches) // chunk_size + 1) * chunk_size
    repeats = target_len // len(pitches) + 1
    return (pitches * repeats)[:target_len]


def _harmonize_pitches(model, mel_pitches, top_k_value):
    """Sample one chord token per melody pitch class with the transformer.

    The melody is processed in 24-note windows advancing by 12 notes. Each
    window is primed with ``[141] + window + [142]``, then every pitch of the
    window is appended followed by one sampled chord token. Each window
    contributes its first 12 (pitch, chord) pairs; the last window contributes
    all 24.

    Returns:
        Interleaved tokens ``[pitch0, chord0, pitch1, chord1, ...]`` trimmed to
        ``2 * len(mel_pitches)``.
    """
    ctx = torch.amp.autocast(device_type='cpu', dtype=torch.bfloat16)

    mel_pitches_ext = _extend_to_chunk_multiple(mel_pitches, 24)

    song = []

    for i in range(0, len(mel_pitches_ext) - 12, 12):
        mel_chunk = mel_pitches_ext[i:i + 24]

        data = [141] + mel_chunk + [142]

        for pitch in mel_chunk:
            data.append(pitch)

            x = torch.tensor([data], dtype=torch.long, device='cpu')

            with ctx:
                out = model.generate(x,
                                     1,
                                     filter_logits_fn=top_k,
                                     filter_kwargs={'k': top_k_value},
                                     temperature=1.0,
                                     return_prime=False,
                                     verbose=False)

            data.append(out.tolist()[0][0])

        # data[26:] holds this window's 24 (pitch, chord) pairs.
        if i != len(mel_pitches_ext) - 24:
            song.extend(data[26:50])
        else:
            song.extend(data[26:])

    return song[:len(mel_pitches) * 2]


def _chord_chunk_signature(chunk):
    """Flatten a chunk of chords into [mean pitch, chord size, ...] for distance comparison."""
    signature = []
    for chord in chunk:
        signature.extend([sum(chord) / len(chord), len(chord)])
    return signature


def _find_best_match(matches_indexes, previous_match_index):
    """Return the candidate chunk index whose signature is closest to the previous chunk's.

    Ties resolve to the earliest candidate.
    """
    previous_sig = _chord_chunk_signature(all_chords_ptcs_chunks[previous_match_index])
    return min(matches_indexes,
               key=lambda midx: TMIDIX.minkowski_distance(
                   previous_sig, _chord_chunk_signature(all_chords_ptcs_chunks[midx])))


def _texture_chords(harm_chords, chunk_length=2):
    """Map harmonized chords onto concrete voicings from the pitches-chords pairs data.

    Consecutive chunks of ``chunk_length`` chord tokens are looked up in
    ``src_chunks``. The first chunk is picked at random among exact matches;
    each following chunk is the match closest to the previous one.

    Returns:
        One pitch list per chord. Shorter than ``harm_chords`` if a later chunk
        has no match ("dead end"); empty if ``harm_chords`` is empty or the
        first chunk has no match.
    """
    if not harm_chords:
        return []

    harm_toks = [TMIDIX.ALL_CHORDS_FILTERED.index(c) for c in harm_chords]
    # Pad with the last chord so the tokens split into whole chunks; the
    # surplus voicings are trimmed off at the end.
    harm_toks += [harm_toks[-1]] * (chunk_length - (len(harm_chords) % chunk_length))

    def matches_for(start):
        trg_chunk = np.array(harm_toks[start:start + chunk_length])
        return np.where((src_chunks == trg_chunk).all(axis=1))[0].tolist()

    sidxs = matches_for(0)

    if not sidxs:
        print('Dead end!')
        return []

    pidx = random.choice(sidxs)
    final_song = list(all_chords_ptcs_chunks[pidx])

    for i in tqdm(range(chunk_length, len(harm_toks), chunk_length)):
        sidxs = matches_for(i)

        if not sidxs:
            print('Dead end!')
            break

        pidx = _find_best_match(sidxs, pidx)
        final_song.extend(all_chords_ptcs_chunks[pidx])

    return final_song[:len(harm_chords)]


def _channel_patches(melody_patch, accompaniment_patch, bass_patch):
    """Build the 16-entry program list: chords on ch 0, bass on ch 2 (if >= 0), melody on ch 3."""
    patches = [0] * 16
    patches[0] = accompaniment_patch
    if bass_patch > -1:
        patches[2] = bass_patch
    patches[3] = melody_patch
    return patches


def _melody_event(note, patch):
    """Render one melody note as a channel-3 ms-SONG event (timings scaled back up by 16)."""
    pitch = note[4]
    return ['note', note[1] * 16, note[2] * 16, 3, pitch, 115 + (pitch % 12), patch]


def _render_textured(melody_notes, textured_chords, melody_patch, accompaniment_patch, bass_patch):
    """Render melody notes with their textured chord voicings (absolute pitches) and optional bass."""
    output_score = []

    for note, chord_pitches in zip(melody_notes, textured_chords):
        event = _melody_event(note, melody_patch)
        start, dur = event[1], event[2]
        output_score.append(event)

        for pitch in chord_pitches:
            output_score.append(['note', start, dur, 0, pitch, max(40, pitch), accompaniment_patch])

        if bass_patch > -1:
            bass_pc = chord_pitches[-1] % 12
            output_score.append(['note', start, dur, 2, bass_pc + 24, 120 - bass_pc, bass_patch])

    return output_score


def _render_plain(harmonized, melody_patch, accompaniment_patch, bass_patch):
    """Render (melody note, chord token) pairs using block chords voiced from MIDI pitch 48."""
    output_score = []

    for note, chord_tok in harmonized:
        event = _melody_event(note, melody_patch)
        start, dur = event[1], event[2]
        output_score.append(event)

        chord = TMIDIX.ALL_CHORDS_FILTERED[chord_tok - 12]

        for c in chord:
            pitch = 48 + c
            output_score.append(['note', start, dur, 0, pitch, max(40, pitch), accompaniment_patch])

        if bass_patch > -1:
            output_score.append(['note', start, dur, 2, chord[-1] + 24, 120 - chord[-1], bass_patch])

    return output_score
|
|
|
|
|
|
|
if __name__ == "__main__":

    # The names PDT, soundfont, all_chords_ptcs_chunks and src_chunks are
    # module globals read by Harmonize_Melody at request time.
    PDT = timezone('US/Pacific')

    rule = '=' * 70

    print(rule)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(rule)

    soundfont = "SGM-v2.01-YamahaGrand-Guit-Bass-v2.7.sf2"

    print('Loading Melody Harmonizer Transformer Pitches Chords Pairs Data...')
    print(rule)

    pairs_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('Melody_Harmonizer_Transformer_Pitches_Chords_Pairs_Data')
    all_chords_toks_chunks, all_chords_ptcs_chunks = pairs_data

    print(rule)
    print('Total number of pitches chords pairs:', len(all_chords_toks_chunks))
    print(rule)
    print('Loading pitches chords pairs...')

    # Matrix form of the chord-token chunks so they can be matched row-wise.
    src_chunks = np.array(all_chords_toks_chunks)

    print('Done!')
    print(rule)

    with gr.Blocks() as app:

        # --- Header ---------------------------------------------------------
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Melody Harmonizer Transformer</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Harmonize any MIDI melody with transformers</h1>")
        gr.Markdown(
            "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Melody-Harmonizer-Transformer&style=flat)\n\n"
            "This is a demo for Monster MIDI Dataset\n\n"
            "Check out [Monster MIDI Dataset](https://github.com/asigalov61/Monster-MIDI-Dataset) on GitHub!\n\n"
        )

        # --- Inputs ---------------------------------------------------------
        gr.Markdown("## Upload your MIDI or select a sample example below")
        gr.Markdown("### For best results upload only monophonic melody MIDIs")

        input_src_midi = gr.File(label="Source MIDI", file_types=[".midi", ".mid", ".kar"])

        gr.Markdown("## Select harmonization options")

        source_melody_transpose_value = gr.Slider(-6, 6, value=0, step=1, label="Source melody transpose value", info="You can transpose source melody by specified number of semitones if the original melody key does not harmonize well")
        model_top_k_sampling_value = gr.Slider(1, 50, value=25, step=1, label="Model sampling top_k value", info="Decreasing this value may produce better harmonization results in some cases")
        texture_harmonized_chords = gr.Checkbox(label="Texture harmonized chords", value=True, info="Texture harmonized chords for more pleasant listening")
        melody_MIDI_patch_number = gr.Slider(0, 127, value=40, step=1, label="Source melody MIDI patch number")
        harmonized_accompaniment_MIDI_patch_number = gr.Slider(0, 127, value=0, step=1, label="Harmonized accompaniment MIDI patch number")
        base_MIDI_patch_number = gr.Slider(-1, 127, value=35, step=1, label="Base MIDI patch number")

        run_btn = gr.Button("Harmonize Melody", variant="primary")

        # --- Outputs --------------------------------------------------------
        gr.Markdown("## Harmonization results")

        output_summary = gr.Textbox(label="Melody harmonization summary")
        output_audio = gr.Audio(label="Output MIDI audio", format="mp3", elem_id="midi_audio")
        output_plot = gr.Plot(label="Output MIDI score plot")
        output_midi = gr.File(label="Output MIDI file", file_types=[".mid"])

        # Same component order as Harmonize_Melody's parameters / return tuple;
        # shared by the button and the examples.
        harmonizer_inputs = [
            input_src_midi,
            source_melody_transpose_value,
            model_top_k_sampling_value,
            texture_harmonized_chords,
            melody_MIDI_patch_number,
            harmonized_accompaniment_MIDI_patch_number,
            base_MIDI_patch_number,
        ]
        harmonizer_outputs = [output_audio, output_plot, output_midi, output_summary]

        run_event = run_btn.click(Harmonize_Melody,
                                  inputs=harmonizer_inputs,
                                  outputs=harmonizer_outputs)

        gr.Examples(
            examples=[
                ["USSR Anthem Seed Melody.mid", 0, 25, True, 40, 0, 35],
            ],
            inputs=harmonizer_inputs,
            outputs=harmonizer_outputs,
            fn=Harmonize_Melody,
            cache_examples=False,
        )

    app.queue().launch()