# ReactSeq/onmt/transforms/terminology.py
import re

import ahocorasick
import spacy

from onmt.transforms import register_transform
from onmt.utils.logging import logger

from .transform import Transform
class TermMatcher(object):
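    """Matches termbase entries against source/target sentence pairs.

    Both the termbase entries and the sentences are lemmatized with spaCy,
    and the source lemmas are indexed in an Aho-Corasick automaton so that
    each source sentence can be scanned for known terms in a single pass.
    """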
def __init__(
self,
termbase_path,
src_spacy_language_model,
tgt_spacy_language_model,
term_example_ratio,
src_term_stoken,
tgt_term_stoken,
tgt_term_etoken,
delimiter,
term_corpus_ratio=0.2,
):
self.term_example_ratio = term_example_ratio
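        # Load spaCy pipelines for lemmatization only; the parser and NER
        # components are disabled to keep preprocessing fast.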
self.src_nlp = spacy.load(src_spacy_language_model, disable=["parser", "ner"])
self.tgt_nlp = spacy.load(tgt_spacy_language_model, disable=["parser", "ner"])
        # We exclude the tokenization rules for contractions in order to avoid
        # inconsistencies with pyonmttok's tokenization
        # (e.g. "I ca n't" with spacy vs. "I can ' t" with pyonmttok).
self.src_nlp.tokenizer.rules = {
key: value
for key, value in self.src_nlp.tokenizer.rules.items()
if "'" not in key and "’" not in key and "‘" not in key
}
self.tgt_nlp.tokenizer.rules = {
key: value
for key, value in self.tgt_nlp.tokenizer.rules.items()
if "'" not in key and "’" not in key and "‘" not in key
}
self.internal_termbase = self._create_internal_termbase(termbase_path)
self.automaton = self._create_automaton()
self.term_corpus_ratio = term_corpus_ratio
self.src_term_stoken = src_term_stoken
self.tgt_term_stoken = tgt_term_stoken
self.tgt_term_etoken = tgt_term_etoken
self.delimiter = delimiter
def _create_internal_termbase(self, termbase_path):
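        """Read a tab-separated termbase file and lemmatize both sides.

        Entries whose source or target lemma is a stopword are discarded.
        """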
logger.debug("Creating termbase with lemmas for Terminology transform")
# Use Spacy's stopwords to get rid of junk entries
src_stopwords = self.src_nlp.Defaults.stop_words
tgt_stopwords = self.tgt_nlp.Defaults.stop_words
termbase = list()
with open(termbase_path, mode="r", encoding="utf-8") as file:
pairs = file.readlines()
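            # Each termbase line is expected to hold one tab-separated
            # "src_term<TAB>tgt_term" pair.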
for pair in pairs:
                src_term, tgt_term = pair.split("\t")
                # Lowercase the lemmas so they line up with the lowercased
                # source/target lemmas used at matching time in
                # _src_sentence_with_terms.
                src_lemma = " ".join(
                    "∥".join(tok.lemma_.lower().split())
                    for tok in self.src_nlp(src_term)
                ).strip()
                tgt_lemma = " ".join(
                    tok.lemma_.lower() for tok in self.tgt_nlp(tgt_term)
                ).strip()
if (
src_lemma.lower() not in src_stopwords
and tgt_lemma.lower() not in tgt_stopwords
):
termbase.append((src_lemma, tgt_lemma))
logger.debug(
f"Created termbase with {len(termbase)} lemmas "
f"for Terminology transform"
)
return termbase
def _create_automaton(self):
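        """Build an Aho-Corasick automaton keyed on the source lemmas.

        Each key maps to its (source lemma, target lemma) pair so that both
        are available when a match is found.
        """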
automaton = ahocorasick.Automaton()
for term in self.internal_termbase:
automaton.add_word(term[0], (term[0], term[1]))
automaton.make_automaton()
return automaton
def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
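        """Annotate the source sentence with target terms from the termbase.

        Matched source tokens are wrapped inline as
        src_term_stoken + source term + tgt_term_stoken + ∥ + target lemma
        + tgt_term_etoken. Returns a tuple of (source tokens, is_match).
        """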
maybe_augmented = source_string.split(self.delimiter)
source_only = maybe_augmented[0].strip()
augmented_part = (
maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None
)
doc_src = self.src_nlp(source_only)
doc_tgt = self.tgt_nlp(target_string)
# Perform tokenization with spacy for consistency.
tokenized_source = [tok.text for tok in doc_src]
lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src]
lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt]
lemmatized_source_string = " ".join(lemmatized_source)
offset = 0
source_with_terms = list()
term_counter = 0
max_terms_allowed = int(len(tokenized_source) * self.term_example_ratio)
is_match = False
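        # Scan the lemmatized source string with the Aho-Corasick automaton;
        # iter_long yields the longest non-overlapping matches.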
for match_end, (src_entry, tgt_entry) in self.automaton.iter_long(
lemmatized_source_string
):
if term_counter == max_terms_allowed:
break
match_start = match_end - len(src_entry) + 1
            # Keep a match only if the target lemma is present in the
            # lemmatized target string and the source match is a whole-word
            # match, i.e. bounded by whitespace or the string boundaries.
if (
(tgt_entry.lower() not in " ".join(lemmatized_target).lower())
or (
len(lemmatized_source_string) != match_end + 1
and not (lemmatized_source_string[match_end + 1].isspace())
)
or (
not lemmatized_source_string[match_start - 1].isspace()
and match_start != 0
)
):
continue
else:
term_counter += 1
# Map the lemmatized string match index to
# the lemmatized list index
lemma_list_index = 0
for i, w in enumerate(lemmatized_source):
if lemma_list_index == match_start:
lemma_list_index = i
break
else:
lemma_list_index += len(w) + 1
# We need to know if the term is multiword
num_words_in_src_term = len(src_entry.split())
src_term = " ".join(
tokenized_source[
lemma_list_index : lemma_list_index + num_words_in_src_term
]
).strip()
                # Join multiword target lemmas with a unique separator so they
                # can be treated as a single token without shifting indices.
tgt_term = tgt_entry.replace(" ", "∥").rstrip().lower()
source_with_terms.append(
f"{lemmatized_source_string[offset: match_start]}"
f"{self.src_term_stoken}{src_term}{self.tgt_term_stoken}∥"
f"{tgt_term}{self.tgt_term_etoken}"
)
offset = match_end + 1
is_match = True
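        # If at least one term was matched, append the remainder of the
        # lemmatized string and rebuild the tokenized source around the
        # injected terms.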
if is_match:
source_with_terms.append(lemmatized_source_string[offset:])
tokenized_source_with_terms = "".join(source_with_terms).split()
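            # Sanity check: the term-augmented source must keep the same
            # number of tokens as the original source; otherwise fall back to
            # the original tokens.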
if not (
len(tokenized_source)
== len(lemmatized_source)
== len(tokenized_source_with_terms)
):
final_string = " ".join(tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), not is_match
            # Construct the final source from the term-augmented lemma list.
            # We compare each token in that list with the corresponding token
            # in the original lemma list: if they are identical, we emit the
            # token from the original tokenized source; if they differ, the
            # lemma has been augmented with a term, so we inject the augmented
            # token into the final list.
completed_tokenized_source = list()
for idx in range(len(tokenized_source_with_terms)):
# Restore the spaces in multi-word terms
src_lemma = tokenized_source_with_terms[idx].replace("∥", " ")
if lemmatized_source[idx].replace("∥", " ") == src_lemma:
completed_tokenized_source.append(tokenized_source[idx])
else:
completed_tokenized_source.append(src_lemma)
if augmented_part is not None:
final_string = " ".join(
completed_tokenized_source
+ [self.delimiter]
+ augmented_part.split()
)
else:
final_string = " ".join(completed_tokenized_source)
fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
return fixed_punct.split(), is_match
        else:
            # No term was matched: return the original tokens and signal the
            # non-match to the caller.
            final_string = " ".join(tokenized_source)
            fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
            return fixed_punct.split(), is_match
@register_transform(name="terminology")
class TerminologyTransform(Transform):
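    """Annotate source examples with target terms from a termbase."""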
def __init__(self, opts):
super().__init__(opts)
@classmethod
def add_options(cls, parser):
"""Available options for terminology matching."""
group = parser.add_argument_group("Transform/Terminology")
group.add(
"--termbase_path",
"-termbase_path",
type=str,
help="Path to a dictionary file with terms.",
)
group.add(
"--src_spacy_language_model",
"-src_spacy_language_model",
type=str,
help="Name of the spacy language model for the source corpus.",
)
group.add(
"--tgt_spacy_language_model",
"-tgt_spacy_language_model",
type=str,
help="Name of the spacy language model for the target corpus.",
)
group.add(
"--term_corpus_ratio",
"-term_corpus_ratio",
type=float,
default=0.3,
help="Ratio of corpus to augment with terms.",
)
group.add(
"--term_example_ratio",
"-term_example_ratio",
type=float,
default=0.2,
help="Max terms allowed in an example.",
)
group.add(
"--src_term_stoken",
"-src_term_stoken",
type=str,
help="The source term start token.",
default="⦅src_term_start⦆",
)
group.add(
"--tgt_term_stoken",
"-tgt_term_stoken",
type=str,
help="The target term start token.",
default="⦅tgt_term_start⦆",
)
group.add(
"--tgt_term_etoken",
"-tgt_term_etoken",
type=str,
help="The target term end token.",
default="⦅tgt_term_end⦆",
)
group.add(
"--term_source_delimiter",
"-term_source_delimiter",
type=str,
help="Any special token used for augmented source sentences. "
"The default is the fuzzy token used in the "
"FuzzyMatch transform.",
default="⦅fuzzy⦆",
)
def _parse_opts(self):
self.termbase_path = self.opts.termbase_path
self.src_spacy_language_model = self.opts.src_spacy_language_model
self.tgt_spacy_language_model = self.opts.tgt_spacy_language_model
self.term_corpus_ratio = self.opts.term_corpus_ratio
self.term_example_ratio = self.opts.term_example_ratio
self.term_source_delimiter = self.opts.term_source_delimiter
self.src_term_stoken = self.opts.src_term_stoken
self.tgt_term_stoken = self.opts.tgt_term_stoken
self.tgt_term_etoken = self.opts.tgt_term_etoken
@classmethod
def get_specials(cls, opts):
"""Add the term tokens to the src vocab."""
src_specials = list()
src_specials.extend(
[opts.src_term_stoken, opts.tgt_term_stoken, opts.tgt_term_etoken]
)
return (src_specials, list())
def warm_up(self, vocabs=None):
"""Create the terminology matcher."""
super().warm_up(None)
self.termmatcher = TermMatcher(
self.termbase_path,
self.src_spacy_language_model,
self.tgt_spacy_language_model,
self.term_example_ratio,
self.src_term_stoken,
self.tgt_term_stoken,
self.tgt_term_etoken,
self.term_source_delimiter,
self.term_corpus_ratio,
)
def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
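        """Apply terminology matching to a batch of examples.

        At most `term_corpus_ratio` of the bucket is augmented with terms.
        """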
bucket_size = len(batch)
examples_with_terms = 0
for i, (ex, _, _) in enumerate(batch):
            # Skip every other example to improve performance. This effectively
            # caps `term_corpus_ratio` at 0.5, which is already quite high.
            # TODO: expose this example skipping as an option.
if i % 2 == 0:
original_src = ex["src"]
augmented_example, is_match = self.apply(ex, is_train, stats, **kwargs)
if is_match and (
examples_with_terms < bucket_size * self.term_corpus_ratio
):
examples_with_terms += 1
ex["src"] = augmented_example["src"]
else:
ex["src"] = original_src
logger.debug(f"Added terms to {examples_with_terms}/{bucket_size} examples")
return batch
def apply(self, example, is_train=False, stats=None, **kwargs) -> tuple:
"""Add terms to source examples."""
example["src"], is_match = self.termmatcher._src_sentence_with_terms(
" ".join(example["src"]), " ".join(example["tgt"])
)
return example, is_match
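
# A minimal, illustrative sketch of an OpenNMT-py style YAML snippet that would
# enable this transform. The paths and spaCy model names below are placeholders
# chosen for the example, not values shipped with this repository; only the
# option names and the "terminology" transform name come from the code above.
#
#   transforms: [terminology]
#   termbase_path: data/termbase.tsv            # hypothetical termbase file
#   src_spacy_language_model: en_core_web_sm    # placeholder spaCy model
#   tgt_spacy_language_model: de_core_news_sm   # placeholder spaCy model
#   term_corpus_ratio: 0.3
#   term_example_ratio: 0.2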