import configargparse as cfargparse
import os
import torch

import onmt.opts as opts
from onmt.utils.logging import logger
from onmt.constants import CorpusName, ModelTask
from onmt.transforms import AVAILABLE_TRANSFORMS


class DataOptsCheckerMixin(object):
    """Checker with methods for validate data related options."""

    @staticmethod
    def _validate_file(file_path, info):
        """Check `file_path` is valid or raise `IOError`."""
        if not os.path.isfile(file_path):
            raise IOError(f"Please check path of your {info} file!")

    @classmethod
    def _validate_data(cls, opt):
        """Parse corpora specified in data field of YAML file."""
        import yaml
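
        # Illustrative example of the YAML structure expected in `opt.data`
        # (keys follow the checks below; the corpus names, file paths and
        # transform names are placeholders, not defaults):
        #
        #   corpus_1:
        #       path_src: data/src-train.txt
        #       path_tgt: data/tgt-train.txt
        #       transforms: [filtertoolong]
        #       weight: 1
        #   valid:
        #       path_src: data/src-val.txt
        #       path_tgt: data/tgt-val.txt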

        default_transforms = opt.transforms
        if len(default_transforms) != 0:
            logger.info(f"Default transforms: {default_transforms}.")
        corpora = yaml.safe_load(opt.data)

        for cname, corpus in corpora.items():
            # Check Transforms
            _transforms = corpus.get("transforms", None)
            if _transforms is None:
                logger.info(
                    f"Missing transforms field for {cname} data, "
                    f"set to default: {default_transforms}."
                )
                corpus["transforms"] = default_transforms
            # Check path
            path_src = corpus.get("path_src", None)
            path_tgt = corpus.get("path_tgt", None)
            if path_src is None:
                raise ValueError(
                    f"Corpus {cname} src path is required."
                    "tgt path is also required for non language"
                    " modeling tasks."
                )
            else:
                opt.data_task = ModelTask.SEQ2SEQ
                if path_tgt is None:
                    logger.debug(
                        "path_tgt is None, it should be set unless the task"
                        " is language modeling"
                    )
                    opt.data_task = ModelTask.LANGUAGE_MODEL
                    # tgt is src for LM task
                    corpus["path_tgt"] = path_src
                    corpora[cname] = corpus
                    path_tgt = path_src
                cls._validate_file(path_src, info=f"{cname}/path_src")
                cls._validate_file(path_tgt, info=f"{cname}/path_tgt")
            path_align = corpus.get("path_align", None)
            if path_align is None:
                if hasattr(opt, "lambda_align") and opt.lambda_align > 0.0:
                    raise ValueError(
                        f"Corpus {cname} alignment file path are "
                        "required when lambda_align > 0.0"
                    )
                corpus["path_align"] = None
            else:
                cls._validate_file(path_align, info=f"{cname}/path_align")
            # Check weight
            weight = corpus.get("weight", None)
            if weight is None:
                if cname != CorpusName.VALID:
                    logger.warning(
                        f"Corpus {cname}'s weight should be given."
                        " We default it to 1 for you."
                    )
                corpus["weight"] = 1

            # Check features
            if opt.n_src_feats > 0:
                if "inferfeats" not in corpus["transforms"]:
                    raise ValueError(
                        "'inferfeats' transform is required "
                        "when setting source features"
                    )

        logger.info(f"Parsed {len(corpora)} corpora from -data.")
        opt.data = corpora

    @classmethod
    def _validate_transforms_opts(cls, opt):
        """Check options used by transforms."""
        for name, transform_cls in AVAILABLE_TRANSFORMS.items():
            if name in opt._all_transform:
                transform_cls._validate_options(opt)

    @classmethod
    def _get_all_transform(cls, opt):
        """Should only called after `_validate_data`."""
        all_transforms = set(opt.transforms)
        for cname, corpus in opt.data.items():
            _transforms = set(corpus["transforms"])
            if len(_transforms) != 0:
                all_transforms.update(_transforms)
        if hasattr(opt, "lambda_align") and opt.lambda_align > 0.0:
            if not all_transforms.isdisjoint({"sentencepiece", "bpe", "onmt_tokenize"}):
                raise ValueError(
                    "lambda_align is not compatible with on-the-fly tokenization."
                )
            if not all_transforms.isdisjoint({"tokendrop", "prefix", "bart"}):
                raise ValueError(
                    "lambda_align is not compatible yet with"
                    " potentiel token deletion/addition."
                )
        opt._all_transform = all_transforms

    @classmethod
    def _get_all_transform_translate(cls, opt):
        opt._all_transform = opt.transforms

    @classmethod
    def _validate_vocab_opts(cls, opt, build_vocab_only=False):
        """Check options relate to vocab."""

        if build_vocab_only:
            if not opt.share_vocab:
                assert opt.tgt_vocab, "-tgt_vocab is required if not -share_vocab."
            return
        # validation when train:
        cls._validate_file(opt.src_vocab, info="src vocab")
        if not opt.share_vocab:
            cls._validate_file(opt.tgt_vocab, info="tgt vocab")

        if opt.dump_transforms:
            assert (
                opt.save_data
            ), "-save_data should be set if -dump_transforms is set."
        # Check embeddings stuff
        if opt.both_embeddings is not None:
            assert (
                opt.src_embeddings is None and opt.tgt_embeddings is None
            ), "You don't need -src_embeddings or -tgt_embeddings \
                if -both_embeddings is set."

        if any(
            [
                opt.both_embeddings is not None,
                opt.src_embeddings is not None,
                opt.tgt_embeddings is not None,
            ]
        ):
            assert (
                opt.embeddings_type is not None
            ), "You need to specify an -embeddings_type!"
            assert (
                opt.save_data
            ), "-save_data should be set if pretrained embeddings are used."

    @classmethod
    def _validate_language_model_compatibilities_opts(cls, opt):
        if opt.model_task != ModelTask.LANGUAGE_MODEL:
            return

        logger.info("encoder is not used for LM task")

        assert opt.share_vocab and (
            opt.tgt_vocab is None
        ), "vocab must be shared for LM task"

        assert (
            opt.decoder_type == "transformer"
        ), "Only transformer decoder is supported for LM task"

    @classmethod
    def _validate_source_features_opts(cls, opt):
        if opt.src_feats_defaults is not None:
            assert opt.n_src_feats == len(
                opt.src_feats_defaults.split("│")
            ), "The number of source feature defaults does not match -n_src_feats"

    @classmethod
    def validate_prepare_opts(cls, opt, build_vocab_only=False):
        """Validate all options relate to prepare (data/transform/vocab)."""
        if opt.n_sample != 0:
            assert (
                opt.save_data
            ), "-save_data should be set if you want to save samples."
        cls._validate_data(opt)
        cls._get_all_transform(opt)
        cls._validate_transforms_opts(opt)
        cls._validate_vocab_opts(opt, build_vocab_only=build_vocab_only)
        cls._validate_source_features_opts(opt)

    @classmethod
    def validate_model_opts(cls, opt):
        cls._validate_language_model_compatibilities_opts(opt)


class ArgumentParser(cfargparse.ArgumentParser, DataOptsCheckerMixin):
    """OpenNMT option parser powered with option check methods."""

    def __init__(
        self,
        config_file_parser_class=cfargparse.YAMLConfigFileParser,
        formatter_class=cfargparse.ArgumentDefaultsHelpFormatter,
        **kwargs,
    ):
        super(ArgumentParser, self).__init__(
            config_file_parser_class=config_file_parser_class,
            formatter_class=formatter_class,
            **kwargs,
        )

    @classmethod
    def defaults(cls, *args):
        """Get default arguments added to a parser by all ``*args``."""
        dummy_parser = cls()
        for callback in args:
            callback(dummy_parser)
        defaults = dummy_parser.parse_known_args([])[0]
        return defaults
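
    # Illustrative usage (hypothetical caller, mirroring `ckpt_model_opts`):
    #   model_defaults = ArgumentParser.defaults(opts.model_opts)
    #   print(model_defaults.word_vec_size)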

    @classmethod
    def update_model_opts(cls, model_opt):
        if model_opt.word_vec_size > 0:
            model_opt.src_word_vec_size = model_opt.word_vec_size
            model_opt.tgt_word_vec_size = model_opt.word_vec_size

        # Backward compatibility with "fix_word_vecs_*" opts
        if hasattr(model_opt, "fix_word_vecs_enc"):
            model_opt.freeze_word_vecs_enc = model_opt.fix_word_vecs_enc
        if hasattr(model_opt, "fix_word_vecs_dec"):
            model_opt.freeze_word_vecs_dec = model_opt.fix_word_vecs_dec

        if model_opt.layers > 0:
            model_opt.enc_layers = model_opt.layers
            model_opt.dec_layers = model_opt.layers

        if model_opt.hidden_size > 0:
            model_opt.enc_hid_size = model_opt.hidden_size
            model_opt.dec_hid_size = model_opt.hidden_size

        model_opt.brnn = model_opt.encoder_type == "brnn"

        if model_opt.copy_attn_type is None:
            model_opt.copy_attn_type = model_opt.global_attention

        if model_opt.alignment_layer is None:
            model_opt.alignment_layer = -2
            model_opt.lambda_align = 0.0
            model_opt.full_context_alignment = False

    @classmethod
    def validate_model_opts(cls, model_opt):
        assert model_opt.model_type in ["text"], (
            "Unsupported model type %s" % model_opt.model_type
        )

        # encoder and decoder should be same sizes
        same_size = model_opt.enc_hid_size == model_opt.dec_hid_size
        assert same_size, "The encoder and decoder rnns must be the same size for now"

        assert (
            model_opt.rnn_type != "SRU" or model_opt.gpu_ranks
        ), "Using SRU requires -gpu_ranks set."
        if model_opt.share_embeddings:
            if model_opt.model_type != "text":
                raise AssertionError("--share_embeddings requires --model_type text.")
        if model_opt.lambda_align > 0.0:
            assert (
                model_opt.decoder_type == "transformer"
            ), "Only transformer is supported to joint learn alignment."
            assert (
                model_opt.alignment_layer < model_opt.dec_layers
                and model_opt.alignment_layer >= -model_opt.dec_layers
            ), "N° alignment_layer should be smaller than number of layers."
            logger.info(
                "Joint learn alignment at layer [{}] "
                "with {} heads in full_context '{}'.".format(
                    model_opt.alignment_layer,
                    model_opt.alignment_heads,
                    model_opt.full_context_alignment,
                )
            )

        if model_opt.feat_merge == "concat" and model_opt.feat_vec_size > 0:
            assert (
                model_opt.feat_vec_size * model_opt.n_src_feats
            ) + model_opt.src_word_vec_size == model_opt.hidden_size, (
                "(feat_vec_size * n_src_feats) + "
                "src_word_vec_size should be equal to hidden_size with "
                "-feat_merge concat mode."
            )

        if model_opt.position_encoding and model_opt.max_relative_positions != 0:
            raise ValueError(
                "Cannot use absolute and relative position encoding at the"
                "same time. Use either --position_encoding=true for legacy"
                "absolute position encoding or --max_realtive_positions with"
                " -1 for Rotary, or > 0 for Relative Position Representations"
                "as in https://arxiv.org/pdf/1803.02155.pdf"
            )
        if model_opt.multiquery and model_opt.num_kv == 0:
            model_opt.num_kv = 1

    @classmethod
    def ckpt_model_opts(cls, ckpt_opt):
        # Load default opt values, then overwrite with the opts in
        # the checkpoint. That way, if there are new options added,
        # the defaults are used.
        opt = cls.defaults(opts.model_opts)
        opt.__dict__.update(ckpt_opt.__dict__)
        return opt

    @classmethod
    def validate_train_opts(cls, opt):
        if torch.cuda.is_available() and not opt.gpu_ranks:
            logger.warning("You have a CUDA device; consider running with -gpu_ranks")
        if opt.world_size < len(opt.gpu_ranks):
            raise AssertionError(
                "parameter counts of -gpu_ranks must be less or equal "
                "than -world_size."
            )
        if opt.world_size == len(opt.gpu_ranks) and min(opt.gpu_ranks) > 0:
            raise AssertionError(
                "-gpu_ranks should have master(=0) rank "
                "unless -world_size is greater than len(gpu_ranks)."
            )
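
        # Illustrative: a typical single-node, two-GPU run would use
        # `-world_size 2 -gpu_ranks 0 1` (rank 0 present, per the checks above).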

        assert len(opt.dropout) == len(
            opt.dropout_steps
        ), "Number of dropout values must match accum_steps values"

        assert len(opt.attention_dropout) == len(
            opt.dropout_steps
        ), "Number of attention_dropout values must match accum_steps values"

        assert len(opt.accum_count) == len(
            opt.accum_steps
        ), "Number of accum_count values must match number of accum_steps"

        if opt.update_vocab:
            assert opt.train_from, "-update_vocab needs -train_from option"
            assert opt.reset_optim in [
                "states",
                "all",
            ], '-update_vocab needs -reset_optim "states" or "all"'

    @classmethod
    def validate_translate_opts(cls, opt):
        if opt.gold_align:
            assert opt.report_align, "-report_align should be enabled with -gold_align"
            assert (
                not opt.replace_unk
            ), "-replace_unk option can not be used with -gold_align enabled"
            assert opt.tgt, "-tgt should be specified with -gold_align"

    @classmethod
    def validate_translate_opts_dynamic(cls, opt):
        # It comes from training
        # TODO: needs to be added as inference opt
        opt.share_vocab = False