# Hugging Face Space source: app.py by ylacombe.
# Last commit 240b319 ("Update app.py", verified); file size 5.34 kB.
import io
import math
from typing import Optional
import numpy as np
import spaces
import gradio as gr
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
from huggingface_hub import InferenceClient
import nltk
import random
# Sentence-splitting data for nltk.sent_tokenize (used in generate_base).
# Older NLTK releases load the pickled "punkt" resource, newer ones load
# "punkt_tab"; fetching both keeps the app working on either version.
nltk.download("punkt")
nltk.download("punkt_tab")

# Prefer CUDA, then Apple-silicon MPS, then CPU; use half precision only off-CPU.
device = "cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
torch_dtype = torch.float16 if device != "cpu" else torch.float32

# Tokenizers and feature extractor come from the base Parler-TTS checkpoint,
# while the weights come from the Jenny fine-tune (presumably sharing the base
# tokenizer - confirm if either checkpoint changes).
repo_id = "parler-tts/parler_tts_mini_v0.1"
jenny_repo_id = "ylacombe/parler-tts-mini-jenny-30H"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    jenny_repo_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
).to(device)

# Hugging Face inference client; generate_base uses it to write the story text.
client = InferenceClient()

description_tokenizer = AutoTokenizer.from_pretrained(repo_id)
# Left padding: sentences batched together are padded at their start.
prompt_tokenizer = AutoTokenizer.from_pretrained(repo_id, padding_side="left")
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 42
def numpy_to_mp3(audio_array, sampling_rate):
    """Encode a waveform held in a numpy array as MP3 bytes.

    Floating-point input is peak-normalised to the 16-bit PCM range and
    converted to ``int16``. Integer input is passed through unchanged, and its
    itemsize is used as the sample width.

    Args:
        audio_array: numpy array of samples. It is written out as a single
            (mono) channel; multi-dimensional input is flattened by ``tobytes()``.
        sampling_rate: Sample rate of ``audio_array`` in samples per second.

    Returns:
        The encoded MP3 file contents as ``bytes`` (320 kbit/s).
    """
    # Normalize audio_array if it's floating-point.
    if np.issubdtype(audio_array.dtype, np.floating):
        # Scale in float64: the model may emit float16, whose short mantissa
        # would visibly quantise a 16-bit signal. Non-finite samples become
        # silence so they cannot poison the peak or the int16 cast.
        samples = np.nan_to_num(
            audio_array.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0
        )
        peak = float(np.max(np.abs(samples))) if samples.size else 0.0
        # A silent (or empty) clip has no peak to divide by; leave it at zero
        # instead of producing NaNs.
        if peak > 0.0:
            samples = samples / peak * 32767  # Normalize to 16-bit range
        audio_array = samples.astype(np.int16)

    # Create an audio segment from the raw PCM bytes.
    audio_segment = AudioSegment(
        audio_array.tobytes(),
        frame_rate=sampling_rate,
        sample_width=audio_array.dtype.itemsize,
        channels=1,
    )

    # Export to MP3 at a high bitrate to maximise quality. The context manager
    # closes the buffer even if the export raises.
    with io.BytesIO() as mp3_io:
        audio_segment.export(mp3_io, format="mp3", bitrate="320k")
        return mp3_io.getvalue()
