"""Gradio app that transcribes piano audio to MIDI and score formats.

Two modes are exposed: transcribing an uploaded audio file, and
transcribing a NetEase Cloud Music (music.163.com) song given its page
link or numeric song ID.
"""

import os
import re
import json
import uuid
import shutil
from urllib.parse import urlparse

import torch
import requests
import gradio as gr
from piano_transcription_inference import PianoTranscription, load_audio, sample_rate
from modelscope import snapshot_download

from convert import midi2xml, xml2abc, xml2mxl, xml2jpg

# Model checkpoint, fetched (or reused from the local cache) at import time.
WEIGHTS_PATH = (
    snapshot_download("Genius-Society/piano_trans", cache_dir="./__pycache__")
    + "/CRNN_note_F1=0.9677_pedal_F1=0.9186.pth"
)

# (connect, read) timeout in seconds so a stalled server cannot hang a request.
REQUEST_TIMEOUT = (10, 60)


def clean_cache(cache_dir):
    """Reset *cache_dir* to an empty directory, creating it if needed."""
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)

    # makedirs (rather than mkdir) also works when the parent is missing.
    os.makedirs(cache_dir, exist_ok=True)


def download_audio(url: str, save_path: str):
    """Stream the resource at *url* into the file *save_path*.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection problems or timeouts.
    """
    # The context manager releases the streamed connection even on failure.
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT) as response:
        response.raise_for_status()
        with open(save_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)


def is_url(s: str):
    """Return True when *s* parses as a URL with both a scheme and a host."""
    try:
        result = urlparse(s)
        return all([result.scheme, result.netloc])

    except (ValueError, TypeError, AttributeError):
        # Malformed URL or a non-string input: treat as "not a URL".
        return False


def audio2midi(audio_path: str, cache_dir: str):
    """Transcribe the audio file at *audio_path* into ``output.mid``.

    Returns:
        A ``(midi_path, title)`` tuple, where *title* is the capitalized
        file name of the audio without its extension.
    """
    audio, _ = load_audio(audio_path, sr=sample_rate)
    transcriptor = PianoTranscription(
        device="cuda" if torch.cuda.is_available() else "cpu",
        checkpoint_path=WEIGHTS_PATH,
    )
    midi_path = f"{cache_dir}/output.mid"
    transcriptor.transcribe(audio, midi_path)

    # splitext copes with names that have no extension or several dots.
    title = os.path.splitext(os.path.basename(audio_path))[0].capitalize()
    return midi_path, title


def upl_infer(audio_path: str, cache_dir="./__pycache__/mode1"):
    """Transcribe an uploaded audio file and convert the result.

    Returns:
        ``(midi, pdf, xml, mxl, abc, jpg)``. On failure every slot is
        ``None`` except the ABC slot, which carries the error message so
        that it is shown in the UI textbox.
    """
    clean_cache(cache_dir)
    try:
        print(audio_path)
        midi, title = audio2midi(audio_path, cache_dir)
        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
        return midi, pdf, xml, mxl, abc, jpg

    except Exception as e:
        # UI boundary: report the error in the textbox instead of crashing.
        return None, None, None, None, f"{e}", None


def get_1st_int(input_string: str):
    """Return the first run of digits in *input_string*, normalized via int.

    Returns an empty string when the input contains no digits.
    """
    match = re.search(r"\d+", input_string)
    return str(int(match.group())) if match else ""


def music163_song_info(id: str):
    """Look up a song through the music.163.com song detail API.

    Returns:
        A ``(song_name, free)`` tuple. *free* is True when the reported
        ``fee`` is 0 or 8. When the response holds no songs, the name is
        "The song does not exist" and *free* is False.

    Raises:
        ConnectionError: if the API does not answer with HTTP 200.
    """
    detail_api = "https://music.163.com/api/v3/song/detail"
    parm_dict = {"id": id, "c": str([{"id": id}]), "csrf_token": ""}
    response = requests.get(detail_api, params=parm_dict, timeout=REQUEST_TIMEOUT)
    if response.status_code != 200:
        raise ConnectionError(f"Error: {response.status_code}, {response.text}")

    data = json.loads(response.text)
    if data and "songs" in data and data["songs"]:
        first_song = data["songs"][0]
        fee = int(first_song["fee"])
        return str(first_song["name"]), fee in (0, 8)

    return "The song does not exist", False


