# Spaces: Running  (web-page status residue captured with this file; not part of the program)
import openai | |
from pytube import YouTube | |
import argparse | |
import os | |
from pathlib import Path | |
from tqdm import tqdm | |
from src.srt_util.srt import SrtScript | |
from src.Pigeon import Pigeon | |
import stable_whisper | |
import whisper | |
from src.srt_util import srt2ass | |
import logging | |
from datetime import datetime | |
import torch | |
import subprocess | |
import time | |
def parse_args():
    """Parse the pipeline's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: with the string options ``link``, ``video_file``,
        ``audio_file``, ``srt_file``, ``download``, ``output_dir``,
        ``video_name``, ``model_name`` and ``log_dir``, plus the boolean
        switches ``only_srt`` and ``v``.
    """
    # Every string-valued option is optional; only help text and default vary.
    string_options = (
        ("--link", {"help": "youtube video link here", "default": None}),
        ("--video_file", {"help": "local video path here", "default": None}),
        ("--audio_file", {"help": "local audio path here", "default": None}),
        ("--srt_file", {"help": "srt file input path here", "default": None}),
        ("--download", {"help": "download path", "default": './downloads'}),
        ("--output_dir", {"help": "translate result path", "default": './results'}),
        ("--video_name",
         {"help": "video name, if use video link as input, the name will auto-filled by youtube video name",
          "default": 'placeholder'}),
        ("--model_name", {"help": "model name only support gpt-4 and gpt-3.5-turbo", "default": "gpt-4"}),
        ("--log_dir", {"help": "log path", "default": './logs'}),
    )
    # Boolean switches (note: spelled with a single leading dash).
    switches = (
        ("-only_srt", "set script output to only .srt file"),
        ("-v", "auto encode script with video"),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in string_options:
        cli.add_argument(flag, type=str, required=False, **spec)
    for flag, description in switches:
        cli.add_argument(flag, help=description, action='store_true')
    return cli.parse_args()
def get_sources(args, download_path, result_path, video_name):
    """Resolve the audio/video inputs for a run and prepare the result folder.

    Depending on ``args`` this downloads the video and audio from YouTube
    (``args.link``), extracts audio from a local video with ffmpeg
    (``args.video_file``), or uses neither (e.g. srt-only input). An explicit
    ``args.audio_file`` always takes precedence as the audio source.

    Args:
        args: parsed options; reads ``link``, ``video_file`` and ``audio_file``.
        download_path: directory containing the ``audio``/``video`` subfolders.
        result_path: directory under which ``<video_name>/`` is created.
        video_name: name of the video; ``'placeholder'`` is replaced by the
            downloaded file's stem when a link is used.

    Returns:
        tuple: ``(audio_path, audio_file, video_path, video_name)``.
        ``audio_file`` is a binary file object opened on ``audio_path`` (the
        caller is responsible for closing it), or ``None`` when there is no
        audio source.

    Raises:
        SystemExit: if the YouTube download fails or no audio stream exists.
    """
    audio_path = None
    video_path = None

    if args.link is not None and args.video_file is None:
        # Download video and audio from YouTube.
        video = None
        audio = None
        try:
            yt = YouTube(args.link, use_oauth=True, allow_oauth_cache=True)
            video = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first()
            if video:
                video.download(f'{download_path}/video')
                print('Video download completed!')
            else:
                print("Error: Video stream not found")
            audio = yt.streams.filter(only_audio=True, file_extension='mp4').first()
            if audio:
                audio.download(f'{download_path}/audio')
                print('Audio download completed!')
            else:
                print("Error: Audio stream not found")
        except Exception as e:
            print("Connection Error")
            print(e)
            # Exit with a failure status instead of reporting success.
            raise SystemExit(1)

        if audio is None:
            # Without an audio stream there is nothing to transcribe; stop
            # cleanly instead of failing on ``None.default_filename``.
            raise SystemExit(1)
        if video is not None:
            video_path = f'{download_path}/video/{video.default_filename}'
        audio_path = '{}/audio/{}'.format(download_path, audio.default_filename)
        if video_name == 'placeholder':
            video_name = audio.default_filename.split('.')[0]
    elif args.video_file is not None:
        # Read from local video; extract audio only if none was supplied.
        video_path = args.video_file
        if args.audio_file is None:
            os.makedirs(f'{download_path}/audio', exist_ok=True)
            audio_path = f'{download_path}/audio/{video_name}.mp3'
            subprocess.run(['ffmpeg', '-i', video_path, '-f', 'mp3', '-ab', '192000', '-vn', audio_path])

    # Per-video result folder (parents are created as needed).
    os.makedirs(f'{result_path}/{video_name}', exist_ok=True)

    # An explicitly supplied audio file overrides any downloaded/extracted one.
    if args.audio_file is not None:
        audio_path = args.audio_file

    # Open the chosen audio exactly once so no handle is leaked.
    audio_file = open(audio_path, "rb") if audio_path is not None else None

    return audio_path, audio_file, video_path, video_name
def get_srt_class(srt_file_en, result_path, video_name, audio_path, audio_file=None, whisper_model='large',
                  method="stable"):
    """Obtain the source-language ``SrtScript``, transcribing audio if needed.

    If ``srt_file_en`` is given, or ``<result_path>/<video_name>/<video_name>_en.srt``
    already exists, the script is parsed from that file. Otherwise the audio is
    transcribed with the selected method and the resulting segments are
    wrapped in an ``SrtScript``.

    Args:
        srt_file_en: path of an existing source .srt file, or ``None``.
        result_path: base result directory.
        video_name: name of the video (used to build the default .srt path).
        audio_path: path of the audio file (used by the local whisper models).
        audio_file: open binary audio file (used by the ``"api"`` method).
        whisper_model: model name passed to the local whisper loaders.
        method: ``"api"``, ``"basic"`` or ``"stable"``.

    Returns:
        tuple: ``(srt_file_en, srt)`` - the .srt path and the ``SrtScript``.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
    """
    if srt_file_en is not None:
        return srt_file_en, SrtScript.parse_from_srt_file(srt_file_en)

    # Default location of the transcript under the result path.
    srt_file_en = "{}/{}/{}_en.srt".format(result_path, video_name, video_name)
    if os.path.exists(srt_file_en):
        # Reuse a transcript produced by an earlier run.
        return srt_file_en, SrtScript.parse_from_srt_file(srt_file_en)

    # use cuda if available
    devices = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if method == "api":
        # use OpenAI API for transcribe
        # NOTE(review): the result is indexed with ['segments'] below; confirm
        # the API response format used here actually includes segments.
        transcript = openai.Audio.transcribe("whisper-1", audio_file)
    elif method == "basic":
        # use local whisper model
        model = whisper.load_model(whisper_model, device=devices)
        transcript = model.transcribe(audio_path)
    elif method == "stable":
        # use stable-whisper
        model = stable_whisper.load_model(whisper_model, device=devices)
        transcript = model.transcribe(audio_path, regroup=False,
                                      initial_prompt="Hello, welcome to my lecture. Are you good my friend?")
        # Regroup segments; the chained calls' return value is intentionally
        # unused, as in the original (presumably they modify the result in place).
        (
            transcript
            .split_by_punctuation(['.', '。', '?'])
            .merge_by_gap(.15, max_words=3)
            .merge_by_punctuation([' '])
            .split_by_punctuation(['.', '。', '?'])
        )
        transcript = transcript.to_dict()
    else:
        raise ValueError("invalid speech to text method")

    # Single construction point: the per-branch ``SRT_script(...)`` calls
    # referenced an undefined name and were redundant with this one.
    srt = SrtScript(transcript['segments'])  # read segments to SRT class
    return srt_file_en, srt
def script_split(script_in, chunk_size=1000):
    """Split a script into chunks of whole sentences within a size limit.

    Sentences are separated by a blank line (``"\\n\\n"``). Consecutive
    sentences are packed into a chunk while the chunk's character length stays
    within ``chunk_size``; a single sentence longer than the limit becomes a
    chunk of its own.

    Args:
        script_in (str): sentences joined by ``"\\n\\n"``; a trailing
            separator is allowed and ignored.
        chunk_size (int): maximum chunk length in characters.

    Returns:
        tuple: ``(script_arr, range_arr)`` where ``script_arr[i]`` is the
        stripped text of chunk ``i`` and ``range_arr[i]`` is the 1-based,
        inclusive ``(start, end)`` sentence range that chunk covers.
    """
    sentences = script_in.split('\n\n')
    # A trailing separator produces one empty final piece that is not a
    # sentence; dropping it keeps the range arithmetic independent of whether
    # the input ends with "\n\n".
    if sentences and sentences[-1] == '':
        sentences.pop()

    script_arr = []
    range_arr = []
    start = 1
    end = 0
    chunk = ""
    for sentence in sentences:
        # Flush only a non-empty chunk, so an oversize first sentence does not
        # emit an empty chunk with the invalid range (1, 0).
        if chunk and len(chunk) + len(sentence) + 1 > chunk_size:
            script_arr.append(chunk.strip())
            range_arr.append((start, end))
            start = end + 1
            chunk = ""
        chunk += sentence + '\n\n'
        end += 1

    if chunk.strip():
        script_arr.append(chunk.strip())
        range_arr.append((start, end))

    return script_arr, range_arr
def check_translation(sentence, translation):
    """Tell whether a translation kept the source's sentence count.

    Sentences are delimited by blank lines (``"\\n\\n"``), so equal separator
    counts mean the model neither merged nor split sentences.

    Returns:
        bool: ``True`` when both texts contain the same number of sentences.
    """
    separator = '\n\n'
    return sentence.count(separator) == translation.count(separator)
def get_response(model_name, sentence):
    """
    Generates a translated response for a given sentence using a specified OpenAI model.

    Args:
        model_name (str): The name of the OpenAI model to be used for translation, either "gpt-3.5-turbo" or "gpt-4".
        sentence (str): The English sentence related to StarCraft 2 videos that needs to be translated into Chinese.

    Returns:
        str: The translated Chinese sentence, maintaining the original format, meaning, and number of lines.

    Raises:
        ValueError: If ``model_name`` is not one of the supported models.
    """
    # Fail loudly on an unsupported model instead of falling through without
    # ever producing a response.
    if model_name not in ("gpt-3.5-turbo", "gpt-4"):
        raise ValueError(
            "unsupported model name: {!r} (expected 'gpt-3.5-turbo' or 'gpt-4')".format(model_name))

    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            # System prompt (Chinese), roughly: "You are a translation
            # assistant whose task is to translate StarCraft videos. You will
            # be given an English passage split by line and must output the
            # translated text while preserving the meaning and the number of
            # lines."
            {"role": "system",
             "content": "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"},
            {"role": "user", "content": sentence}
        ],
        temperature=0.15
    )
    return response['choices'][0]['message']['content'].strip()
