import os
from typing import Dict, List

import ctranslate2
import transformers


class PreTrainedPipeline:
    """Ukrainian chat pipeline: translate uk->en, generate an English reply, translate back to uk.

    Models are loaded from two subdirectories of ``path``:

    * ``opt`` -- a CTranslate2 ``Generator`` used with the ``facebook/opt-350m`` tokenizer;
    * ``m2m100`` -- a CTranslate2 ``Translator`` used with the ``facebook/m2m100_418M`` tokenizer.

    All models run on CPU.
    """

    def __init__(self, path: str):
        # Conversational generator. The methods are named "dialogpt", but the
        # model directory is "opt" and it is paired with the OPT-350m tokenizer.
        # eos_token is the turn separator appended to the prompt and used to
        # cut the reply out of the decoded output.
        self.eos_token = "\n"
        dialogpt_path = os.path.join(path, "opt")
        self.generator = ctranslate2.Generator(dialogpt_path, device="cpu", compute_type="float")
        self.tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/opt-350m")

        # M2M100 translator, used in both directions (uk->en and en->uk).
        m2m100_path = os.path.join(path, "m2m100")
        self.translator = ctranslate2.Translator(m2m100_path, device="cpu", compute_type="int8")
        self.m2m100_tokenizer = transformers.AutoTokenizer.from_pretrained("facebook/m2m100_418M")

    def __call__(self, inputs: str) -> List[Dict]:
        """Answer a Ukrainian message with a Ukrainian message.

        Args:
            inputs: the user's message in Ukrainian.

        Returns:
            A one-element list ``[{"generated_text": <reply in Ukrainian>}]``.
        """
        # Ukrainian -> English
        en_text = self.m2m100(inputs, "uk", "en")
        # English reply from the generator
        generated_text = self.dialogpt(en_text)
        # English -> Ukrainian
        uk_text = self.m2m100(generated_text, "en", "uk")
        return [{"generated_text": uk_text}]

    def dialogpt(self, inputs: str) -> str:
        """Generate an English reply to ``inputs``.

        The prompt is ``inputs`` followed by ``self.eos_token``. The decoded
        output is cut after the last occurrence of ``self.eos_token``, and the
        remaining tail is returned as the reply. If the separator does not occur
        in the decoded output, the whole decoded text is returned.

        Args:
            inputs: English text to respond to.

        Returns:
            The reply text (may be empty).
        """
        text = inputs + self.eos_token
        start_tokens = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        results = self.generator.generate_batch([start_tokens], max_length=50, repetition_penalty=1.2)
        output = results[0].sequences[0]

        tokens = self.tokenizer.convert_tokens_to_ids(output)
        generated_text = self.tokenizer.decode(tokens)

        # Finding the separator relies on the prompt being echoed back in the
        # result (assumed: ctranslate2's default include_prompt_in_result=True).
        try:
            eos_index = self.index_last(generated_text, self.eos_token)
        except ValueError:
            # No separator at all: there is nothing to strip, so treat the
            # whole decoded text as the reply instead of crashing.
            return generated_text

        # Skip the whole separator, not just one character, so a
        # multi-character eos_token is removed completely.
        # NOTE(review): if the model ends its turn with the separator, this
        # tail is empty -- confirm whether that case should fall back to the
        # previous line.
        answer_text = generated_text[eos_index + len(self.eos_token):]
        return answer_text

    @staticmethod
    def index_last(li: str, char: str) -> int:
        """Return the index where the last occurrence of ``char`` in ``li`` starts.

        Args:
            li: text to search.
            char: substring to look for (one or more characters).

        Raises:
            ValueError: if ``char`` does not occur in ``li``.
        """
        # str.rindex matches the full substring. The previous
        # reverse-the-haystack approach only worked for single-character
        # needles, because the needle itself was not reversed.
        return li.rindex(char)

    def m2m100(self, inputs: str, from_lang: str, to_lang: str) -> str:
        """Translate ``inputs`` from ``from_lang`` to ``to_lang`` with M2M100.

        Args:
            inputs: text to translate.
            from_lang: M2M100 source-language code (e.g. "uk", "en").
            to_lang: M2M100 target-language code.

        Returns:
            The decoded translation.
        """
        # The tokenizer adds the source-language token based on src_lang.
        self.m2m100_tokenizer.src_lang = from_lang
        source = self.m2m100_tokenizer.convert_ids_to_tokens(self.m2m100_tokenizer.encode(inputs))
        # Force the decoder to start with the target-language token.
        target_prefix = [self.m2m100_tokenizer.lang_code_to_token[to_lang]]
        results = self.translator.translate_batch([source], target_prefix=[target_prefix])
        # Drop the leading target-language token before decoding.
        target = results[0].hypotheses[0][1:]
        translated_text = self.m2m100_tokenizer.decode(self.m2m100_tokenizer.convert_tokens_to_ids(target))
        return translated_text