import codecs
import os

from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe


class ScoringPreparator:
    """Allow the calculation of metrics via the Trainer's
    training_eval_handler method.
    """

    def __init__(self, vocabs, opt):
        self.vocabs = vocabs
        self.opt = opt
        # Make sure the prediction-dump directory exists up front so
        # translate() can append to files inside it.
        if self.opt.dump_preds is not None:
            if not os.path.exists(self.opt.dump_preds):
                os.makedirs(self.opt.dump_preds)
        self._setup_transforms(opt.transforms)

    def _setup_transforms(self, transforms):
        """Build the transform pipeline from the given transform names.

        Shared by ``__init__`` and ``warm_up`` (previously duplicated
        verbatim in both). Sets ``self.transforms`` and ``self.transform``.
        """
        self.transforms = transforms
        transforms_cls = get_transforms_cls(self.transforms)
        built = make_transforms(self.opt, transforms_cls, self.vocabs)
        self.transform = TransformPipe.build_from(built.values())

    def warm_up(self, transforms):
        """Rebuild the transform pipeline with ``transforms``."""
        self._setup_transforms(transforms)

    def translate(self, model, gpu_rank, step):
        """Compute and save the sentences predicted by the current
        model's state on the validation set.

        Args:
            model (:obj:`onmt.models.NMTModel`): The current model's state.
            gpu_rank (int): Ordinal rank of the gpu where the
                translation is to be done.
            step (int): The current training step.

        Returns:
            preds (list): Detokenized predictions.
            raw_refs (list): Detokenized target sentences.
        """
        # ########## #
        # Translator #
        # ########## #

        # Set translation options from the default translate CLI; the
        # dummy -model/-src values are only there to satisfy required args.
        parser = ArgumentParser()
        translate_opts(parser)
        base_args = ["-model", "dummy"] + ["-src", "dummy"]
        opt = parser.parse_args(base_args)
        opt.gpu = gpu_rank
        ArgumentParser.validate_translate_opts(opt)

        # Build translator from options. Output is discarded (devnull);
        # NOTE(review): out_file is never closed — Translator keeps a
        # reference to it, so closing here would be unsafe; confirm upstream.
        scorer = GNMTGlobalScorer.from_opt(opt)
        out_file = codecs.open(os.devnull, "w", "utf-8")
        # NOTE(review): model_opt aliases self.opt, so update_model_opts /
        # the num_workers/tgt assignments below mutate the trainer's opt
        # in place — presumably intended; verify against callers.
        model_opt = self.opt
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        translator = Translator.from_opt(
            model,
            self.vocabs,
            opt,
            model_opt,
            global_scorer=scorer,
            out_file=out_file,
            report_align=opt.report_align,
            report_score=False,
            logger=None,
        )

        # ################### #
        # Validation iterator #
        # ################### #

        # Reinstantiate the validation iterator (single-process, no target
        # side — targets are read back from disk below).
        transforms_cls = get_transforms_cls(model_opt._all_transform)
        model_opt.num_workers = 0
        model_opt.tgt = None
        valid_iter = build_dynamic_dataset_iter(
            model_opt,
            transforms_cls,
            translator.vocabs,
            task=CorpusTask.VALID,
            copy=model_opt.copy_attn,
        )

        # Retrieve raw references and sources, dropping empty lines.
        with codecs.open(
            valid_iter.corpora_info["valid"]["path_tgt"], "r", encoding="utf-8"
        ) as f:
            raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
        with codecs.open(
            valid_iter.corpora_info["valid"]["path_src"], "r", encoding="utf-8"
        ) as f:
            raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]

        valid_iter = IterOnDevice(valid_iter, opt.gpu)

        # ########### #
        # Predictions #
        # ########### #

        _, preds = translator._translate(
            valid_iter,
            transform=valid_iter.transform,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug,
        )

        # ####### #
        # Outputs #
        # ####### #

        # Flatten predictions (one sublist per batch).
        preds = [x.lstrip() for sublist in preds for x in sublist]

        # Save results. Explicit utf-8 keeps the dump readable regardless
        # of the platform locale (was previously locale-dependent); zip
        # avoids an IndexError if srcs/refs were filtered shorter than preds.
        if len(preds) > 0 and self.opt.scoring_debug:
            path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
            with open(path, "a", encoding="utf-8") as file:
                for src, ref, pred in zip(raw_srcs, raw_refs, preds):
                    file.write("SOURCE: {}\n".format(src))
                    file.write("REF: {}\n".format(ref))
                    file.write("PRED: {}\n\n".format(pred))
        return preds, raw_refs