# Translate and save
def translate(srt, script_arr, range_arr, model_name, video_name, video_link, attempts_count=5):
    """
    Translates the given script chunks with the chat model and stores the results in the SRT object.

    For every chunk the model is queried; if the translation does not contain the same number of
    sentences as the source (see ``check_translation``), the request is retried up to ``attempts_count``
    times. If the counts still differ, the chunk is translated sentence by sentence instead. Any
    exception raised while translating a chunk is reported and the chunk is retried after 30 seconds
    (indefinitely, as before).

    Args:
        srt: Object receiving the results through ``set_translation``.
        script_arr (list): A list of strings representing the original script chunks to be translated.
        range_arr (list): A list of ``(start, end)`` tuples giving the sentence range of each chunk.
        model_name (str): The name of the translation model to be used.
        video_name (str): The name of the video.
        video_link (str): The link to the video.
        attempts_count (int): Number of retries allowed per chunk for mismatched sentence counts.
    """
    logging.info("start translating...")
    chunk_total = min(len(script_arr), len(range_arr))
    for sentence, sentence_range in tqdm(zip(script_arr, range_arr), total=chunk_total):
        sentence_range = (sentence_range[0], sentence_range[1])

        # using chatgpt model
        print(f"now translating sentences {sentence_range}")
        logging.info(f"now translating sentences {sentence_range}, time: {datetime.now()}")

        while True:
            try:
                translation = get_response(model_name, sentence)

                # Retry budget is per chunk: a shared counter would stay at 0
                # after the first difficult chunk and force the slow fallback
                # on every following chunk.
                retries_left = attempts_count
                while retries_left > 0 and not check_translation(sentence, translation):
                    translation = get_response(model_name, sentence)
                    retries_left -= 1

                # Fall back only if the sentence counts still differ.
                if not check_translation(sentence, translation):
                    logging.info("merge sentence issue found for range %s", sentence_range)
                    translation = "\n\n".join(
                        get_response(model_name, single_sentence)
                        for single_sentence in sentence.split("\n\n")
                    )
                    logging.info("solved by individually translation!")
                break
            except Exception as e:
                # Use a %s placeholder: passing the exception as an extra
                # positional argument without one is a logging format error.
                logging.warning("An error has occurred during translation: %s", e)
                print("An error has occurred during translation:", e)
                print("Retrying... the script will continue after 30 seconds.")
                time.sleep(30)

        srt.set_translation(translation, sentence_range, model_name, video_name, video_link)
