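"""
Gradio Space: add English subtitles to videos spoken in Finnish.

Pipeline: Finnish ASR (wav2vec2) -> T5 casing/punctuation correction ->
Marian Finnish-to-English translation -> SRT generation -> ffmpeg burn-in.
"""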
import json
import os
import re
import time
from pathlib import Path

import ffmpeg
import gradio as gr
import pandas as pd
import torch
from fastT5 import export_and_get_onnx_model
from fuzzywuzzy import fuzz
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, MarianMTModel,
                          MarianTokenizer, pipeline)
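# NOTE: this script targets the library versions the Space was written for —
# an older Gradio (~3.x, which still had gr.Variable, Blocks(enable_queue=...)
# and demo.encrypt) plus fastT5 and fuzzywuzzy; newer releases have renamed or
# removed some of these APIs.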
MODEL = "Finnish-NLP/wav2vec2-base-fi-voxpopuli-v2-finetuned"
marian_nmt_model = "Helsinki-NLP/opus-mt-tc-big-fi-en"

# Marian model for Finnish -> English translation
tokenizer_marian = MarianTokenizer.from_pretrained(marian_nmt_model)
model = MarianMTModel.from_pretrained(marian_nmt_model)

# Pipelines take a device index (-1 = CPU); torch takes a device object
sr_pipeline_device = 0 if torch.cuda.is_available() else -1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Finnish ASR pipeline with word-level timestamps
speech_recognizer = pipeline(
    task="automatic-speech-recognition",
    model=MODEL,
    tokenizer=MODEL,
    framework="pt",
    device=sr_pipeline_device,
)
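# For reference, with return_timestamps="word" the pipeline output is shaped
# roughly like (values here are illustrative):
# {"text": "hyvää huomenta", "chunks": [{"text": "hyvää", "timestamp": (0.5, 1.0)}, ...]}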
# T5 model (exported to ONNX via fastT5) for casing + punctuation correction
model_checkpoint = 'Finnish-NLP/t5-small-nl24-casing-punctuation-correction'
tokenizer_t5 = AutoTokenizer.from_pretrained(model_checkpoint)
model_t5 = export_and_get_onnx_model(model_checkpoint)
# Alternative: load the plain PyTorch model instead of the ONNX export
# model_t5 = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_flax=False, torch_dtype=torch.float32).to(device)
videos_out_path = Path("./videos_out")
videos_out_path.mkdir(parents=True, exist_ok=True)

# Example videos and their precomputed metadata shipped with the Space
samples_data = sorted(Path('examples').glob('*.json'))
SAMPLES = []
for file in samples_data:
    with open(file) as f:
        sample = json.load(f)
    SAMPLES.append(sample)
VIDEOS = list(map(lambda x: [x['video']], SAMPLES))
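# Each examples/*.json is expected to hold at least a 'video' path and
# precomputed word 'timestamps' (see load_example further down).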
total_inferences_since_reboot = 0
total_cuts_since_reboot = 0
async def speech_to_text(video_file_path):
    """
    Takes a video path, extracts the audio track, and transcribes it to text
    with word-level timestamps, using the
    https://huggingface.co/tasks/automatic-speech-recognition pipeline.
    """
    global total_inferences_since_reboot
    if video_file_path is None:
        raise ValueError("Error: no video input")
    video_path = Path(video_file_path)
    try:
        # Convert video to 16 kHz mono WAV, piped straight into memory
        audio_memory, _ = ffmpeg.input(video_path).output(
            '-', format="wav", ac=1, ar='16k').overwrite_output().global_args(
            '-loglevel', 'quiet').run(capture_stdout=True)
    except Exception as e:
        raise RuntimeError("Error converting video to audio") from e

    try:
        # ASR with word-level timestamps; chunking keeps memory bounded on long inputs
        output = speech_recognizer(
            audio_memory, return_timestamps="word", chunk_length_s=10, stride_length_s=(4, 2))
        transcription = output["text"].lower()
        timestamps = [[chunk["text"].lower(), chunk["timestamp"][0], chunk["timestamp"][1]]
                      for chunk in output['chunks']]

        # Restore casing and punctuation with the T5 model
        input_ids = tokenizer_t5(transcription, return_tensors="pt").input_ids.to(device)
        outputs = model_t5.generate(input_ids, max_length=128)
        case_corrected_text = tokenizer_t5.decode(outputs[0], skip_special_tokens=True)

        # Translate the corrected transcript to English
        translated = model.generate(**tokenizer_marian(
            [case_corrected_text], return_tensors="pt", padding=True))
        translated_plain = "".join(
            [tokenizer_marian.decode(t, skip_special_tokens=True) for t in translated])

        total_inferences_since_reboot += len(timestamps)

        # Plain float (not float16) so longer timestamps keep sub-second precision
        df = pd.DataFrame(timestamps, columns=['word', 'start', 'stop'])
        df['start'] = df['start'].astype('float')
        df['stop'] = df['stop'].astype('float')
        print("\n\ntotal_inferences_since_reboot: ",
              total_inferences_since_reboot, "\n\n")
        return (transcription, transcription, timestamps, df,
                case_corrected_text, translated_plain)
    except Exception as e:
        raise RuntimeError("Error running inference with local model") from e
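# Each `timestamps` row is [word, start_s, stop_s]; the hidden DataFrame built
# from it is what create_srt() below consumes.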
def create_srt(text_out_t5, df):
    """Align the T5-corrected sentences with the word-level timestamps and
    translate each sentence, producing one subtitle row per sentence."""
    df.columns = ['word', 'start', 'stop']
    df_sentences = pd.DataFrame(columns=['sentence', 'start', 'stop', 'translated'])

    # Split the corrected text into sentences on end punctuation
    t5_sentences = [s for s in re.split('[.?!]', text_out_t5) if s]

    new_start, new_stop = None, None
    for i, sentence in enumerate(t5_sentences):
        words = sentence.lower().split(" ")
        # Take a ~10 s window of word timestamps for this sentence
        if i == 0:
            df_subset = df[df['stop'] < 10]
            start = df.iloc[0]['start']
        else:
            df_subset = df[(df['start'] >= new_start) & (df['stop'] <= new_stop)]
            start = df_subset.iloc[0]['start']

        # Fuzzy-match the sentence's last word against the words in the window
        # to locate where the sentence ends
        found_match_value = 0
        found_match_word = ""
        for word in df_subset['word']:
            temp_value = fuzz.partial_ratio(word, words[-1])
            if temp_value > found_match_value:
                found_match_value = temp_value
                found_match_word = word
        stop = df_subset[df_subset['word'] == found_match_word]

        # Translate the sentence to English
        translated = model.generate(**tokenizer_marian(
            t5_sentences[i], return_tensors="pt", padding=True))
        translated_plain = [tokenizer_marian.decode(t, skip_special_tokens=True)
                            for t in translated]

        row = {
            'sentence': t5_sentences[i],
            'start': start,
            'stop': stop.iloc[0]['stop'],
            'translated': translated_plain[0],
        }
        # DataFrame.append was removed in pandas 2.0; use concat instead
        df_sentences = pd.concat([df_sentences, pd.DataFrame([row])],
                                 ignore_index=True)

        # Advance the window to the word after this sentence's last word
        try:
            new_start = df.iloc[stop.index.values[0] + 1]['start']
            new_stop = new_start + 10
        except IndexError:
            # Ran past the last transcribed word: keep what we have and stop
            df_sentences = df_sentences.iloc[0:i + 1]
            break
    return df_sentences
def create_srt_and_burn(video_in, srt_sentences):
    """Write the subtitle rows to an SRT file and burn it into the video
    with ffmpeg."""
    srt_sentences.columns = ['sentence', 'start', 'stop', 'translated']
    srt_sentences.dropna(inplace=True)
    srt_sentences['start'] = srt_sentences['start'].astype('float')
    srt_sentences['stop'] = srt_sentences['stop'].astype('float')

    def srt_time(seconds):
        # HH:MM:SS plus up to three fractional digits; ffmpeg's SRT parser
        # accepts '.' as well as the spec's ',' separator
        stamp = time.strftime('%H:%M:%S', time.gmtime(seconds))
        if "." in str(seconds):
            stamp += '.' + str(seconds).split('.')[1][:3]
        return stamp

    with open('testi.srt', 'w') as file:
        for i in range(len(srt_sentences)):
            file.write(str(i + 1))
            file.write('\n')
            file.write(srt_time(srt_sentences.iloc[i]['start']))
            file.write(' --> ')
            file.write(srt_time(srt_sentences.iloc[i]['stop']))
            file.write('\n')
            file.write(srt_sentences.iloc[i]['translated'])
            if i != len(srt_sentences) - 1:
                file.write('\n\n')

    video_out = str(Path(video_in)).replace('.mp4', '_out.mp4')
    try:
        command = "ffmpeg -i {} -y -vf subtitles=./testi.srt {}".format(
            Path(video_in), Path(video_out))
        os.system(command)
    except Exception as e:
        print(e)
    return video_out
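# A generated entry in testi.srt looks like (illustrative values):
#   1
#   00:00:00.0 --> 00:00:04.36
#   This is the translated first sentence.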
# ---- Gradio Layout -----
video_in = gr.Video(label="Video file", interactive=True)
text_in = gr.Textbox(label="Transcription", lines=10, interactive=True)
text_out_t5 = gr.Textbox(label="Transcription T5", lines=10, interactive=True)
translation_out = gr.Textbox(label="Translation", lines=10, interactive=True)
text_out_timestamps = gr.Textbox(label="Word level timestamps", lines=10, interactive=True)
srt_sentences = gr.DataFrame(label="Srt lines", row_count=(0, "dynamic"))
video_out = gr.Video(label="Video Out")
diff_out = gr.HighlightedText(label="Cuts Diffs", combine_adjacent=True)
examples = gr.components.Dataset(
    components=[video_in], samples=VIDEOS, type="index")

demo = gr.Blocks(enable_queue=True, css='''
    #cut_btn, #reset_btn { align-self:stretch; }
    #\\31 3 { max-width: 540px; }
    .output-markdown {max-width: 65ch !important;}
''')
demo.encrypt = False
with demo:
    transcription_var = gr.Variable()
    timestamps_var = gr.Variable()
    timestamps_df = gr.Dataframe(visible=False, row_count=(0, "dynamic"))

    with gr.Row():
        with gr.Column():
            gr.Markdown('''
            # Create videos with English subtitles from videos spoken in Finnish
            This project is a quick proof of concept of a simple video editor that adds English subtitles to Finnish videos.
            The Space currently only works for short videos (up to 128 tokens), but this will be improved in future versions.
            It uses our finetuned Finnish ASR models, our pretrained + finetuned Finnish T5 model for casing and punctuation correction, and Opus-MT models from the University of Helsinki for Finnish --> English translation.
            This Space was inspired by https://huggingface.co/spaces/radames/edit-video-by-editing-text
            ''')

    with gr.Row():
        examples.render()

        def load_example(id):
            video = SAMPLES[id]['video']
            transcription = ''
            timestamps = SAMPLES[id]['timestamps']
            return (video, transcription, transcription, timestamps)

        examples.click(
            load_example,
            inputs=[examples],
            outputs=[video_in, text_in, transcription_var, timestamps_var],
            queue=False)

    with gr.Row():
        with gr.Column():
            video_in.render()
            transcribe_btn = gr.Button("1. Press here to transcribe Audio")
            transcribe_btn.click(speech_to_text, [video_in], [
                text_in, transcription_var, text_out_timestamps,
                timestamps_df, text_out_t5, translation_out])

    with gr.Row():
        gr.Markdown('''
        ### Here you will get the outputs from the different processing steps:
        the raw ASR output, the T5 output with corrected casing + punctuation,
        sentence-level translations, and word-level timestamps''')

    with gr.Row():
        with gr.Column():
            text_in.render()
        with gr.Column():
            text_out_t5.render()
        with gr.Column():
            translation_out.render()
        with gr.Column():
            text_out_timestamps.render()

    with gr.Row():
        with gr.Column():
            make_srt_btn = gr.Button("2. Press here to create rows for subtitles")
            make_srt_btn.click(create_srt, [text_out_t5, timestamps_df],
                               [srt_sentences])

    with gr.Row():
        with gr.Column():
            srt_sentences.render()

    with gr.Row():
        with gr.Column():
            burn_srt_btn = gr.Button(
                "3. Press here to create the subtitle file and burn the translations into the video")
            burn_srt_btn.click(create_srt_and_burn, [video_in, srt_sentences],
                               [video_out])
            video_out.render()

if __name__ == "__main__":
    demo.launch(debug=True)