import os
from math import floor
from typing import Optional
import spaces
import torch
import gradio as gr
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
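# Hugging Face Space demo: long-form Japanese speech transcription with speaker diarization
# using kotoba-whisper-v2.2, served through a Gradio tabbed interface (upload + microphone).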
# config
model_name = "kotoba-tech/kotoba-whisper-v2.2"
example_file = "sample_diarization_japanese.mp3"
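# Load the diarization-enabled ASR pipeline. On GPU, use bfloat16 and SDPA attention for speed;
# otherwise fall back to the CPU defaults. trust_remote_code=True lets transformers load the
# model's custom pipeline code.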
if torch.cuda.is_available():
    pipe = pipeline(
        model=model_name,
        chunk_length_s=15,
        batch_size=16,
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_kwargs={'attn_implementation': 'sdpa'},
        trust_remote_code=True
    )
else:
    pipe = pipeline(model=model_name, chunk_length_s=15, batch_size=16, trust_remote_code=True)
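# A minimal sketch of calling the pipeline directly (outside the Gradio UI), assuming the custom
# pipeline accepts a file path like the standard ASR pipeline does:
#   result = pipe(example_file, add_punctuation=True)
#   for speaker in result["speaker_ids"]:
#       print(result[f"text/{speaker}"])

# Format a (start, end) timestamp pair as "[mm:ss.d -> mm:ss.d]:" for the markdown output.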
def format_time(start: Optional[float], end: Optional[float]):

    def _format_time(seconds: Optional[float]):
        if seconds is None:
            return "[no timestamp available]"
        minutes = floor(seconds / 60)
        seconds = seconds - minutes * 60  # remainder within the minute
        m_seconds = floor(round(seconds - floor(seconds), 1) * 10)
        seconds = floor(seconds)
        return f'{minutes:02}:{seconds:02}.{m_seconds:01}'

    return f"[{_format_time(start)} -> {_format_time(end)}]:"
@spaces.GPU
def get_prediction(inputs, **kwargs):
    return pipe(inputs, **kwargs)
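# Gradio callback: read the uploaded/recorded file, decode it with ffmpeg, run the pipeline,
# and format the result as one markdown section per speaker. Slider values of 0 (speaker
# counts) and 30 (chunk length) mean "use the pipeline defaults" and are passed as None.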
def transcribe(inputs: str,
               add_punctuation: bool,
               add_silence_end: bool,
               add_silence_start: bool,
               num_speakers: float,
               min_speakers: float,
               max_speakers: float,
               chunk_length_s: float):
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
    with open(inputs, "rb") as f:
        inputs = f.read()
    array = ffmpeg_read(inputs, pipe.feature_extractor.sampling_rate)
    prediction = get_prediction(
        inputs={"array": array, "sampling_rate": pipe.feature_extractor.sampling_rate},
        add_punctuation=add_punctuation,
        num_speakers=int(num_speakers) if num_speakers != 0 else None,
        min_speakers=int(min_speakers) if min_speakers != 0 else None,
        max_speakers=int(max_speakers) if max_speakers != 0 else None,
        chunk_length_s=int(chunk_length_s) if chunk_length_s != 30 else None,
        add_silence_end=0.5 if add_silence_end else None,
        add_silence_start=0.5 if add_silence_start else None
    )
    output = ""
    for n, s in enumerate(prediction["speaker_ids"]):
        text_timestamped = "\n".join([f"- **{format_time(*c['timestamp'])}** {c['text']}" for c in prediction[f"chunks/{s}"]])
        output += f'### Speaker {n+1} \n{prediction[f"text/{s}"]}\n\n{text_timestamped}\n'
    return output
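# Gradio UI: two interfaces sharing the same callback and configuration, one for uploaded
# audio files and one for microphone recordings, combined in a tabbed layout below.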
description = (f"Transcribe and diarize long-form microphone or uploaded audio with the click of a button! This demo "
               f"uses Kotoba-Whisper [{model_name}](https://huggingface.co/{model_name}).")
title = f"Audio Transcription and Diarization with {os.path.basename(model_name)}"
shared_config = {"fn": transcribe, "title": title, "description": description, "allow_flagging": "never", "examples": [
    [example_file, True, True, True, 0, 0, 0, 30],
    [example_file, True, True, True, 4, 0, 0, 30]
]}
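# Upload and microphone interfaces (they differ only in the audio source).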
i_upload = gr.Interface(
    inputs=[
        gr.Audio(sources="upload", type="filepath", label="Audio file"),
        gr.Checkbox(label="add punctuation", value=True),
        gr.Checkbox(label="add silence at the end", value=True),
        gr.Checkbox(label="add silence at the start", value=True),
        gr.Slider(0, 10, label="num speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(0, 10, label="min speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(0, 10, label="max speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(5, 30, label="chunk length for ASR (seconds)", value=30, step=1),
    ],
    outputs=gr.Markdown(),
    **shared_config
)
i_mic = gr.Interface(
    inputs=[
        gr.Audio(sources="microphone", type="filepath", label="Microphone input"),
        gr.Checkbox(label="add punctuation", value=True),
        gr.Checkbox(label="add silence at the end", value=True),
        gr.Checkbox(label="add silence at the start", value=True),
        gr.Slider(0, 10, label="num speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(0, 10, label="min speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(0, 10, label="max speakers (set 0 for auto-detect mode)", value=0, step=1),
        gr.Slider(5, 30, label="chunk length for ASR (seconds)", value=30, step=1),
    ],
    outputs=gr.Markdown(),
    **shared_config
)
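# Combine both interfaces into tabs and launch with a request queue.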
with gr.Blocks() as demo:
    gr.TabbedInterface([i_upload, i_mic], ["Audio file", "Microphone"])

demo.queue(api_open=False, default_concurrency_limit=40).launch(show_api=False, show_error=True)