File size: 7,895 Bytes
4943752 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
"""
Greedy Word Swap with Word Importance Ranking
===================================================
When the word importance ranking (WIR) method is set to ``unk``, this is a
reimplementation of the search method from the paper "Is BERT Really Robust?
A Strong Baseline for Natural Language Attack on Text Classification and
Entailment" by Jin et al., 2019. See https://arxiv.org/abs/1907.11932 and
https://github.com/jind11/TextFooler.
"""
import numpy as np
import torch
from torch.nn.functional import softmax
from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import SearchMethod
from textattack.shared.validators import (
transformation_consists_of_word_swaps_and_deletions,
)
class GreedyWordSwapWIR(SearchMethod):
    """An attack that greedily chooses from a list of possible perturbations in
    order of index, after ranking indices by importance.

    Args:
        wir_method: method for ranking the most important words. One of
            ``"unk"``, ``"weighted-saliency"``, ``"delete"``, ``"gradient"``
            or ``"random"``; any other value makes ranking raise
            ``ValueError``.
        unk_token: token substituted for a word when measuring its importance
            with the ``"unk"`` and ``"weighted-saliency"`` methods.
    """

    def __init__(self, wir_method="unk", unk_token="[UNK]"):
        self.wir_method = wir_method
        self.unk_token = unk_token

    def _score_texts(self, texts):
        """Evaluate ``texts`` with the goal function.

        Returns:
            ``(scores, search_over)``: a NumPy array holding the ``score`` of
            every result returned by ``get_goal_results`` (in order), and the
            ``search_over`` flag it reported.
        """
        results, search_over = self.get_goal_results(texts)
        return np.array([result.score for result in results]), search_over

    def _unk_texts(self, initial_text, indices_to_order):
        """Return one copy of ``initial_text`` per index, with the word at that
        index replaced by ``self.unk_token``."""
        return [
            initial_text.replace_word_at_index(i, self.unk_token)
            for i in indices_to_order
        ]

    def _max_swap_scores(
        self, initial_text, indices_to_order, num_scores, search_over
    ):
        """For each index, find the best goal score reachable by one swap there.

        Args:
            initial_text: text whose words are being ranked.
            indices_to_order: word indices to evaluate, in order.
            num_scores: length the result is padded to if the search ends
                early, so it lines up with the saliency scores.
            search_over: ``search_over`` flag from the previous query.

        Returns:
            ``(delta_ps, search_over)`` where ``delta_ps[k]`` is the maximum
            goal score among swaps at ``indices_to_order[k]`` (``0.0`` when
            there is no candidate or no result), and ``search_over`` is the
            most recent flag reported by the goal function.
        """
        delta_ps = []
        for idx in indices_to_order:
            if search_over:
                # Stop querying, but keep delta_ps the same size as the
                # saliency scores it will be multiplied with.
                delta_ps = delta_ps + [0.0] * (num_scores - len(delta_ps))
                break
            candidates = self.get_transformations(
                initial_text,
                original_text=initial_text,
                indices_to_modify=[idx],
            )
            if not candidates:
                # No valid substitutions for this word.
                delta_ps.append(0.0)
                continue
            swap_results, search_over = self.get_goal_results(candidates)
            swap_scores = [result.score for result in swap_results]
            delta_ps.append(np.max(swap_scores) if swap_scores else 0.0)
        return delta_ps, search_over

    def _gradient_scores(self, initial_text, indices_to_order):
        """Score each index by the L1 norm of the mean gradient of the model
        tokens aligned with that word (``0.0`` if no token is aligned)."""
        victim_model = self.get_victim_model()
        grad_output = victim_model.get_grad(initial_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = initial_text.align_with_model_tokens(victim_model)
        # One score per *candidate* index, not per word of the text. The
        # argsort in ``_get_index_order`` must only yield positions that are
        # valid for ``indices_to_order``.
        index_scores = np.zeros(len(indices_to_order))
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if matched_tokens:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        return index_scores

    def _get_index_order(self, initial_text):
        """Returns word indices of ``initial_text`` in descending order of
        importance.

        Returns:
            ``(index_order, search_over)``. ``search_over`` is the flag from
            the last goal-function query made while ranking (always ``False``
            for the ``"gradient"`` and ``"random"`` methods, which make none).

        Raises:
            ValueError: if ``self.wir_method`` is not a supported method.
        """
        _, indices_to_order = self.get_indices_to_order(initial_text)

        if self.wir_method == "random":
            index_order = indices_to_order
            np.random.shuffle(index_order)
            return index_order, False

        if self.wir_method == "unk":
            index_scores, search_over = self._score_texts(
                self._unk_texts(initial_text, indices_to_order)
            )
        elif self.wir_method == "weighted-saliency":
            # Word saliency: goal score with each word replaced by unk_token,
            # normalised with a softmax over all candidate indices.
            saliency_scores, search_over = self._score_texts(
                self._unk_texts(initial_text, indices_to_order)
            )
            softmax_saliency_scores = softmax(
                torch.Tensor(saliency_scores), dim=0
            ).numpy()
            # Weight saliency by the best score reachable by swapping the word.
            delta_ps, search_over = self._max_swap_scores(
                initial_text,
                indices_to_order,
                len(softmax_saliency_scores),
                search_over,
            )
            index_scores = softmax_saliency_scores * np.array(delta_ps)
        elif self.wir_method == "delete":
            index_scores, search_over = self._score_texts(
                [initial_text.delete_word_at_index(i) for i in indices_to_order]
            )
        elif self.wir_method == "gradient":
            index_scores = self._gradient_scores(initial_text, indices_to_order)
            search_over = False
        else:
            raise ValueError(f"Unsupported WIR method {self.wir_method}")

        # argsort of the negated scores yields positions by descending score.
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order, search_over

    @staticmethod
    def _most_similar_success(cur_result, results):
        """Pick the most similar successful candidate.

        ``results`` is expected sorted by descending score. Scanning stops at
        the first result that did not succeed or carries no
        ``similarity_score`` attack attribute. Returns the scanned result with
        the highest similarity, or ``cur_result`` if none was found.
        """
        best_result = cur_result
        max_similarity = -float("inf")
        # TODO: Use vectorwise operations
        for result in results:
            if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
                break
            try:
                similarity_score = result.attacked_text.attack_attrs[
                    "similarity_score"
                ]
            except KeyError:
                # If the attack was run without any similarity metrics,
                # candidates won't have a similarity score. In this case,
                # keep the candidate that changed the original score the most.
                break
            if similarity_score > max_similarity:
                max_similarity = similarity_score
                best_result = result
        return best_result

    def perform_search(self, initial_result):
        """Greedily perturb words in order of importance.

        For each ranked index, keep the best-scoring transformation if it
        improves the current score. Stops when the goal succeeds, the indices
        are exhausted, or the goal function reports the search is over.

        Returns:
            The best goal-function result found. On success this is the
            successful candidate with the highest similarity score, if any.
        """
        attacked_text = initial_result.attacked_text

        # Sort words by order of importance
        index_order, search_over = self._get_index_order(attacked_text)

        cur_result = initial_result
        for index in index_order:
            if search_over:
                break
            transformed_text_candidates = self.get_transformations(
                cur_result.attacked_text,
                original_text=initial_result.attacked_text,
                indices_to_modify=[index],
            )
            if not transformed_text_candidates:
                continue
            results, search_over = self.get_goal_results(
                transformed_text_candidates
            )
            if not results:
                # The goal function evaluated none of the candidates;
                # ``results[0]`` below would raise IndexError.
                continue
            results = sorted(results, key=lambda x: -x.score)
            # Skip swaps which don't improve the score
            if results[0].score <= cur_result.score:
                continue
            cur_result = results[0]
            # If we succeeded, return the candidate with best similarity.
            if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
                return self._most_similar_success(cur_result, results)
        return cur_result

    def check_transformation_compatibility(self, transformation):
        """Since it ranks words by their importance, GreedyWordSwapWIR is
        limited to word swap and deletion transformations."""
        return transformation_consists_of_word_swaps_and_deletions(transformation)

    @property
    def is_black_box(self):
        """Only gradient-based ranking needs access to model internals."""
        return self.wir_method != "gradient"

    def extra_repr_keys(self):
        return ["wir_method"]
|