Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
10.1 kB
from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform, ObservableStats
class FilterTooLongStats(ObservableStats):
    """Running statistics for FilterTooLongTransform.

    A freshly built instance accounts for exactly one filtered example;
    instances are folded together with ``update``, which sums their counts.
    """

    __slots__ = ["filtered"]

    def __init__(self):
        # Start at one: callers create one instance per dropped example.
        self.filtered = 1

    def update(self, other: "FilterTooLongStats"):
        """Add ``other``'s filtered count into this instance."""
        total = self.filtered + other.filtered
        self.filtered = total
@register_transform(name="filtertoolong")
class FilterTooLongTransform(Transform):
    """Filter out examples whose source or target sequence is too long."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """
        Available options relate to this Transform.
        For performance it is better to use multiple of 8
        On target side, since we'll add BOS/EOS, we filter with minus 2
        """
        group = parser.add_argument_group("Transform/Filter")
        group.add(
            "--src_seq_length",
            "-src_seq_length",
            type=int,
            default=192,
            help="Maximum source sequence length.",
        )
        group.add(
            "--tgt_seq_length",
            "-tgt_seq_length",
            type=int,
            default=192,
            help="Maximum target sequence length.",
        )

    def _parse_opts(self):
        """Copy the configured length limits from ``self.opts``."""
        self.src_seq_length = self.opts.src_seq_length
        self.tgt_seq_length = self.opts.tgt_seq_length

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Return None if too long else return as is.

        Args:
            example: mapping with a ``"src"`` token sequence and an optional
                ``"tgt"`` token sequence (absent or ``None`` when there is
                no target, e.g. source-only input).
            is_train: unused here; kept for the Transform interface.
            stats: optional stats collector; when given, one
                ``FilterTooLongStats`` is recorded per dropped example.

        Returns:
            ``None`` when the example exceeds a limit, else ``example``.
        """
        if len(example["src"]) > self.src_seq_length:
            too_long = True
        else:
            tgt = example.get("tgt")
            # Only check the target when one exists; two positions are
            # reserved for the BOS/EOS tokens added on the target side.
            too_long = tgt is not None and len(tgt) > self.tgt_seq_length - 2
        if too_long:
            if stats is not None:
                stats.update(FilterTooLongStats())
            return None
        return example

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return "{}={}, {}={}".format(
            "src_seq_length", self.src_seq_length, "tgt_seq_length", self.tgt_seq_length
        )
