File size: 4,793 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import codecs
import os
from onmt.utils.parse import ArgumentParser
from onmt.translate import GNMTGlobalScorer, Translator
from onmt.opts import translate_opts
from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.inputters.inputter import IterOnDevice
from onmt.transforms import get_transforms_cls, make_transforms, TransformPipe


class ScoringPreparator:
    """Allow the calculation of metrics via the Trainer's
    training_eval_handler method.

    The instance keeps the vocabularies and the options it was built
    with, plus a transform pipeline built from the configured transforms.
    """

    def __init__(self, vocabs, opt):
        """Store vocabs/options and build the initial transform pipeline.

        Args:
            vocabs: Vocabularies handed to ``make_transforms`` and to the
                translator built in :meth:`translate`.
            opt: Options namespace. ``dump_preds`` and ``transforms`` are
                read here; ``scoring_debug``, ``copy_attn`` and
                ``_all_transform`` are read in :meth:`translate`.
        """
        self.vocabs = vocabs
        self.opt = opt
        if self.opt.dump_preds is not None:
            # exist_ok avoids the check-then-create race of
            # ``os.path.exists`` followed by ``os.makedirs``.
            os.makedirs(self.opt.dump_preds, exist_ok=True)
        self._build_transform_pipe(opt.transforms)

    def warm_up(self, transforms):
        """Rebuild the transform pipeline from a new list of transforms.

        Args:
            transforms: Transform names passed to ``get_transforms_cls``.
        """
        self._build_transform_pipe(transforms)

    def _build_transform_pipe(self, transforms):
        """Set ``self.transforms`` and build ``self.transform`` from it."""
        self.transforms = transforms
        transforms_cls = get_transforms_cls(self.transforms)
        built = make_transforms(self.opt, transforms_cls, self.vocabs)
        self.transform = TransformPipe.build_from(built.values())

    @staticmethod
    def _read_nonempty_lines(path):
        """Return the lines of a UTF-8 file, newline-stripped, empties dropped."""
        with codecs.open(path, "r", encoding="utf-8") as f:
            stripped = (line.strip("\n") for line in f)
            return [line for line in stripped if line]

    def _dump_predictions(self, step, raw_srcs, raw_refs, preds):
        """Append SOURCE/REF/PRED triples for ``step`` to the dump folder."""
        path = os.path.join(self.opt.dump_preds, f"preds.valid_step_{step}.txt")
        # Explicit UTF-8 so the dump does not depend on the platform locale;
        # the sources and references are read as UTF-8 as well.
        with open(path, "a", encoding="utf-8") as file:
            # zip stops at the shortest list, so a length mismatch truncates
            # the debug dump instead of raising IndexError.
            for src, ref, pred in zip(raw_srcs, raw_refs, preds):
                file.write("SOURCE: {}\n".format(src))
                file.write("REF: {}\n".format(ref))
                file.write("PRED: {}\n\n".format(pred))

    def translate(self, model, gpu_rank, step):
        """Compute and save the sentences predicted by the
        current model's state on the validation corpus.

        Args:
            model (:obj:`onmt.models.NMTModel`): The current model's state.
            gpu_rank (int): Ordinal rank of the gpu where the
                translation is to be done.
            step: The current training step; used in the dump file name.
        Returns:
            preds (list): Flattened predictions with leading whitespace
                stripped.
            raw_refs (list): Non-empty lines of the validation target file.

        Note:
            ``self.opt`` is updated in place (``update_model_opts``,
            ``num_workers = 0``, ``tgt = None``), so those changes stay
            visible on the options object after this call.
        """
        # ########## #
        # Translator #
        # ########## #

        # Set translation options: default translate options, with
        # placeholder values for the required -model/-src arguments.
        parser = ArgumentParser()
        translate_opts(parser)
        opt = parser.parse_args(["-model", "dummy", "-src", "dummy"])
        opt.gpu = gpu_rank
        ArgumentParser.validate_translate_opts(opt)

        scorer = GNMTGlobalScorer.from_opt(opt)
        model_opt = self.opt  # same object as self.opt, not a copy
        ArgumentParser.update_model_opts(model_opt)
        ArgumentParser.validate_model_opts(model_opt)

        # The translator's output goes to os.devnull; the context manager
        # closes that handle once translation is done (it used to leak on
        # every call).
        with codecs.open(os.devnull, "w", "utf-8") as out_file:
            # Build translator from options
            translator = Translator.from_opt(
                model,
                self.vocabs,
                opt,
                model_opt,
                global_scorer=scorer,
                out_file=out_file,
                report_align=opt.report_align,
                report_score=False,
                logger=None,
            )

            # ################### #
            # Validation iterator #
            # ################### #

            # Reinstantiate the validation iterator
            transforms_cls = get_transforms_cls(model_opt._all_transform)
            model_opt.num_workers = 0
            model_opt.tgt = None

            valid_iter = build_dynamic_dataset_iter(
                model_opt,
                transforms_cls,
                translator.vocabs,
                task=CorpusTask.VALID,
                copy=model_opt.copy_attn,
            )

            # Retrieve raw references and sources
            valid_info = valid_iter.corpora_info["valid"]
            raw_refs = self._read_nonempty_lines(valid_info["path_tgt"])
            raw_srcs = self._read_nonempty_lines(valid_info["path_src"])

            valid_iter = IterOnDevice(valid_iter, opt.gpu)

            # ########### #
            # Predictions #
            # ########### #

            _, preds = translator._translate(
                valid_iter,
                transform=valid_iter.transform,
                attn_debug=opt.attn_debug,
                align_debug=opt.align_debug,
            )

        # ####### #
        # Outputs #
        # ####### #

        # Flatten predictions
        preds = [x.lstrip() for sublist in preds for x in sublist]

        # Save results. __init__ treats dump_preds as optional, so skip the
        # dump when no folder is configured; os.path.join(None, ...) would
        # raise TypeError.
        if preds and self.opt.scoring_debug and self.opt.dump_preds is not None:
            self._dump_predictions(step, raw_srcs, raw_refs, preds)
        return preds, raw_refs