Eason Lu committed on
Commit
b37d0d4
·
1 Parent(s): 6808a65

Former-commit-id: 2d7d950f54b1deb8b4dd9b68c98d65384954c47e

configs/task_config.yaml CHANGED
@@ -1,18 +1,35 @@
1
  # configuration for each task
2
- model: gpt-4
3
- # output type that user receive
4
- output_type:
5
- subtitle: srt
6
- video: False
7
- bilingal: False
8
  source_lang: EN
9
  target_lang: ZH
10
  field: SC2
11
- chunk_size: 1000
 
 
 
 
 
 
 
 
12
  pre_process:
13
- ON: True
14
  sentence_form: True
15
  spell_check: False
16
  term_correct: True
 
 
 
 
 
 
 
17
  post_process:
18
- ON: True
 
 
 
 
 
 
 
 
 
 
1
  # configuration for each task
 
 
 
 
 
 
2
  source_lang: EN
3
  target_lang: ZH
4
  field: SC2
5
+
6
+ # ASR config
7
+ ASR:
8
+ ASR_model: whisper
9
+ whisper_config:
10
+ whisper_model: tiny
11
+ method: stable
12
+
13
+ # pre-process module config
14
  pre_process:
 
15
  sentence_form: True
16
  spell_check: False
17
  term_correct: True
18
+
19
+ # Translation module config
20
+ translation:
21
+ model: gpt-4
22
+ chunk_size: 1000
23
+
24
+ # post-process module config
25
  post_process:
26
+ check_len_and_split: True
27
+ remove_trans_punctuation: True
28
+
29
+ # output type that the user receives
30
+ output_type:
31
+ subtitle: srt
32
+ video: False
33
+ bilingal: False
34
+
35
+
entries/run.py CHANGED
@@ -10,6 +10,13 @@ from datetime import datetime
10
  import shutil
11
  from uuid import uuid4
12
 
 
 
 
 
 
 
 
13
  def parse_args():
14
  parser = argparse.ArgumentParser()
15
  parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
@@ -42,8 +49,9 @@ if __name__ == "__main__":
42
  task_dir.mkdir(parents=False, exist_ok=False)
43
  task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
44
 
45
- # logging
46
- logging.basicConfig(level=logging.INFO, handlers=[
 
47
  logging.FileHandler(
48
  "{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
49
  'w', encoding='utf-8')])
 
10
  import shutil
11
  from uuid import uuid4
12
 
13
+ """
14
+ Main entry for terminal environment.
15
+ Use it for debugging and development purposes.
16
+ Usage: python3 entries/run.py [-h] [--link LINK] [--video_file VIDEO_FILE] [--audio_file AUDIO_FILE] [--srt_file SRT_FILE] [--continue CONTINUE]
17
+ [--launch_cfg LAUNCH_CFG] [--task_cfg TASK_CFG]
18
+ """
19
+
20
  def parse_args():
21
  parser = argparse.ArgumentParser()
22
  parser.add_argument("--link", help="youtube video link here", default=None, type=str, required=False)
 
49
  task_dir.mkdir(parents=False, exist_ok=False)
50
  task_dir.joinpath("results").mkdir(parents=False, exist_ok=False)
51
 
52
+ # logging setting
53
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
54
+ logging.basicConfig(level=logging.INFO, format=logfmt, handlers=[
55
  logging.FileHandler(
56
  "{}/{}_{}.log".format(task_dir, f"task_{task_id}", datetime.now().strftime("%m%d%Y_%H%M%S")),
57
  'w', encoding='utf-8')])
src/srt_util/srt.py CHANGED
@@ -185,7 +185,6 @@ class SrtScript(object):
185
 
186
  def inner_func(target, input_str):
