from datetime import timedelta
import os
import whisper
from csv import reader
import re


def _seconds_to_srt_timestamp(seconds) -> str:
    """Format a time offset in seconds as an SRT timestamp ``HH:MM:SS,mmm``.

    The offset is rounded to the nearest millisecond. Truncating a float
    product instead would turn e.g. 2.3 s into ``,299``. Hours are
    zero-padded to at least two digits, so offsets of 10 h or more stay valid.
    Negative offsets are clamped to zero.
    """
    total_ms = max(0, int(round(seconds * 1000)))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millis = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


class SRT_segment(object):
    """A single subtitle entry: a time range, its source text and translation.

    Attributes:
        start_time_str, end_time_str: start/end timestamps as strings. They
            are formatted ``HH:MM:SS,mmm`` for dict input, or taken verbatim
            from the timing line for SRT-block input.
        duration: the timing line ``"<start> --> <end>"``.
        source_text: text in the source language.
        translation: translated text; empty until assigned.
    """

    def __init__(self, *args) -> None:
        """Build a segment from a transcription dict or a parsed SRT block.

        Args:
            args[0]: either
                * a dict with numeric ``'start'`` / ``'end'`` offsets in
                  seconds and a ``'text'`` string, or
                * a list/tuple of SRT block lines
                  ``[index, "start --> end", text_line, ...]``. Any number
                  of text lines is accepted; blank lines are ignored.

        Raises:
            IndexError: if no argument is given.
            TypeError: if ``args[0]`` is neither a dict nor a list/tuple.
        """
        source = args[0]
        if isinstance(source, dict):
            self.start_time_str = _seconds_to_srt_timestamp(source['start'])
            self.end_time_str = _seconds_to_srt_timestamp(source['end'])
            self.source_text = source['text']
            self.duration = f"{self.start_time_str} --> {self.end_time_str}"
        elif isinstance(source, (list, tuple)):
            # Line 0 is the entry index (unused), line 1 the timing line, and
            # everything after it is subtitle text (possibly several lines).
            self.duration = source[1].strip()
            times = self.duration.split(" --> ")
            self.start_time_str = times[0]
            self.end_time_str = times[1]
            self.source_text = "\n".join(
                line for line in source[2:] if line.strip()
            )
        else:
            raise TypeError(
                "SRT_segment expects a segment dict or a list of SRT lines, "
                f"got {type(source).__name__}"
            )
        self.translation = ""

    def merge_seg(self, seg) -> None:
        """Append ``seg`` to this segment in place.

        The end time becomes ``seg``'s end time. Source texts are joined with
        a single space unless one side already has whitespace at the boundary.
        Transcription texts usually carry a leading space; SRT-parsed texts do
        not, and plain concatenation would fuse words ("the" + "store" ->
        "thestore"). Translations are concatenated unchanged.
        """
        if (self.source_text and seg.source_text
                and not self.source_text[-1].isspace()
                and not seg.source_text[0].isspace()):
            self.source_text += " "
        self.source_text += seg.source_text
        self.translation += seg.translation
        self.end_time_str = seg.end_time_str
        self.duration = f"{self.start_time_str} --> {self.end_time_str}"

    def __str__(self) -> str:
        """Return the SRT entry body (timing line + source text), no index."""
        return f'{self.duration}\n{self.source_text}\n\n'

    def get_trans_str(self) -> str:
        """Return the SRT entry body with the translation as its text."""
        return f'{self.duration}\n{self.translation}\n\n'

    def get_bilingual_str(self) -> str:
        """Return the SRT entry body with source text then translation."""
        return f'{self.duration}\n{self.source_text}\n{self.translation}\n\n'


