import copy
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import nltk
import string
from copy import deepcopy
from torchprofile import profile_macs
from datetime import datetime
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
from blackbox_utils.Attack_base import MyAttack


class CharacterAttack(MyAttack):
    # TODO: keep a list of positions and modify a different token position on each call
    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
        super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)

    def compute_importance(self, text):
        # Rank token ids by importance: drop one token at a time, re-score the
        # sentence with compute_loss, and sort ids by the resulting loss in
        # descending order.
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        # print(current_tensor)
        word_losses = {}
        for idx in range(1, len(current_tensor) - 1):
            # print(current_tensor[:idx])
            # print(current_tensor[idx + 1:])
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses

    def compute_loss(self, text):
        # Profile the model's MACs on this input, normalize by the tokenized
        # length, and pass the scaled value to `predict` as the score.
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs['input_ids'], inputs['attention_mask'], inputs['token_type_ids'])
        # toc = datetime.now()
        macs = profile_macs(self.model, shift_inputs)
        # tic = datetime.now()
        # print((tic - toc).total_seconds())
        result = self.random_tokenizer(*inputs, padding=self.padding, max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        macs_per_token = macs / (token_length * 10 ** 8)
        return self.predict(macs_per_token)

    def mutation(self, current_adv_text):
        # Generate candidate adversarial strings via character-level edits.
        current_tensor = self.preprocess_function(current_adv_text)
        # print(current_tensor)
        current_tensor = current_tensor["input_ids"][0]
        # print(current_tensor)
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings

    @staticmethod
    def transfer(c: str):
        # Swap the case of an ASCII letter; leave any other character unchanged.
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c

    def character_replace_mutation(self, current_text, current_tensor):
        important_tensor = self.compute_importance(current_text)
        # current_string = [self.tokenizer.decoder[int(t)] for t in current_tensor]
        new_strings = [current_text]
        # Walk the vocabulary ids in importance order and mutate the first token
        # that actually appears in the text.
        # print(current_tensor)
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_decode_token = self.tokenizer.decode([int(t)])
            # print(ori_decode_token)
            # if self.space_token in ori_decode_token:
            #     ori_token = ori_decode_token.replace(self.space_token, '')
            # else:
            ori_token = ori_decode_token
            # Skip single-character tokens and tokens not present in the sentence.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # todo
                continue
            # Insert an arbitrary character at each position of the token.
            candidate = [ori_token[:i] + insert + ori_token[i:] for i in range(len(ori_token)) for insert in self.insert_character]
            # Flip the case of one character.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:] for i in range(1, len(ori_token))]
            # print(candidate)
            # Replace at most one occurrence in the sentence.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]] for c in candidate]
            # ori_tensor_pos = current_tensor.eq(int(t)).nonzero()
            #
            # for p in ori_tensor_pos:
            #     new_strings += [current_string[:p] + c + current_string[p + 1:] for c in candidate]
            # Return as soon as one valid mutation has been generated.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
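

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (an assumption, not part of the original script):
# the MyAttack base class is expected to provide `preprocess_function`,
# `random_tokenizer`, `insert_character`, and `predict`; the model name and
# argument values below are illustrative placeholders only.
#
#   from transformers import BertTokenizer, BertForSequenceClassification
#
#   tok = BertTokenizer.from_pretrained("bert-base-uncased")
#   mdl = BertForSequenceClassification.from_pretrained("bert-base-uncased")
#   attack = CharacterAttack(
#       "bert-base-uncased", mdl, tok, device="cpu", max_per=5,
#       padding="max_length", max_length=128, label_to_id=None,
#       sentence1_key="sentence", sentence2_key=None,
#   )
#   # `mutation` takes a [sentence1, sentence2] pair (as used in
#   # compute_importance) and returns candidate perturbations of sentence1.
#   candidates = attack.mutation(["a quick brown fox jumps over the lazy dog", None])
# ---------------------------------------------------------------------------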