|
from onmt.utils.logging import logger |
|
from onmt.transforms import register_transform |
|
from .transform import Transform, ObservableStats |
|
|
|
|
|
class FilterTooLongStats(ObservableStats):
    """Running statistics for FilterTooLongTransform.

    A freshly built instance carries a count of one; merging instances
    through `update` sums their counts.
    """

    __slots__ = ["filtered"]

    def __init__(self):
        # Every new stats object starts with a count of 1.
        self.filtered = 1

    def update(self, other: "FilterTooLongStats"):
        """Add the `filtered` count of `other` to this instance."""
        self.filtered = self.filtered + other.filtered
|
|
|
|
|
@register_transform(name="filtertoolong")
class FilterTooLongTransform(Transform):
    """Filter out examples whose source or target side is too long."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Add the options related to this transform to `parser`.

        For performance it is better to use a multiple of 8.
        On the target side, since BOS/EOS will be added, the filter
        compares against the configured length minus 2.
        """
        group = parser.add_argument_group("Transform/Filter")
        group.add(
            "--src_seq_length",
            "-src_seq_length",
            type=int,
            default=192,
            help="Maximum source sequence length.",
        )
        group.add(
            "--tgt_seq_length",
            "-tgt_seq_length",
            type=int,
            default=192,
            help="Maximum target sequence length.",
        )

    def _parse_opts(self):
        """Copy the length limits from `self.opts` onto the instance."""
        self.src_seq_length = self.opts.src_seq_length
        self.tgt_seq_length = self.opts.tgt_seq_length

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if the example is too long, else return it as is.

        Args:
            example: mapping with a "src" token sequence and, optionally,
                a "tgt" token sequence.
            is_train: unused here.
            stats: optional statistics object; when given, it is updated
                with a new `FilterTooLongStats` for each rejected example.

        Returns:
            `example` unchanged when it is within the limits, else None.
        """
        src_too_long = len(example["src"]) > self.src_seq_length

        # The target may be absent or None; only check its length when
        # there is one. Two positions are reserved for BOS/EOS.
        tgt = example.get("tgt")
        tgt_too_long = tgt is not None and len(tgt) > self.tgt_seq_length - 2

        if not (src_too_long or tgt_too_long):
            return example

        if stats is not None:
            stats.update(FilterTooLongStats())
        return None

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return (
            f"src_seq_length={self.src_seq_length}, "
            f"tgt_seq_length={self.tgt_seq_length}"
        )
|
|
|
|
|
@register_transform(name="prefix")
class PrefixTransform(Transform):
    """Add Prefix to src (& tgt) sentence."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Add the options related to this transform to `parser`."""
        group = parser.add_argument_group("Transform/Prefix")
        group.add(
            "--src_prefix",
            "-src_prefix",
            type=str,
            default="",
            help="String to prepend to all source example.",
        )
        group.add(
            "--tgt_prefix",
            "-tgt_prefix",
            type=str,
            default="",
            help="String to prepend to all target example.",
        )

    @staticmethod
    def _get_prefix(corpus):
        """Get prefix strings of a `corpus`.

        Returns:
            A dict ``{"src": ..., "tgt": ...}`` (missing entries default to
            the empty string) when "prefix" is listed in the corpus
            transforms, else None.
        """
        if "prefix" not in corpus["transforms"]:
            return None
        return {
            "src": corpus.get("src_prefix", ""),
            "tgt": corpus.get("tgt_prefix", ""),
        }

    @classmethod
    def get_prefix_dict(cls, opts):
        """Get all needed prefix correspond to corpus in `opts`.

        Returns:
            A dict mapping each corpus name of `opts.data` that uses this
            transform to its prefix dict, plus an "infer" entry filled from
            `opts.src_prefix` / `opts.tgt_prefix` when those attributes
            exist. The "infer" entry may hold only one of the two sides.
        """
        prefix_dict = {}

        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                prefix = cls._get_prefix(corpus)
                if prefix is not None:
                    logger.debug(f"Get prefix for {c_name}: {prefix}")
                    prefix_dict[c_name] = prefix

        if hasattr(opts, "src_prefix"):
            prefix_dict.setdefault("infer", {})["src"] = opts.src_prefix
            logger.debug(f"Get prefix for src infer: {opts.src_prefix}")
        if hasattr(opts, "tgt_prefix"):
            prefix_dict.setdefault("infer", {})["tgt"] = opts.tgt_prefix
            logger.debug(f"Get prefix for tgt infer: {opts.tgt_prefix}")

        return prefix_dict

    @classmethod
    def get_specials(cls, opts):
        """Get special vocabs added by prefix transform.

        Returns:
            A tuple ``(src_specials, tgt_specials)`` of sets holding the
            whitespace-separated tokens of every configured prefix.
        """
        src_specials, tgt_specials = set(), set()
        for prefix in cls.get_prefix_dict(opts).values():
            # The "infer" entry may define only one side, so do not
            # assume both keys are present.
            src_specials.update(prefix.get("src", "").split())
            tgt_specials.update(prefix.get("tgt", "").split())
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        """Warm up to get prefix dictionary.

        `vocabs` is not forwarded: the parent `warm_up` is called with None.
        """
        super().warm_up(None)
        self.prefix_dict = self.get_prefix_dict(self.opts)

    def _prepend(self, example, prefix):
        """Prepend the tokens of `prefix` to each side of `example`.

        A side that is missing or None in `example` is created from the
        prefix tokens alone, but only when that prefix is non-empty.
        """
        for side, side_prefix in prefix.items():
            if example.get(side) is not None:
                example[side] = side_prefix.split() + example[side]
            elif len(side_prefix) > 0:
                example[side] = side_prefix.split()
        return example

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply prefix prepend to example.

        Should provide `corpus_name` to get correspond prefix.

        Raises:
            ValueError: if `corpus_name` is not given, or if no prefix is
                registered for that corpus.
        """
        corpus_name = kwargs.get("corpus_name", None)
        if corpus_name is None:
            raise ValueError("corpus_name is required.")
        corpus_prefix = self.prefix_dict.get(corpus_name, None)
        if corpus_prefix is None:
            raise ValueError(f"prefix for {corpus_name} does not exist.")
        return self._prepend(example, corpus_prefix)

    def apply_reverse(self, translated):
        """Strip the inference target prefix from the string `translated`.

        The text is returned unchanged when no "infer" target prefix is
        configured, when it is empty, or when `translated` does not start
        with it. Otherwise the prefix is removed together with at most one
        separating space that follows it.
        """
        # Either the "infer" entry or its "tgt" key may be absent.
        infer_prefix = self.prefix_dict.get("infer") or {}
        tgt_prefix = infer_prefix.get("tgt", "")
        if not tgt_prefix or not translated.startswith(tgt_prefix):
            return translated

        remainder = translated[len(tgt_prefix):]
        # Only drop the following character when it really is the
        # separator; never eat a character of the translation itself.
        if remainder.startswith(" "):
            remainder = remainder[1:]
        return remainder

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return f"prefix_dict={self.prefix_dict}"
|
|
|
|
|
@register_transform(name="suffix")
class SuffixTransform(Transform):
    """Add Suffix to src (& tgt) sentence."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Add the options related to this transform to `parser`."""
        group = parser.add_argument_group("Transform/Suffix")
        group.add(
            "--src_suffix",
            "-src_suffix",
            type=str,
            default="",
            help="String to append to all source example.",
        )
        group.add(
            "--tgt_suffix",
            "-tgt_suffix",
            type=str,
            default="",
            help="String to append to all target example.",
        )

    @staticmethod
    def _get_suffix(corpus):
        """Get suffix strings of a `corpus`.

        Returns:
            A dict ``{"src": ..., "tgt": ...}`` (missing entries default to
            the empty string) when "suffix" is listed in the corpus
            transforms, else None.
        """
        if "suffix" not in corpus["transforms"]:
            return None
        return {
            "src": corpus.get("src_suffix", ""),
            "tgt": corpus.get("tgt_suffix", ""),
        }

    @classmethod
    def get_suffix_dict(cls, opts):
        """Get all needed suffix correspond to corpus in `opts`.

        Returns:
            A dict mapping each corpus name of `opts.data` that uses this
            transform to its suffix dict, plus an "infer" entry filled from
            `opts.src_suffix` / `opts.tgt_suffix` when those attributes
            exist. The "infer" entry may hold only one of the two sides.
        """
        suffix_dict = {}

        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                suffix = cls._get_suffix(corpus)
                if suffix is not None:
                    logger.debug(f"Get suffix for {c_name}: {suffix}")
                    suffix_dict[c_name] = suffix

        if hasattr(opts, "src_suffix"):
            suffix_dict.setdefault("infer", {})["src"] = opts.src_suffix
            logger.debug(f"Get suffix for src infer: {opts.src_suffix}")
        if hasattr(opts, "tgt_suffix"):
            suffix_dict.setdefault("infer", {})["tgt"] = opts.tgt_suffix
            logger.debug(f"Get suffix for tgt infer: {opts.tgt_suffix}")

        return suffix_dict

    @classmethod
    def get_specials(cls, opts):
        """Get special vocabs added by suffix transform.

        Returns:
            A tuple ``(src_specials, tgt_specials)`` of sets holding the
            whitespace-separated tokens of every configured suffix.
        """
        src_specials, tgt_specials = set(), set()
        for suffix in cls.get_suffix_dict(opts).values():
            # The "infer" entry may define only one side, so do not
            # assume both keys are present.
            src_specials.update(suffix.get("src", "").split())
            tgt_specials.update(suffix.get("tgt", "").split())
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        """Warm up to get suffix dictionary.

        `vocabs` is not forwarded: the parent `warm_up` is called with None.
        """
        super().warm_up(None)
        self.suffix_dict = self.get_suffix_dict(self.opts)

    def _append(self, example, suffix):
        """Append the tokens of `suffix` to each side of `example`.

        A side that is missing or None in `example` is created from the
        suffix tokens alone, but only when that suffix is non-empty.
        """
        for side, side_suffix in suffix.items():
            if example.get(side) is not None:
                example[side] = example[side] + side_suffix.split()
            elif len(side_suffix) > 0:
                example[side] = side_suffix.split()
        return example

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply suffix append to example.

        Should provide `corpus_name` to get correspond suffix.

        Raises:
            ValueError: if `corpus_name` is not given, or if no suffix is
                registered for that corpus.
        """
        corpus_name = kwargs.get("corpus_name", None)
        if corpus_name is None:
            raise ValueError("corpus_name is required.")
        corpus_suffix = self.suffix_dict.get(corpus_name, None)
        if corpus_suffix is None:
            raise ValueError(f"suffix for {corpus_name} does not exist.")
        return self._append(example, corpus_suffix)

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return f"suffix_dict={self.suffix_dict}"
|
|