class SRT_script():
    """An ordered list of :class:`SRT_segment` with SRT read/write helpers."""

    def __init__(self, segments) -> None:
        """Wrap each item of ``segments`` (dicts or SRT line lists) as a segment."""
        self.segments = [SRT_segment(seg) for seg in segments]

    @classmethod
    def parse_from_srt_file(cls, path: str):
        """Load an ``.srt`` file (UTF-8) into an :class:`SRT_script`.

        Entries are separated by blank lines. An entry may contain any number
        of text lines. Extra blank lines (e.g. at the end of the file) are
        tolerated.
        """
        with open(path, 'r', encoding="utf-8") as f:
            script_lines = f.read().splitlines()

        blocks = []
        current = []
        for line in script_lines:
            if line.strip():
                current.append(line)
            elif current:
                blocks.append(current)
                current = []
        if current:  # last entry may not be followed by a blank line
            blocks.append(current)
        return cls(blocks)

    def merge_segs(self, idx_list) -> SRT_segment:
        """Merge the segments at ``idx_list`` into the first one and return it.

        The first segment is mutated in place; ``self.segments`` itself is
        not modified here.
        """
        final_seg = self.segments[idx_list[0]]
        for idx in idx_list[1:]:
            final_seg.merge_seg(self.segments[idx])
        return final_seg

    def form_whole_sentence(self, end_marks=('.',)):
        """Merge consecutive segments so that each one ends a sentence.

        A segment ends a sentence when its source text, ignoring trailing
        whitespace, ends with one of ``end_marks``. The default ``('.',)``
        keeps the original behaviour; a string such as ``".?!"`` is treated
        as a set of single-character marks. Segments after the last sentence
        end are kept as one final merged segment rather than dropped.
        """
        end_marks = tuple(end_marks)
        # Groups of indices to merge, e.g. [[0], [1, 2, 3, 4], [5, 6], [7]].
        merge_list = []
        sentence = []
        for i, seg in enumerate(self.segments):
            sentence.append(i)
            if seg.source_text.rstrip().endswith(end_marks):
                merge_list.append(sentence)
                sentence = []
        if sentence:
            merge_list.append(sentence)
        self.segments = [self.merge_segs(idx_list) for idx_list in merge_list]

    def set_translation(self, translate: str, id_range: tuple):
        """Assign translated paragraphs to a range of segments.

        Args:
            translate: translated text, one paragraph per segment, separated
                by blank lines. Each paragraph is presumably
                ``"<label>: <translation>"``. Everything after the first ':'
                is used; a paragraph without ':' is used whole. Blank
                paragraphs and paragraphs containing ``"(Note:"`` are
                discarded.
            id_range: ``(start, end)`` 1-based, inclusive segment ids.

        When the paragraph count differs from the range size, diagnostic
        output is printed and paragraphs are assigned in order until either
        side runs out.
        """
        start_seg_id = id_range[0]
        end_seg_id = id_range[1]
        target_segs = self.segments[start_seg_id - 1:end_seg_id]
        lines = [
            line for line in translate.split('\n\n')
            if line.strip() and "(Note:" not in line
        ]
        if len(lines) != (end_seg_id - start_seg_id + 1):
            print(id_range)
            for seg in target_segs:
                print(seg.source_text)
            print(translate)
        # TODO: naive alignment; a smarter merge-aware matching is needed.
        for seg, line in zip(target_segs, lines):
            _, sep, text = line.partition(":")
            seg.translation = text if sep else line

    def split_seg(self, seg_id):
        """Not implemented yet: split segment ``seg_id`` into two parts."""
        # TODO: evenly split seg to 2 parts and add new seg into self.segments
        pass

    def check_len_and_split(self, threshold):
        """Not implemented yet: split segments whose length >= ``threshold``."""
        # TODO: if sentence length >= threshold, split this segments to two
        pass

    def get_source_only(self):
        """Return all source texts as numbered ``SENTENCE n: ...`` paragraphs."""
        return "".join(
            f'SENTENCE {i+1}: {seg.source_text}\n\n\n'
            for i, seg in enumerate(self.segments)
        )

    def reform_src_str(self):
        """Render the script as SRT text containing the source texts."""
        return "".join(
            f'{i+1}\n' + str(seg) for i, seg in enumerate(self.segments)
        )

    def reform_trans_str(self):
        """Render the script as SRT text containing the translations."""
        return "".join(
            f'{i+1}\n' + seg.get_trans_str()
            for i, seg in enumerate(self.segments)
        )

    def form_bilingual_str(self):
        """Render the script as SRT text with source and translation lines."""
        return "".join(
            f'{i+1}\n' + seg.get_bilingual_str()
            for i, seg in enumerate(self.segments)
        )

    def write_srt_file_src(self, path: str):
        """Write the source-language SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_src_str())

    def write_srt_file_translate(self, path: str):
        """Write the translated SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.reform_trans_str())

    def write_srt_file_bilingual(self, path: str):
        """Write the bilingual SRT to ``path`` (UTF-8)."""
        with open(path, "w", encoding='utf-8') as f:
            f.write(self.form_bilingual_str())

    def correct_with_force_term(self, dict_path: str = "finetune_data/dict.csv"):
        """Replace dictionary terms in every segment's source text.

        ``dict_path`` is a CSV file whose first column is the term and whose
        second column is its replacement. Rows with fewer than two columns
        are skipped. Matching is case-insensitive on whitespace-separated
        tokens. A token also matches when surrounding punctuation (e.g.
        ``"VA,"`` or ``"(VA)."``) is ignored; that punctuation is kept
        around the replacement. All whitespace in the text is preserved
        exactly.
        """
        # TODO: shortcut translation i.e. VA, ob
        # TODO: variety of translation
        # newline='' is what the csv module expects for file objects.
        with open(dict_path, 'r', encoding='utf-8', newline='') as f:
            term_dict = {
                row[0].lower(): row[1] for row in reader(f) if len(row) >= 2
            }

        punctuation = '.,!?;:"\'()'

        def _replace(match):
            token = match.group(0)
            replacement = term_dict.get(token.lower())
            if replacement is not None:
                return replacement
            stripped_left = token.lstrip(punctuation)
            core = stripped_left.rstrip(punctuation)
            if core and core != token:
                replacement = term_dict.get(core.lower())
                if replacement is not None:
                    prefix = token[:len(token) - len(stripped_left)]
                    suffix = stripped_left[len(core):]
                    return prefix + replacement + suffix
            return token

        for seg in self.segments:
            seg.source_text = re.sub(r"\S+", _replace, seg.source_text)