theodotus committed on
Commit
fcae46c
·
1 Parent(s): 6917fe9

Added m2m100 func

Browse files
Files changed (1) hide show
  1. pipeline.py +14 -1
pipeline.py CHANGED
@@ -12,6 +12,10 @@ class PreTrainedPipeline():
12
  dialogpt_path = os.path.join(path, "dialogpt")
13
  self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
14
  self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
 
 
 
 
15
 
16
  def __call__(self, inputs: str) -> List[Dict]:
17
 
@@ -32,4 +36,13 @@ class PreTrainedPipeline():
32
  eos_index = tokens.index(self.tokenizer.eos_token_id)
33
  answer_tokens = tokens[eos_index+1:]
34
  generated_text = self.tokenizer.decode(answer_tokens)
35
- return generated_text
 
 
 
 
 
 
 
 
 
 
12
  dialogpt_path = os.path.join(path, "dialogpt")
13
  self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
14
  self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
15
+ # Init M2M100
16
+ m2m100_path = os.path.join(path, "m2m100")
17
+ self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
18
+ self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")
19
 
20
  def __call__(self, inputs: str) -> List[Dict]:
21
 
 
36
  eos_index = tokens.index(self.tokenizer.eos_token_id)
37
  answer_tokens = tokens[eos_index+1:]
38
  generated_text = self.tokenizer.decode(answer_tokens)
39
+ return generated_text
40
+
41
def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
    """Translate *inputs* from ``from_lang`` to ``to_lang`` with the M2M100 model.

    The text is tokenized with ``self.m2m100_tokenizer``, translated by
    ``self.translator`` with the target-language token forced as the
    decoding prefix, and the resulting tokens are decoded back to a string.

    Args:
        inputs: Text to translate.
        from_lang: Source language code; assigned to the tokenizer's
            ``src_lang`` attribute (and left set after the call).
        to_lang: Target language code; must be a key of the tokenizer's
            ``lang_code_to_token`` mapping, otherwise the lookup fails
            (a ``KeyError`` if the mapping is a plain dict).

    Returns:
        The translated text, or an empty string when *inputs* is empty or
        contains only whitespace.
    """
    # Nothing to translate: skip the model call instead of decoding
    # whatever the model would produce for an empty source sentence.
    if not inputs or not inputs.strip():
        return ""

    tokenizer = self.m2m100_tokenizer
    tokenizer.src_lang = from_lang

    # The translator is fed token strings rather than ids, so the encoded
    # ids are mapped back to their token form first.
    source_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(inputs))

    # Force the first generated token to be the target-language token.
    target_prefix = [tokenizer.lang_code_to_token[to_lang]]
    results = self.translator.translate_batch(
        [source_tokens], target_prefix=[target_prefix]
    )

    # Best hypothesis of the single batch item, minus its first token
    # (presumably the forced language prefix echoed back in the output;
    # see the CTranslate2 M2M-100 guide).
    target_tokens = results[0].hypotheses[0][1:]
    return tokenizer.decode(tokenizer.convert_tokens_to_ids(target_tokens))