|
""" |
|
Genetic Algorithm Word Swap |
|
==================================== |
|
""" |
|
from abc import ABC, abstractmethod |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from textattack.goal_function_results import GoalFunctionResultStatus |
|
from textattack.search_methods import PopulationBasedSearch, PopulationMember |
|
from textattack.shared.validators import transformation_consists_of_word_swaps |
|
|
|
|
|
class GeneticAlgorithm(PopulationBasedSearch, ABC):
    """Base class for attacking a model with word substitutions using a
    genetic algorithm.

    Subclasses supply the population representation by implementing
    `_modify_population_member`, `_get_word_select_prob_weights`,
    `_crossover_operation` and `_initialize_population`.

    Args:
        pop_size (int): The population size. Defaults to 60.
        max_iters (int): The maximum number of iterations to use. Defaults to 20.
        temp (float): Temperature for softmax function used to normalize probability dist when sampling parents.
            Higher temperature increases the sensitivity to lower probability candidates. Defaults to 0.3.
        give_up_if_no_improvement (bool): If True, stop the search early if no candidate that improves the score is found.
            Defaults to False.
        post_crossover_check (bool): If True, check if child produced from crossover step passes the constraints.
            Defaults to True.
        max_crossover_retries (int): Maximum number of crossover retries if resulting child fails to pass the constraints.
            Applied only when `post_crossover_check` is set to `True`.
            Setting it to 0 means we immediately take one of the parents at random as the child upon failure.
            Defaults to 20.
    """

    def __init__(
        self,
        pop_size=60,
        max_iters=20,
        temp=0.3,
        give_up_if_no_improvement=False,
        post_crossover_check=True,
        max_crossover_retries=20,
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.temp = temp
        self.give_up_if_no_improvement = give_up_if_no_improvement
        self.post_crossover_check = post_crossover_check
        self.max_crossover_retries = max_crossover_retries

        # Set from the second value returned by `get_goal_results`; once True
        # the search loops below stop issuing further work.
        self._search_over = False

    @abstractmethod
    def _modify_population_member(self, pop_member, new_text, new_result, word_idx):
        """Modify `pop_member` by returning a new copy with `new_text`,
        `new_result`, and, `attributes` altered appropriately for given
        `word_idx`"""
        raise NotImplementedError()

    @abstractmethod
    def _get_word_select_prob_weights(self, pop_member):
        """Get the attribute of `pop_member` that is used for determining
        probability of each word being selected for perturbation."""
        raise NotImplementedError()

    def _perturb(self, pop_member, original_result, index=None):
        """Perturb `pop_member` and return it. Replaces a word at a random
        position (unless `index` is specified) in `pop_member`.

        Args:
            pop_member (PopulationMember): The population member being perturbed.
            original_result (GoalFunctionResult): Result of original sample being attacked
            index (int): Index of word to perturb. `None` (the default) means
                the word is sampled from the selection weights; `0` is a valid
                explicit index.
        Returns:
            Perturbed `PopulationMember`. The input `pop_member` is returned
            unchanged when no candidate improves its score.
        """
        num_words = pop_member.attacked_text.num_words

        # Work on a copy: entries of this array are zeroed out below.
        word_select_prob_weights = np.copy(
            self._get_word_select_prob_weights(pop_member)
        )
        non_zero_indices = np.count_nonzero(word_select_prob_weights)
        if non_zero_indices == 0:
            return pop_member

        iterations = 0
        while iterations < non_zero_indices:
            # Compare against None explicitly so that index 0 (the first word)
            # is honoured instead of being treated as "not specified".
            if index is not None:
                idx = index
            else:
                w_select_probs = word_select_prob_weights / np.sum(
                    word_select_prob_weights
                )
                idx = np.random.choice(num_words, 1, p=w_select_probs)[0]

            transformed_texts = self.get_transformations(
                pop_member.attacked_text,
                original_text=original_result.attacked_text,
                indices_to_modify=[idx],
            )

            if not len(transformed_texts):
                iterations += 1
                continue

            new_results, self._search_over = self.get_goal_results(transformed_texts)

            # Score change of every candidate relative to the current member.
            diff_scores = (
                torch.Tensor([r.score for r in new_results]) - pop_member.result.score
            )
            if len(diff_scores) and diff_scores.max() > 0:
                # Keep the candidate with the largest improvement.
                idx_with_max_score = diff_scores.argmax()
                pop_member = self._modify_population_member(
                    pop_member,
                    transformed_texts[idx_with_max_score],
                    new_results[idx_with_max_score],
                    idx,
                )
                return pop_member

            # No candidate at this position improved the score; exclude the
            # position from subsequent sampling.
            word_select_prob_weights[idx] = 0
            iterations += 1

            if self._search_over:
                break

        return pop_member

    @abstractmethod
    def _crossover_operation(self, pop_member1, pop_member2):
        """Actual operation that takes `pop_member1` text and `pop_member2`
        text and mixes the two to generate crossover between `pop_member1` and
        `pop_member2`.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
        Returns:
            Tuple of `AttackedText` and a dictionary of attributes.
        """
        raise NotImplementedError()

    def _post_crossover_check(
        self, new_text, parent_text1, parent_text2, original_text
    ):
        """Check if `new_text` that has been produced by performing crossover
        between `parent_text1` and `parent_text2` aligns with the constraints.

        Args:
            new_text (AttackedText): Text produced by crossover operation
            parent_text1 (AttackedText): Parent text of `new_text`
            parent_text2 (AttackedText): Second parent text of `new_text`
            original_text (AttackedText): Original text
        Returns:
            `True` if `new_text` meets the constraints. If otherwise, return `False`.
        """
        if "last_transformation" in new_text.attack_attrs:
            # Use as reference whichever parent carries a recorded
            # transformation, preferring the first parent.
            previous_text = (
                parent_text1
                if "last_transformation" in parent_text1.attack_attrs
                else parent_text2
            )
            passed_constraints = self._check_constraints(
                new_text, previous_text, original_text=original_text
            )
            return passed_constraints
        else:
            # No transformation recorded on the child: nothing is checked.
            return True

    def _crossover(self, pop_member1, pop_member2, original_text):
        """Generates a crossover between pop_member1 and pop_member2.

        If the child fails to satisfy the constraints, we re-try crossover for a fixed number of times,
        before taking one of the parents at random as the resulting child.

        Args:
            pop_member1 (PopulationMember): The first population member.
            pop_member2 (PopulationMember): The second population member.
            original_text (AttackedText): Original text
        Returns:
            A population member containing the crossover.
        """
        x1_text = pop_member1.attacked_text
        x2_text = pop_member2.attacked_text

        num_tries = 0
        passed_constraints = False
        # One initial attempt plus up to `max_crossover_retries` retries.
        while num_tries < self.max_crossover_retries + 1:
            new_text, attributes = self._crossover_operation(pop_member1, pop_member2)

            # Rebuild the child's `modified_indices`: positions listed in
            # `newly_modified_indices` take their status from the second
            # parent, all others from the first.
            # (assumes both attributes support set operators -- TODO confirm)
            replaced_indices = new_text.attack_attrs["newly_modified_indices"]
            new_text.attack_attrs["modified_indices"] = (
                x1_text.attack_attrs["modified_indices"] - replaced_indices
            ) | (x2_text.attack_attrs["modified_indices"] & replaced_indices)

            # Carry over a parent's `last_transformation`, if any, preferring
            # the first parent.
            if "last_transformation" in x1_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x1_text.attack_attrs[
                    "last_transformation"
                ]
            elif "last_transformation" in x2_text.attack_attrs:
                new_text.attack_attrs["last_transformation"] = x2_text.attack_attrs[
                    "last_transformation"
                ]

            if self.post_crossover_check:
                passed_constraints = self._post_crossover_check(
                    new_text, x1_text, x2_text, original_text
                )

            if not self.post_crossover_check or passed_constraints:
                break

            num_tries += 1

        if self.post_crossover_check and not passed_constraints:
            # Every attempt failed the constraints: fall back to one of the
            # parents, chosen uniformly at random.
            pop_mem = pop_member1 if np.random.uniform() < 0.5 else pop_member2
            return pop_mem
        else:
            new_results, self._search_over = self.get_goal_results([new_text])
            return PopulationMember(
                new_text, result=new_results[0], attributes=attributes
            )

    @abstractmethod
    def _initialize_population(self, initial_result, pop_size):
        """
        Initialize a population of size `pop_size` with `initial_result`
        Args:
            initial_result (GoalFunctionResult): Original text
            pop_size (int): size of population
        Returns:
            population as `list[PopulationMember]`
        """
        raise NotImplementedError()

    def perform_search(self, initial_result):
        """Run the genetic search starting from `initial_result`.

        Each iteration sorts the population by score, keeps the best member,
        and fills the rest of the next generation with perturbed crossover
        children of sampled parents.

        Args:
            initial_result (GoalFunctionResult): Result of the original sample.
        Returns:
            The result of the first member of the final population.
        """
        self._search_over = False
        population = self._initialize_population(initial_result, self.pop_size)
        pop_size = len(population)
        current_score = initial_result.score

        for _ in range(self.max_iters):
            # Highest score first, so population[0] is the current best member.
            population = sorted(population, key=lambda x: x.result.score, reverse=True)

            if (
                self._search_over
                or population[0].result.goal_status
                == GoalFunctionResultStatus.SUCCEEDED
            ):
                break

            if population[0].result.score > current_score:
                current_score = population[0].result.score
            elif self.give_up_if_no_improvement:
                break

            # Temperature-scaled softmax over the negated scores.
            # NOTE(review): exp(-score / temp) gives lower-scoring members a
            # higher selection probability, whereas the rest of this search
            # treats higher scores as better -- confirm the sign is intended.
            pop_scores = torch.Tensor([pm.result.score for pm in population])
            logits = ((-pop_scores) / self.temp).exp()
            select_probs = (logits / logits.sum()).cpu().numpy()

            # Sample pop_size - 1 parent pairs; the remaining slot of the next
            # generation is reserved for the current best member.
            parent1_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)
            parent2_idx = np.random.choice(pop_size, size=pop_size - 1, p=select_probs)

            children = []
            for idx in range(pop_size - 1):
                child = self._crossover(
                    population[parent1_idx[idx]],
                    population[parent2_idx[idx]],
                    initial_result.attacked_text,
                )
                if self._search_over:
                    break

                child = self._perturb(child, initial_result)
                children.append(child)

            if self._search_over:
                # Stop without replacing the population, so the sorted best
                # member found so far is what gets returned.
                break

            population = [population[0]] + children

        return population[0].result

    def check_transformation_compatibility(self, transformation):
        """The genetic algorithm is specifically designed for word
        substitutions."""
        return transformation_consists_of_word_swaps(transformation)

    @property
    def is_black_box(self):
        """Always `True` for this search method."""
        return True

    def extra_repr_keys(self):
        """Return the names of the constructor arguments stored on the instance."""
        return [
            "pop_size",
            "max_iters",
            "temp",
            "give_up_if_no_improvement",
            "post_crossover_check",
            "max_crossover_retries",
        ]
|
|