"""Gradio demo for Reactive Perturbation Defocusing (RPD).

The script unpacks the bundled checkpoints, builds one TextAttack attacker per
(dataset, attack recipe) pair, and serves a UI that attacks a natural example
and then asks the TAD classifier to repair the adversarial example.
"""

import os
import random
import zipfile
from difflib import Differ

import gradio as gr
import nltk
import pandas as pd
from findfile import find_files
from anonymous_demo import TADCheckpointManager
from textattack import Attacker
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from textattack.datasets import Dataset
from textattack.models.wrappers import HuggingFaceModelWrapper

# Unpack the bundled checkpoints into the working directory. The context
# manager closes the archive handle as soon as extraction is done.
with zipfile.ZipFile("checkpoints.zip", "r") as checkpoint_archive:
    checkpoint_archive.extractall(os.getcwd())


class ModelWrapper(HuggingFaceModelWrapper):
    """Expose a classifier with an ``infer`` method as a callable model wrapper.

    Only ``model`` is stored; the parent initializer is not called.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` entry of ``model.infer`` for each input text."""
        return [
            self.model.infer(text_input, print_result=False, **kwargs)["probs"]
            for text_input in text_inputs
        ]


class SentAttacker:
    """Bundle a classifier with a TextAttack recipe into a ready ``Attacker``."""

    def __init__(self, model, recipe_class=BAEGarg2019):
        model_wrapper = ModelWrapper(model)
        recipe = recipe_class.build(model_wrapper)
        # WordNet defaults to english. Set the default language to French ('fra')
        # recipe.transformation.language = "en"

        # The Attacker needs a dataset at construction time; a single empty
        # placeholder sample is supplied here.
        _dataset = Dataset([("", 0)])
        self.attacker = Attacker(recipe, _dataset)


def diff_texts(text1, text2):
    """Return ``(chunk, tag)`` pairs describing how ``text2`` differs from ``text1``.

    ``tag`` is the leading diff marker emitted by ``difflib.Differ`` (for
    example ``"+"`` or ``"-"``), or ``None`` when the chunk is unchanged.
    """
    d = Differ()
    return [
        (token[2:], token[0] if token[0] != " " else None)
        for token in d.compare(text1, text2)
    ]


def get_ensembled_tad_results(results):
    """Return the most frequent ``"label"`` among ``results`` (majority vote).

    When several labels share the highest count, the one that was first seen
    latest is returned.

    Raises:
        ValueError: if ``results`` is empty.
    """
    votes = {}
    for r in results:
        votes[r["label"]] = votes.get(r["label"], 0) + 1

    top_count = max(votes.values())
    winner = None
    for label, count in votes.items():
        if count == top_count:
            winner = label
    return winner


nltk.download("omw-1.4")

sent_attackers = {}
tad_classifiers = {}

attack_recipes = {
    "bae": BAEGarg2019,
    "pwws": PWWSRen2019,
    "textfooler": TextFoolerJin2019,
    "pso": PSOZang2020,
    "iga": IGAWang2019,
    "GA": GeneticAlgorithmAlzantot2018,
    "wordbugger": DeepWordBugGao2018,
    "clare": CLARE2020,
}

# Build one attacker per (dataset, recipe) pair; classifiers are loaded once
# per dataset and shared by all attackers for that dataset.
for attacker in ["pwws", "bae", "textfooler", "pso", "wordbugger", "clare"]:
    for dataset in [
        "agnews10k",
        "amazon",
        "sst2",
        # 'imdb'
    ]:
        classifier_key = "tad-{}".format(dataset)
        if classifier_key not in tad_classifiers:
            tad_classifiers[
                classifier_key
            ] = TADCheckpointManager.get_tad_text_classifier(classifier_key.upper())

        sent_attackers["tad-{}{}".format(dataset, attacker)] = SentAttacker(
            tad_classifiers[classifier_key], attack_recipes[attacker]
        )
        # "pwws" is the first recipe in the outer loop, so this key already
        # exists on every iteration.
        tad_classifiers[classifier_key].sent_attacker = sent_attackers[
            "tad-{}pwws".format(dataset)
        ]


# Substrings that disqualify a file from being treated as a raw test split.
_EXCLUDED_FILE_KEYS = [".adv", ".org", ".defense", ".inference", "train."] + [
    ".py",
    ".md",
    "readme",
    "log",
    "result",
    "zip",
    ".state_dict",
    ".model",
    ".png",
    "acc_",
    "f1_",
    ".origin",
    ".adv",
    ".csv",
]

# Separator between the text and its integer label in the dataset files.
_LABEL_SEPARATOR = "$LABEL$"


