import copy
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import nltk
import string
from copy import deepcopy
from torchprofile import profile_macs
from datetime import datetime

from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
from blackbox_utils.Attack_base import MyAttack


class CharacterAttack(MyAttack):
    """Character-level black-box attack: perturb important tokens by inserting characters or flipping case."""
    # TODO: keep a list so that each call only modifies a different token position
    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)

    def compute_importance(self, text):
        """Rank token ids by the loss of the sentence with that token removed (descending)."""
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        word_losses = {}
        # Skip the special tokens at both ends ([CLS] / [SEP]).
        for idx in range(1, len(current_tensor) - 1):
            # Remove the token at position idx and re-score the remaining sentence.
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        # Token ids sorted by loss, most important first.
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses

    def compute_loss(self, text):
        """Profile the model on `text` and score it by (scaled) MACs per token."""
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
        macs = profile_macs(self.model, shift_inputs)
        # Tokenize the raw text pair (not the already-encoded `inputs`) to get the token count.
        result = self.random_tokenizer(*text, padding=self.padding, max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        macs_per_token = macs / (token_length * 10 ** 8)
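        # Illustrative arithmetic (assumed numbers, not from the original code):
        # ~2.2e10 MACs over 22 tokens -> 2.2e10 / (22 * 1e8) = 10.0 scaled MACs per token.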

        return self.predict(macs_per_token)

    def mutation(self, current_adv_text):
        """Return character-level mutation candidates for the current adversarial text."""
        current_tensor = self.preprocess_function(current_adv_text)["input_ids"][0]
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings

    @staticmethod
    def transfer(c: str):
        """Flip the case of an ASCII letter; leave any other character unchanged."""
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c

    def character_replace_mutation(self, current_text, current_tensor):
        """Apply one character-level edit to the most important token present in the text."""
        important_tensor = self.compute_importance(current_text)
        new_strings = [current_text]
        # Walk over the vocabulary ids in importance order and mutate the first
        # token that actually occurs in the current text.
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_decode_token = self.tokenizer.decode([int(t)])
            ori_token = ori_decode_token
            # Skip single-character tokens and tokens that do not appear verbatim in the first sentence.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # TODO
                continue
            # Insert an arbitrary character at every position inside the token.
            candidate = [ori_token[:i] + insert + ori_token[i:] for i in range(len(ori_token)) for insert in self.insert_character]
            # Flip the case of one character of the token.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:] for i in range(1, len(ori_token))]
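            # Example (illustrative only): for ori_token = "cat", and assuming a space is
            # among self.insert_character, the insertions give " cat", "c at", "ca t",
            # and the case flips give "Cat", "cAt" (the last character is never flipped).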
            # Replace at most one occurrence of the original token in the first sentence.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]] for c in candidate]
            # Return as soon as one token has produced valid mutations.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
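

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original attack code). It relies only
# on the constructor signature above and on `mutation`; the checkpoint name,
# hyper-parameter values, label map, and sentence keys are illustrative, and it
# assumes MyAttack's __init__ sets up preprocess_function, random_tokenizer,
# insert_character, and predict as used by the methods above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BertForSequenceClassification

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForSequenceClassification.from_pretrained("bert-base-uncased").to(device)

    attack = CharacterAttack(
        name="character-attack-demo",
        model=model,
        tokenizer=tokenizer,
        device=device,
        max_per=5,                  # assumed: maximum number of perturbations per sample
        padding="max_length",
        max_length=128,
        label_to_id={"entailment": 0, "neutral": 1, "contradiction": 2},
        sentence1_key="premise",
        sentence2_key="hypothesis",
    )

    sample = ["A man is playing a guitar.", "A person plays an instrument."]
    candidates = attack.mutation(sample)  # character-level mutation candidates
    print(len(candidates), candidates[:3])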