from onmt.utils.logging import logger
from onmt.transforms import register_transform
from .transform import Transform
import spacy
import ahocorasick
import re


class TermMatcher(object):
    def __init__(
        self,
        termbase_path,
        src_spacy_language_model,
        tgt_spacy_language_model,
        term_example_ratio,
        src_term_stoken,
        tgt_term_stoken,
        tgt_term_etoken,
        delimiter,
        term_corpus_ratio=0.2,
    ):
        self.term_example_ratio = term_example_ratio
        self.src_nlp = spacy.load(src_spacy_language_model, disable=["parser", "ner"])
        self.tgt_nlp = spacy.load(tgt_spacy_language_model, disable=["parser", "ner"])
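        # Remove the tokenizer exception rules that contain apostrophes so
        # contractions are not split, keeping spaCy tokenization consistent
        # between the termbase and the corpus.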
        self.src_nlp.tokenizer.rules = {
            key: value
            for key, value in self.src_nlp.tokenizer.rules.items()
            if "'" not in key and "’" not in key and "‘" not in key
        }
        self.tgt_nlp.tokenizer.rules = {
            key: value
            for key, value in self.tgt_nlp.tokenizer.rules.items()
            if "'" not in key and "’" not in key and "‘" not in key
        }
        self.internal_termbase = self._create_internal_termbase(termbase_path)
        self.automaton = self._create_automaton()
        self.term_corpus_ratio = term_corpus_ratio
        self.src_term_stoken = src_term_stoken
        self.tgt_term_stoken = tgt_term_stoken
        self.tgt_term_etoken = tgt_term_etoken
        self.delimiter = delimiter

    def _create_internal_termbase(self, termbase_path):
        logger.debug("Creating termbase with lemmas for Terminology transform")
        src_stopwords = self.src_nlp.Defaults.stop_words
        tgt_stopwords = self.tgt_nlp.Defaults.stop_words
        termbase = list()
        with open(termbase_path, mode="r", encoding="utf-8") as file:
            pairs = file.readlines()
            for pair in pairs:
                src_term, tgt_term = map(str, pair.split("\t"))
                src_lemma = " ".join(
                    "∥".join(tok.lemma_.split()) for tok in self.src_nlp(src_term)
                ).strip()
                tgt_lemma = " ".join(
                    tok.lemma_ for tok in self.tgt_nlp(tgt_term)
                ).strip()
                if (
                    src_lemma.lower() not in src_stopwords
                    and tgt_lemma.lower() not in tgt_stopwords
                ):
                    termbase.append((src_lemma, tgt_lemma))
        logger.debug(
            f"Created termbase with {len(termbase)} lemmas "
            f"for Terminology transform"
        )
        return termbase

    def _create_automaton(self):
        automaton = ahocorasick.Automaton()
        for term in self.internal_termbase:
            automaton.add_word(term[0], (term[0], term[1]))
        automaton.make_automaton()
        return automaton

    def _src_sentence_with_terms(self, source_string, target_string) -> tuple:
        maybe_augmented = source_string.split(self.delimiter)
        source_only = maybe_augmented[0].strip()
        augmented_part = (
            maybe_augmented[1].strip() if len(maybe_augmented) > 1 else None
        )

        doc_src = self.src_nlp(source_only)
        doc_tgt = self.tgt_nlp(target_string)
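        # Tokenize and lemmatize both sides with spaCy; "∥" joins multi-token
        # lemmas so the lemma list stays index-aligned with the tokenized source.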
        tokenized_source = [tok.text for tok in doc_src]
        lemmatized_source = ["∥".join(tok.lemma_.lower().split()) for tok in doc_src]
        lemmatized_target = [tok.lemma_.lower() for tok in doc_tgt]

        lemmatized_source_string = " ".join(lemmatized_source)

        offset = 0
        source_with_terms = list()
        term_counter = 0

        max_terms_allowed = int(len(tokenized_source) * self.term_example_ratio)
        is_match = False
        for match_end, (src_entry, tgt_entry) in self.automaton.iter_long(
            lemmatized_source_string
        ):

            if term_counter == max_terms_allowed:
                break

            match_start = match_end - len(src_entry) + 1
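            # Skip the match if the target lemma is not found in the target
            # sentence, or if the match does not fall on word boundaries in
            # the lemmatized source string.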
            if (
                (tgt_entry.lower() not in " ".join(lemmatized_target).lower())
                or (
                    len(lemmatized_source_string) != match_end + 1
                    and not (lemmatized_source_string[match_end + 1].isspace())
                )
                or (
                    not lemmatized_source_string[match_start - 1].isspace()
                    and match_start != 0
                )
            ):
                continue
            else:
                term_counter += 1
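            # Map the character offset of the match in the lemmatized string
            # back to a token index in the lemma list.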
            lemma_list_index = 0
            for i, w in enumerate(lemmatized_source):
                if lemma_list_index == match_start:
                    lemma_list_index = i
                    break
                else:
                    lemma_list_index += len(w) + 1
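            # The term may span several words: recover its surface form from
            # the tokenized source.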
            num_words_in_src_term = len(src_entry.split())
            src_term = " ".join(
                tokenized_source[
                    lemma_list_index : lemma_list_index + num_words_in_src_term
                ]
            ).strip()
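            # Join the annotation with "∥" so it counts as a single
            # whitespace-delimited token for the alignment check below; the
            # "∥" is expanded back to spaces when the final source is rebuilt.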
            tgt_term = tgt_entry.replace(" ", "∥").rstrip().lower()
            source_with_terms.append(
                f"{lemmatized_source_string[offset: match_start]}"
                f"{self.src_term_stoken}∥{src_term}∥{self.tgt_term_stoken}∥"
                f"{tgt_term}∥{self.tgt_term_etoken}"
            )

            offset = match_end + 1
            is_match = True

        if is_match:
            source_with_terms.append(lemmatized_source_string[offset:])
            tokenized_source_with_terms = "".join(source_with_terms).split()

            if not (
                len(tokenized_source)
                == len(lemmatized_source)
                == len(tokenized_source_with_terms)
            ):
                final_string = " ".join(tokenized_source)
                fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
                return fixed_punct.split(), not is_match
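            # Token counts are consistent, so rebuild the surface-form source:
            # keep the original token where no annotation was inserted,
            # otherwise keep the annotated chunk.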
            completed_tokenized_source = list()
            for idx in range(len(tokenized_source_with_terms)):
                src_lemma = tokenized_source_with_terms[idx].replace("∥", " ")
                if lemmatized_source[idx].replace("∥", " ") == src_lemma:
                    completed_tokenized_source.append(tokenized_source[idx])
                else:
                    completed_tokenized_source.append(src_lemma)

            if augmented_part is not None:
                final_string = " ".join(
                    completed_tokenized_source
                    + [self.delimiter]
                    + augmented_part.split()
                )
            else:
                final_string = " ".join(completed_tokenized_source)

            fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
            return fixed_punct.split(), is_match
        else:
            final_string = " ".join(tokenized_source)
            fixed_punct = re.sub(r" ([^\w\s⦅\-\–])", r"\1", final_string)
            return fixed_punct.split(), not is_match


@register_transform(name="terminology")
class TerminologyTransform(Transform):
    def __init__(self, opts):
        super().__init__(opts)

    @classmethod
    def add_options(cls, parser):
        """Available options for terminology matching."""

        group = parser.add_argument_group("Transform/Terminology")
        group.add(
            "--termbase_path",
            "-termbase_path",
            type=str,
            help="Path to a dictionary file with terms.",
        )
        group.add(
            "--src_spacy_language_model",
            "-src_spacy_language_model",
            type=str,
            help="Name of the spacy language model for the source corpus.",
        )
        group.add(
            "--tgt_spacy_language_model",
            "-tgt_spacy_language_model",
            type=str,
            help="Name of the spacy language model for the target corpus.",
        )
        group.add(
            "--term_corpus_ratio",
            "-term_corpus_ratio",
            type=float,
            default=0.3,
            help="Ratio of corpus to augment with terms.",
        )
        group.add(
            "--term_example_ratio",
            "-term_example_ratio",
            type=float,
            default=0.2,
            help="Max terms allowed in an example.",
        )
        group.add(
            "--src_term_stoken",
            "-src_term_stoken",
            type=str,
            help="The source term start token.",
            default="⦅src_term_start⦆",
        )
        group.add(
            "--tgt_term_stoken",
            "-tgt_term_stoken",
            type=str,
            help="The target term start token.",
            default="⦅tgt_term_start⦆",
        )
        group.add(
            "--tgt_term_etoken",
            "-tgt_term_etoken",
            type=str,
            help="The target term end token.",
            default="⦅tgt_term_end⦆",
        )
        group.add(
            "--term_source_delimiter",
            "-term_source_delimiter",
            type=str,
            help="Any special token used for augmented source sentences. "
            "The default is the fuzzy token used in the "
            "FuzzyMatch transform.",
            default="⦅fuzzy⦆",
        )

    def _parse_opts(self):
        self.termbase_path = self.opts.termbase_path
        self.src_spacy_language_model = self.opts.src_spacy_language_model
        self.tgt_spacy_language_model = self.opts.tgt_spacy_language_model
        self.term_corpus_ratio = self.opts.term_corpus_ratio
        self.term_example_ratio = self.opts.term_example_ratio
        self.term_source_delimiter = self.opts.term_source_delimiter
        self.src_term_stoken = self.opts.src_term_stoken
        self.tgt_term_stoken = self.opts.tgt_term_stoken
        self.tgt_term_etoken = self.opts.tgt_term_etoken

    @classmethod
    def get_specials(cls, opts):
        """Add the term tokens to the src vocab."""
        src_specials = list()
        src_specials.extend(
            [opts.src_term_stoken, opts.tgt_term_stoken, opts.tgt_term_etoken]
        )
        return (src_specials, list())

    def warm_up(self, vocabs=None):
        """Create the terminology matcher."""

        super().warm_up(None)
        self.termmatcher = TermMatcher(
            self.termbase_path,
            self.src_spacy_language_model,
            self.tgt_spacy_language_model,
            self.term_example_ratio,
            self.src_term_stoken,
            self.tgt_term_stoken,
            self.tgt_term_etoken,
            self.term_source_delimiter,
            self.term_corpus_ratio,
        )

    def batch_apply(self, batch, is_train=False, stats=None, **kwargs):
        bucket_size = len(batch)
        examples_with_terms = 0

        for i, (ex, _, _) in enumerate(batch):
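            # Only every other example is considered for augmentation, and the
            # share of augmented examples is further capped by
            # term_corpus_ratio.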
            if i % 2 == 0:
                original_src = ex["src"]
                augmented_example, is_match = self.apply(ex, is_train, stats, **kwargs)
                if is_match and (
                    examples_with_terms < bucket_size * self.term_corpus_ratio
                ):
                    examples_with_terms += 1
                    ex["src"] = augmented_example["src"]
                else:
                    ex["src"] = original_src

        logger.debug(f"Added terms to {examples_with_terms}/{bucket_size} examples")
        return batch

    def apply(self, example, is_train=False, stats=None, **kwargs) -> tuple:
        """Add terms to source examples."""

        example["src"], is_match = self.termmatcher._src_sentence_with_terms(
            " ".join(example["src"]), " ".join(example["tgt"])
        )
        return example, is_match