# MusicGen / app.py
# Uploaded by Tirath5504 ("Upload 2 files", commit 534596f, 2.23 kB).
import streamlit as st
import numpy as np
import torch
import transformers
from packaging.version import parse
import sys
import io
import importlib.metadata as importlib_metadata
import soundfile as sf
import importlib.metadata as importlib_metadata
# Extra keyword arguments forwarded to `from_pretrained` when loading MusicGen.
# On transformers >= 4.40.0 the "eager" attention implementation is requested
# explicitly; on older releases nothing extra is passed.
# NOTE(review): the reason for forcing "eager" is not stated in this file —
# presumably the default attention backend misbehaved with torchscript=True; confirm.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)
@st.cache_resource
def _load_musicgen():
    """Load the MusicGen-small model and its processor once, shared across reruns.

    Streamlit re-executes the whole script on every interaction, so loading
    inside ``generate`` would re-read the checkpoint on every click.
    ``st.cache_resource`` keeps a single shared instance for the server process.
    """
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        "facebook/musicgen-small",
        torchscript=True,
        return_dict=False,
        **loading_kwargs,
    )
    processor = transformers.AutoProcessor.from_pretrained("facebook/musicgen-small")
    return model, processor


def _to_wav_bytes(waveform, sampling_rate):
    """Encode a float waveform as an in-memory 16-bit PCM WAV file.

    Assumes samples are nominally in [-1, 1], since they are scaled by 2**15.
    Returns an ``io.BytesIO`` rewound to offset 0.
    """
    # Clip after scaling: a sample at or beyond +/-1.0 lands outside the int16
    # range (even exactly +1.0 maps to 32768). Without clipping, the
    # float->int16 cast overflows (wrap-around / undefined result) and
    # produces loud clicks.
    pcm = np.clip(waveform * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
    buffer = io.BytesIO()
    sf.write(buffer, pcm, sampling_rate, format="WAV")
    buffer.seek(0)
    return buffer


def generate(prompt, sample_length=8):
    """Generate a music clip from a text prompt and return it as WAV data.

    Args:
        prompt: Text description of the music to generate.
        sample_length: Target clip length in seconds (defaults to 8).

    Returns:
        An ``io.BytesIO`` positioned at offset 0 that holds a 16-bit PCM WAV
        file at the model's audio-encoder sampling rate.
    """
    model, processor = _load_musicgen()
    # Codec frames needed for `sample_length` seconds, plus 3 extra tokens.
    # NOTE(review): the +3 presumably covers MusicGen's codebook delay
    # pattern — confirm.
    n_tokens = sample_length * model.config.audio_encoder.frame_rate + 3
    sampling_rate = model.config.audio_encoder.sampling_rate
    inputs = processor(text=[prompt], padding=True, return_tensors="pt")
    audio_values = model.generate(
        **inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens
    )
    # Take the first (only) batch item and drop size-1 dimensions.
    # Assumes a single audio channel, so the result is 1-D — confirm.
    waveform = audio_values[0].cpu().squeeze().numpy()
    return _to_wav_bytes(waveform, sampling_rate)
def _use_example(example):
    """Button callback: copy *example* into the "Text Prompt" box.

    Streamlit runs callbacks before the script re-executes. That is the only
    point at which a widget's session-state value may be overwritten.
    """
    st.session_state["text_prompt"] = example


st.title("Music Generator")
# Keyed so the sidebar example buttons can fill it in through session state.
text_prompt = st.text_input("Text Prompt", key="text_prompt")
examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
]
st.sidebar.title("Examples")
for example in examples:
    # Assigning to text_prompt here would be lost: every rerun re-reads it from
    # the text box. The callback updates the widget's own state instead, and
    # the click itself triggers the rerun.
    st.sidebar.button(example, on_click=_use_example, args=(example,))

if st.button("Generate Audio"):
    if text_prompt.strip():
        with st.spinner("Generating audio..."):
            audio_output = generate(text_prompt)
        st.audio(audio_output, format="audio/wav")
    else:
        st.warning("Please enter a text prompt.")

# Debugging
if st.checkbox("Show debug info"):
    st.write("Text Prompt:", text_prompt)