File size: 4,793 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import codecs
import os
from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe
class ScoringPreparator:
    """Allow the calculation of metrics via the Trainer's
    training_eval_handler method.

    Builds a transform pipeline from the training options and, on demand,
    re-translates the validation corpus with the current model state so
    that predictions can be compared against the raw references.
    """

    def __init__(self, vocabs, opt):
        """Initialize the preparator and build the transform pipeline.

        Args:
            vocabs: Source/target vocabularies passed to the transforms
                and later to the translator.
            opt: Training options. Reads ``opt.transforms`` (names of the
                transforms to apply) and ``opt.dump_preds`` (optional
                directory where per-step predictions are dumped).
        """
        self.vocabs = vocabs
        self.opt = opt
        # Create the dump directory eagerly so that later writes in
        # translate() cannot fail on a missing path.
        if self.opt.dump_preds is not None:
            if not os.path.exists(self.opt.dump_preds):
                os.makedirs(self.opt.dump_preds)
        self.transforms = opt.transforms
        transforms_cls = get_transforms_cls(self.transforms)
        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
        self.transform = TransformPipe.build_from(transforms.values())

    def warm_up(self, transforms):
        """Rebuild the transform pipeline from a new list of transform names.

        Args:
            transforms: Iterable of transform names replacing the ones
                captured at construction time.
        """
        self.transforms = transforms
        transforms_cls = get_transforms_cls(self.transforms)
        transforms = make_transforms(self.opt, transforms_cls, self.vocabs)
        self.transform = TransformPipe.build_from(transforms.values())

    def translate(self, model, gpu_rank, step):
        """Compute and save the sentences predicted by the
        current model's state on the validation corpus.

        Args:
            model (:obj:`onmt.models.NMTModel`): The current model's state.
            gpu_rank (int): Ordinal rank of the gpu where the
                translation is to be done.
            step: The current training step (used to name the dump file).
        Returns:
            preds (list): Detokenized predictions
            raw_refs (list): Raw (non-empty) target sentences
        """
        # ########## #
        # Translator #
        # ########## #
        # Set translation options: parse defaults from dummy args, then
        # pin the GPU to the caller's rank.
        parser = ArgumentParser()
        translate_opts(parser)
        base_args = ["-model", "dummy"] + ["-src", "dummy"]
        opt = parser.parse_args(base_args)
        opt.gpu = gpu_rank
        ArgumentParser.validate_translate_opts(opt)
        # Build translator from options; scores are discarded (os.devnull).
        scorer = GNMTGlobalScorer.from_opt(opt)
        out_file = codecs.open(os.devnull, "w", "utf-8")
        # NOTE(review): model_opt aliases self.opt — the num_workers/tgt
        # assignments below mutate the trainer's options in place. Looks
        # intentional (persists across calls) but verify; a copy may be safer.
        model_opt = self.opt
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)
        translator = Translator.from_opt(
            model,
            self.vocabs,
            opt,
            model_opt,
            global_scorer=scorer,
            out_file=out_file,
            report_align=opt.report_align,
            report_score=False,
            logger=None,
        )
        # ################### #
        # Validation iterator #
        # ################### #
        # Reinstantiate the validation iterator; tgt is cleared so the
        # iterator yields source-only examples for inference.
        transforms_cls = get_transforms_cls(model_opt._all_transform)
        model_opt.num_workers = 0
        model_opt.tgt = None
        valid_iter = build_dynamic_dataset_iter(
            model_opt,
            transforms_cls,
            translator.vocabs,
            task=CorpusTask.VALID,
            copy=model_opt.copy_attn,
        )
        # Retrieve raw references and sources straight from the corpus
        # files. Blank lines are dropped; assumes the iterator drops the
        # same lines so preds stay index-aligned — TODO confirm.
        with codecs.open(
            valid_iter.corpora_info["valid"]["path_tgt"], "r", encoding="utf-8"
        ) as f:
            raw_refs = [line.strip("\n") for line in f if line.strip("\n")]
        with codecs.open(
            valid_iter.corpora_info["valid"]["path_src"], "r", encoding="utf-8"
        ) as f:
            raw_srcs = [line.strip("\n") for line in f if line.strip("\n")]
        valid_iter = IterOnDevice(valid_iter, opt.gpu)
        # ########### #
        # Predictions #
        # ########### #
        _, preds = translator._translate(
            valid_iter,
            transform=valid_iter.transform,
            attn_debug=opt.attn_debug,
            align_debug=opt.align_debug,
        )
        # ####### #
        # Outputs #
        # ####### #
        # Flatten predictions (one sublist per batch) into a single list.
        preds = [x.lstrip() for sublist in preds for x in sublist]
        # Save results.
        # NOTE(review): this assumes dump_preds is set whenever
        # scoring_debug is — os.path.join(None, ...) would raise otherwise.
        if len(preds) > 0 and self.opt.scoring_debug:
            path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
            # Append mode: repeated calls at the same step accumulate.
            with open(path, "a") as file:
                for i in range(len(preds)):
                    file.write("SOURCE: {}\n".format(raw_srcs[i]))
                    file.write("REF: {}\n".format(raw_refs[i]))
                    file.write("PRED: {}\n\n".format(preds[i]))
        return preds, raw_refs
|