|
from onmt.transforms import register_transform |
|
from .transform import Transform |
|
from onmt.utils.alignment import subword_map_by_joiner, subword_map_by_spacer |
|
|
|
|
|
@register_transform(name="inferfeats")
class InferFeatsTransform(Transform):
    """Spread source features over the subword tokens of an example.

    When an example carries ``src_feats``, every feature sequence is
    rebuilt so that it holds one value per entry of ``example["src"]``,
    looked up through the mapping returned by ``subword_map_by_joiner``
    or ``subword_map_by_spacer`` (presumably one word index per subword;
    confirm against ``onmt.utils.alignment``).
    """

    def __init__(self, opts):
        """Hand ``opts`` over to the base ``Transform`` initializer."""
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Register the options available for this transform on ``parser``."""
        options = parser.add_argument_group("Transform/InferFeats")
        options.add(
            "--reversible_tokenization",
            "-reversible_tokenization",
            default="joiner",
            choices=["joiner", "spacer"],
            help="Type of reversible tokenization applied on the tokenizer.",
        )

    def _parse_opts(self):
        """Cache the ``reversible_tokenization`` option on the instance."""
        super()._parse_opts()
        self.reversible_tokenization = self.opts.reversible_tokenization

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Re-index ``example["src_feats"]`` to follow ``example["src"]``.

        Args:
            example: mapping holding ``"src"`` and, optionally,
                ``"src_feats"``; ``"src_original"`` is also read when the
                tokenization type is ``"joiner"``.
            is_train: accepted for interface compatibility; not used here.
            stats: accepted for interface compatibility; not used here.
            **kwargs: accepted for interface compatibility; not used here.

        Returns:
            The same ``example`` object. It is returned untouched when it
            has no ``"src_feats"`` key; otherwise ``"src_feats"`` is
            replaced by the re-indexed feature lists.
        """
        # Nothing to infer when the example carries no source features.
        if "src_feats" not in example:
            return example

        if self.reversible_tokenization == "joiner":
            reference_tokens = example["src_original"]
            subword_word_ids = subword_map_by_joiner(
                example["src"], original_subwords=reference_tokens
            )
        else:
            # Any value other than "joiner" takes the spacer path.
            subword_word_ids = subword_map_by_spacer(example["src"])

        # Pair the mapping with the tokens so that the result is never
        # longer than the shorter of the two sequences.
        aligned_word_ids = [
            word_id for _, word_id in zip(example["src"], subword_word_ids)
        ]

        # One output list per feature, one value per aligned subword.
        example["src_feats"] = [
            [values[word_id] for word_id in aligned_word_ids]
            for values in example["src_feats"]
        ]

        return example

    def _repr_args(self):
        """Return an empty string: this transform reports no arguments."""
        return ""
|
|