# ReactSeq / onmt/tests/test_subword_marker.py
# Uploaded by Oopstom ("Upload 313 files"), commit c668e80 (verified), 12 kB.
import unittest
from onmt.transforms.bart import word_start_finder
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer
from onmt.constants import SubwordMarker
class TestWordStartFinder(unittest.TestCase):
    """Tests for ``word_start_finder`` under each subword-marking scheme.

    Every case uses the sentence
    "however, according to the logs, she is hard-working." and checks
    which tokens the finder reports as the first piece of a word.
    """

    def test_word_start_naive(self):
        # With ignore_subword=True every token is expected to start a word.
        finder = word_start_finder(ignore_subword=True)
        tokens = [
            "however", ",", "according", "to", "the", "logs", ",",
            "she", "is", "hard", "-", "working", ".",
        ]
        self.assertEqual(finder(tokens), [True] * len(tokens))

    def test_word_start_joiner(self):
        # Tokens glued to their predecessor by a joiner ("■," or "working"
        # after the trailing joiner of "■-■") must not be word starts.
        finder = word_start_finder(is_joiner=True)
        tokens = [
            "however", "■,", "according", "to", "the", "logs", "■,",
            "she", "is", "hard", "■-■", "working", "■.",
        ]
        expected = [
            True, False, True, True, True, True, False,
            True, True, True, False, False, False,
        ]
        self.assertEqual(finder(tokens), expected)

    def test_word_start_spacer(self):
        # Called with no arguments, the finder is run on spacer-marked
        # tokens: a leading "▁" starts a word, while bare punctuation and
        # "working" continue the previous one.
        finder = word_start_finder()
        expected = [
            True, False, True, True, True, True, False,
            True, True, True, False, False, False,
        ]
        with_prefix = [
            "▁however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        self.assertEqual(finder(with_prefix), expected)

        # No dummy prefix: the first token lacks "▁" but is still expected
        # to be reported as a word start.
        without_prefix = ["however"] + with_prefix[1:]
        self.assertEqual(finder(without_prefix), expected)
class TestSubwordGroup(unittest.TestCase):
    """Tests for the subword-to-word maps built by ``subword_map_by_joiner``
    and ``subword_map_by_spacer``.

    Each expected list has one entry per input token: the index of the word
    that token belongs to. The expected outputs group case-markup
    placeholders (``⦅mrk_...⦆``) with an adjacent word instead of giving
    them a word index of their own.
    """

    def test_subword_group_joiner(self):
        data_in = [
            "however", "■,", "according", "to", "the", "logs", "■,",
            "she", "is", "hard", "■-■", "working", "■.",
        ]
        true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_case_markup(self):
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "however", "■,", "according", "to", "the", "logs", "■,",
            "⦅mrk_begin_case_region_U⦆",
            "she", "is", "hard", "■-■", "working",
            "⦅mrk_end_case_region_U⦆",
            "■.",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_case_markup_advanced(self):
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "dummy", "text",
            "⦅mrk_case_modifier_C⦆",
            "1■", "h■", "k",
            "⦅mrk_begin_case_region_U⦆",
            "th■",
            "⦅mrk_end_case_region_U⦆",
            "n", "more", "dummy", "text",
        ]
        true_out = [0, 0, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 5, 6]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_prior_tokenization(self):
        # original_subwords is a coarser prior tokenization; word ids follow
        # its boundaries (e.g. "HARD-WORKING" is one word, while "■," after
        # "However" is a word of its own there).
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "how■", "ever", "■,", "according", "to", "the", "logs", "■,",
            "⦅mrk_begin_case_region_U⦆",
            "she", "is", "hard", "■-■", "working",
            "⦅mrk_end_case_region_U⦆",
            "■.",
        ]
        original_data_in = [
            "However", "■,", "according", "to", "the", "logs", "■,",
            "SHE", "IS", "HARD-WORKING", "■.",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 7, 8, 9, 9, 9, 9, 10]
        out = subword_map_by_joiner(
            data_in, marker=SubwordMarker.JOINER, original_subwords=original_data_in
        )
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_prior_tokenization_harder(self):
        # When the prior tokenization equals the subwords themselves, every
        # token (markup included) is expected to map to its own index.
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "how■", "ever", "■,", "according", "to", "the", "logs", "■,",
            "⦅mrk_begin_case_region_U⦆",
            "she", "is", "hard", "■-■", "working",
            "⦅mrk_end_case_region_U⦆",
            "■.",
        ]
        original_data_in = list(data_in)  # same tokens, separate list
        true_out = list(range(len(data_in)))
        out = subword_map_by_joiner(
            data_in, marker=SubwordMarker.JOINER, original_subwords=original_data_in
        )
        self.assertEqual(out, true_out)

    def test_subword_group_joiner_with_new_joiner(self):
        # Joiners emitted as standalone "■" tokens.
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "however", "■", ",", "according", "to", "the", "logs", "■", ",",
            "⦅mrk_begin_case_region_U⦆",
            "she", "is", "hard", "■", "-", "■", "working",
            "⦅mrk_end_case_region_U⦆",
            "■", ".",
        ]
        true_out = [
            0, 0, 0, 0, 1, 2, 3, 4, 4, 4, 5,
            5, 6, 7, 7, 7, 7, 7, 7, 7, 7,
        ]
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_naive(self):
        # No joiners at all: each token is expected to be its own word.
        data_in = [
            "however", ",", "according", "to", "the", "logs", ",",
            "she", "is", "hard", "-", "working", ".",
        ]
        true_out = list(range(len(data_in)))
        out = subword_map_by_joiner(data_in, marker=SubwordMarker.JOINER)
        self.assertEqual(out, true_out)

    def test_subword_group_spacer(self):
        # With dummy prefix: the very first token carries a leading spacer.
        # (Previously this list lacked the "▁" on "however", which made it
        # identical to the no-dummy case below and left this path untested.)
        data_in = [
            "▁however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁she", "▁is", "▁hard", "-", "working", ".",
        ]
        true_out = [0, 0, 1, 2, 3, 4, 4, 5, 6, 7, 7, 7, 7]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)

        # No dummy prefix: same tokens except the first lacks "▁"; the word
        # ids are expected to be unchanged.
        no_dummy = ["however"] + data_in[1:]
        no_dummy_out = subword_map_by_spacer(
            no_dummy, marker=SubwordMarker.SPACER
        )
        self.assertEqual(no_dummy_out, true_out)

    def test_subword_group_spacer_with_case_markup(self):
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "▁however", ",", "▁according", "▁to", "▁the", "▁logs", ",",
            "▁⦅mrk_begin_case_region_U⦆",
            "▁she", "▁is", "▁hard", "-", "working", ".",
            "▁⦅mrk_end_case_region_U⦆",
        ]
        true_out = [0, 0, 0, 1, 2, 3, 4, 4, 5, 5, 6, 7, 7, 7, 7, 7]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)

    def test_subword_group_spacer_with_spacer_new(self):
        # Spacers emitted as standalone "▁" tokens.
        data_in = [
            "⦅mrk_case_modifier_C⦆",
            "▁", "however", ",",
            "▁", "according",
            "▁", "to",
            "▁", "the",
            "▁", "logs", ",",
            "▁", "⦅mrk_begin_case_region_U⦆",
            "▁", "she",
            "▁", "is",
            "▁", "hard", "-", "working", ".",
            "▁", "⦅mrk_end_case_region_U⦆",
        ]
        true_out = [
            0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 4,
            5, 5, 5, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7,
        ]
        out = subword_map_by_spacer(data_in, marker=SubwordMarker.SPACER)
        self.assertEqual(out, true_out)
if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()