|
|
|
|
|
|
|
import copy |
|
import unittest |
|
import glob |
|
import os |
|
|
|
from onmt.utils.parse import ArgumentParser |
|
from onmt.opts import dynamic_prepare_opts |
|
from onmt.train_single import prepare_transforms_vocabs |
|
from onmt.constants import CorpusName |
|
|
|
|
|
# Path prefix passed as ``save_data`` by the sampling test cases below; files
# matching this prefix with a ".pt" suffix are deleted after every build.
SAVE_DATA_PREFIX = "data/test_data_prepare"
|
|
|
|
|
def get_default_opts():
    """Build the baseline options used by the data-preparation tests.

    A parser populated by ``dynamic_prepare_opts`` parses a fixed argument
    list (the ``data/data.yaml`` config plus source and target vocabulary
    paths). The result is then handed to
    ``ArgumentParser.validate_prepare_opts``.

    Returns:
        The parsed options object, with ``copy_attn`` set to ``False``.
    """
    arg_parser = ArgumentParser(description="data sample prepare")
    dynamic_prepare_opts(arg_parser)

    # Flag/value pairs, kept side by side for readability.
    cli_args = [
        "-config", "data/data.yaml",
        "-src_vocab", "data/vocab-train.src",
        "-tgt_vocab", "data/vocab-train.tgt",
    ]

    # Only the first element of the parse result (the options) is kept;
    # any arguments the parser did not recognise are ignored.
    parse_result = arg_parser.parse_known_args(cli_args)
    opt = parse_result[0]

    # copy_attn is assigned by hand before validation -- presumably because
    # the prepare-only parser does not define it; confirm against onmt.opts.
    opt.copy_attn = False
    ArgumentParser.validate_prepare_opts(opt)
    return opt
|
|
|
|
|
# Parsed once at import time and shared by every TestData instance; tests that
# override a setting work on a deep copy of it (see _add_test).
default_opts = get_default_opts()
|
|
|
|
|
class TestData(unittest.TestCase):
    """Test case whose ``test_*`` methods are attached dynamically.

    ``_add_test`` generates one test method per entry of ``test_databuild``;
    each of them calls :meth:`dataset_build` with either the shared default
    options or a modified deep copy of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Shared baseline options (module-level ``default_opts``).
        self.opt = default_opts

    def dataset_build(self, opt):
        """Run ``prepare_transforms_vocabs`` with ``opt`` and clean up after it.

        Args:
            opt: options object for this run; ``skip_empty_level`` and
                ``save_data`` are read from it.

        Raises:
            IOError: re-raised unless ``opt.skip_empty_level == "error"``,
                in which case the error is only printed.
        """
        try:
            prepare_transforms_vocabs(opt, {})
        except SystemExit as err:
            # A SystemExit from the build is reported, not propagated.
            print(err)
        except IOError as err:
            if opt.skip_empty_level != "error":
                # Bare ``raise`` keeps the original traceback intact.
                raise
            print(f"Catched IOError: {err}")
        finally:
            # Remove any generated ``*.pt`` files.
            for pt_file in glob.glob(SAVE_DATA_PREFIX + "*.pt"):
                os.remove(pt_file)

            # Clean up based on the options this run actually used. Reading
            # ``self.opt`` (the shared defaults) here would look in the wrong
            # place whenever a test overrides ``save_data``, and could leave
            # the sample directory behind.
            if opt.save_data:
                sample_path = os.path.join(
                    os.path.dirname(opt.save_data), CorpusName.SAMPLE
                )
                if os.path.exists(sample_path):
                    for sample_file in glob.glob(sample_path + "/*"):
                        os.remove(sample_file)
                    os.rmdir(sample_path)
|
|
|
|
|
def _add_test(param_setting, methodname):
    """
    Register a generated test method on TestData.

    Args:
        param_setting: list of (param, setting) tuples applied on top of the
            default options; when empty, the defaults are used unchanged
        methodname: name of the TestData method that the generated test
            calls with the resulting options
    """
    # The test name embeds the printed form of the settings, with whitespace
    # replaced by underscores; an empty setting list yields "..._standard".
    if param_setting:
        suffix = "_".join(str(param_setting).split())
    else:
        suffix = "standard"
    name = "test_" + methodname + "_" + suffix

    def test_method(self):
        opt = self.opt
        if param_setting:
            # Work on a private copy so the shared defaults stay untouched.
            opt = copy.deepcopy(opt)
            for param, setting in param_setting:
                setattr(opt, param, setting)
        target = getattr(self, methodname)
        target(opt)

    setattr(TestData, name, test_method)
    test_method.__name__ = name
|
|
|
|
|
# Option overrides for the generated ``dataset_build`` tests. Each entry is a
# list of (option name, value) pairs set on a copy of the default options; the
# empty list runs with the defaults unchanged. The printed form of an entry
# becomes part of the generated test name, so edit entries with care.
test_databuild = [
    [],
    [("src_vocab_size", 1), ("tgt_vocab_size", 1)],
    [("src_vocab_size", 10000), ("tgt_vocab_size", 10000)],
    [("src_seq_len", 1)],
    [("src_seq_len", 5000)],
    [("src_seq_length_trunc", 1)],
    [("src_seq_length_trunc", 5000)],
    [("tgt_seq_len", 1)],
    [("tgt_seq_len", 5000)],
    [("tgt_seq_length_trunc", 1)],
    [("tgt_seq_length_trunc", 5000)],
    [("copy_attn", True)],
    [("share_vocab", True)],
    [("n_sample", 30), ("save_data", SAVE_DATA_PREFIX)],
    [("n_sample", 30), ("save_data", SAVE_DATA_PREFIX), ("skip_empty_level", "error")],
]

# Attach one test method to TestData per entry above.
for p in test_databuild:
    _add_test(p, "dataset_build")
|
|