#!/usr/bin/env python
"""Get vocabulary coutings from transformed corpora samples."""
import os
import copy
import multiprocessing as mp
import pyonmttok
from functools import partial
from onmt.utils.logging import init_logger, logger
from onmt.utils.misc import set_random_seed, check_path
from onmt.utils.parse import ArgumentParser
from onmt.opts import dynamic_prepare_opts
from onmt.inputters.text_corpus import build_corpora_iters, get_corpora
from onmt.inputters.text_utils import process, append_features_to_text
from onmt.transforms import make_transforms, get_transforms_cls
from onmt.constants import CorpusName, CorpusTask
from collections import Counter


MAXBUCKETSIZE = 256000  # max number of buffered examples per parallel ingest in ingest_tokens()


def write_files_from_queues(sample_path, queues):
    """
    Standalone process that reads data from
    queues in order and write to sample files.
    """
    os.makedirs(sample_path, exist_ok=True)
    for c_name in queues.keys():
        dest_base = os.path.join(sample_path, "{}.{}".format(c_name, CorpusName.SAMPLE))
        with open(dest_base + ".src", "w", encoding="utf-8") as f_src, open(
            dest_base + ".tgt", "w", encoding="utf-8"
        ) as f_tgt:
            while True:
                _next = False
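                # Poll each worker queue round-robin so that lines are written
                # back in original corpus order (workers read with a stride).
                # "blank" marks an example filtered out by the transforms;
                # "break" marks the end of a worker's share of the corpus.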
                for q in queues[c_name]:
                    item = q.get()
                    if item == "blank":
                        continue
                    if item == "break":
                        _next = True
                        break
                    _, src_line, tgt_line = item
                    f_src.write(src_line + "\n")
                    f_tgt.write(tgt_line + "\n")
                if _next:
                    break


def build_sub_vocab(corpora, transforms, opts, n_sample, stride, offset):
    """Build vocab on (strided) subpart of the data."""
    sub_counter_src = Counter()
    sub_counter_tgt = Counter()
    sub_counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]
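    # Each worker reads every `stride`-th example starting at `offset`,
    # so the corpus is covered exactly once across all workers.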
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            maybe_example = process(CorpusTask.TRAIN, [item])
            if maybe_example is not None:
                maybe_example = maybe_example[0]
            else:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("blank")
                continue
            src_line, tgt_line = (
                maybe_example["src"]["src"],
                maybe_example["tgt"]["tgt"],
            )
            sub_counter_src.update(src_line.split(" "))
            sub_counter_tgt.update(tgt_line.split(" "))

            if "feats" in maybe_example["src"]:
                src_feats_lines = maybe_example["src"]["feats"]
                for k in range(opts.n_src_feats):
                    sub_counter_src_feats[k].update(src_feats_lines[k].split(" "))
            else:
                src_feats_lines = []

            if opts.dump_samples:
                src_pretty_line = append_features_to_text(src_line, src_feats_lines)
                build_sub_vocab.queues[c_name][offset].put(
                    (i, src_pretty_line, tgt_line)
                )
            if n_sample > 0 and ((i + 1) * stride + offset) >= n_sample:
                if opts.dump_samples:
                    build_sub_vocab.queues[c_name][offset].put("break")
                break
        if opts.dump_samples:
            build_sub_vocab.queues[c_name][offset].put("break")
    return sub_counter_src, sub_counter_tgt, sub_counter_src_feats


def init_pool(queues):
    """Add the queues as attribute of the pooled function."""
    build_sub_vocab.queues = queues


def build_vocab(opts, transforms, n_sample=3):
    """Build vocabulary from data."""

    if n_sample == -1:
        logger.info(f"n_sample={n_sample}: Build vocab on full datasets.")
    elif n_sample > 0:
        logger.info(f"Build vocab on {n_sample} transformed examples/corpus.")
    else:
        raise ValueError(f"n_sample should be > 0 or == -1, got {n_sample}.")

    if opts.dump_samples:
        logger.info(
            "The samples on which the vocab is built will be "
            "dumped to disk. It may slow down the process."
        )
    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    counter_src = Counter()
    counter_tgt = Counter()
    counter_src_feats = [Counter() for _ in range(opts.n_src_feats)]

    queues = {
        c_name: [
            mp.Queue(opts.vocab_sample_queue_size) for i in range(opts.num_threads)
        ]
        for c_name in corpora.keys()
    }
    sample_path = os.path.join(os.path.dirname(opts.save_data), CorpusName.SAMPLE)
    if opts.dump_samples:
        write_process = mp.Process(
            target=write_files_from_queues, args=(sample_path, queues), daemon=True
        )
        write_process.start()
    with mp.Pool(opts.num_threads, init_pool, [queues]) as p:
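        # Bind every argument except the worker offset; imap then supplies
        # offsets 0..num_threads-1 so each worker counts its stride slice.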
        func = partial(
            build_sub_vocab, corpora, transforms, opts, n_sample, opts.num_threads
        )
        for sub_counter_src, sub_counter_tgt, sub_counter_src_feats in p.imap(
            func, range(0, opts.num_threads)
        ):
            counter_src.update(sub_counter_src)
            counter_tgt.update(sub_counter_tgt)
            for i in range(opts.n_src_feats):
                counter_src_feats[i].update(sub_counter_src_feats[i])
    if opts.dump_samples:
        write_process.join()
    return counter_src, counter_tgt, counter_src_feats


def ingest_tokens(opts, transforms, n_sample, learner, stride, offset):
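    """Feed transformed src/tgt token lines to a subword learner,
    buffering examples and processing them across `opts.num_threads` workers."""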
    def _mp_ingest(data):
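        # Split the buffered examples into one chunk per worker, run the
        # processing in parallel, then feed the resulting lines to the learner.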
        func = partial(process, CorpusTask.TRAIN)
        chunk = len(data) // opts.num_threads
        with mp.Pool(opts.num_threads) as pool:
            buckets = pool.map(
                func,
                [data[i * chunk : (i + 1) * chunk] for i in range(0, opts.num_threads)],
            )
        for bucket in buckets:
            for ex in bucket:
                if ex is not None:
                    src_line, tgt_line = (ex["src"]["src"], ex["tgt"]["tgt"])
                    learner.ingest(src_line)
                    learner.ingest(tgt_line)

    corpora = get_corpora(opts, task=CorpusTask.TRAIN)
    datasets_iterables = build_corpora_iters(
        corpora,
        transforms,
        opts.data,
        skip_empty_level=opts.skip_empty_level,
        stride=stride,
        offset=offset,
    )
    to_ingest = []
    for c_name, c_iter in datasets_iterables.items():
        for i, item in enumerate(c_iter):
            if n_sample >= 0 and i >= n_sample:
                break
            if len(to_ingest) >= MAXBUCKETSIZE:
                _mp_ingest(to_ingest)
                to_ingest = []
            to_ingest.append(item)
        _mp_ingest(to_ingest)


def make_learner(tokenization_type, symbols):
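    """Return a pyonmttok subword learner for the requested tokenization type."""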
    if tokenization_type == "bpe":
        # BPE training
        learner = pyonmttok.BPELearner(tokenizer=None, symbols=symbols)
    elif tokenization_type == "sentencepiece":
        # SentencePiece training
        learner = pyonmttok.SentencePieceLearner(
            vocab_size=symbols, character_coverage=0.98
        )
    else:
        raise ValueError(f"Unsupported tokenization type: {tokenization_type}")
    return learner


def build_vocab_main(opts):
    """Apply transforms to samples of specified data and build vocab from it.

    Transforms that need vocab will be disabled in this.
    Built vocab is saved in plain text format as following and can be pass as
    `-src_vocab` (and `-tgt_vocab`) when training:
    ```
    <tok_0>\t<count_0>
    <tok_1>\t<count_1>
    ```
    """

    ArgumentParser.validate_prepare_opts(opts, build_vocab_only=True)
    assert (
        opts.n_sample == -1 or opts.n_sample > 1
    ), f"Illegal argument n_sample={opts.n_sample}."

    logger = init_logger()
    set_random_seed(opts.seed, False)
    transforms_cls = get_transforms_cls(opts._all_transform)

    if opts.learn_subwords:
        logger.info(f"Ingesting {opts.src_subword_type} model from corpus")
        learner = make_learner(opts.src_subword_type, opts.learn_subwords_size)
        if opts.src_subword_model is not None:
            tok_path = opts.src_subword_model
        else:
            data_dir = os.path.split(opts.save_data)[0]
            if not os.path.exists(data_dir):
                os.makedirs(data_dir)
            tok_path = os.path.join(data_dir, f"{opts.src_subword_type}.model")
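        # Temporarily disable subword segmentation and joiner annotation so the
        # learner sees raw tokens; the original options are restored afterwards.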
        save_opts = copy.deepcopy(opts)
        opts.src_subword_type = "none"
        opts.tgt_subword_type = "none"
        opts.src_onmttok_kwargs["joiner_annotate"] = False
        opts.tgt_onmttok_kwargs["joiner_annotate"] = False
        transforms = make_transforms(opts, transforms_cls, None)
        ingest_tokens(opts, transforms, opts.n_sample, learner, 1, 0)
        logger.info(f"Learning {tok_path} model, patience")
        learner.learn(tok_path)
        opts = save_opts

    transforms = make_transforms(opts, transforms_cls, None)

    logger.info(f"Counter vocab from {opts.n_sample} samples.")
    src_counter, tgt_counter, src_feats_counter = build_vocab(
        opts, transforms, n_sample=opts.n_sample
    )

    logger.info(f"Counters src: {len(src_counter)}")
    logger.info(f"Counters tgt: {len(tgt_counter)}")
    for i, feat_counter in enumerate(src_feats_counter):
        logger.info(f"Counters src feat_{i}: {len(feat_counter)}")

    def save_counter(counter, save_path):
        check_path(save_path, exist_ok=opts.overwrite, log=logger.warning)
        with open(save_path, "w", encoding="utf8") as fo:
            for tok, count in counter.most_common():
                fo.write(tok + "\t" + str(count) + "\n")

    if opts.share_vocab:
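        # With a shared vocab, merge source and target counts and write a
        # single vocab file to `opts.src_vocab`.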
        src_counter += tgt_counter
        tgt_counter = src_counter
        logger.info(f"Counters after share:{len(src_counter)}")
        save_counter(src_counter, opts.src_vocab)
    else:
        save_counter(src_counter, opts.src_vocab)
        save_counter(tgt_counter, opts.tgt_vocab)

    for i, c in enumerate(src_feats_counter):
        save_counter(c, f"{opts.src_vocab}_feat{i}")


def _get_parser():
    parser = ArgumentParser(description="build_vocab.py")
    dynamic_prepare_opts(parser, build_vocab_only=True)
    return parser


def main():
    parser = _get_parser()
    opts, unknown = parser.parse_known_args()
    build_vocab_main(opts)


if __name__ == "__main__":
    main()