|
import transformers |
|
import ctranslate2 |
|
|
|
from typing import List, Dict |
|
import os |
|
|
|
|
|
|
|
class PreTrainedPipeline:
    """Chat pipeline: translate "uk" -> "en", reply with a generator model, translate back "en" -> "uk".

    Both models are CTranslate2 exports loaded from sub-directories of ``path``:
    ``<path>/opt`` (a ``ctranslate2.Generator``) and ``<path>/m2m100`` (a
    ``ctranslate2.Translator``). Their tokenizers are loaded with
    ``transformers.AutoTokenizer.from_pretrained`` from ``facebook/opt-350m`` and
    ``facebook/m2m100_418M``.
    """

    def __init__(self, path: str):
        """Load the generator, the translator and their tokenizers.

        Args:
            path: Directory containing the ``opt`` and ``m2m100`` CTranslate2 model folders.
        """
        # Turn separator: appended to the prompt, and the reply is cut at it.
        self.eos_token = "\n"

        dialogpt_path = os.path.join(path, "opt")
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="float")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

        m2m100_path = os.path.join(path, "m2m100")
        self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
        self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")

    def __call__(self, inputs: str) -> List[Dict]:
        """Reply to a "uk" message with a "uk" answer.

        Args:
            inputs: The user's message.

        Returns:
            A one-element list ``[{"generated_text": reply}]``.
        """
        en_text = self.m2m100(inputs, "uk", "en")
        generated_text = self.dialogpt(en_text)
        uk_text = self.m2m100(generated_text, "en", "uk")
        return [{"generated_text": uk_text}]

    def dialogpt(self, inputs: str) -> str:
        """Generate a one-line English reply to ``inputs`` with the generator model.

        The prompt is ``inputs + self.eos_token``. The reply is the first non-blank
        ``eos_token``-separated line produced after that prompt, stripped of
        surrounding whitespace ("" if the model produced nothing usable).
        """
        prompt_ids = self.tokenizer.encode(inputs + self.eos_token)
        start_tokens = self.tokenizer.convert_ids_to_tokens(prompt_ids)

        results = self.generator.generate_batch([start_tokens], max_length=50, repetition_penalty=1.2)
        output_tokens = results[0].sequences[0]

        # Decode the prompt and the output the same way (special tokens, e.g. a BOS
        # token added by ``encode``, are dropped) so the prompt can be removed as a
        # plain text prefix. This works whether or not the result echoes the prompt.
        prompt_text = self.tokenizer.decode(prompt_ids, skip_special_tokens=True)
        generated_text = self.tokenizer.decode(
            self.tokenizer.convert_tokens_to_ids(output_tokens), skip_special_tokens=True
        )
        if generated_text.startswith(prompt_text):
            generated_text = generated_text[len(prompt_text):]

        # Generation is only bounded by max_length, so the model may continue past its
        # reply. Keep the first line, not the last (possibly truncated or empty) one.
        return self._first_line(generated_text, self.eos_token)

    @staticmethod
    def _first_line(text: str, separator: str) -> str:
        """Return the first non-blank ``separator``-delimited segment of ``text``, stripped; "" if none."""
        for segment in text.split(separator):
            stripped = segment.strip()
            if stripped:
                return stripped
        return ""

    @staticmethod
    def index_last(li: str, char: str):
        """Return the start index of the last occurrence of ``char`` in ``li``.

        Works for multi-character ``char`` too. Raises ``ValueError`` when ``char``
        does not occur in ``li`` (the ``str.rindex`` contract).
        """
        return li.rindex(char)

    def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
        """Translate ``inputs`` from ``from_lang`` to ``to_lang`` with the M2M100 translator.

        Args:
            inputs: Text to translate. Blank text returns "" without running the model.
            from_lang: Source language code, as used by the tokenizer (e.g. "uk", "en").
            to_lang: Target language code, a key of ``lang_code_to_token``.

        Note: assigns ``self.m2m100_tokenizer.src_lang``, i.e. mutates the shared
        tokenizer between calls.
        """
        if not inputs.strip():
            # Avoid sending an empty source sentence to the translator.
            return ""

        # src_lang selects the source-language token the tokenizer adds when encoding.
        self.m2m100_tokenizer.src_lang = from_lang
        source = self.m2m100_tokenizer.convert_ids_to_tokens(self.m2m100_tokenizer.encode(inputs))

        # Force the output to begin with the target-language token.
        target_prefix = [self.m2m100_tokenizer.lang_code_to_token[to_lang]]
        results = self.translator.translate_batch([source], target_prefix=[target_prefix])

        # Drop that leading language token before decoding.
        target = results[0].hypotheses[0][1:]
        return self.m2m100_tokenizer.decode(self.m2m100_tokenizer.convert_tokens_to_ids(target))