Spaces:
Running
Running
Merge branch 'oop-refactor' of https://github.com/project-kxkg/project-t into oop-refactor
Browse files- src/task.py +4 -3
- src/translators/translation.py +16 -6
src/task.py
CHANGED
@@ -11,7 +11,7 @@ import subprocess
|
|
11 |
from src.srt_util.srt import SrtScript
|
12 |
from src.srt_util.srt2ass import srt2ass
|
13 |
from time import time, strftime, gmtime, sleep
|
14 |
-
from src.translators.translation import get_translation,
|
15 |
|
16 |
import torch
|
17 |
import stable_whisper
|
@@ -183,9 +183,10 @@ class Task:
|
|
183 |
self.progress = TaskStatus.TRANSLATING.value[0], new_progress
|
184 |
|
185 |
# Module 3: perform srt translation
|
186 |
-
def translation(self):
|
187 |
logging.info("---------------------Start Translation--------------------")
|
188 |
-
|
|
|
189 |
|
190 |
# Module 4: perform srt post process steps
|
191 |
def postprocess(self):
|
|
|
11 |
from src.srt_util.srt import SrtScript
|
12 |
from src.srt_util.srt2ass import srt2ass
|
13 |
from time import time, strftime, gmtime, sleep
|
14 |
+
from src.translators.translation import get_translation, prompt_selector
|
15 |
|
16 |
import torch
|
17 |
import stable_whisper
|
|
|
183 |
self.progress = TaskStatus.TRANSLATING.value[0], new_progress
|
184 |
|
185 |
# Module 3: perform srt translation
|
186 |
+
def translation(self,task_cfg):
|
187 |
logging.info("---------------------Start Translation--------------------")
|
188 |
+
prompt = prompt_selector(self.source_lang,self.target_lang,task_cfg['field'])
|
189 |
+
get_translation(self.SRT_Script, self.model, self.task_id, prompt, task_cfg['chunk_size'])
|
190 |
|
191 |
# Module 4: perform srt post process steps
|
192 |
def postprocess(self):
|
src/translators/translation.py
CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
|
|
5 |
from src.srt_util.srt import split_script
|
6 |
from .LLM_task import LLM_task
|
7 |
|
8 |
-
def get_translation(srt, model, video_name):
|
9 |
-
script_arr, range_arr = split_script(srt.get_source_only())
|
10 |
-
translate(srt, script_arr, range_arr, model, video_name)
|
11 |
pass
|
12 |
|
13 |
def check_translation(sentence, translation):
|
@@ -26,8 +26,18 @@ def check_translation(sentence, translation):
|
|
26 |
|
27 |
# TODO{david}: prompts selector
|
28 |
def prompt_selector(src_lang, tgt_lang, domain):
|
29 |
-
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
|
33 |
"""
|
@@ -51,7 +61,7 @@ def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count
|
|
51 |
raise Exception("Warning! No Input have passed to LLM!")
|
52 |
if task is None:
|
53 |
task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
|
54 |
-
|
55 |
previous_length = 0
|
56 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|
57 |
# update the range based on previous length
|
|
|
5 |
from src.srt_util.srt import split_script
|
6 |
from .LLM_task import LLM_task
|
7 |
|
8 |
+
def get_translation(srt, model, video_name, task, chunk_size = 1000):
|
9 |
+
script_arr, range_arr = split_script(srt.get_source_only(),chunk_size)
|
10 |
+
translate(srt, script_arr, range_arr, model, video_name, task)
|
11 |
pass
|
12 |
|
13 |
def check_translation(sentence, translation):
|
|
|
26 |
|
27 |
# TODO{david}: prompts selector
|
28 |
def prompt_selector(src_lang, tgt_lang, domain):
|
29 |
+
language_map = {
|
30 |
+
"EN": "English",
|
31 |
+
"ZH": "Chinese",
|
32 |
+
}
|
33 |
+
src_lang = language_map[src_lang]
|
34 |
+
tgt_lang = language_map[tgt_lang]
|
35 |
+
prompt = f"""
|
36 |
+
you are a translation assistant, your job is to translate a video in domain of {domain} from {src_lang} to {tgt_lang},
|
37 |
+
you will be provided with a segement in {[src_lang]} parsed by line, where your translation text should keep the original
|
38 |
+
meaning and the number of lines.
|
39 |
+
"""
|
40 |
+
return prompt
|
41 |
|
42 |
def translate(srt, script_arr, range_arr, model_name, video_name, attempts_count=5, task=None, temp = 0.15):
|
43 |
"""
|
|
|
61 |
raise Exception("Warning! No Input have passed to LLM!")
|
62 |
if task is None:
|
63 |
task = "你是一个翻译助理,你的任务是翻译星际争霸视频,你会被提供一个按行分割的英文段落,你需要在保证句意和行数的情况下输出翻译后的文本。"
|
64 |
+
print(task)
|
65 |
previous_length = 0
|
66 |
for sentence, range_ in tqdm(zip(script_arr, range_arr)):
|
67 |
# update the range based on previous length
|