def url_infer(song: str, cache_dir="./__pycache__/mode2"):
    """Download a NetEase Cloud Music song and transcribe it.

    *song* is either a song page link containing ``?id=`` or a bare
    numeric song ID.

    Returns:
        ``(audio_path, midi, pdf, xml, mxl, abc, jpg)``. On failure every
        slot is ``None`` except the ABC slot, which carries the error
        message so that it is shown in the UI textbox.
    """
    song_name = ""
    clean_cache(cache_dir)
    try:
        song = (song or "").strip()
        is_163_link = is_url(song) and "163" in song and "?id=" in song
        if not (is_163_link or song.isdigit()):
            # Without this guard a directory path would reach the audio loader.
            raise ValueError(
                "Unsupported input: expected a music.163.com song link or a numeric song ID"
            )

        song_id = get_1st_int(song.split("?id=")[-1])
        if not song_id:
            raise ValueError("Unable to find a song ID in the input")

        song_url = f"https://music.163.com/song/media/outer/url?id={song_id}.mp3"
        song_name, free = music163_song_info(song_id)
        if not free:
            raise AttributeError("Unable to parse VIP songs")

        # Unique per-request directory so concurrent requests cannot collide.
        audio_dir = f"/tmp/gradio/{uuid.uuid4().hex}/"
        os.makedirs(audio_dir, exist_ok=True)
        audio_path = audio_dir + f"{song_id}.mp3"
        download_audio(song_url, audio_path)

        midi, title = audio2midi(audio_path, cache_dir)
        if song_name:
            # Prefer the real song name over the file-name-derived title.
            title = song_name

        xml = midi2xml(midi, title)
        abc = xml2abc(xml)
        mxl = xml2mxl(xml)
        pdf, jpg = xml2jpg(xml)
        return audio_path, midi, pdf, xml, mxl, abc, jpg

    except Exception as e:
        # UI boundary: report the error in the textbox instead of crashing.
        return None, None, None, None, None, f"{e}", None


if __name__ == "__main__":
    with gr.Blocks() as iface:
        gr.Markdown("# Piano Transcription Tool")
        with gr.Tab("Uploading Mode"):
            gr.Interface(
                fn=upl_infer,
                inputs=gr.Audio(
                    label="Upload an audio",
                    type="filepath",
                ),
                outputs=[
                    gr.File(label="Download MIDI"),
                    gr.File(label="Download PDF score"),
                    gr.File(label="Download MusicXML"),
                    gr.File(label="Download MXL"),
                    gr.Textbox(label="ABC notation", show_copy_button=True),
                    gr.Image(label="Staff", type="filepath"),
                ],
                description="Please make sure the audio is completely uploaded before clicking Submit",
                flagging_mode="never",
            )

        with gr.Tab("Direct Link Mode"):
            gr.Interface(
                fn=url_infer,
                inputs=gr.Textbox(
                    label="Input audio direct link",
                    placeholder="https://music.163.com/#/song?id=",
                ),
                outputs=[
                    gr.Audio(label="Download audio", type="filepath"),
                    gr.File(label="Download MIDI"),
                    gr.File(label="Download PDF score"),
                    gr.File(label="Download MusicXML"),
                    gr.File(label="Download MXL"),
                    gr.Textbox(label="ABC notation", show_copy_button=True),
                    gr.Image(label="Staff", type="filepath"),
                ],
                description="For Netease Cloud music, you can directly input the non-VIP song page link",
                examples=["1945798894", "1945798973", "1946098771"],
                flagging_mode="never",
                cache_examples=False,
            )

    iface.launch()