"""
Greedy Word Swap with Word Importance Ranking
===================================================

When the WIR method is set to ``unk``, this is a reimplementation of the search
method from the paper: Is BERT Really Robust?

A Strong Baseline for Natural Language Attack on Text Classification and
Entailment by Jin et al., 2019. See https://arxiv.org/abs/1907.11932 and
https://github.com/jind11/TextFooler.
"""

import numpy as np
import torch
from torch.nn.functional import softmax

from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import (
    transformation_consists_of_word_swaps_and_deletions,
)


class GreedyWordSwapWIR(SearchMethod):
    """An attack that greedily chooses from a list of possible perturbations in
    order of index, after ranking indices by importance.

    Args:
        wir_method: method for ranking the most important words (``"unk"``,
            ``"weighted-saliency"``, ``"delete"``, ``"gradient"``, or ``"random"``)
        unk_token: token used to mask words when computing ``unk`` and
            ``weighted-saliency`` importance scores
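
    Example::

        >>> # TextFooler-style search: rank words via UNK masking, then
        >>> # greedily swap words in order of importance.
        >>> search_method = GreedyWordSwapWIR(wir_method="unk")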
    """

    def __init__(self, wir_method="unk", unk_token="[UNK]"):
        self.wir_method = wir_method
        self.unk_token = unk_token

    def _get_index_order(self, initial_text):
        """Returns word indices of ``initial_text`` in descending order of
        importance."""
        len_text, indices_to_order = self.get_indices_to_order(initial_text)

        if self.wir_method == "unk":
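            # TextFooler-style ranking: mask each word with the unknown token,
            # score every masked text in one batch, and treat a larger score
            # as a more important word.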
            leave_one_texts = [
                initial_text.replace_word_at_index(i, self.unk_token)
                for i in indices_to_order
            ]
            leave_one_results, search_over = self.get_goal_results(leave_one_texts)
            index_scores = np.array([result.score for result in leave_one_results])

        elif self.wir_method == "weighted-saliency":
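            # First compute word saliency: the goal-function score obtained
            # when each word is masked with the unknown token.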
            leave_one_texts = [
                initial_text.replace_word_at_index(i, self.unk_token)
                for i in indices_to_order
            ]
            leave_one_results, search_over = self.get_goal_results(leave_one_texts)
            saliency_scores = np.array([result.score for result in leave_one_results])

            softmax_saliency_scores = softmax(
                torch.Tensor(saliency_scores), dim=0
            ).numpy()
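
            # Then, for each word, find the largest score change attainable by
            # any single swap at that position; the final importance score is
            # this change weighted by the word's softmax-normalized saliency.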
            delta_ps = []
            for idx in indices_to_order:
                # If the query budget runs out mid-loop, pad ``delta_ps`` with
                # zeros so it stays the same length as
                # ``softmax_saliency_scores``.
                if search_over:
                    delta_ps = delta_ps + [0.0] * (
                        len(softmax_saliency_scores) - len(delta_ps)
                    )
                    break

                transformed_text_candidates = self.get_transformations(
                    initial_text,
                    original_text=initial_text,
                    indices_to_modify=[idx],
                )
                if not transformed_text_candidates:
                    # No valid substitutions for this word.
                    delta_ps.append(0.0)
                    continue
                swap_results, search_over = self.get_goal_results(
                    transformed_text_candidates
                )
                score_change = [result.score for result in swap_results]
                if not score_change:
                    delta_ps.append(0.0)
                    continue
                max_score_change = np.max(score_change)
                delta_ps.append(max_score_change)

            index_scores = softmax_saliency_scores * np.array(delta_ps)

        elif self.wir_method == "delete":
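            # Rank each word by the score obtained when the word is deleted
            # outright rather than masked.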
            leave_one_texts = [
                initial_text.delete_word_at_index(i) for i in indices_to_order
            ]
            leave_one_results, search_over = self.get_goal_results(leave_one_texts)
            index_scores = np.array([result.score for result in leave_one_results])

        elif self.wir_method == "gradient":
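            # White-box ranking: score each word by the L1 norm of the
            # gradient over its model tokens, averaging the gradient when a
            # word maps to several subword tokens.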
            victim_model = self.get_victim_model()
            index_scores = np.zeros(len_text)
            grad_output = victim_model.get_grad(initial_text.tokenizer_input)
            gradient = grad_output["gradient"]
            word2token_mapping = initial_text.align_with_model_tokens(victim_model)
            for i, index in enumerate(indices_to_order):
                matched_tokens = word2token_mapping[index]
                if not matched_tokens:
                    index_scores[i] = 0.0
                else:
                    agg_grad = np.mean(gradient[matched_tokens], axis=0)
                    index_scores[i] = np.linalg.norm(agg_grad, ord=1)

            search_over = False

        elif self.wir_method == "random":
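            # Visit words in a random order; no importance scores are
            # computed, so no goal-function queries are spent on ranking.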
            index_order = indices_to_order
            np.random.shuffle(index_order)
            search_over = False
        else:
            raise ValueError(f"Unsupported WIR method {self.wir_method}")

        if self.wir_method != "random":
            # Sort indices by descending importance score.
            index_order = np.array(indices_to_order)[(-index_scores).argsort()]

        return index_order, search_over

    def perform_search(self, initial_result):
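        """Greedily perturbs ``initial_result.attacked_text`` one index at a
        time, in order of decreasing word importance, until the goal is met
        or the indices are exhausted."""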
        attacked_text = initial_result.attacked_text

        # Sort words by order of importance.
        index_order, search_over = self._get_index_order(attacked_text)
        i = 0
        cur_result = initial_result
        results = None
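
        # At each step, try every transformation at the next-most-important
        # index and keep the best-scoring candidate if it improves on the
        # current result.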
        while i < len(index_order) and not search_over:
            transformed_text_candidates = self.get_transformations(
                cur_result.attacked_text,
                original_text=initial_result.attacked_text,
                indices_to_modify=[index_order[i]],
            )
            i += 1
            if len(transformed_text_candidates) == 0:
                continue
            results, search_over = self.get_goal_results(transformed_text_candidates)
            results = sorted(results, key=lambda x: -x.score)
            # Skip swaps which don't improve the attack score.
            if results[0].score > cur_result.score:
                cur_result = results[0]
            else:
                continue

            if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
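                # On success, prefer the successful candidate most similar to
                # the original text; fall back to the highest-scoring one if
                # no constraint recorded a similarity score.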
                best_result = cur_result
                max_similarity = -float("inf")
                for result in results:
                    if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
                        break
                    candidate = result.attacked_text
                    try:
                        similarity_score = candidate.attack_attrs["similarity_score"]
                    except KeyError:
                        # If the attack was run without any similarity metrics,
                        # candidates won't have a similarity score. In this
                        # case, break and return the candidate that changed
                        # the original score the most.
                        break
                    if similarity_score > max_similarity:
                        max_similarity = similarity_score
                        best_result = result
                return best_result

        return cur_result

    def check_transformation_compatibility(self, transformation):
        """Since it ranks words by their importance, GreedyWordSwapWIR is
        limited to word swap and deletion transformations."""
        return transformation_consists_of_word_swaps_and_deletions(transformation)

    @property
    def is_black_box(self):
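        """Gradient-based ranking needs access to model internals, so the
        search is black-box for every WIR method except ``gradient``."""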
        if self.wir_method == "gradient":
            return False
        else:
            return True

    def extra_repr_keys(self):
        return ["wir_method"]