Spaces:
Sleeping
Sleeping
import streamlit as st | |
import numpy as np | |
import torch | |
import transformers | |
from packaging.version import parse | |
import sys | |
import io | |
import importlib.metadata as importlib_metadata | |
import soundfile as sf | |
import importlib.metadata as importlib_metadata | |
# Extra keyword arguments forwarded to MusicgenForConditionalGeneration.from_pretrained.
# On transformers >= 4.40.0 the eager attention implementation is requested
# explicitly; on older releases no extra arguments are passed.
# NOTE(review): presumably eager is forced because the model is loaded with
# torchscript=True and other attention backends don't support that — confirm.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)
@st.cache_resource
def _load_musicgen():
    """Load the MusicGen-small model and its processor once and reuse them.

    Streamlit re-executes this script on every widget interaction, so without
    ``st.cache_resource`` the weights would be reloaded for every generation
    request.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small",
        torchscript=True,
        return_dict=False,
        **loading_kwargs,
    )
    processor = transformers.AutoProcessor.from_pretrained("facebook/musicgen-small")
    return model, processor


def generate(prompt, sample_length=8):
    """Generate a music clip for ``prompt`` and return it as an in-memory WAV file.

    Args:
        prompt: Text description of the music to generate.
        sample_length: Target clip length, used as seconds times the audio
            encoder's frame rate (default 8, the previously hard-coded value).

    Returns:
        An ``io.BytesIO`` rewound to position 0, containing WAV data written
        from int16 samples at the audio encoder's sampling rate.
    """
    model, processor = _load_musicgen()

    # Number of codec frames to generate for the requested length; the +3
    # margin is carried over from the original implementation.
    n_tokens = sample_length * model.config.audio_encoder.frame_rate + 3
    sampling_rate = model.config.audio_encoder.sampling_rate

    inputs = processor(text=[prompt], padding=True, return_tensors="pt")
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )

    waveform = audio_values[0].cpu().squeeze().numpy()
    # Scale to the int16 range and clip first: a sample at +1.0 scales to
    # 32768, which does not fit in int16 and would overflow on the cast.
    pcm = np.clip(waveform * 2**15, -(2**15), 2**15 - 1).astype(np.int16)

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, pcm, sampling_rate, format='WAV')
    audio_buffer.seek(0)  # rewind so st.audio reads from the start
    return audio_buffer
st.title("Music Generator")
text_prompt = st.text_input("Text Prompt", "")

# Preset prompts offered in the sidebar; used when no prompt is typed.
examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
]
st.sidebar.title("Examples")
selected_example = st.sidebar.radio("Select an example", examples)

if st.button("Generate Audio"):
    # A non-blank typed prompt takes precedence; the selected sidebar example
    # is the fallback. (Previously the typed prompt was never used.)
    prompt = text_prompt.strip() or selected_example
    if prompt:
        with st.spinner("Generating audio..."):
            audio_output = generate(prompt)
        st.audio(audio_output, format='audio/wav')
    else:
        st.warning("Please select or enter a text prompt.")

if st.checkbox("Show debug info"):
    st.write("Text Prompt:", text_prompt)