# MusicGen / app.py
# Source snapshot: commit e03464f ("Update app.py") by Tirath5504, 2.13 kB.
import streamlit as st
import numpy as np
import torch
import transformers
from packaging.version import parse
import sys
import io
import importlib.metadata as importlib_metadata
import soundfile as sf
import importlib.metadata as importlib_metadata
# Extra keyword arguments forwarded to `from_pretrained` in generate().
# On transformers >= 4.40.0 the "eager" attention implementation is requested
# explicitly; on older releases no extra arguments are passed.
# NOTE(review): presumably this keeps attention on the plain (non-SDPA) path
# for the torchscript-enabled model; confirm the reason for the 4.40.0 cutoff.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)
@st.cache_resource
def _load_model_and_processor():
    """Load the MusicGen-small model and its processor once per server process.

    ``st.cache_resource`` keeps the returned objects across Streamlit reruns,
    so the weights are not reloaded on every button press.
    """
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small",
        torchscript=True,
        return_dict=False,
        **loading_kwargs,
    )
    processor = transformers.AutoProcessor.from_pretrained("facebook/musicgen-small")
    return model, processor


def generate(prompt, sample_length=8):
    """Generate a WAV clip for *prompt* with MusicGen-small.

    Args:
        prompt: Text description of the music to generate.
        sample_length: Target clip length in seconds (default 8). It is turned
            into a token budget via the audio encoder's frame rate.

    Returns:
        An ``io.BytesIO`` rewound to offset 0, holding 16-bit PCM WAV data at
        the audio encoder's sampling rate.
    """
    model, processor = _load_model_and_processor()

    # Token budget for generation; the "+ 3" is kept from the original code.
    # NOTE(review): presumably it pads for the codebook delay pattern — confirm.
    n_tokens = sample_length * model.config.audio_encoder.frame_rate + 3
    sampling_rate = model.config.audio_encoder.sampling_rate

    inputs = processor(
        text=[prompt],
        padding=True,
        return_tensors="pt",
    )
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )

    # Scale to the int16 range and clip before casting: a sample at exactly
    # +1.0 (or any overshoot) would otherwise wrap around to a large negative
    # value when converted to int16, producing an audible click.
    waveform = audio_values[0].cpu().squeeze().numpy() * 2**15
    pcm16 = np.clip(waveform, -(2**15), 2**15 - 1).astype(np.int16)

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, pcm16, sampling_rate, format='WAV')
    audio_buffer.seek(0)
    return audio_buffer
# Canned prompts offered in the sidebar.
examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
]

# Main panel: page heading and a free-form prompt box (empty by default).
st.title("Music Generator")
text_prompt = st.text_input("Text Prompt", "")

# Sidebar: heading plus a radio list of the canned prompts; the chosen string
# is kept in `selected_example` for later use in the script.
with st.sidebar:
    st.title("Examples")
    selected_example = st.radio("Select an example", examples)
if st.button("Generate Audio"):
    # A non-blank typed prompt takes precedence over the sidebar example.
    # Previously only `selected_example` was passed to generate(), and because
    # the radio always has a selection (index 0 by default) the text box was
    # silently ignored.
    prompt = text_prompt.strip() or selected_example
    if prompt:
        with st.spinner("Generating audio..."):
            audio_output = generate(prompt)
        st.audio(audio_output, format='audio/wav')
    else:
        st.warning("Please select or enter a text prompt.")

# Optional debug panel echoing the raw text-box contents.
if st.checkbox("Show debug info"):
    st.write("Text Prompt:", text_prompt)