DWizard commited on
Commit
6808a65
·
2 Parent(s): a12b2b8 4f0065c

Merge branch 'oop-refactor' of https://github.com/project-kxkg/project-t into oop-refactor

Browse files
Files changed (2) hide show
  1. src/task.py +4 -3
  2. src/translators/translation.py +16 -6
src/task.py CHANGED
@@ -11,7 +11,7 @@ import subprocess
11
  from src.srt_util.srt import SrtScript
12
  from src.srt_util.srt2ass import srt2ass
13
  from time import time, strftime, gmtime, sleep
14
- from src.translators.translation import get_translation, translate
15
 
16
  import torch
17
  import stable_whisper
@@ -183,9 +183,10 @@ class Task:
183
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
184
 
185
  # Module 3: perform srt translation
186
- def translation(self):
187
  logging.info("---------------------Start Translation--------------------")
188
- get_translation(self.SRT_Script, self.model, self.task_id)
 
189
 
190
  # Module 4: perform srt post process steps
191
  def postprocess(self):
 
11
  from src.srt_util.srt import SrtScript
12
  from src.srt_util.srt2ass import srt2ass
13
  from time import time, strftime, gmtime, sleep
14
+ from src.translators.translation import get_translation, prompt_selector
15
 
16
  import torch
17
  import stable_whisper
 
183
  self.progress = TaskStatus.TRANSLATING.value[0], new_progress
184
 
185
  # Module 3: perform srt translation
186
+ def translation(self,task_cfg):
187
  logging.info("---------------------Start Translation--------------------")
188
+ prompt = prompt_selector(self.source_lang,self.target_lang,task_cfg['field'])
189
+ get_translation(self.SRT_Script, self.model, self.task_id, prompt, task_cfg['chunk_size'])
190
 
191
  # Module 4: perform srt post process steps
192
  def postprocess(self):
src/translators/translation.py CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
- def get_translation(srt, model, video_name):
9
- script_arr, range_arr = split_script(srt.get_source_only())
10
- translate(srt, script_arr, range_arr, model, video_name)
11
  pass
12
 
13
  def check_translation(sentence, translation):
@@ -26,8 +26,18 @@ def check_translation(sentence, translation):
26
 
27
  # TODO{david}: prompts selector
28
  def prompt_selector(src_lang, tgt_lang, domain):
29
-
30
- return ""
 
 
 
 
 
 
 
 
 
 
31
 
32
  def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
33
  """
@@ -51,7 +61,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count
51
  raise Exception("Warning! No Input have passed to LLM!")
52
  if task is None:
53
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
54
-
55
  previous_length = 0
56
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
57
  # update the range based on previous length
 
5
  from src.srt_util.srt import split_script
6
  from .LLM_task import LLM_task
7
 
8
+ def get_translation(srt, model, video_name, task, chunk_size = 1000):
9
+ script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
10
+ translate(srt, script_arr, range_arr, model, video_name, task)
11
  pass
12
 
13
  def check_translation(sentence, translation):
 
26
 
27
  # TODO{david}: prompts selector
28
  def prompt_selector(src_lang, tgt_lang, domain):
29
+ language_map = {
30
+ "EN": "English",
31
+ "ZH": "Chinese",
32
+ }
33
+ src_lang = language_map[src_lang]
34
+ tgt_lang = language_map[tgt_lang]
35
+ prompt = f"""
36
+ you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
37
+ you will be provided with a segement in {[src_lang]} parsed by line, where your translation text should keep the original
38
+ meaning and the number of lines.
39
+ """
40
+ return prompt
41
 
42
  def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
43
  """
 
61
  raise Exception("Warning! No Input have passed to LLM!")
62
  if task is None:
63
  task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
64
+ print(task)
65
  previous_length = 0
66
  for sentence, range_ in tqdm(zip(script_arr, range_arr)):
67
  # update the range based on previous length