def _load_test_examples(dataset):
    """Read all ``(text, label)`` pairs from the test files found for ``dataset``."""
    test_files = find_files(
        "./",
        [dataset, "test", "text_defense"],
        exclude_key=_EXCLUDED_FILE_KEYS,
    )
    examples = []
    for data_file in test_files:
        with open(data_file, mode="r", encoding="utf8") as fin:
            for line in fin:
                if not line.strip():
                    # A blank (e.g. trailing) line carries no example.
                    continue
                text, label = line.split(_LABEL_SEPARATOR)
                examples.append((text.strip(), int(label.strip())))
    return examples


def _get_random_example(dataset):
    """Return one uniformly sampled ``(text, label)`` test example of ``dataset``.

    Raises:
        ValueError: if no test example could be found for ``dataset``.
    """
    examples = _load_test_examples(dataset)
    if not examples:
        raise ValueError(
            "No test examples found for dataset {!r}".format(dataset)
        )
    # random.choice never indexes past the end, unlike
    # examples[random.randint(0, len(examples))], whose upper bound is inclusive.
    return random.choice(examples)


def get_sst2_example():
    """Return a random ``(text, label)`` pair from the SST2 test split."""
    return _get_random_example("sst2")


def get_agnews_example():
    """Return a random ``(text, label)`` pair from the AGNews test split."""
    return _get_random_example("agnews")


def get_amazon_example():
    """Return a random ``(text, label)`` pair from the Amazon test split."""
    return _get_random_example("amazon")


def get_imdb_example():
    """Return a random ``(text, label)`` pair from the IMDB test split."""
    return _get_random_example("imdb")


# Texts that have already been attacked; a repeated text is replaced by a
# freshly sampled dataset example.
cache = set()


def _sample_example(dataset):
    """Pick the example getter whose key occurs in ``dataset`` and call it."""
    dataset_name = dataset.lower()
    if "agnews" in dataset_name:
        return get_agnews_example()
    if "sst2" in dataset_name:
        return get_sst2_example()
    if "amazon" in dataset_name:
        return get_amazon_example()
    if "imdb" in dataset_name:
        return get_imdb_example()
    raise ValueError("Unsupported dataset: {!r}".format(dataset))


def generate_adversarial_example(dataset, attacker, text=None, label=None):
    """Attack an example and repair the adversarial result with the classifier.

    Args:
        dataset: Dataset name as shown in the UI (matched case-insensitively).
        attacker: Attack recipe name as shown in the UI (case-insensitive).
        text: Optional natural example. When empty, or when it was already
            used before, a random test example of ``dataset`` is used instead.
        label: Ground-truth label of ``text``; must be convertible to ``int``.

    Returns:
        An 11-tuple matching the outputs bound to the "generate" button:
        original text, original label, restored text, repaired label,
        adversarial text, three character diffs (original, adversarial,
        restored), the perturbed label, the classification DataFrame and the
        adversarial-detection DataFrame.

    Raises:
        ValueError: if ``label`` cannot be converted to an integer.
    """
    classifier_key = "tad-{}".format(dataset.lower())
    attacker_key = "tad-{}{}".format(dataset.lower(), attacker.lower())

    # Keep trying with fresh dataset samples until an attack succeeds on an
    # example the model originally classified correctly and the classifier
    # returns a result for it.
    while True:
        if not text or text in cache:
            text, label = _sample_example(dataset)
        cache.add(text)

        try:
            ground_truth = int(label)
        except (TypeError, ValueError) as err:
            raise ValueError(
                "An integer original label is required, got {!r}".format(label)
            ) from err

        result = None
        attack_result = sent_attackers[attacker_key].attacker.simple_attack(
            text, ground_truth
        )
        if isinstance(attack_result, SuccessfulAttackResult):
            original = attack_result.original_result
            perturbed = attack_result.perturbed_result
            if (
                perturbed.output != original.ground_truth_output
                and original.output == original.ground_truth_output
            ):
                # with defense
                result = tad_classifiers[classifier_key].infer(
                    perturbed.attacked_text.text
                    + "!ref!{},{},{}".format(
                        original.ground_truth_output,
                        1,
                        perturbed.output,
                    ),
                    print_result=True,
                    defense="pwws",
                )

        if result:
            break
        # Nothing usable came back: drop the input so a new example is sampled.
        text, label = None, None

    classification_df = {
        "is_repaired": result["is_fixed"],
        "pred_label": result["label"],
        "confidence": round(result["confidence"], 3),
        "is_correct": result["ref_label_check"],
    }

    advdetection_df = {}
    if result["is_adv_label"] != "0":
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        # advdetection_df['ref_is_attack'] = result['ref_is_adv_label']
        # advdetection_df['is_correct'] = result['ref_is_adv_check']

    adversarial_text = attack_result.perturbed_result.attacked_text.text
    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        adversarial_text,
        diff_texts(text, text),
        diff_texts(text, adversarial_text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )


