theodotus committed on
Commit
90541db
·
1 Parent(s): b9571f6

Use opt for generation

Browse files
Files changed (1) hide show
  1. pipeline.py +15 -9
pipeline.py CHANGED
@@ -9,9 +9,10 @@ import os
9
  class PreTrainedPipeline():
10
  def __init__(self, path: str):
11
  # Init DialoGPT
12
- dialogpt_path = os.path.join(path, "dialogpt")
13
- self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="int8")
14
- self.tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
 
15
  # Init M2M100
16
  m2m100_path = os.path.join(path, "m2m100")
17
  self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
@@ -29,17 +30,22 @@ class PreTrainedPipeline():
29
 
30
  def dialogpt(self, inputs: str) -> str:
31
  # Get input tokens
32
- text = inputs + self.tokenizer.eos_token
33
  start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
34
  # generate
35
- results = self.generator.generate_batch([start_tokens])
36
  output = results[0].sequences[0]
37
  # left only answers
38
  tokens = self.tokenizer.convert_tokens_to_ids(output)
39
- eos_index = tokens.index(self.tokenizer.eos_token_id)
40
- answer_tokens = tokens[eos_index+1:]
41
- generated_text = self.tokenizer.decode(answer_tokens)
42
- return generated_text
 
 
 
 
 
43
 
44
  def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
45
  self.m2m100_tokenizer.src_lang = from_lang
 
9
  class PreTrainedPipeline():
10
  def __init__(self, path: str):
11
  # Init DialoGPT
12
+ self.eos_token = "\n"
13
+ dialogpt_path = os.path.join(path, "opt")
14
+ self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="float")
15
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")
16
  # Init M2M100
17
  m2m100_path = os.path.join(path, "m2m100")
18
  self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
 
30
 
31
  def dialogpt(self, inputs: str) -> str:
32
  # Get input tokens
33
+ text = inputs + self.eos_token
34
  start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
35
  # generate
36
+ results = self.generator.generate_batch([start_tokens], max_length=50, repetition_penalty=1.2)
37
  output = results[0].sequences[0]
38
  # left only answers
39
  tokens = self.tokenizer.convert_tokens_to_ids(output)
40
+ generated_text = self.tokenizer.decode(tokens)
41
+ eos_index = self.index_last(generated_text, self.eos_token)
42
+ answer_text = generated_text[eos_index+1:]
43
+ return answer_text
44
+
45
+ @staticmethod
46
+ def index_last(li: str, char: str):
47
+ idx = len(li) - 1 - li[::-1].index(char)
48
+ return idx
49
 
50
  def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
51
  self.m2m100_tokenizer.src_lang = from_lang