# Hugging Face Space page residue: "Spaces: Sleeping" (Space status at capture time).
import pickle
from io import BytesIO
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import torchaudio as ta
import whisper
from transformers import WhisperProcessor, WhisperForConditionalGeneration
# Set up device and dtype
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch_dtype is not applied to the model loaded below, so the
# weights presumably stay in from_pretrained's default dtype (float32) even on GPU.
torch_dtype = torch.float16 if device == "cuda:0" else torch.float32
SAMPLING_RATE = 16000  # every upload is resampled to this rate before processing
CHUNK_LENGTH_S = 20  # seconds of audio decoded per Whisper generate() call


@st.cache_resource
def _load_hf_whisper(model_name, target_device):
    """Load the Hugging Face Whisper processor and model once per process.

    Streamlit re-executes this script on every widget interaction; caching the
    pair with ``st.cache_resource`` avoids reloading the weights on each rerun.

    Args:
        model_name: Hugging Face Hub id passed to ``from_pretrained``.
        target_device: torch device string the model is moved to.

    Returns:
        ``(processor, model)`` tuple.
    """
    hf_processor = WhisperProcessor.from_pretrained(model_name)
    hf_model = WhisperForConditionalGeneration.from_pretrained(model_name).to(target_device)
    return hf_processor, hf_model


# Load Whisper model and processor
processor, model = _load_hf_whisper("openai/whisper-small", device)
# Title of the app
st.title("Audio Player with Live Transcription")

# Sidebar: the file picker plus the button that starts processing.
with st.sidebar:
    st.header("Upload Audio Files")
    uploaded_files = st.file_uploader(
        "Choose audio files",
        type=["mp3", "wav"],
        accept_multiple_files=True,
    )
    submit_button = st.button("Submit")

# Per-session storage that survives Streamlit's script reruns. All keys are
# created together on the first run, so checking one of them is enough.
if "audio_files" not in st.session_state:
    st.session_state.update(
        audio_files=[],
        transcriptions={},
        translations={},
        detected_languages=[],
        waveforms=[],
    )
@st.cache_resource
def _load_language_id_model(size="small"):
    """Load the openai-whisper model used for language identification once.

    Cached with ``st.cache_resource`` so that neither the per-file loop nor
    later Streamlit reruns reload the weights from disk.
    """
    return whisper.load_model(size)


def detect_language(audio_file):
    """Return the language key Whisper rates most probable for *audio_file*.

    Only the first 30 s window (``whisper.pad_or_trim``) is inspected.

    Args:
        audio_file: waveform tensor as returned by ``torchaudio.load``.
            Presumably shaped (channels, samples) and already resampled to
            16 kHz by the caller; a 1-D tensor is also accepted.

    Returns:
        The key with the highest probability in Whisper's language
        distribution (a language code, e.g. "en").
    """
    whisper_model = _load_language_id_model("small")
    # Down-mix to mono. Whisper treats a 2-D mel as one clip but a 3-D mel as
    # a batch (one entry per channel), which changes what detect_language
    # returns. Averaging keeps all channels and always yields a single clip.
    mono = audio_file.mean(dim=0) if audio_file.dim() > 1 else audio_file
    trimmed_audio = whisper.pad_or_trim(mono)
    mel = whisper.log_mel_spectrogram(trimmed_audio).to(whisper_model.device)
    _, probs = whisper_model.detect_language(mel)
    # A single (2-D) mel yields one dict. A batched mel yields a list of dicts,
    # so take the first entry in that case.
    if isinstance(probs, list):
        probs = probs[0]
    detected_lang = max(probs, key=probs.get)
    print(f"Detected language: {detected_lang}")
    return detected_lang