demo = gr.Blocks()
with demo:
    # Heading text stays on the same line as its "#" markers so that it is a
    # valid string literal and renders as a Markdown heading.
    gr.Markdown(
        "# Reactive Perturbation Defocusing for Textual Adversarial Defense"
    )
    gr.Markdown("## Clarifications")
    gr.Markdown(
        "- This demo has no mechanism to ensure the adversarial example will be correctly repaired by RPD."
        " The repair success rate is actually the performance reported in the paper (approximately up to 97%.)"
    )
    gr.Markdown(
        "- The red (+) and green (-) colors in the character edition indicate the character is added "
        "or deleted in the adversarial example compared to the original input natural example."
    )
    gr.Markdown(
        "- The adversarial example and repaired adversarial example may be unnatural to read, "
        "while it is because the attackers usually generate unnatural perturbations."
        "RPD does not introduce additional unnatural perturbations."
    )
    gr.Markdown(
        "- To our best knowledge, Reactive Perturbation Defocusing is a novel approach in adversarial defense "
        ". RPD significantly (>10% defense accuracy improvement) outperforms the state-of-the-art methods."
    )
    gr.Markdown(
        "- The DeepWordBug, IGA, GA, PSO, and CLARE attackers are very slow on CPU Devices."
        " And they are unknown attackers to RPD's adversarial detector. "
    )

    gr.Markdown("## Natural Example Input")
    with gr.Group():
        with gr.Row():
            input_dataset = gr.Radio(
                choices=["SST2", "AGNews10K", "Amazon"],
                value="SST2",
                label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
            )
            input_attacker = gr.Radio(
                choices=[
                    "BAE",
                    "PWWS",
                    "TextFooler",
                    "WordBugger",
                    "PSO",
                    "CLARE",
                ],
                value="TextFooler",
                label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
            )
    with gr.Group():
        with gr.Row():
            input_sentence = gr.Textbox(
                placeholder="Input a natural example...",
                label="Alternatively, input a natural example and its original label to generate an adversarial example.",
            )
            input_label = gr.Textbox(
                placeholder="Original label...", label="Original Label"
            )

    button_gen = gr.Button(
        "Generate an adversarial example and repair using RPD (No GPU, Time:3-10 mins )",
        variant="primary",
    )

    gr.Markdown(
        "## Generated Adversarial Example and Repaired Adversarial Example"
    )
    with gr.Group():
        with gr.Column():
            with gr.Row():
                output_original_example = gr.Textbox(label="Original Example")
                output_original_label = gr.Textbox(label="Original Label")
            with gr.Row():
                output_adv_example = gr.Textbox(label="Adversarial Example")
                output_adv_label = gr.Textbox(label="Perturbed Label")
            with gr.Row():
                output_repaired_example = gr.Textbox(
                    label="Repaired Adversarial Example by RPD"
                )
                output_repaired_label = gr.Textbox(label="Repaired Label")

    gr.Markdown("## The Output of Reactive Perturbation Defocusing")
    with gr.Group():
        output_is_adv_df = gr.DataFrame(label="Adversarial Example Detection Result")
        gr.Markdown(
            "The is_adversarial field indicates an adversarial example is detected. "
            "The perturbed_label is the predicted label of the adversarial example. "
            "The confidence field represents the confidence of the predicted adversarial example detection. "
        )
        output_df = gr.DataFrame(label="Repaired Standard Classification Result")
        gr.Markdown(
            "If is_repaired=true, it has been repaired by RPD. "
            "The pred_label field indicates the standard classification result. "
            "The confidence field represents the confidence of the predicted label. "
            "The is_correct field indicates whether the predicted label is correct."
        )

    gr.Markdown("## Example Comparisons")
    ori_text_diff = gr.HighlightedText(
        label="The Original Natural Example",
        combine_adjacent=True,
    )
    adv_text_diff = gr.HighlightedText(
        label="Character Editions of Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )
    restored_text_diff = gr.HighlightedText(
        label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
        combine_adjacent=True,
    )

    # Bind functions to buttons
    button_gen.click(
        fn=generate_adversarial_example,
        inputs=[input_dataset, input_attacker, input_sentence, input_label],
        outputs=[
            output_original_example,
            output_original_label,
            output_repaired_example,
            output_repaired_label,
            output_adv_example,
            ori_text_diff,
            adv_text_diff,
            restored_text_diff,
            output_adv_label,
            output_df,
            output_is_adv_df,
        ],
    )

demo.launch()