|
""" |
|
Word Embedding Distance |
|
-------------------------- |
|
""" |
|
|
|
from textattack.constraints import Constraint |
|
from textattack.shared import AbstractWordEmbedding, WordEmbedding |
|
from textattack.shared.validators import transformation_consists_of_word_swaps |
|
|
|
|
|
class WordEmbeddingDistance(Constraint):
    """A constraint on word substitutions which places a maximum distance
    between the embedding of the word being deleted and the word being
    inserted.

    Exactly one of ``min_cos_sim`` or ``max_mse_dist`` must be given; it
    selects which metric is enforced.

    Args:
        embedding (obj): Wrapper for word embedding. Must be an instance of
            ``AbstractWordEmbedding``. When ``None``, the embedding returned by
            ``WordEmbedding.counterfitted_GLOVE_embedding()`` is used.
        include_unknown_words (bool): Whether or not the constraint is
            fulfilled if the embedding of x or x_adv is unknown.
        min_cos_sim (:obj:`float`, optional): The minimum cosine similarity
            between word embeddings.
        max_mse_dist (:obj:`float`, optional): The maximum distance between
            word embeddings, as measured by ``embedding.get_mse_dist``.
        cased (bool): Whether embedding supports uppercase & lowercase
            (defaults to False, or just lowercase).
        compare_against_original (bool): If `True`, compare new `x_adv`
            against the original `x`. Otherwise, compare it against the
            previous `x_adv`.

    Raises:
        ValueError: If both or neither of ``min_cos_sim`` and
            ``max_mse_dist`` are provided, or if ``embedding`` is not an
            ``AbstractWordEmbedding``.
    """

    def __init__(
        self,
        embedding=None,
        include_unknown_words=True,
        min_cos_sim=None,
        max_mse_dist=None,
        cased=False,
        compare_against_original=True,
    ):
        super().__init__(compare_against_original)

        # Exactly one threshold must be chosen. Test against ``None`` rather
        # than truthiness so that a threshold of 0 / 0.0 counts as "set".
        # Validated before loading the default embedding so misconfiguration
        # fails without that load.
        if (min_cos_sim is None) == (max_mse_dist is None):
            raise ValueError("You must choose either `min_cos_sim` or `max_mse_dist`.")
        self.min_cos_sim = min_cos_sim
        self.max_mse_dist = max_mse_dist

        self.include_unknown_words = include_unknown_words
        self.cased = cased

        if embedding is None:
            embedding = WordEmbedding.counterfitted_GLOVE_embedding()
        if not isinstance(embedding, AbstractWordEmbedding):
            raise ValueError(
                "`embedding` object must be of type `textattack.shared.AbstractWordEmbedding`."
            )
        self.embedding = embedding

    def get_cos_sim(self, a, b):
        """Returns the cosine similarity of words with IDs a and b."""
        return self.embedding.get_cos_sim(a, b)

    def get_mse_dist(self, a, b):
        """Returns the MSE distance of words with IDs a and b."""
        return self.embedding.get_mse_dist(a, b)

    def _check_constraint(self, transformed_text, reference_text):
        """Returns true if (``transformed_text`` and ``reference_text``) are
        closer than ``self.min_cos_sim`` or ``self.max_mse_dist``.

        Only the words at the positions listed in
        ``transformed_text.attack_attrs["newly_modified_indices"]`` are
        compared, pairwise, against the same positions in ``reference_text``.

        Returns ``False`` if any such index is out of range for either text,
        if any compared pair violates the chosen threshold, or (when
        ``include_unknown_words`` is false) if a word has no embedding ID.

        Raises:
            KeyError: If ``newly_modified_indices`` is missing from
                ``transformed_text.attack_attrs``.
        """
        try:
            indices = transformed_text.attack_attrs["newly_modified_indices"]
        except KeyError as err:
            raise KeyError(
                "Cannot apply word embedding distance constraint without "
                "`newly_modified_indices`"
            ) from err

        ref_words = reference_text.words
        transformed_words = transformed_text.words

        # A modified index beyond either text's length cannot be compared;
        # reject rather than raise IndexError in the loop below.
        if any(i >= len(ref_words) or i >= len(transformed_words) for i in indices):
            return False

        for i in indices:
            ref_word = ref_words[i]
            transformed_word = transformed_words[i]

            if not self.cased:
                # Uncased embeddings only contain lowercase entries.
                ref_word = ref_word.lower()
                transformed_word = transformed_word.lower()

            try:
                ref_id = self.embedding.word2index(ref_word)
                transformed_id = self.embedding.word2index(transformed_word)
            except KeyError:
                # At least one of the two words has no embedding ID.
                if self.include_unknown_words:
                    continue
                return False

            # Use explicit ``None`` checks so a threshold of 0 is enforced
            # instead of being silently skipped.
            if self.min_cos_sim is not None:
                cos_sim = self.get_cos_sim(ref_id, transformed_id)
                if cos_sim < self.min_cos_sim:
                    return False

            if self.max_mse_dist is not None:
                mse_dist = self.get_mse_dist(ref_id, transformed_id)
                if mse_dist > self.max_mse_dist:
                    return False

        return True

    def check_compatibility(self, transformation):
        """WordEmbeddingDistance requires a word being both deleted and
        inserted at the same index in order to compare their embeddings,
        therefore it's restricted to word swaps."""
        return transformation_consists_of_word_swaps(transformation)

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        Only the threshold that is actually in use (``min_cos_sim`` or
        ``max_mse_dist``) is included.

        To print customized extra information, you should reimplement
        this method in your own constraint. Both single-line and multi-
        line strings are acceptable.
        """
        if self.min_cos_sim is None:
            metric = "max_mse_dist"
        else:
            metric = "min_cos_sim"
        return [
            "embedding",
            metric,
            "cased",
            "include_unknown_words",
        ] + super().extra_repr_keys()
|
|