def main_old():
    """Legacy end-to-end pipeline: fetch sources, transcribe, translate, export.

    Steps, in order: parse CLI options, prepare download/result/log folders,
    resolve the media sources, build the source ``SrtScript`` (transcribing
    via the OpenAI API when no .srt exists), preprocess it, translate it chunk
    by chunk, post-process, write ``_zh``/``_bi`` .srt files (and .ass files
    unless ``-only_srt``), and optionally burn the subtitles into the video
    with ffmpeg (``-v``).

    Raises:
        TypeError: if none of link, video file, audio file or srt file is given.
    """
    args = parse_args()

    # input check: input should be either video file or youtube video link.
    if args.link is None and args.video_file is None and args.srt_file is None and args.audio_file is None:
        raise TypeError("need video source or srt file")

    # set up
    start_time = time.time()
    openai.api_key = os.getenv("OPENAI_API_KEY")

    # Create every working directory idempotently, so an existing download
    # folder that lacks its audio/video subfolders is completed too.
    DOWNLOAD_PATH = Path(args.download)
    for directory in (DOWNLOAD_PATH, DOWNLOAD_PATH / 'audio', DOWNLOAD_PATH / 'video'):
        directory.mkdir(parents=True, exist_ok=True)
    RESULT_PATH = Path(args.output_dir)
    RESULT_PATH.mkdir(parents=True, exist_ok=True)

    # set video name as the input file name if not specified
    VIDEO_NAME = args.video_name
    if args.video_name == 'placeholder':
        # Use the input file's base name up to its first dot.
        if args.video_file is not None:
            VIDEO_NAME = Path(args.video_file).name.split('.')[0]
        elif args.audio_file is not None:
            VIDEO_NAME = Path(args.audio_file).name.split('.')[0]
        elif args.srt_file is not None:
            VIDEO_NAME = Path(args.srt_file).name.split('.')[0].split("_")[0]

    audio_path, audio_file, video_path, VIDEO_NAME = get_sources(args, DOWNLOAD_PATH, RESULT_PATH, VIDEO_NAME)

    os.makedirs(args.log_dir, exist_ok=True)
    logging.basicConfig(level=logging.INFO, handlers=[
        logging.FileHandler("{}/{}_{}.log".format(args.log_dir, VIDEO_NAME, datetime.now().strftime("%m%d%Y_%H%M%S")),
                            'w', encoding='utf-8')])

    logging.info("---------------------Video Info---------------------")
    logging.info("Video name: {}, translation model: {}, video link: {}".format(VIDEO_NAME, args.model_name, args.link))

    try:
        srt_file_en, srt = get_srt_class(args.srt_file, RESULT_PATH, VIDEO_NAME, audio_path, audio_file, method="api")
    finally:
        # The audio handle is only needed for transcription; release it.
        if audio_file is not None:
            audio_file.close()

    # SRT class preprocess
    logging.info("---------------------Start Preprocessing SRT class---------------------")
    srt.write_srt_file_src(srt_file_en)
    srt.form_whole_sentence()
    processed_srt_file_en = srt_file_en.split('.srt')[0] + '_processed.srt'
    srt.write_srt_file_src(processed_srt_file_en)

    script_input = srt.get_source_only()

    # write ass
    if not args.only_srt:
        logging.info("write English .srt file to .ass")
        # NOTE(review): ``srt2ass`` is imported with ``from src.srt_util import srt2ass``;
        # confirm that name is the callable and not the module.
        assSub_en = srt2ass(processed_srt_file_en, "default", "No", "Modest")
        logging.info('ASS subtitle saved as: ' + assSub_en)

    script_arr, range_arr = script_split(script_input)
    logging.info("---------------------Start Translation--------------------")
    translate(srt, script_arr, range_arr, args.model_name, VIDEO_NAME, args.link)

    # SRT post-processing
    logging.info("---------------------Start Post-processing SRT class---------------------")
    srt.check_len_and_split()
    srt.remove_trans_punctuation()
    srt.write_srt_file_translate(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt")
    srt.write_srt_file_bilingual(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_bi.srt")

    # write ass
    if not args.only_srt:
        logging.info("write Chinese .srt file to .ass")
        assSub_zh = srt2ass(f"{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.srt", "default", "No", "Modest")
        logging.info('ASS subtitle saved as: ' + assSub_zh)

    # encode to .mp4 video file
    if args.v:
        if video_path is None:
            logging.warning("no video source available, skip encoding video file")
        else:
            logging.info("encoding video file")
            subtitle_ext = "srt" if args.only_srt else "ass"
            # Argument list instead of a shell string: paths containing spaces
            # or shell metacharacters are passed to ffmpeg unchanged.
            subprocess.run([
                'ffmpeg', '-i', str(video_path),
                '-vf', f'subtitles={RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}_zh.{subtitle_ext}',
                f'{RESULT_PATH}/{VIDEO_NAME}/{VIDEO_NAME}.mp4',
            ])

    end_time = time.time()
    logging.info(
        "Pipeline finished, time duration:{}".format(time.strftime("%H:%M:%S", time.gmtime(end_time - start_time))))
def main():
    """Script entry point: create a ``Pigeon`` instance and call its ``run()``."""
    Pigeon().run()


# Start the pipeline only when this file is executed directly.
if __name__ == "__main__":
    main()