Added m2m100 translation function
Browse files- pipeline.py +14 -1
pipeline.py
CHANGED
@@ -12,6 +12,10 @@ class PreTrainedPipeline():
|
|
12 |
dialogpt_path = os.path.join(path, "dialogpt")
|
13 |
self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
|
14 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
|
|
|
|
|
|
|
|
15 |
|
16 |
def __call__(self, inputs: str) -> List[Dict]:
|
17 |
|
@@ -32,4 +36,13 @@ class PreTrainedPipeline():
|
|
32 |
eos_index = tokens.index(self.tokenizer.eos_token_id)
|
33 |
answer_tokens = tokens[eos_index+1:]
|
34 |
generated_text = self.tokenizer.decode(answer_tokens)
|
35 |
-
return generated_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
dialogpt_path = os.path.join(path, "dialogpt")
|
13 |
self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
|
14 |
self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
|
15 |
+
# Init M2M100
|
16 |
+
m2m100_path = os.path.join(path, "m2m100")
|
17 |
+
self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
|
18 |
+
self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
|
19 |
|
20 |
def __call__(self, inputs: str) -> List[Dict]:
|
21 |
|
|
|
36 |
eos_index = tokens.index(self.tokenizer.eos_token_id)
|
37 |
answer_tokens = tokens[eos_index+1:]
|
38 |
generated_text = self.tokenizer.decode(answer_tokens)
|
39 |
+
return generated_text
|
40 |
+
|
41 |
+
def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
    """Translate ``inputs`` from ``from_lang`` to ``to_lang`` via the CTranslate2 M2M100 model.

    Follows the standard CTranslate2 M2M-100 recipe: encode the source text
    to subword token strings, force decoding to start with the target-language
    token, then strip that prefix token from the best hypothesis before
    decoding back to plain text.
    """
    # M2M100 tokenization depends on the source language being set first.
    tokenizer = self.m2m100_tokenizer
    tokenizer.src_lang = from_lang

    # CTranslate2 consumes subword token *strings*, not ids.
    source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(inputs))

    # Decoding must begin with the target-language token (e.g. "__fr__").
    prefix = [tokenizer.lang_code_to_token[to_lang]]
    batch_results = self.translator.translate_batch(
        [source_tokens], target_prefix=[prefix]
    )

    # Take the best hypothesis and drop the leading language token.
    hypothesis = batch_results[0].hypotheses[0][1:]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(hypothesis))
|