File size: 2,228 Bytes
534596f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import streamlit as st
import numpy as np
import torch
import transformers
from packaging.version import parse
import sys
import io
import importlib.metadata as importlib_metadata
import soundfile as sf
import importlib.metadata as importlib_metadata

# Extra keyword arguments forwarded to ``from_pretrained`` when the model is
# loaded. On transformers 4.40.0 or newer the "eager" attention
# implementation is requested explicitly; older releases get no extra kwargs.
loading_kwargs = (
    {"attn_implementation": "eager"}
    if parse(importlib_metadata.version("transformers")) >= parse("4.40.0")
    else {}
)

def generate(prompt):
    """Generate a music clip from a text prompt and return it as WAV bytes.

    Loads the "facebook/musicgen-small" model and its processor, samples
    audio conditioned on ``prompt``, converts the waveform to 16-bit PCM and
    writes it to an in-memory WAV file.

    Args:
        prompt: Text description of the music to generate.

    Returns:
        io.BytesIO: Buffer holding the WAV data, rewound to position 0.
    """
    # NOTE(review): the model and processor are re-loaded on every call,
    # which is slow; consider caching them across Streamlit reruns.
    model_id = "facebook/musicgen-small"
    model = transformers.MusicgenForConditionalGeneration.from_pretrained(
        model_id, torchscript=True, return_dict=False, **loading_kwargs
    )
    processor = transformers.AutoProcessor.from_pretrained(model_id)

    audio_config = model.config.audio_encoder
    # Clip length; presumably seconds, since it is multiplied by the
    # encoder frame rate to obtain a token count -- TODO confirm.
    sample_length = 8
    # NOTE(review): the "+ 3" presumably compensates for MusicGen's codebook
    # delay pattern -- confirm against the model documentation.
    n_tokens = sample_length * audio_config.frame_rate + 3
    sampling_rate = audio_config.sampling_rate

    inputs = processor(
        text=[
            prompt,
        ],
        padding=True,
        return_tensors="pt",
    )
    audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=n_tokens)

    # Only one prompt is submitted, so take the first generated item.
    waveform = audio_values[0].cpu().squeeze().numpy()

    # Scale to the int16 range and clip before casting: samples at or beyond
    # +/-1.0 would otherwise overflow int16 and wrap around, which is heard
    # as loud clicks in the output.
    int16_limits = np.iinfo(np.int16)
    pcm = np.clip(waveform * 2**15, int16_limits.min, int16_limits.max).astype(np.int16)

    audio_buffer = io.BytesIO()
    sf.write(audio_buffer, pcm, sampling_rate, format='WAV')
    audio_buffer.seek(0)  # rewind so the caller reads from the beginning
    return audio_buffer

st.title("Music Generator")

# The prompt box is bound to session state under the key "text_prompt" so
# that the sidebar example buttons can fill it in programmatically.
text_prompt = st.text_input("Text Prompt", key="text_prompt")

examples = [
    "80s pop track with bassy drums and synth",
    "Earthy tones, environmentally conscious, ukulele-infused, harmonic, breezy, easygoing, organic instrumentation, gentle grooves",
    "90s rock song with loud guitars and heavy drums",
    "Heartful EDM with beautiful synths and chords",
]


def _use_example(example):
    """Copy the chosen example prompt into the prompt text box."""
    st.session_state["text_prompt"] = example


st.sidebar.title("Examples")
for example in examples:
    # Assigning a local variable and then forcing a rerun loses the value,
    # because the script restarts from the top. An on_click callback runs
    # before the rerun, so the text input above is rebuilt with the example
    # already stored in session state.
    st.sidebar.button(example, on_click=_use_example, args=(example,))

# Main action: synthesize audio for the current prompt and play it back.
if st.button("Generate Audio"):
    if not text_prompt:
        # Nothing to generate from; ask the user for input instead.
        st.warning("Please enter a text prompt.")
    else:
        with st.spinner("Generating audio..."):
            wav_buffer = generate(text_prompt)
            st.audio(wav_buffer, format='audio/wav')

# Optional debug view showing the prompt value the script currently holds.
show_debug = st.checkbox("Show debug info")
if show_debug:
    st.write("Text Prompt:", text_prompt)