@register_transform(name="prefix")
class PrefixTransform(Transform):
    """Add Prefix to src (& tgt) sentence."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options relate to this Transform."""
        group = parser.add_argument_group("Transform/Prefix")
        group.add(
            "--src_prefix",
            "-src_prefix",
            type=str,
            default="",
            help="String to prepend to all source example.",
        )
        group.add(
            "--tgt_prefix",
            "-tgt_prefix",
            type=str,
            default="",
            help="String to prepend to all target example.",
        )

    @staticmethod
    def _get_prefix(corpus):
        """Return ``{"src": ..., "tgt": ...}`` prefixes for `corpus`.

        Returns ``None`` when the corpus does not list the "prefix"
        transform; missing prefix keys default to the empty string.
        """
        if "prefix" not in corpus["transforms"]:
            return None
        return {
            "src": corpus.get("src_prefix", ""),
            "tgt": corpus.get("tgt_prefix", ""),
        }

    @classmethod
    def get_prefix_dict(cls, opts):
        """Get all needed prefix correspond to corpus in `opts`.

        Keys are corpus names from ``opts.data`` (only corpora using this
        transform), plus an ``"infer"`` entry filled from
        ``opts.src_prefix`` / ``opts.tgt_prefix`` when those exist.
        """
        prefix_dict = {}
        # prefix src/tgt for each dataset
        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                prefix = cls._get_prefix(corpus)
                if prefix is not None:
                    logger.debug(f"Get prefix for {c_name}: {prefix}")
                    prefix_dict[c_name] = prefix
        # prefix as general option for inference
        if hasattr(opts, "src_prefix"):
            prefix_dict.setdefault("infer", {})["src"] = opts.src_prefix
            logger.debug(f"Get prefix for src infer: {opts.src_prefix}")
        if hasattr(opts, "tgt_prefix"):
            prefix_dict.setdefault("infer", {})["tgt"] = opts.tgt_prefix
            logger.debug(f"Get prefix for tgt infer: {opts.tgt_prefix}")
        return prefix_dict

    @classmethod
    def get_specials(cls, opts):
        """Get special vocabs added by prefix transform.

        Returns a ``(src_specials, tgt_specials)`` pair of sets holding the
        whitespace-separated tokens of every configured prefix.
        """
        prefix_dict = cls.get_prefix_dict(opts)
        src_specials, tgt_specials = set(), set()
        for prefix in prefix_dict.values():
            # An "infer" entry may carry only one side, and the option value
            # may be None; treat both as "no prefix" instead of crashing.
            src_specials.update((prefix.get("src") or "").split())
            tgt_specials.update((prefix.get("tgt") or "").split())
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        """Warm up to get prefix dictionary."""
        super().warm_up(None)
        self.prefix_dict = self.get_prefix_dict(self.opts)

    def _prepend(self, example, prefix):
        """Prepend `prefix` tokens to each side of `example` (in place).

        A side that is missing/None becomes the prefix tokens alone, but
        only when that side's prefix string is non-empty. Returns `example`.
        """
        for side, side_prefix in prefix.items():
            if example.get(side) is not None:
                example[side] = side_prefix.split() + example[side]
            elif len(side_prefix) > 0:
                example[side] = side_prefix.split()
        return example

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply prefix prepend to example.
        Should provide `corpus_name` to get correspond prefix.

        Raises:
            ValueError: if `corpus_name` is not given or has no prefix entry.
        """
        corpus_name = kwargs.get("corpus_name", None)
        if corpus_name is None:
            raise ValueError("corpus_name is required.")
        corpus_prefix = self.prefix_dict.get(corpus_name, None)
        if corpus_prefix is None:
            raise ValueError(f"prefix for {corpus_name} does not exist.")
        return self._prepend(example, corpus_prefix)

    def apply_reverse(self, translated):
        """Strip the inference target prefix from a translated string.

        The prefix is removed only when it is the whole string or is
        followed by a single space separator. Otherwise the string is
        returned unchanged, so that no unrelated character is dropped.
        It is also returned unchanged when no inference target prefix
        is configured.
        """
        infer_prefix = self.prefix_dict.get("infer") or {}
        tgt_prefix = infer_prefix.get("tgt") or ""
        if not tgt_prefix:
            return translated
        if translated == tgt_prefix:
            return ""
        if translated.startswith(tgt_prefix + " "):
            return translated[len(tgt_prefix) + 1 :]
        return translated

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return "{}={}".format("prefix_dict", self.prefix_dict)
@register_transform(name="suffix")
class SuffixTransform(Transform):
    """Add Suffix to src (& tgt) sentence."""

    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options relate to this Transform."""
        group = parser.add_argument_group("Transform/Suffix")
        group.add(
            "--src_suffix",
            "-src_suffix",
            type=str,
            default="",
            help="String to append to all source example.",
        )
        group.add(
            "--tgt_suffix",
            "-tgt_suffix",
            type=str,
            default="",
            help="String to append to all target example.",
        )

    @staticmethod
    def _get_suffix(corpus):
        """Return ``{"src": ..., "tgt": ...}`` suffixes for `corpus`.

        Returns ``None`` when the corpus does not list the "suffix"
        transform; missing suffix keys default to the empty string.
        """
        if "suffix" not in corpus["transforms"]:
            return None
        return {
            "src": corpus.get("src_suffix", ""),
            "tgt": corpus.get("tgt_suffix", ""),
        }

    @classmethod
    def get_suffix_dict(cls, opts):
        """Get all needed suffix correspond to corpus in `opts`.

        Keys are corpus names from ``opts.data`` (only corpora using this
        transform), plus an ``"infer"`` entry filled from
        ``opts.src_suffix`` / ``opts.tgt_suffix`` when those exist.
        """
        suffix_dict = {}
        # suffix src/tgt for each dataset
        if hasattr(opts, "data"):
            for c_name, corpus in opts.data.items():
                suffix = cls._get_suffix(corpus)
                if suffix is not None:
                    logger.debug(f"Get suffix for {c_name}: {suffix}")
                    suffix_dict[c_name] = suffix
        # suffix as general option for inference
        if hasattr(opts, "src_suffix"):
            suffix_dict.setdefault("infer", {})["src"] = opts.src_suffix
            logger.debug(f"Get suffix for src infer: {opts.src_suffix}")
        if hasattr(opts, "tgt_suffix"):
            suffix_dict.setdefault("infer", {})["tgt"] = opts.tgt_suffix
            logger.debug(f"Get suffix for tgt infer: {opts.tgt_suffix}")
        return suffix_dict

    @classmethod
    def get_specials(cls, opts):
        """Get special vocabs added by suffix transform.

        Returns a ``(src_specials, tgt_specials)`` pair of sets holding the
        whitespace-separated tokens of every configured suffix.
        """
        suffix_dict = cls.get_suffix_dict(opts)
        src_specials, tgt_specials = set(), set()
        for suffix in suffix_dict.values():
            # An "infer" entry may carry only one side, and the option value
            # may be None; treat both as "no suffix" instead of crashing.
            src_specials.update((suffix.get("src") or "").split())
            tgt_specials.update((suffix.get("tgt") or "").split())
        return (src_specials, tgt_specials)

    def warm_up(self, vocabs=None):
        """Warm up to get suffix dictionary."""
        super().warm_up(None)
        self.suffix_dict = self.get_suffix_dict(self.opts)

    def _append(self, example, suffix):
        """Append `suffix` tokens to each side of `example` (in place).

        A side that is missing/None becomes the suffix tokens alone, but
        only when that side's suffix string is non-empty. Returns `example`.
        """
        for side, side_suffix in suffix.items():
            if example.get(side) is not None:
                example[side] = example[side] + side_suffix.split()
            elif len(side_suffix) > 0:
                example[side] = side_suffix.split()
        return example

    def apply(self, example, is_train=False, stats=None, **kwargs):
        """Apply suffix append to example.
        Should provide `corpus_name` to get correspond suffix.

        Raises:
            ValueError: if `corpus_name` is not given or has no suffix entry.
        """
        corpus_name = kwargs.get("corpus_name", None)
        if corpus_name is None:
            raise ValueError("corpus_name is required.")
        corpus_suffix = self.suffix_dict.get(corpus_name, None)
        if corpus_suffix is None:
            raise ValueError(f"suffix for {corpus_name} does not exist.")
        return self._append(example, corpus_suffix)

    def _repr_args(self):
        """Return str represent key arguments for class."""
        return "{}={}".format("suffix_dict", self.suffix_dict)