Rapid-Textual-Adversarial-Defense
/
textattack
/goal_functions
/classification
/targeted_classification.py
"""
Determine if an attack has been successful in targeted classification
-----------------------------------------------------------------------
"""
from .classification_goal_function import ClassificationGoalFunction | |
class TargetedClassification(ClassificationGoalFunction):
    """A targeted attack on classification models which attempts to maximize
    the score of the target label.

    Complete when the target label is the predicted label.

    Args:
        target_class (int): Index of the label the attack tries to make the
            model predict. Defaults to ``0``. All other positional and
            keyword arguments are forwarded to ``ClassificationGoalFunction``.
    """

    def __init__(self, *args, target_class=0, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_class = target_class

    def _is_goal_complete(self, model_output, _):
        """Return whether the targeted attack has succeeded.

        True when ``target_class`` is the arg-max of ``model_output``, or when
        ``self.ground_truth_output`` (presumably set by the base class) already
        equals ``target_class`` -- i.e. the original label is the target.
        """
        return (
            self.target_class == model_output.argmax()
        ) or self.ground_truth_output == self.target_class

    def _get_score(self, model_output, _):
        """Return the model's score for ``target_class``.

        Raises:
            ValueError: if ``target_class`` is not a valid index into
                ``model_output`` (negative, or >= the number of classes).
        """
        num_classes = len(model_output)
        # Reject negative indices explicitly: Python-style negative indexing
        # would otherwise silently score a class counted from the end.
        if not 0 <= self.target_class < num_classes:
            raise ValueError(
                f"target class set to {self.target_class} with {num_classes} classes."
            )
        return model_output[self.target_class]

    def extra_repr_keys(self):
        """Return the attribute names shown in this object's repr.

        ``maximizable`` is listed only when ``self.maximizable`` is truthy.
        """
        if self.maximizable:
            return ["maximizable", "target_class"]
        return ["target_class"]