187
  response = openai.ChatCompletion.create(
188
- # model=model,
189
  model="gpt-4",
190
  messages=[
191
  {"role": "system",
@@ -208,19 +207,13 @@ class SrtScript(object):
208
  flag = True
209
  while flag:
210
  flag = False
211
- # print("translate:")
212
- # print(translate)
213
  try:
214
- # print("target")
215
- # print(end_seg_id - start_seg_id + 1)
216
  translate = inner_func(end_seg_id - start_seg_id + 1, translate)
217
  except Exception as e:
218
  print("An error has occurred during solving unmatched lines:", e)
219
  print("Retrying...")
220
  flag = True
221
  lines = translate.split('\n')
222
- # print("result")
223
- # print(len(lines))
224
 
225
  if len(lines) < (end_seg_id - start_seg_id + 1):
226
  solved = False
@@ -264,6 +257,7 @@ class SrtScript(object):
264
  # evenly split seg to 2 parts and add new seg into self.segments
265
 
266
  # ignore the initial comma to solve the recursion problem
 
267
  if len(seg.source_text) > 2:
268
  if seg.source_text[:2] == ', ':
269
  seg.source_text = seg.source_text[2:]
 
185
 
186
  def inner_func(target, input_str):
187
  response = openai.ChatCompletion.create(
 
188
  model="gpt-4",
189
  messages=[
190
  {"role": "system",
 
207
  flag = True
208
  while flag:
209
  flag = False
 
 
210
  try:
 
 
211
  translate = inner_func(end_seg_id - start_seg_id + 1, translate)
212
  except Exception as e:
213
  print("An error has occurred during solving unmatched lines:", e)
214
  print("Retrying...")
215
  flag = True
216
  lines = translate.split('\n')
 
 
217
 
218
  if len(lines) < (end_seg_id - start_seg_id + 1):
219
  solved = False
 
257
  # evenly split seg to 2 parts and add new seg into self.segments
258
 
259
  # ignore the initial comma to solve the recursion problem
260
+ # FIXME: accommodate multilingual setting
261
  if len(seg.source_text) > 2:
262
  if seg.source_text[:2] == ', ':
263
  seg.source_text = seg.source_text[2:]
src/task.py CHANGED
@@ -55,7 +55,6 @@ class TaskStatus(str, Enum):
55
  OUTPUT_MODULE = 'OUTPUT_MODULE'
56
 
57
 
58
-
59
  class Task:
60
  @property
61
  def status(self):
@@ -70,69 +69,74 @@ class Task:
70
  def __init__(self, task_id, task_local_dir, task_cfg):
71
  self.__status_lock = threading.Lock()
72
  self.__status = TaskStatus.CREATED
 
73
  openai.api_key = getenv("OPENAI_API_KEY")
74
- self.launch_info = task_cfg # do not use, just for fallback
 
75
  self.task_local_dir = task_local_dir
76
- self.model = task_cfg["model"]
77
- self.gpu_status = 0
 
 
78
  self.output_type = task_cfg["output_type"]
79
  self.target_lang = task_cfg["target_lang"]
80
  self.source_lang = task_cfg["source_lang"]
81
  self.field = task_cfg["field"]
82
  self.pre_setting = task_cfg["pre_process"]
83
  self.post_setting = task_cfg["post_process"]
84
- self.task_id = task_id
85
  self.audio_path = None
86
  self.SRT_Script = None
87
  self.result = None
88
  self.s_t = None
89
  self.t_e = None
90
 
91
- print(f" Task ID: {self.task_id}")
92
- logging.info(f" Task ID: {self.task_id}")
93
- logging.info(f" {self.source_lang} -> {self.target_lang} task in {self.field}")
94
- logging.info(f" Model: \t\t\t{self.model}")
95
- logging.info(f" subtitle_type: \t\t{self.output_type['subtitle']}")
96
- logging.info(f" video_ouput: \t\t{self.output_type['video']}")
97
- logging.info(f" bilingal_ouput: \t{self.output_type['bilingal']}")
98
- logging.info(" PREprocess setting:")
99
- for key, value in self.pre_setting:
100
- logging.info(f" {key}: {value}")
101
- logging.info(" POSTprocess setting:")
102
- for key, value in self.post_setting:
103
- logging.info(f" {key}: {value}")
104
 
105
  @staticmethod
106
  def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
107
  # convert to audio
108
- logging.info(" Task Creation method: Youtube Link")
109
  return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)
110
 
111
  @staticmethod
112
  def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
113
  # get audio path
114
- logging.info(" Task Creation method: Audio File")
115
  return AudioTask(task_id, task_dir, task_cfg, audio_path)
116
 
117
  @staticmethod
118
  def fromVideoFile(video_path, task_id, task_dir, task_cfg):
119
  # get audio path
120
- logging.info(" Task Creation method: Video File")
121
  return VideoTask(task_id, task_dir, task_cfg, video_path)
122
 
123
  # Module 1 ASR: audio --> SRT_script
124
- def get_srt_class(self, whisper_model='tiny', method="stable"):
125
  # Instead of using the script_en variable directly, we'll use script_input
 
126
  self.status = TaskStatus.INITIALIZING_ASR
127
  self.t_s = time()
128
  # self.SRT_Script = SrtScript
129
-
 
130
  src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
131
  if not Path.exists(src_srt_path):
132
  # extract script from audio
133
  logging.info("extract script from audio")
134
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
135
- # logging.info("device: ", device)
136
 
137
  if method == "api":
138
  with open(self.audio_path, 'rb') as audio_file:
@@ -158,7 +162,6 @@ class Task:
158
  self.SRT_Script.write_srt_file_src(src_srt_path)
159
 
160
  # Module 2: SRT preprocess: perform preprocess steps
161
- # TODO: multi-lang and multi-field support according to task_cfg
162
  def preprocess(self):
163
  self.status = TaskStatus.PRE_PROCESSING
164
  logging.info("--------------------Start Preprocessing SRT class--------------------")
@@ -183,18 +186,20 @@ class Task:
183
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
184
 
185
  # Module 3: perform srt translation
186
- def translation(self,task_cfg):
187
  logging.info("---------------------Start Translation--------------------")
188
- prompt = prompt_selector(self.source_lang,self.target_lang,task_cfg['field'])
189
- get_translation(self.SRT_Script, self.model, self.task_id, prompt, task_cfg['chunk_size'])
190
 
191
  # Module 4: perform srt post process steps
192
  def postprocess(self):
193
  self.status = TaskStatus.POST_PROCESSING
194
 
195
  logging.info("---------------------Start Post-processing SRT class---------------------")
196
- self.SRT_Script.check_len_and_split()
197
- self.SRT_Script.remove_trans_punctuation()
 
 
198
  logging.info("---------------------Post-processing SRT class finished---------------------")
199
 
200
  # Module 5: output module
@@ -233,11 +238,9 @@ class Task:
233
 
234
  def run_pipeline(self):
235
  self.get_srt_class()
236
- if self.pre_setting["ON"]:
237
- self.preprocess()
238
  self.translation()
239
- if self.post_setting["ON"]:
240
- self.postprocess()
241
  self.result = self.output_render()
242
  print(self.result)
243
 
@@ -259,7 +262,6 @@ class YoutubeTask(Task):
259
  audio = yt.streams.filter(only_audio=True).first()
260
  if audio:
261
  audio.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp3")
262
- # logging.info(f'Audio download completed to {self.task_local_dir}!')
263
  else:
264
  logging.info(" download audio failed, using ffmpeg to extract audio")
265
  subprocess.run(
 
55
  OUTPUT_MODULE = 'OUTPUT_MODULE'
56
 
57
 
 
58
  class Task:
59
  @property
60
  def status(self):
 
69
  def __init__(self, task_id, task_local_dir, task_cfg):
70
  self.__status_lock = threading.Lock()
71
  self.__status = TaskStatus.CREATED
72
+ self.gpu_status = 0
73
  openai.api_key = getenv("OPENAI_API_KEY")
74
+ self.task_id = task_id
75
+
76
  self.task_local_dir = task_local_dir
77
+ self.ASR_setting = task_cfg["ASR"]
78
+ self.translation_setting = task_cfg["translation"]
79
+ self.translation_model = self.translation_setting["model"]
80
+
81
  self.output_type = task_cfg["output_type"]
82
  self.target_lang = task_cfg["target_lang"]
83
  self.source_lang = task_cfg["source_lang"]
84
  self.field = task_cfg["field"]
85
  self.pre_setting = task_cfg["pre_process"]
86
  self.post_setting = task_cfg["post_process"]
87
+
88
  self.audio_path = None
89
  self.SRT_Script = None
90
  self.result = None
91
  self.s_t = None
92
  self.t_e = None
93
 
94
+ print(f"Task ID: {self.task_id}")
95
+ logging.info(f"Task ID: {self.task_id}")
96
+ logging.info(f"{self.source_lang} -> {self.target_lang} task in {self.field}")
97
+ logging.info(f"Translation Model: {self.translation_model}")
98
+ logging.info(f"subtitle_type: {self.output_type['subtitle']}")
99
+ logging.info(f"video_ouput: {self.output_type['video']}")
100
+ logging.info(f"bilingal_ouput: {self.output_type['bilingal']}")
101
+ logging.info("Pre-process setting:")
102
+ for key in self.pre_setting:
103
+ logging.info(f"{key}: {self.pre_setting[key]}")
104
+ logging.info("Post-process setting:")
105
+ for key in self.post_setting:
106
+ logging.info(f"{key}: {self.post_setting[key]}")
107
 
108
  @staticmethod
109
  def fromYoutubeLink(youtube_url, task_id, task_dir, task_cfg):
110
  # convert to audio
111
+ logging.info("Task Creation method: Youtube Link")
112
  return YoutubeTask(task_id, task_dir, task_cfg, youtube_url)
113
 
114
  @staticmethod
115
  def fromAudioFile(audio_path, task_id, task_dir, task_cfg):
116
  # get audio path
117
+ logging.info("Task Creation method: Audio File")
118
  return AudioTask(task_id, task_dir, task_cfg, audio_path)
119
 
120
  @staticmethod
121
  def fromVideoFile(video_path, task_id, task_dir, task_cfg):
122
  # get audio path
123
+ logging.info("Task Creation method: Video File")
124
  return VideoTask(task_id, task_dir, task_cfg, video_path)
125
 
126
  # Module 1 ASR: audio --> SRT_script
127
+ def get_srt_class(self):
128
  # Instead of using the script_en variable directly, we'll use script_input
129
+ # TODO: set up ASR module like translator
130
  self.status = TaskStatus.INITIALIZING_ASR
131
  self.t_s = time()
132
  # self.SRT_Script = SrtScript
133
+ method = self.ASR_setting["whisper_config"]["method"]
134
+ whisper_model = self.ASR_setting["whisper_config"]["whisper_model"]
135
  src_srt_path = self.task_local_dir.joinpath(f"task_{self.task_id})_{self.source_lang}.srt")
136
  if not Path.exists(src_srt_path):
137
  # extract script from audio
138
  logging.info("extract script from audio")
139
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
140
 
141
  if method == "api":
142
  with open(self.audio_path, 'rb') as audio_file:
 
162
  self.SRT_Script.write_srt_file_src(src_srt_path)
163
 
164
  # Module 2: SRT preprocess: perform preprocess steps
 
165
  def preprocess(self):
166
  self.status = TaskStatus.PRE_PROCESSING
167
  logging.info("--------------------Start Preprocessing SRT class--------------------")
 
186
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
187
 
188
  # Module 3: perform srt translation
189
+ def translation(self):
190
  logging.info("---------------------Start Translation--------------------")
191
+ prompt = prompt_selector(self.source_lang, self.target_lang, self.field)
192
+ get_translation(self.SRT_Script, self.translation_model, self.task_id, prompt, self.translation_setting['chunk_size'])
193
 
194
  # Module 4: perform srt post process steps
195
  def postprocess(self):
196
  self.status = TaskStatus.POST_PROCESSING
197
 
198
  logging.info("---------------------Start Post-processing SRT class---------------------")
199
+ if self.post_setting["check_len_and_split"]:
200
+ self.SRT_Script.check_len_and_split()
201
+ if self.post_setting["remove_trans_punctuation"]:
202
+ self.SRT_Script.remove_trans_punctuation()
203
  logging.info("---------------------Post-processing SRT class finished---------------------")
204
 
205
  # Module 5: output module
 
238
 
239
  def run_pipeline(self):
240
  self.get_srt_class()
241
+ self.preprocess()
 
242
  self.translation()
243
+ self.postprocess()
 
244
  self.result = self.output_render()
245
  print(self.result)
246
 
 
262
  audio = yt.streams.filter(only_audio=True).first()
263
  if audio:
264
  audio.download(str(self.task_local_dir), filename=f"task_{self.task_id}.mp3")
 
265
  else:
266
  logging.info(" download audio failed, using ffmpeg to extract audio")
267
  subprocess.run(
src/translators/translation.py CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
- def get_translation(srt, model, video_name, task, chunk_size = 1000):
9
  script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
10
- translate(srt, script_arr, range_arr, model, video_name, task)
11
  pass
12
 
13
  def check_translation(sentence, translation):
@@ -39,7 +39,7 @@ def prompt_selector(src_lang, tgt_lang, domain):
39
  """
40
  return prompt
41
 
42
- def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
43
  """
44
  Translates the given script array into another language using the chatgpt and writes to the SRT file.
45
 
@@ -61,14 +61,14 @@ def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count
61
  raise Exception("Warning! No Input have passed to LLM!")
62
  if task is None:
63
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
64
- print(task)
65
  previous_length = 0
66
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
67
  # update the range based on previous length
68
  range_ = (range_[0] + previous_length, range_[1] + previous_length)
69
  # using chatgpt model
70
  print(f"now translating sentences {range_}")
71
- #logging.info(f"now translating sentences {range_}, time: {datetime.now()}")
72
  flag = True
73
  while flag:
74
  flag = False
 
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
+ def get_translation(srt, model, video_name, prompt, chunk_size = 1000):
9
  script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
10
+ translate(srt, script_arr, range_arr, model, video_name, task=prompt)
11
  pass
12
 
13
  def check_translation(sentence, translation):
 
39
  """
40
  return prompt
41
 
42
+ def translate(srt, script_arr, range_arr, model_name, video_name=None, attempts_count=5, task=None, temp = 0.15):
43
  """
44
  Translates the given script array into another language using the chatgpt and writes to the SRT file.
45
 
 
61
  raise Exception("Warning! No Input have passed to LLM!")
62
  if task is None:
63
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
64
+ logging.info(f"translation prompt: {task}")
65
  previous_length = 0
66
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
67
  # update the range based on previous length
68
  range_ = (range_[0] + previous_length, range_[1] + previous_length)
69
  # using chatgpt model
70
  print(f"now translating sentences {range_}")
71
+ logging.info(f"now translating sentences {range_}")
72
  flag = True
73
  while flag:
74
  flag = False