# DialoGPT-uk / pipeline.py
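"""Ukrainian dialogue pipeline.

Translates the Ukrainian input to English with M2M100, generates an English
reply with an OPT model converted to CTranslate2, and translates the reply
back to Ukrainian.
"""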
import os
from typing import Dict, List

import ctranslate2
import transformers


class PreTrainedPipeline:
    def __init__(self, path: str):
        # Init the DialoGPT-style generator (an OPT model converted to CTranslate2)
        self.eos_token = "\n"  # a newline marks the end of a dialogue turn
        dialogpt_path = os.path.join(path, "opt")
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="float")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

        # Init the M2M100 translator (CTranslate2, int8-quantized)
        m2m100_path = os.path.join(path, "m2m100")
        self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
        self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")

    def __call__(self, inputs: str) -> List[Dict]:
        # Translate the Ukrainian input to English
        en_text = self.m2m100(inputs, "uk", "en")
        # Generate an English reply
        generated_text = self.dialogpt(en_text)
        # Translate the reply back to Ukrainian
        uk_text = self.m2m100(generated_text, "en", "uk")
        return [{"generated_text": uk_text}]

    def dialogpt(self, inputs: str) -> str:
        # Tokenize the prompt; the trailing EOS marks the end of the user's turn
        text = inputs + self.eos_token
        start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))
        # Generate a continuation
        results = self.generator.generate_batch([start_tokens], max_length=50, repetition_penalty=1.2)
        output = results[0].sequences[0]
        # Keep only the model's answer: the text after the last EOS
        tokens = self.tokenizer.convert_tokens_to_ids(output)
        generated_text = self.tokenizer.decode(tokens)
        eos_index = self.index_last(generated_text, self.eos_token)
        answer_text = generated_text[eos_index + 1:]
        return answer_text

    @staticmethod
    def index_last(li: str, char: str) -> int:
        # Index of the last occurrence of char in li
        idx = len(li) - 1 - li[::-1].index(char)
        return idx

    def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
        # Tokenize the source text with the source language set
        self.m2m100_tokenizer.src_lang = from_lang
        source = self.m2m100_tokenizer.convert_ids_to_tokens(self.m2m100_tokenizer.encode(inputs))
        # Force the target language token as the decoding prefix
        target_prefix = [self.m2m100_tokenizer.lang_code_to_token[to_lang]]
        results = self.translator.translate_batch([source], target_prefix=[target_prefix])
        # Drop the language prefix token and detokenize
        target = results[0].hypotheses[0][1:]
        translated_text = self.m2m100_tokenizer.decode(self.m2m100_tokenizer.convert_tokens_to_ids(target))
        return translated_text
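

# --- Usage sketch (not part of the original file) ---------------------------
# Minimal local smoke test. It assumes the two CTranslate2 model directories
# ("opt" and "m2m100") already exist under MODEL_DIR, e.g. produced with the
# CTranslate2 converter (the exact conversion used for this repo is not shown):
#   ct2-transformers-converter --model <finetuned-opt-model> --output_dir opt
#   ct2-transformers-converter --model facebook/m2m100_418M --output_dir m2m100 --quantization int8
# MODEL_DIR below is a hypothetical path used only for illustration.
if __name__ == "__main__":
    MODEL_DIR = "."
    pipeline = PreTrainedPipeline(MODEL_DIR)
    # Input is Ukrainian; the returned list holds a dict with the Ukrainian reply.
    print(pipeline("Привіт! Як справи?"))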