# Audio-encoder timing constants.
# sampling_rate: audio samples per second (stream_audio divides a clip's
# sample count by it to report seconds).
# frame_rate: presumably codec frames per second - confirm against the
# Parler-TTS config.
sampling_rate, frame_rate = (
    model.audio_encoder.config.sampling_rate,
    model.audio_encoder.config.frame_rate,
)
@spaces.GPU
def generate_base(subject, setting):
    """Write a bedtime story with the chat LLM and synthesise it sentence by sentence.

    Args:
        subject: Story protagonist chosen in the UI (e.g. "Princess").
        setting: Story location chosen in the UI (e.g. "Forest").

    Returns:
        A 3-tuple ``(None, None, state)``.
        - The two ``None`` values clear the story textbox and the audio player.
        - ``state`` is a dict with ``"audio"`` (one numpy waveform per
          sentence) and ``"text"`` (the matching sentence strings).

    Raises:
        gr.Error: if the LLM returns no text that can be narrated.
    """
    messages = [
        {
            # Must be "system" (not "sytem") for the instructions to be treated
            # as the system prompt.
            "role": "system",
            "content": (
                "You are an award-winning children's bedtime story author lauded for your inventive stories. "
                "You want to write a bed time story for your child. They will give you the subject and setting "
                "and you will write the entire story. It should be targetted at children 5 and younger and take about "
                "a minute to read"
            ),
        },
        {"role": "user", "content": f"Please tell me a story about a {subject} in {setting}"},
    ]
    gr.Info("Generating story", duration=3)
    response = client.chat_completion(messages, max_tokens=2048, seed=random.randint(1, 5000))
    gr.Info("Story Generated", duration=3)

    # The message content may be None on a degenerate completion; treat it as empty.
    story = response.choices[0].message.content or ""
    # Flatten paragraphs so the sentence splitter sees one continuous text.
    model_input = story.replace("\n", " ").strip()
    model_input_tokens = nltk.sent_tokenize(model_input)
    if not model_input_tokens:
        raise gr.Error("The story generator returned an empty story. Please try again.")

    gr.Info("Generating Audio")
    description = "Jenny speaks at an average pace with a calm delivery in a very confined sounding environment with clear audio quality."
    # One batch row per sentence; prompt_tokenizer was created with
    # padding_side="left", so shorter sentences are padded at the start.
    story_tokens = prompt_tokenizer(model_input_tokens, return_tensors="pt", padding=True).to(device)
    # The same voice description is repeated for every sentence in the batch.
    description_tokens = description_tokenizer(
        [description] * len(model_input_tokens), return_tensors="pt"
    ).to(device)
    speech_output = model.generate(
        input_ids=description_tokens.input_ids,
        prompt_input_ids=story_tokens.input_ids,
        attention_mask=description_tokens.attention_mask,
        prompt_attention_mask=story_tokens.attention_mask,
    )
    # NOTE(review): batched outputs may be padded to the longest sentence;
    # verify whether per-sample lengths should be trimmed.
    speech_output = [output.cpu().numpy() for output in speech_output]
    gr.Info("Generated Audio")
    return None, None, {"audio": speech_output, "text": model_input_tokens}
import time
def stream_audio(state):
    """Replay a pre-generated story one sentence at a time.

    Args:
        state: dict produced by ``generate_base``. ``"audio"`` holds one
            waveform array per sentence and ``"text"`` holds the matching
            sentences.

    Yields:
        ``(transcript, mp3_bytes)`` pairs.
        - ``transcript`` contains every sentence read so far, each followed
          by a newline.
        - ``mp3_bytes`` is the encoded audio of the newest sentence.
        The generator sleeps 5 seconds after each yield, presumably to pace
        the text with playback.
    """
    clips = state["audio"]
    lines = state["text"]
    gr.Info("Reading Story")
    transcript_parts = []
    for line, clip in zip(lines, clips):
        seconds = round(clip.shape[0] / sampling_rate, 2)
        print(f"Sample of length: {seconds} seconds")
        transcript_parts.append(f"{line}\n")
        yield "".join(transcript_parts), numpy_to_mp3(clip, sampling_rate=sampling_rate)
        time.sleep(5)
# Gradio UI: pick a subject and setting, generate the story, then stream it.
with gr.Blocks() as block:
    gr.HTML(
        """
        <h1> Bedtime Story Reader 😴🔊 </h1>
        <p> Powered by <a href="https://github.com/huggingface/parler-tts"> Parler-TTS</a></p>
        """
    )
    with gr.Group():
        with gr.Row():
            subject = gr.Dropdown(
                value="Princess", choices=["Prince", "Princess", "Dog", "Cat"], label="Subject"
            )
            setting = gr.Dropdown(
                value="Forest", choices=["Forest", "Kingdom", "Jungle", "Underwater"], label="Setting"
            )
        with gr.Row():
            run_button = gr.Button("Generate Story", variant="primary")
    with gr.Row():
        with gr.Group():
            audio_out = gr.Audio(label="Bed time story", streaming=True, autoplay=True)
            story = gr.Textbox(label="Story")

    inputs = [subject, setting]
    outputs = [story, audio_out]
    # Carries generate_base's {"audio", "text"} result into stream_audio.
    state = gr.State()
    # Two steps. generate_base writes the story and audio into `state` and
    # clears both outputs. Only if it succeeds does stream_audio replay the
    # story sentence by sentence.
    run_button.click(fn=generate_base, inputs=inputs, outputs=[story, audio_out, state]).success(
        stream_audio, inputs=state, outputs=outputs
    )

block.queue()
block.launch(share=True)