|
""" |
|
|
|
TextAttack Constraint Class |
|
===================================== |
|
""" |
|
|
|
from abc import ABC, abstractmethod |
|
|
|
import textattack |
|
from textattack.shared.utils import ReprMixin |
|
|
|
|
|
class Constraint(ReprMixin, ABC):
    """An abstract class that represents constraints on adversarial text
    examples. Constraints evaluate whether transformations from a
    ``AttackedText`` to another ``AttackedText`` meet certain conditions.

    Subclasses must implement ``_check_constraint``. They may also override
    ``_check_constraint_many`` (to check many texts at once) and
    ``check_compatibility`` (to skip transformations the constraint does not
    apply to).

    Args:
        compare_against_original (bool): If `True`, the reference text should be the original text under attack.
            If `False`, the reference text is the most recent text from which the transformed text was generated.
            All constraints must have this attribute.
    """

    def __init__(self, compare_against_original):
        self.compare_against_original = compare_against_original

    def call_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. First checks compatibility with latest
        ``Transformation``, then calls ``_check_constraint_many`` on the
        compatible texts only.

        Texts whose last transformation is incompatible with this constraint
        are not checked and always pass through. The returned list holds the
        compatible texts that satisfied the constraint, followed by every
        incompatible text, so the input order is not preserved.

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list[AttackedText]: The texts that pass (or are exempt from) the constraint.

        Raises:
            KeyError: If a text has no ``last_transformation`` entry in its
                ``attack_attrs``.
        """
        incompatible_transformed_texts = []
        compatible_transformed_texts = []
        for transformed_text in transformed_texts:
            # Guard only the attribute lookup so that a KeyError raised inside
            # a subclass's ``check_compatibility`` is not misreported as a
            # missing ``last_transformation``.
            try:
                last_transformation = transformed_text.attack_attrs[
                    "last_transformation"
                ]
            except KeyError as err:
                raise KeyError(
                    "transformed_text must have `last_transformation` attack_attr to apply constraint"
                ) from err
            if self.check_compatibility(last_transformation):
                compatible_transformed_texts.append(transformed_text)
            else:
                incompatible_transformed_texts.append(transformed_text)

        filtered_texts = self._check_constraint_many(
            compatible_transformed_texts, reference_text
        )
        # ``list(...)`` accepts any iterable returned by an overridden
        # ``_check_constraint_many``.
        return list(filtered_texts) + incompatible_transformed_texts

    def _check_constraint_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. Calls ``_check_constraint`` on each text individually;
        subclasses may override this to check texts in bulk.

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list[AttackedText]: The texts for which ``_check_constraint`` is
            truthy, in their original order.
        """
        return [
            transformed_text
            for transformed_text in transformed_texts
            if self._check_constraint(transformed_text, reference_text)
        ]

    def __call__(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. First
        checks compatibility with latest ``Transformation``, then calls
        ``_check_constraint``.

        If the last transformation is incompatible with this constraint, the
        constraint is not applied and ``True`` is returned.

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            ``True`` for incompatible transformations; otherwise the result of
            ``_check_constraint``.

        Raises:
            TypeError: If either argument is not an ``AttackedText``.
            KeyError: If ``transformed_text`` has no ``last_transformation``
                entry in its ``attack_attrs``.
        """
        if not isinstance(transformed_text, textattack.shared.AttackedText):
            raise TypeError("transformed_text must be of type AttackedText")
        if not isinstance(reference_text, textattack.shared.AttackedText):
            raise TypeError("reference_text must be of type AttackedText")

        # Guard only the attribute lookup so that errors raised inside
        # ``check_compatibility`` propagate unchanged.
        try:
            last_transformation = transformed_text.attack_attrs["last_transformation"]
        except KeyError as err:
            raise KeyError(
                "`transformed_text` must have `last_transformation` attack_attr to apply constraint."
            ) from err
        if not self.check_compatibility(last_transformation):
            # The constraint does not apply to this kind of transformation.
            return True
        return self._check_constraint(transformed_text, reference_text)

    @abstractmethod
    def _check_constraint(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. Must
        be overridden by the specific constraint.

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            bool: Whether ``transformed_text`` satisfies the constraint.
        """
        raise NotImplementedError()

    def check_compatibility(self, transformation):
        """Checks if this constraint is compatible with the given
        transformation. For example, the ``WordEmbeddingDistance`` constraint
        compares the embedding of the word inserted with that of the word
        deleted. Therefore it can only be applied in the case of word swaps,
        and not for transformations which involve only one of insertion or
        deletion.

        The base implementation accepts every transformation; subclasses
        override it to opt out of specific ones.

        Args:
            transformation: The ``Transformation`` to check compatibility with.

        Returns:
            bool: ``True`` if the constraint should be applied.
        """
        return True

    def extra_repr_keys(self):
        """Returns the names of attributes to include in this constraint's
        string representation.

        The list is presumably consumed by ``ReprMixin``; subclasses that add
        configuration attributes should extend it to show them.

        Returns:
            list[str]: Attribute names.
        """
        return ["compare_against_original"]
|
|