"""Streamlit app that generates short music clips from a text prompt.

Run with ``streamlit run <this file>``. The user either types a prompt or
picks one of the example prompts. The app then generates a clip with
``facebook/musicgen-small`` and plays it back as a 16-bit PCM WAV file.
"""

import importlib.metadata as importlib_metadata
import io
import sys  # noqa: F401  (kept from the original import list)

import numpy as np
import soundfile as sf
import streamlit as st
import torch  # noqa: F401  (PyTorch backs the return_tensors="pt" inputs)
import transformers
from packaging.version import parse

MODEL_ID = "facebook/musicgen-small"

# Default clip length in seconds. It is multiplied by the audio encoder's
# frame rate to get the token budget.
DEFAULT_SAMPLE_LENGTH = 8

# Float audio is scaled by 2**15 for 16-bit PCM. It is then clipped to the
# int16 range so that full-scale samples do not wrap around on the cast.
_PCM16_SCALE = 2**15
_PCM16_MIN = -(2**15)
_PCM16_MAX = 2**15 - 1

EXAMPLES = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
]

# Extra keyword arguments for from_pretrained. The eager attention
# implementation is requested only on transformers >= 4.40.0.
loading_kwargs = {}
if parse(importlib_metadata.version("transformers")) >= parse("4.40.0"):
    loading_kwargs["attn_implementation"] = "eager"


@st.cache_resource
def _load_model_and_processor():
    """Load the MusicGen model and processor once and share them across reruns.

    Streamlit re-executes the script on every interaction. Without this cache,
    each click on "Generate Audio" would reload the model weights.
    """
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        MODEL_ID, torchscript=True, return_dict=False, **loading_kwargs
    )
    processor = transformers.AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor


def _to_pcm16(waveform):
    """Convert float audio samples (nominally in [-1, 1]) to int16 PCM.

    Samples are scaled by 2**15 and clipped to [-32768, 32767] before the
    cast. Without the clip, a +1.0 sample would become 32768 and wrap to
    -32768, producing an audible click.
    """
    scaled = np.asarray(waveform, dtype=np.float32) * _PCM16_SCALE
    return np.clip(scaled, _PCM16_MIN, _PCM16_MAX).astype(np.int16)


def generate(prompt, sample_length=DEFAULT_SAMPLE_LENGTH):
    """Generate a music clip for *prompt* and return it as an in-memory WAV.

    Args:
        prompt: Text description of the music to generate.
        sample_length: Target clip length in seconds (default 8). The real
            length depends on how many tokens the model emits.

    Returns:
        An ``io.BytesIO`` rewound to offset 0. It holds a 16-bit PCM WAV file
        at the audio encoder's sampling rate.
    """
    model, processor = _load_model_and_processor()
    audio_config = model.config.audio_encoder

    # One token per encoder frame, plus 3 extra steps. The "+3" is kept from
    # the original tuning; it is presumably there to cover MusicGen's
    # codebook delay pattern (to be confirmed).
    n_tokens = int(sample_length * audio_config.frame_rate) + 3

    inputs = processor(text=[prompt], padding=True, return_tensors="pt")
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )

    # The first (only) batch item, with singleton dims squeezed away.
    waveform = audio_values[0].detach().cpu().squeeze().numpy()

    audio_buffer = io.BytesIO()
    sf.write(
        audio_buffer,
        _to_pcm16(waveform),
        audio_config.sampling_rate,
        format="WAV",
    )
    audio_buffer.seek(0)  # rewind so st.audio reads from the start
    return audio_buffer


def _resolve_prompt(text_prompt, selected_example):
    """Return the prompt to generate from.

    A non-blank typed prompt wins. Otherwise the selected example is used.
    The result is "" when neither is available.

    ``st.radio`` preselects its first option, so the example is almost always
    set. Checking the example first would mean typed text is never used.
    """
    return (text_prompt or "").strip() or selected_example or ""


def main():
    """Render the Streamlit UI and handle the "Generate Audio" action."""
    st.title("Music Generator")
    text_prompt = st.text_input(
        "Text Prompt",
        "",
        help="When non-empty, this takes precedence over the selected example.",
    )

    st.title("Examples")
    selected_example = st.radio("Select an example", EXAMPLES)

    if st.button("Generate Audio"):
        prompt = _resolve_prompt(text_prompt, selected_example)
        if prompt:
            with st.spinner("Generating audio..."):
                audio_output = generate(prompt)
            st.audio(audio_output, format="audio/wav")
        else:
            st.warning("Please select or enter a text prompt.")

    if st.checkbox("Show debug info"):
        if text_prompt:
            st.write("Text Prompt:", text_prompt)
        if selected_example:
            st.write("Selected Example:", selected_example)


# `streamlit run` executes the script as __main__, so the UI renders there.
if __name__ == "__main__":
    main()