def process_long_audio(waveform, sampling_rate, task="transcribe", language=None):
    """Run Whisper over *waveform* in fixed-length chunks and join the text.

    The audio is cut into consecutive windows of ``CHUNK_LENGTH_S`` seconds.
    Each window is decoded independently by the module-level ``model``, and
    the non-empty per-chunk texts are joined with single spaces.

    Args:
        waveform: audio tensor, (channels, samples) as returned by torchaudio,
            or a 1-D mono tensor. Multi-channel input is averaged to mono.
        sampling_rate: sample rate of *waveform* in Hz, forwarded to the processor.
        task: ``"translate"`` forces Whisper's translate prompt. Any other
            value uses the model's default decoding (transcription).
        language: language name or code for the translate prompt. It is
            ignored when transcribing.

    Returns:
        The decoded text, or ``""`` for empty audio.
    """
    # Average the channels rather than keeping only the first one, so speech
    # in other channels of stereo files is not dropped.
    mono = waveform.mean(dim=0) if waveform.dim() > 1 else waveform
    chunk_length = int(CHUNK_LENGTH_S * sampling_rate)

    # The decoder prompt depends only on task/language, so build it once.
    generate_kwargs = {}
    if task == "translate":
        generate_kwargs["forced_decoder_ids"] = processor.get_decoder_prompt_ids(
            language=language, task="translate"
        )

    pieces = []
    for start in range(0, mono.shape[-1], chunk_length):
        chunk = mono[start:start + chunk_length]
        input_features = processor(
            chunk.cpu().numpy(), sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features.to(device)
        with torch.no_grad():
            generated_ids = model.generate(input_features, **generate_kwargs)
        # Decoded chunks usually carry a leading space. Strip them so the join
        # below does not produce double spaces, and skip silent chunks.
        for text in processor.batch_decode(generated_ids, skip_special_tokens=True):
            text = text.strip()
            if text:
                pieces.append(text)
    return " ".join(pieces)
# Process uploaded files
# With accept_multiple_files=True the uploader returns a list ([] when nothing
# is selected), never None, so test for a non-empty list.
if submit_button and uploaded_files:
    st.session_state.audio_files = uploaded_files
    st.session_state.detected_languages = []
    st.session_state.waveforms = []
    # Results are keyed by file index. Drop those of the previous batch so
    # they are not displayed next to the newly uploaded files.
    st.session_state.transcriptions = {}
    st.session_state.translations = {}
    for uploaded_file in uploaded_files:
        # getvalue() does not move the stream position, so the same bytes are
        # still available when the audio player renders the file later.
        waveform, sampling_rate = ta.load(BytesIO(uploaded_file.getvalue()))
        if sampling_rate != SAMPLING_RATE:
            waveform = ta.functional.resample(waveform, orig_freq=sampling_rate, new_freq=SAMPLING_RATE)
        st.session_state.waveforms.append(waveform)
        st.session_state.detected_languages.append(detect_language(waveform))
elif submit_button:
    st.sidebar.warning("Please upload at least one audio file before submitting.")
# Display uploaded files and options
if 'audio_files' in st.session_state and st.session_state.audio_files:
    for i, uploaded_file in enumerate(st.session_state.audio_files):
        col1, col2 = st.columns([1, 3])
        with col1:
            st.write(f"**File name**: {uploaded_file.name}")
            # getvalue() returns the whole upload regardless of the stream
            # position. read() would return b"" once the bytes were consumed.
            st.audio(BytesIO(uploaded_file.getvalue()), format=uploaded_file.type)
            st.write(f"**Detected Language**: {st.session_state.detected_languages[i]}")
        with col2:
            # Explicit per-index keys keep widget IDs unique even when two
            # uploads share a file name (identical labels would clash).
            if st.button(f"Transcribe {uploaded_file.name}", key=f"transcribe_{i}"):
                with st.spinner("Transcribing..."):
                    transcription = process_long_audio(st.session_state.waveforms[i], SAMPLING_RATE)
                    st.session_state.transcriptions[i] = transcription
            if st.session_state.transcriptions.get(i):
                st.write("**Transcription**:")
                st.write(st.session_state.transcriptions[i])
            if st.button(f"Translate {uploaded_file.name}", key=f"translate_{i}"):
                with st.spinner("Translating..."):
                    # Resolve the mapping next to this script, not relative to
                    # the current working directory.
                    # NOTE: pickle.load can execute code embedded in the file;
                    # only ship a trusted languages.pkl with the app.
                    with open(Path(__file__).with_name("languages.pkl"), "rb") as f:
                        lang_dict = pickle.load(f)  # assumes a dict: language code -> name
                    detected_code = st.session_state.detected_languages[i]
                    # Fall back to the raw code when the mapping lacks an entry.
                    # The Whisper prompt builder accepts codes as well as names.
                    detected_language_name = lang_dict.get(detected_code, detected_code)
                    translation = process_long_audio(
                        st.session_state.waveforms[i],
                        SAMPLING_RATE,
                        task="translate",
                        language=detected_language_name,
                    )
                    st.session_state.translations[i] = translation
            if st.session_state.translations.get(i):
                st.write("**Translation**:")
                st.write(st.session_state.translations[i])