""" Greedy Word Swap with Word Importance Ranking =================================================== When WIR method is set to ``unk``, this is a reimplementation of the search method from the paper: Is BERT Really Robust? A Strong Baseline for Natural Language Attack on Text Classification and Entailment by Jin et. al, 2019. See https://arxiv.org/abs/1907.11932 and https://github.com/jind11/TextFooler. """ import numpy as np import torch from torch.nn.functional import softmax from textattack.goal_function_results import GoalFunctionResultStatus from textattack.search_methods import SearchMethod from textattack.shared.validators import ( transformation_consists_of_word_swaps_and_deletions, ) class GreedyWordSwapWIR(SearchMethod): """An attack that greedily chooses from a list of possible perturbations in order of index, after ranking indices by importance. Args: wir_method: method for ranking most important words model_wrapper: model wrapper used for gradient-based ranking """ def __init__(self, wir_method="unk", unk_token="[UNK]"): self.wir_method = wir_method self.unk_token = unk_token def _get_index_order(self, initial_text): """Returns word indices of ``initial_text`` in descending order of importance.""" len_text, indices_to_order = self.get_indices_to_order(initial_text) if self.wir_method == "unk": leave_one_texts = [ initial_text.replace_word_at_index(i, self.unk_token) for i in indices_to_order ] leave_one_results, search_over = self.get_goal_results(leave_one_texts) index_scores = np.array([result.score for result in leave_one_results]) elif self.wir_method == "weighted-saliency": # first, compute word saliency leave_one_texts = [ initial_text.replace_word_at_index(i, self.unk_token) for i in indices_to_order ] leave_one_results, search_over = self.get_goal_results(leave_one_texts) saliency_scores = np.array([result.score for result in leave_one_results]) softmax_saliency_scores = softmax( torch.Tensor(saliency_scores), dim=0 ).numpy() # compute the largest change in score we can find by swapping each word delta_ps = [] for idx in indices_to_order: # Exit Loop when search_over is True - but we need to make sure delta_ps # is the same size as softmax_saliency_scores if search_over: delta_ps = delta_ps + [0.0] * ( len(softmax_saliency_scores) - len(delta_ps) ) break transformed_text_candidates = self.get_transformations( initial_text, original_text=initial_text, indices_to_modify=[idx], ) if not transformed_text_candidates: # no valid synonym substitutions for this word delta_ps.append(0.0) continue swap_results, search_over = self.get_goal_results( transformed_text_candidates ) score_change = [result.score for result in swap_results] if not score_change: delta_ps.append(0.0) continue max_score_change = np.max(score_change) delta_ps.append(max_score_change) index_scores = softmax_saliency_scores * np.array(delta_ps) elif self.wir_method == "delete": leave_one_texts = [ initial_text.delete_word_at_index(i) for i in indices_to_order ] leave_one_results, search_over = self.get_goal_results(leave_one_texts) index_scores = np.array([result.score for result in leave_one_results]) elif self.wir_method == "gradient": victim_model = self.get_victim_model() index_scores = np.zeros(len_text) grad_output = victim_model.get_grad(initial_text.tokenizer_input) gradient = grad_output["gradient"] word2token_mapping = initial_text.align_with_model_tokens(victim_model) for i, index in enumerate(indices_to_order): matched_tokens = word2token_mapping[index] if not matched_tokens: index_scores[i] = 
                if not matched_tokens:
                    index_scores[i] = 0.0
                else:
                    agg_grad = np.mean(gradient[matched_tokens], axis=0)
                    index_scores[i] = np.linalg.norm(agg_grad, ord=1)

            search_over = False

        elif self.wir_method == "random":
            index_order = indices_to_order
            np.random.shuffle(index_order)
            search_over = False
        else:
            raise ValueError(f"Unsupported WIR method {self.wir_method}")

        if self.wir_method != "random":
            index_order = np.array(indices_to_order)[(-index_scores).argsort()]

        return index_order, search_over

    def perform_search(self, initial_result):
        attacked_text = initial_result.attacked_text

        # Sort words by order of importance.
        index_order, search_over = self._get_index_order(attacked_text)
        i = 0
        cur_result = initial_result
        results = None
        while i < len(index_order) and not search_over:
            transformed_text_candidates = self.get_transformations(
                cur_result.attacked_text,
                original_text=initial_result.attacked_text,
                indices_to_modify=[index_order[i]],
            )
            i += 1
            if len(transformed_text_candidates) == 0:
                continue
            results, search_over = self.get_goal_results(transformed_text_candidates)
            results = sorted(results, key=lambda x: -x.score)
            # Skip swaps which don't improve the score.
            if results[0].score > cur_result.score:
                cur_result = results[0]
            else:
                continue
            # If we succeeded, return the candidate with the best similarity.
            if cur_result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
                best_result = cur_result
                # @TODO: Use vectorwise operations
                max_similarity = -float("inf")
                for result in results:
                    if result.goal_status != GoalFunctionResultStatus.SUCCEEDED:
                        break
                    candidate = result.attacked_text
                    try:
                        similarity_score = candidate.attack_attrs["similarity_score"]
                    except KeyError:
                        # If the attack was run without any similarity metrics,
                        # candidates won't have a similarity score. In this
                        # case, break and return the candidate that changed
                        # the original score the most.
                        break
                    if similarity_score > max_similarity:
                        max_similarity = similarity_score
                        best_result = result
                return best_result

        return cur_result

    def check_transformation_compatibility(self, transformation):
        """Since it ranks words by their importance, GreedyWordSwapWIR is
        limited to word swap and deletion transformations."""
        return transformation_consists_of_word_swaps_and_deletions(transformation)

    @property
    def is_black_box(self):
        # Gradient-based ranking requires white-box access to the model.
        return self.wir_method != "gradient"

    def extra_repr_keys(self):
        return ["wir_method"]
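

# Minimal usage sketch: GreedyWordSwapWIR is composed into an Attack together
# with a goal function, constraints, and a word-swap transformation, here a
# pared-down version of the TextFooler recipe. The checkpoint name below is
# only an illustrative assumption; any HuggingFace sequence-classification
# model wrapped for TextAttack works.
if __name__ == "__main__":
    import transformers

    from textattack import Attack
    from textattack.constraints.pre_transformation import (
        RepeatModification,
        StopwordModification,
    )
    from textattack.goal_functions import UntargetedClassification
    from textattack.models.wrappers import HuggingFaceModelWrapper
    from textattack.transformations import WordSwapEmbedding

    # Load a fine-tuned classifier and wrap it for TextAttack (model name is
    # an assumption for this sketch).
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-imdb"
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "textattack/bert-base-uncased-imdb"
    )
    model_wrapper = HuggingFaceModelWrapper(model, tokenizer)

    # Swap words with nearest neighbors in the embedding space, never
    # modifying stopwords or the same index twice.
    attack = Attack(
        goal_function=UntargetedClassification(model_wrapper),
        constraints=[RepeatModification(), StopwordModification()],
        transformation=WordSwapEmbedding(max_candidates=50),
        search_method=GreedyWordSwapWIR(wir_method="unk"),
    )
    print(attack)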