adamtayzzz commited on
Commit
0e73e91
·
verified ·
1 Parent(s): 28b6ab1

Upload 21 files

Browse files
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ import random
6
+
7
+ import datasets
8
+ from datasets import load_dataset, load_metric
9
+ from torch.utils.data import DataLoader
10
+ from tqdm.auto import tqdm
11
+ import gradio as gr
12
+
13
+ import transformers
14
+ from accelerate import Accelerator # huggingface package
15
+ from transformers import (
16
+ AdamW,
17
+ AutoConfig,
18
+ AutoModelForSequenceClassification,
19
+ AutoTokenizer,
20
+ DataCollatorWithPadding,
21
+ PretrainedConfig,
22
+ SchedulerType,
23
+ default_data_collator,
24
+ get_scheduler,
25
+ set_seed,
26
+ BertTokenizer,
27
+ )
28
+ from transformers.utils.versions import require_version
29
+
30
+ import torch
31
+ from test_module.modeling_transkimer import BertForSequenceClassification as TranskimerForSequenceClassification
32
+ from test_module.modeling_transkimer_roberta import RobertaForSequenceClassification as TranskimerRobertaForSequenceClassification
33
+ from test_module.modeling_utils import convert_softmax_mask_to_digit
34
+
35
+ from blackbox_utils.my_attack import CharacterAttack
36
+ from transformers import glue_processors as processors
37
+
38
+
39
+ task_to_keys = {
40
+ "cola": ("sentence", None),
41
+ "mnli": ("premise", "hypothesis"),
42
+ "mrpc": ("sentence1", "sentence2"),
43
+ "qnli": ("question", "sentence"),
44
+ "qqp": ("question1", "question2"),
45
+ "rte": ("sentence1", "sentence2"),
46
+ "sst2": ("sentence", None),
47
+ "stsb": ("sentence1", "sentence2"),
48
+ "wnli": ("sentence1", "sentence2"),
49
+ "imdb": ("text", None),
50
+ }
51
+
52
+ model_path_dict = {
53
+ "transkimer_sst2_not_pad":'./not_pad_0.5',
54
+ }
55
+
56
+
57
+ datasets.utils.logging.set_verbosity_error()
58
+ transformers.utils.logging.set_verbosity_error()
59
+
60
+
61
+ task_name = 'sst2'
62
+ model_type = 'transkimer'
63
+
64
+ # Load pretrained model and tokenizer
65
+ model_path_key = f'{model_type}_{task_name}_not_pad'
66
+ model_path = model_path_dict[model_path_key]
67
+ config = AutoConfig.from_pretrained(model_path, num_labels=num_labels, finetuning_task=task_name)
68
+ tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', use_fast=True)
69
+ model = TranskimerForSequenceClassification.from_pretrained(model_path,from_tf=bool(".ckpt" in model_path),config=config,)
70
+
71
+ # Preprocessing the datasets
72
+ sentence1_key, sentence2_key = task_to_keys[task_name]
73
+
74
+ processor = processors[task_name]()
75
+ label_list = processor.get_labels()
76
+
77
+ label_to_id = {v: i for i, v in enumerate(label_list)}
78
+
79
+ padding = False
80
+
81
+ attack = CharacterAttack(f'{model_type}_{task_name}',model,tokenizer,device='cpu',max_per=10,padding=padding,max_length=128,label_to_id=label_to_id,sentence1_key=sentence1_key,sentence2_key=sentence2_key)
82
+
83
+
84
+ def greet(text):
85
+ text_input = [(text,None)]
86
+ outputs,time = attack.get_prob(text_input)
87
+ _,token_remained,_ = attack.output_analysis(outputs)
88
+ return time,token_remained.item()
89
+
90
+ iface = gr.Interface(fn=greet, inputs=["text","text"], outputs=["number","number"])
91
+ iface.launch()
blackbox_utils/Attack_base.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ # import jieba
4
+ import string
5
+ import numpy as np
6
+ from copy import deepcopy
7
+ from tqdm import tqdm
8
+ import time
9
+ from datetime import datetime
10
+ import os
11
+ from sklearn.linear_model import LinearRegression
12
+ from torch.multiprocessing import Process,Pool
13
+
14
+ from transformers import BertTokenizer
15
+
16
+ os.environ['TOKENIZERS_PARALLELISM']='True'
17
+ # torch.autograd.set_detect_anomaly(True)
18
+
19
+ class BaseAttack:
20
+ def __init__(self, name, model, tokenizer, device, max_per, padding,max_length,label_to_id,sentence1_key,sentence2_key):
21
+ self.name = name
22
+ self.model = model
23
+ self.tokenizer = tokenizer
24
+ self.device = device
25
+ self.model = self.model.to(self.device)
26
+ self.model.eval()
27
+ self.padding = padding
28
+ self.max_length = max_length
29
+ self.label_to_id = label_to_id
30
+ self.sentence1_key = sentence1_key
31
+ self.sentence2_key = sentence2_key
32
+ # 修改token个数的最大值
33
+ self.max_per = max_per
34
+ # linear regression model initialization
35
+ self.linear_regression()
36
+ self.random_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
37
+
38
+ def run_attack(self, x):
39
+ pass
40
+
41
+ def compute_loss(self, x):
42
+ pass
43
+
44
+ def preprocess_function(self,examples,to_device=True):
45
+ # Tokenize the texts
46
+ texts = ((examples[0],) if self.sentence2_key is None else (examples[0], examples[1]))
47
+ result = self.tokenizer(*texts, padding=self.padding, max_length=self.max_length, truncation=True)
48
+ new_result = {}
49
+ for key,item in result.items():
50
+ if to_device:
51
+ new_result[key] = torch.tensor(item).unsqueeze(0).to(self.device)
52
+ else:
53
+ new_result[key] = torch.tensor(item).unsqueeze(0)
54
+ return new_result
55
+
56
+ def get_pred(self,input_):
57
+ return self.get_prob(input_).logits.argmax(dim=-1)
58
+
59
+ def get_prob(self,input_):
60
+ toc = datetime.now()
61
+ batch = self.preprocess_function(input_)
62
+ # batch['gumbel_softmax']=gradient
63
+ # print(batch)
64
+ outputs = self.model(**batch) # get all logits
65
+ tic = datetime.now()
66
+ running_time = (tic-toc).total_seconds()
67
+ return outputs,running_time
68
+
69
+ def output_analysis(self,outputs):
70
+ # print(outputs)
71
+
72
+ all_skim_loss, all_tokens_remained = list(), list()
73
+ all_layer_tokens_remained = [[] for _ in range(len(outputs.layer_tokens_remained))]
74
+
75
+ all_skim_loss.append(outputs.skim_loss)
76
+ all_tokens_remained.append(outputs.tokens_remained)
77
+ for layer_idx,mac in enumerate(outputs.layer_tokens_remained):
78
+ all_layer_tokens_remained[layer_idx].append(mac)
79
+
80
+ skim_loss = torch.mean(torch.stack(all_skim_loss))
81
+ tokens_remained = torch.mean(torch.stack(all_tokens_remained))
82
+ layers_result = [torch.mean(torch.stack(macs)) for i,macs in enumerate(all_layer_tokens_remained)]
83
+
84
+ return skim_loss,tokens_remained,layers_result
85
+
86
+ def load_data(self,model_path_key,mode='train'):
87
+ path = f'flops_count/{model_path_key}/{mode}'
88
+ if os.path.exists(f'{path}/process_data.pth'):
89
+ print(f'loading data from {path}')
90
+ data = torch.load(f'{path}/process_data.pth')
91
+ else:
92
+ time_list = torch.load(f'{path}/time_list.pth')
93
+ ratio_list = torch.load(f'{path}/ratio_list.pth')
94
+ token_num_list = torch.load(f'{path}/text_len_list_tokenizer.pth')
95
+
96
+ ratio_list_ = []
97
+ for ratio in ratio_list:
98
+ ratio_list_.append(ratio.item())
99
+ y = np.expand_dims(np.array(ratio_list_),axis=1)
100
+ # print(x.shape)
101
+
102
+ time_list_ = []
103
+ for time,token_num in zip(time_list,token_num_list):
104
+ time_list_.append((time/(token_num*(10**8))))
105
+ x = np.expand_dims(np.array(time_list_),axis=1)
106
+ # print(y.shape)
107
+
108
+ data = dict()
109
+ data['x']=x
110
+ data['y']=y
111
+ torch.save(data,f'{path}/process_data.pth')
112
+
113
+ return data
114
+
115
+ def predict(self,x):
116
+ return self.w*x+self.b
117
+
118
+ def linear_regression(self):
119
+ print("="*20)
120
+ print('Linear Regression Generation')
121
+ data_train = self.load_data(self.name,mode='train')
122
+ data_test = self.load_data(self.name,mode='test')
123
+ # print(data_train,data_test)
124
+
125
+ reg = LinearRegression().fit(data_train['x'],data_train['y'])
126
+ train_score = reg.score(data_train['x'],data_train['y'])
127
+ test_score = reg.score(data_test['x'],data_test['y'])
128
+ print(f'train set score: {train_score}')
129
+ print(f'test set score: {test_score}')
130
+
131
+ self.w = reg.coef_[0][0]
132
+ self.b = reg.intercept_[0]
133
+ print("w:",self.w)
134
+ print("b:",self.b)
135
+
136
+ print(self.predict(0.8))
137
+
138
+
139
+ class MyAttack(BaseAttack):
140
+ def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
141
+ super(MyAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)
142
+ # self.insert_character = string.punctuation
143
+ self.insert_character = string.digits
144
+ self.insert_character += string.ascii_letters
145
+ # self.insert_character -= """"'/\\"""
146
+ # print(self.insert_character)
147
+
148
+ self.origin_ratio = []
149
+ self.attack_ratio = []
150
+ self.layer_result = []
151
+ self.origin_layer_result = []
152
+
153
+ # @torch.no_grad()
154
+ # def select_best(self, new_strings):
155
+ # best_string = None
156
+ # best_loss = 0
157
+ # for new_string in new_strings:
158
+ # new_predicted_loss = self.compute_loss(new_string)
159
+ # if new_predicted_loss>best_loss:
160
+ # best_loss = new_predicted_loss
161
+ # best_string = new_string
162
+
163
+ # assert best_string is not None
164
+ # return best_string,best_loss
165
+
166
+ @torch.no_grad()
167
+ def select_best(self, new_strings):
168
+ # self.model.to('cpu')
169
+ best_string = None
170
+ best_loss = 0
171
+ with Pool(processes=4) as pool:
172
+ loss_list = pool.map(self.compute_loss,new_strings)
173
+ idx = np.argmax(np.array(loss_list))
174
+ best_loss = loss_list[idx]
175
+ best_string = new_strings[idx]
176
+ # self.model.to(self.device)
177
+ # for new_string in new_strings:
178
+ # new_predicted_loss = self.compute_loss(new_string)
179
+ # if new_predicted_loss>best_loss:
180
+ # best_loss = new_predicted_loss
181
+ # best_string = new_string
182
+
183
+ assert best_string is not None
184
+ # self.model.to(self.device)
185
+ return best_string,best_loss
186
+
187
+ def compute_loss(self, xxx):
188
+ raise NotImplementedError
189
+
190
+ def mutation(self, current_adv_text, grad, modify_pos):
191
+ raise NotImplementedError
192
+
193
+ def run_attack(self, text):
194
+ # assert len(text) == 1
195
+ # print(text)
196
+ text[0] = text[0].strip(" .")
197
+ text[1] = text[1].strip(" .")
198
+ print(f'Origin Text: {text}')
199
+ current_adv_text = deepcopy(text)
200
+ # max_per 最多扰动单词的个数
201
+ # pbar = tqdm(range(self.max_per))
202
+
203
+ best_loss = 0
204
+ best_tokens_remained = 0
205
+ best_layer_result = None
206
+
207
+ output,_ = self.get_prob(current_adv_text)
208
+ origin_skim_loss,origin_ratio_,origin_layer_result_ = self.output_analysis(output)
209
+ print(origin_skim_loss,origin_ratio_)
210
+ self.origin_ratio.append(origin_ratio_.item())
211
+ self.origin_layer_result.append(origin_layer_result_)
212
+
213
+
214
+ # for it in pbar:
215
+ for _ in range(self.max_per):
216
+ # 得到每个修改的位置
217
+ new_strings = self.mutation(current_adv_text)
218
+ #print(new_strings)
219
+ current_adv_text,current_loss = self.select_best(new_strings)
220
+ # print(new_strings)
221
+ # print(current_adv_text,current_loss,current_tokens_remained)
222
+ if current_loss > best_loss:
223
+ best_adv_text = deepcopy(current_adv_text)
224
+ best_loss = current_loss
225
+ print(best_adv_text)
226
+
227
+ output,_ = self.get_prob(best_adv_text)
228
+ _,best_tokens_remained,best_layer_result = self.output_analysis(output)
229
+
230
+ self.attack_ratio.append(best_tokens_remained.item())
231
+ self.layer_result.append(best_layer_result)
232
+ print(f'Malicious Text: {best_adv_text}')
233
+ print(f'Origin Ratio: {self.origin_ratio[-1]} Attack Ratio: {self.attack_ratio[-1]}')
234
+ print(f'Layer Result: {self.layer_result[-1]}')
235
+
236
+ return best_adv_text,best_loss,best_tokens_remained,best_layer_result
237
+
238
+
blackbox_utils/__pycache__/Attack_base.cpython-37.pyc ADDED
Binary file (6.68 kB). View file
 
blackbox_utils/__pycache__/my_attack.cpython-37.pyc ADDED
Binary file (3.89 kB). View file
 
blackbox_utils/my_attack.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ import nltk
8
+ import string
9
+ from copy import deepcopy
10
+ from torchprofile import profile_macs
11
+ from datetime import datetime
12
+
13
+ from transformers import BertTokenizer, BertModel, BertForMaskedLM
14
+ from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer
15
+ from blackbox_utils.Attack_base import MyAttack
16
+
17
+
18
+ class CharacterAttack(MyAttack):
19
+ # TODO: 存储一个list每次只修改不同的token位置
20
+ def __init__(self, name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key):
21
+ super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per, padding, max_length, label_to_id, sentence1_key, sentence2_key)
22
+
23
+ def compute_importance(self, text):
24
+ current_tensor = self.preprocess_function(text)["input_ids"][0]
25
+ # print(current_tensor)
26
+ word_losses = {}
27
+ for idx in range(1,len(current_tensor)-1):
28
+ # print(current_tensor[:idx])
29
+ # print(current_tensor[idx+1:])
30
+ sentence_tokens_without = torch.cat([current_tensor[:idx],current_tensor[idx + 1:]])
31
+ sentence_without = self.tokenizer.decode(sentence_tokens_without)
32
+ sentence_without = [sentence_without,text[1]]
33
+ word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
34
+ word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
35
+ return word_losses
36
+
37
+ def compute_loss(self, text):
38
+ inputs = self.preprocess_function(text)
39
+ shift_inputs = (inputs['input_ids'],inputs['attention_mask'],inputs['token_type_ids'])
40
+ # toc = datetime.now()
41
+ macs = profile_macs(self.model, shift_inputs)
42
+ # tic = datetime.now()
43
+ # print((tic-toc).total_seconds())
44
+ result = self.random_tokenizer(*inputs, padding=self.padding, max_length=self.max_length, truncation=True)
45
+ token_length = len(result["input_ids"])
46
+ macs_per_token = macs/(token_length*10**8)
47
+
48
+ return self.predict(macs_per_token)
49
+
50
+ def mutation(self, current_adv_text):
51
+ current_tensor = self.preprocess_function(current_adv_text)
52
+ # print(current_tensor)
53
+ current_tensor = current_tensor["input_ids"][0]
54
+ # print(current_tensor)
55
+ new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
56
+ return new_strings
57
+
58
+ @staticmethod
59
+ def transfer(c: str):
60
+ if c in string.ascii_lowercase:
61
+ return c.upper()
62
+ elif c in string.ascii_uppercase:
63
+ return c.lower()
64
+ return c
65
+
66
+ def character_replace_mutation(self, current_text, current_tensor):
67
+ important_tensor = self.compute_importance(current_text)
68
+ # current_string = [self.tokenizer.decoder[int(t)] for t in current_tensor]
69
+ new_strings = [current_text]
70
+ # 遍历每个vocabulary,查找文本有的第一个token
71
+ # print(current_tensor)
72
+ for t in important_tensor:
73
+ if int(t) not in current_tensor:
74
+ continue
75
+ ori_decode_token = self.tokenizer.decode([int(t)])
76
+ # print(ori_decode_token)
77
+ # if self.space_token in ori_decode_token:
78
+ # ori_token = ori_decode_token.replace(self.space_token, '')
79
+ # else:
80
+ ori_token = ori_decode_token
81
+ # 如果只有一个长度
82
+ if len(ori_token) == 1 or ori_token not in current_text[0]: #todo
83
+ continue
84
+ # 随意插入一个字符
85
+ candidate = [ori_token[:i] + insert + ori_token[i:] for i in range(len(ori_token)) for insert in self.insert_character]
86
+ # 随意更换一个大小写
87
+ candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:] for i in range(1, len(ori_token))]
88
+ # print(candidate)
89
+ # 最多只替换一次
90
+ new_strings += [[current_text[0].replace(ori_token, c, 1),current_text[1]] for c in candidate]
91
+
92
+ # ori_tensor_pos = current_tensor.eq(int(t)).nonzero()
93
+ #
94
+ # for p in ori_tensor_pos:
95
+ # new_strings += [current_string[:p] + c + current_string[p + 1:] for c in candidate]
96
+ # 存在一个有效的改动就返回
97
+ if len(new_strings) > 1:
98
+ return new_strings
99
+ return new_strings
not_pad_0.5/config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "bert-base-uncased",
3
+ "architectures": [
4
+ "BertForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "classifier_dropout": null,
8
+ "finetuning_task": "sst2",
9
+ "gradient_checkpointing": false,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 768,
13
+ "id2label": {
14
+ "0": "negative",
15
+ "1": "positive"
16
+ },
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 3072,
19
+ "label2id": {
20
+ "negative": 0,
21
+ "positive": 1
22
+ },
23
+ "layer_norm_eps": 1e-12,
24
+ "max_position_embeddings": 512,
25
+ "model_type": "bert",
26
+ "num_attention_heads": 12,
27
+ "num_hidden_layers": 12,
28
+ "pad_token_id": 0,
29
+ "position_embedding_type": "absolute",
30
+ "problem_type": "single_label_classification",
31
+ "skim_coefficient": 0.5,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": "4.26.1",
34
+ "type_vocab_size": 2,
35
+ "use_cache": true,
36
+ "vocab_size": 30522
37
+ }
not_pad_0.5/log.log ADDED
The diff for this file is too large to render. See raw diff
 
not_pad_0.5/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b49ee701f19e2ad22efc89e07452b0aed5c3075800004d2a10b63308a29f4363
3
+ size 466610485
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ argparse
2
+ logging
3
+ math
4
+ os
5
+ random
6
+ tqdm
7
+ gradio
8
+ transformers
9
+ accelerate
10
+ torch
11
+ datetime
12
+ time
13
+ copy
14
+ numpy
15
+ string
16
+ sklearn
17
+ nltk
18
+ torchprofile
test_module/__init__.py ADDED
File without changes
test_module/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (151 Bytes). View file
 
test_module/__pycache__/modeling_skim_predictor.cpython-37.pyc ADDED
Binary file (1.9 kB). View file
 
test_module/__pycache__/modeling_transkimer.cpython-37.pyc ADDED
Binary file (57 kB). View file
 
test_module/__pycache__/modeling_transkimer_roberta.cpython-37.pyc ADDED
Binary file (45.5 kB). View file
 
test_module/__pycache__/modeling_utils.cpython-37.pyc ADDED
Binary file (3.71 kB). View file
 
test_module/modeling_skim_predictor.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
+ def init_skim_predictor(module_list, mean_bias=5.0):
5
+ for module in module_list:
6
+ if not isinstance(module, torch.nn.Linear):
7
+ raise ValueError("only support initialization of linear skim predictor")
8
+
9
+ # module.bias.data[1].fill_(5.0)
10
+ # module.bias.data[0].fill_(-5.0)
11
+ # module.weight.data.zero_()
12
+ module.bias.data[1].normal_(mean=mean_bias, std=0.02)
13
+ module.bias.data[0].normal_(mean=-mean_bias, std=0.02)
14
+ module.weight.data.normal_(mean=0.0, std=0.02)
15
+
16
+ module._skim_initialized = True
17
+
18
+ class SkimPredictor(nn.Module):
19
+ def __init__(self, input_size, output_size, hidden_size=None):
20
+ super().__init__()
21
+
22
+ self.hidden_size = hidden_size if hidden_size else input_size
23
+
24
+ self.predictor = nn.Sequential(
25
+ nn.LayerNorm(input_size),
26
+ nn.Linear(input_size, self.hidden_size),
27
+ # nn.GELU(),
28
+ # nn.Linear(self.hidden_size, self.hidden_size),
29
+ nn.LayerNorm(self.hidden_size),
30
+ nn.GELU(),
31
+ nn.Linear(self.hidden_size, output_size),
32
+ )
33
+
34
+ init_skim_predictor([self.predictor[-1]])
35
+
36
+ def forward(self, hidden_states):
37
+ return self.predictor(hidden_states)
38
+
39
+ def test_init_skim_predictor():
40
+ num_layers = 12
41
+
42
+ skim_predictors = torch.nn.ModuleList([torch.nn.Linear(768,2) for _ in range(num_layers)])
43
+ init_skim_predictor(skim_predictors)
44
+
45
+ print(skim_predictors[0].weight, skim_predictors[0].bias)
46
+
47
+ rand_input = torch.rand((4, 16, 768))
48
+ print(skim_predictors[0](rand_input))
49
+
50
+ if __name__ == "__main__":
51
+ test_init_skim_predictor()
test_module/modeling_transkimer.py ADDED
@@ -0,0 +1,2002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+
19
+ import math
20
+ import os
21
+ import warnings
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple
24
+
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ from packaging import version
28
+ from torch import nn
29
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.file_utils import (
33
+ ModelOutput,
34
+ add_code_sample_docstrings,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.modeling_outputs import (
40
+ BaseModelOutputWithPastAndCrossAttentions,
41
+ BaseModelOutputWithPoolingAndCrossAttentions,
42
+ CausalLMOutputWithCrossAttentions,
43
+ MaskedLMOutput,
44
+ MultipleChoiceModelOutput,
45
+ NextSentencePredictorOutput,
46
+ QuestionAnsweringModelOutput,
47
+ SequenceClassifierOutput,
48
+ TokenClassifierOutput,
49
+ )
50
+ from transformers.modeling_utils import (
51
+ PreTrainedModel,
52
+ apply_chunking_to_forward,
53
+ find_pruneable_heads_and_indices,
54
+ prune_linear_layer,
55
+ )
56
+ from transformers.utils import logging
57
+ from transformers.models.bert.configuration_bert import BertConfig
58
+
59
+ from .modeling_skim_predictor import SkimPredictor
60
+ from .modeling_utils import BaseModelOutputWithPastAndCrossAttentionsSkim, BaseModelOutputWithPoolingAndCrossAttentionsSkim, MaskedLMOutputSkim, QuestionAnsweringModelOutputSkim, SequenceClassifierOutputSkim, convert_softmax_mask_to_digit, trunc_with_mask_batched, masked_softmax
61
+
62
+ logger = logging.get_logger(__name__)
63
+
64
+ _CHECKPOINT_FOR_DOC = "bert-base-uncased"
65
+ _CONFIG_FOR_DOC = "BertConfig"
66
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
67
+
68
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
69
+ "bert-base-uncased",
70
+ "bert-large-uncased",
71
+ "bert-base-cased",
72
+ "bert-large-cased",
73
+ "bert-base-multilingual-uncased",
74
+ "bert-base-multilingual-cased",
75
+ "bert-base-chinese",
76
+ "bert-base-german-cased",
77
+ "bert-large-uncased-whole-word-masking",
78
+ "bert-large-cased-whole-word-masking",
79
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
80
+ "bert-large-cased-whole-word-masking-finetuned-squad",
81
+ "bert-base-cased-finetuned-mrpc",
82
+ "bert-base-german-dbmdz-cased",
83
+ "bert-base-german-dbmdz-uncased",
84
+ "cl-tohoku/bert-base-japanese",
85
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
86
+ "cl-tohoku/bert-base-japanese-char",
87
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
88
+ "TurkuNLP/bert-base-finnish-cased-v1",
89
+ "TurkuNLP/bert-base-finnish-uncased-v1",
90
+ "wietsedv/bert-base-dutch-cased",
91
+ # See all BERT models at https://huggingface.co/models?filter=bert
92
+ ]
93
+
94
+
95
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
96
+ """Load tf checkpoints in a pytorch model."""
97
+ try:
98
+ import re
99
+
100
+ import numpy as np
101
+ import tensorflow as tf
102
+ except ImportError:
103
+ logger.error(
104
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
105
+ "https://www.tensorflow.org/install/ for installation instructions."
106
+ )
107
+ raise
108
+ tf_path = os.path.abspath(tf_checkpoint_path)
109
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
110
+ # Load weights from TF model
111
+ init_vars = tf.train.list_variables(tf_path)
112
+ names = []
113
+ arrays = []
114
+ for name, shape in init_vars:
115
+ logger.info(f"Loading TF weight {name} with shape {shape}")
116
+ array = tf.train.load_variable(tf_path, name)
117
+ names.append(name)
118
+ arrays.append(array)
119
+
120
+ for name, array in zip(names, arrays):
121
+ name = name.split("/")
122
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
123
+ # which are not required for using pretrained model
124
+ if any(
125
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
126
+ for n in name
127
+ ):
128
+ logger.info(f"Skipping {'/'.join(name)}")
129
+ continue
130
+ pointer = model
131
+ for m_name in name:
132
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
133
+ scope_names = re.split(r"_(\d+)", m_name)
134
+ else:
135
+ scope_names = [m_name]
136
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
137
+ pointer = getattr(pointer, "weight")
138
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
139
+ pointer = getattr(pointer, "bias")
140
+ elif scope_names[0] == "output_weights":
141
+ pointer = getattr(pointer, "weight")
142
+ elif scope_names[0] == "squad":
143
+ pointer = getattr(pointer, "classifier")
144
+ else:
145
+ try:
146
+ pointer = getattr(pointer, scope_names[0])
147
+ except AttributeError:
148
+ logger.info(f"Skipping {'/'.join(name)}")
149
+ continue
150
+ if len(scope_names) >= 2:
151
+ num = int(scope_names[1])
152
+ pointer = pointer[num]
153
+ if m_name[-11:] == "_embeddings":
154
+ pointer = getattr(pointer, "weight")
155
+ elif m_name == "kernel":
156
+ array = np.transpose(array)
157
+ try:
158
+ assert (
159
+ pointer.shape == array.shape
160
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
161
+ except AssertionError as e:
162
+ e.args += (pointer.shape, array.shape)
163
+ raise
164
+ logger.info(f"Initialize PyTorch weight {name}")
165
+ pointer.data = torch.from_numpy(array)
166
+ return model
167
+
168
+
169
+ class BertEmbeddings(nn.Module):
170
+ """Construct the embeddings from word, position and token_type embeddings."""
171
+
172
+ def __init__(self, config):
173
+ super().__init__()
174
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
175
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
176
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
177
+
178
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
179
+ # any TensorFlow checkpoint file
180
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
181
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
182
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
183
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
184
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
185
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
186
+ self.register_buffer(
187
+ "token_type_ids",
188
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
189
+ persistent=False,
190
+ )
191
+
192
+ def forward(
193
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
194
+ ):
195
+ if input_ids is not None:
196
+ input_shape = input_ids.size()
197
+ else:
198
+ input_shape = inputs_embeds.size()[:-1]
199
+
200
+ seq_length = input_shape[1]
201
+
202
+ if position_ids is None:
203
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
204
+
205
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
206
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
207
+ # issue #5664
208
+ if token_type_ids is None:
209
+ if hasattr(self, "token_type_ids"):
210
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
211
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
212
+ token_type_ids = buffered_token_type_ids_expanded
213
+ else:
214
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
215
+
216
+ if inputs_embeds is None:
217
+ inputs_embeds = self.word_embeddings(input_ids)
218
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
219
+
220
+ embeddings = inputs_embeds + token_type_embeddings
221
+ if self.position_embedding_type == "absolute":
222
+ position_embeddings = self.position_embeddings(position_ids)
223
+ embeddings += position_embeddings
224
+ embeddings = self.LayerNorm(embeddings)
225
+ embeddings = self.dropout(embeddings)
226
+ return embeddings
227
+
228
+
229
+ class BertSelfAttention(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
233
+ raise ValueError(
234
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
235
+ f"heads ({config.num_attention_heads})"
236
+ )
237
+
238
+ self.num_attention_heads = config.num_attention_heads
239
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
240
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
241
+
242
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
243
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
244
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
245
+
246
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
247
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
248
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
249
+ self.max_position_embeddings = config.max_position_embeddings
250
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
251
+
252
+ self.is_decoder = config.is_decoder
253
+
254
+ def transpose_for_scores(self, x):
255
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
256
+ x = x.view(*new_x_shape)
257
+ return x.permute(0, 2, 1, 3)
258
+
259
+ def forward(
260
+ self,
261
+ hidden_states,
262
+ attention_mask=None,
263
+ head_mask=None,
264
+ encoder_hidden_states=None,
265
+ encoder_attention_mask=None,
266
+ past_key_value=None,
267
+ output_attentions=False,
268
+ skim_mask=None,
269
+ ):
270
+ mixed_query_layer = self.query(hidden_states)
271
+
272
+ # If this is instantiated as a cross-attention module, the keys
273
+ # and values come from an encoder; the attention mask needs to be
274
+ # such that the encoder's padding tokens are not attended to.
275
+ is_cross_attention = encoder_hidden_states is not None
276
+
277
+ if is_cross_attention and past_key_value is not None:
278
+ # reuse k,v, cross_attentions
279
+ key_layer = past_key_value[0]
280
+ value_layer = past_key_value[1]
281
+ attention_mask = encoder_attention_mask
282
+ elif is_cross_attention:
283
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
284
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
285
+ attention_mask = encoder_attention_mask
286
+ elif past_key_value is not None:
287
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
288
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
289
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
290
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
291
+ else:
292
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
293
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
294
+
295
+ query_layer = self.transpose_for_scores(mixed_query_layer)
296
+
297
+ if self.is_decoder:
298
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
299
+ # Further calls to cross_attention layer can then reuse all cross-attention
300
+ # key/value_states (first "if" case)
301
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
302
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
303
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
304
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
305
+ past_key_value = (key_layer, value_layer)
306
+
307
+ # Take the dot product between "query" and "key" to get the raw attention scores.
308
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
309
+
310
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
311
+ seq_length = hidden_states.size()[1]
312
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
313
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
314
+ distance = position_ids_l - position_ids_r
315
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
316
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
317
+
318
+ if self.position_embedding_type == "relative_key":
319
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
320
+ attention_scores = attention_scores + relative_position_scores
321
+ elif self.position_embedding_type == "relative_key_query":
322
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
323
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
324
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
325
+
326
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
327
+ if attention_mask is not None:
328
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
329
+ attention_scores = attention_scores + attention_mask
330
+
331
+ # Normalize the attention scores to probabilities.
332
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
333
+ # attention_probs = masked_softmax(attention_scores, skim_mask, dim=3)
334
+
335
+ # mask attention probs during training for skimming
336
+ attention_probs = attention_probs * skim_mask[:, None, None, :]
337
+
338
+ # This is actually dropping out entire tokens to attend to, which might
339
+ # seem a bit unusual, but is taken from the original Transformer paper.
340
+ attention_probs = self.dropout(attention_probs)
341
+
342
+ # Mask heads if we want to
343
+ if head_mask is not None:
344
+ attention_probs = attention_probs * head_mask
345
+
346
+ context_layer = torch.matmul(attention_probs, value_layer)
347
+
348
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
349
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
350
+ context_layer = context_layer.view(*new_context_layer_shape)
351
+
352
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
353
+
354
+ if self.is_decoder:
355
+ outputs = outputs + (past_key_value,)
356
+ return outputs
357
+
358
+
359
+ class BertSelfOutput(nn.Module):
360
+ def __init__(self, config):
361
+ super().__init__()
362
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
363
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
364
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
365
+
366
+ def forward(self, hidden_states, input_tensor):
367
+ hidden_states = self.dense(hidden_states)
368
+ hidden_states = self.dropout(hidden_states)
369
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
370
+ return hidden_states
371
+
372
+
373
+ class BertAttention(nn.Module):
374
+ def __init__(self, config):
375
+ super().__init__()
376
+ self.self = BertSelfAttention(config)
377
+ self.output = BertSelfOutput(config)
378
+ self.pruned_heads = set()
379
+
380
+ def prune_heads(self, heads):
381
+ if len(heads) == 0:
382
+ return
383
+ heads, index = find_pruneable_heads_and_indices(
384
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
385
+ )
386
+
387
+ # Prune linear layers
388
+ self.self.query = prune_linear_layer(self.self.query, index)
389
+ self.self.key = prune_linear_layer(self.self.key, index)
390
+ self.self.value = prune_linear_layer(self.self.value, index)
391
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
392
+
393
+ # Update hyper params and store pruned heads
394
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
395
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
396
+ self.pruned_heads = self.pruned_heads.union(heads)
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ attention_mask=None,
402
+ head_mask=None,
403
+ encoder_hidden_states=None,
404
+ encoder_attention_mask=None,
405
+ past_key_value=None,
406
+ output_attentions=False,
407
+ skim_mask=None,
408
+ ):
409
+ self_outputs = self.self(
410
+ hidden_states,
411
+ attention_mask,
412
+ head_mask,
413
+ encoder_hidden_states,
414
+ encoder_attention_mask,
415
+ past_key_value,
416
+ output_attentions,
417
+ skim_mask,
418
+ )
419
+ attention_output = self.output(self_outputs[0], hidden_states)
420
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
421
+ return outputs
422
+
423
+
424
+ class BertIntermediate(nn.Module):
425
+ def __init__(self, config):
426
+ super().__init__()
427
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
428
+ if isinstance(config.hidden_act, str):
429
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
430
+ else:
431
+ self.intermediate_act_fn = config.hidden_act
432
+
433
+ def forward(self, hidden_states):
434
+ hidden_states = self.dense(hidden_states)
435
+ hidden_states = self.intermediate_act_fn(hidden_states)
436
+ return hidden_states
437
+
438
+
439
+ class BertOutput(nn.Module):
440
+ def __init__(self, config):
441
+ super().__init__()
442
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
443
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
444
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
445
+
446
+ def forward(self, hidden_states, input_tensor):
447
+ hidden_states = self.dense(hidden_states)
448
+ hidden_states = self.dropout(hidden_states)
449
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
450
+ return hidden_states
451
+
452
+
453
+ class BertLayer(nn.Module):
454
+ def __init__(self, config):
455
+ super().__init__()
456
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
457
+ self.seq_len_dim = 1
458
+ self.attention = BertAttention(config)
459
+ self.is_decoder = config.is_decoder
460
+ self.add_cross_attention = config.add_cross_attention
461
+ if self.add_cross_attention:
462
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
463
+ self.crossattention = BertAttention(config)
464
+ self.intermediate = BertIntermediate(config)
465
+ self.output = BertOutput(config)
466
+
467
+ def forward(
468
+ self,
469
+ hidden_states,
470
+ attention_mask=None,
471
+ head_mask=None,
472
+ encoder_hidden_states=None,
473
+ encoder_attention_mask=None,
474
+ past_key_value=None,
475
+ output_attentions=False,
476
+ skim_mask=None,
477
+ ):
478
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
479
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
480
+ self_attention_outputs = self.attention(
481
+ hidden_states,
482
+ attention_mask,
483
+ head_mask,
484
+ output_attentions=output_attentions,
485
+ past_key_value=self_attn_past_key_value,
486
+ skim_mask=skim_mask,
487
+ )
488
+ attention_output = self_attention_outputs[0]
489
+
490
+ # if decoder, the last output is tuple of self-attn cache
491
+ if self.is_decoder:
492
+ outputs = self_attention_outputs[1:-1]
493
+ present_key_value = self_attention_outputs[-1]
494
+ else:
495
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
496
+
497
+ cross_attn_present_key_value = None
498
+ if self.is_decoder and encoder_hidden_states is not None:
499
+ assert hasattr(
500
+ self, "crossattention"
501
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
502
+
503
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
504
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
505
+ cross_attention_outputs = self.crossattention(
506
+ attention_output,
507
+ attention_mask,
508
+ head_mask,
509
+ encoder_hidden_states,
510
+ encoder_attention_mask,
511
+ cross_attn_past_key_value,
512
+ output_attentions,
513
+ )
514
+ attention_output = cross_attention_outputs[0]
515
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
516
+
517
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
518
+ cross_attn_present_key_value = cross_attention_outputs[-1]
519
+ present_key_value = present_key_value + cross_attn_present_key_value
520
+
521
+ layer_output = apply_chunking_to_forward(
522
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
523
+ )
524
+ outputs = (layer_output,) + outputs
525
+
526
+ # if decoder, return the attn key/values as the last output
527
+ if self.is_decoder:
528
+ outputs = outputs + (present_key_value,)
529
+
530
+ return outputs
531
+
532
+ def feed_forward_chunk(self, attention_output):
533
+ intermediate_output = self.intermediate(attention_output)
534
+ layer_output = self.output(intermediate_output, attention_output)
535
+ return layer_output
536
+
537
+
538
+ class BertEncoder(nn.Module):
539
+ def __init__(self, config):
540
+ super().__init__()
541
+ self.config = config
542
+ self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
543
+
544
+ # skim predictors for each layer
545
+ self.skim_predictors = nn.ModuleList([SkimPredictor(config.hidden_size, 2) for _ in range(config.num_hidden_layers)])
546
+ # init_skim_predictor(self.skim_predictors)
547
+
548
+ def forward(
549
+ self,
550
+ hidden_states,
551
+ attention_mask=None,
552
+ head_mask=None,
553
+ encoder_hidden_states=None,
554
+ encoder_attention_mask=None,
555
+ past_key_values=None,
556
+ use_cache=None,
557
+ output_attentions=False,
558
+ output_hidden_states=False,
559
+ return_dict=True,
560
+ ):
561
+ all_hidden_states = () if output_hidden_states else None
562
+ all_self_attentions = () if output_attentions else None
563
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
564
+ all_skim_mask = ()
565
+
566
+ next_decoder_cache = () if use_cache else None
567
+
568
+ forward_hidden_states = hidden_states.clone()
569
+ forward_skim_mask = None
570
+
571
+
572
+ for i, layer_module in enumerate(self.layer):
573
+ if output_hidden_states:
574
+ all_hidden_states = all_hidden_states + (hidden_states,)
575
+
576
+ # if gumbel_softmax:
577
+ # # print('gradient')
578
+ # skim_mask = nn.functional.gumbel_softmax(self.skim_predictors[i](hidden_states[:,1:,:]), hard=True, tau=1)
579
+ # else:
580
+ # print('not gradient')
581
+ logits = nn.functional.softmax(self.skim_predictors[i](hidden_states[:,1:,:]),dim=-1)
582
+ # print(logits)
583
+ index = logits.max(dim=-1, keepdim=True)[1]
584
+ # print(index)
585
+ skim_mask = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)
586
+ # print(skim_mask)
587
+
588
+ skim_mask = skim_mask[:,:,1]
589
+ skim_mask_with_cls = torch.ones(skim_mask.shape[0], skim_mask.shape[1]+1, device=skim_mask.device)
590
+ skim_mask_with_cls[:,1:] = skim_mask
591
+ skim_mask = skim_mask_with_cls
592
+ # multiple current layer skim mask with last layer skim mask
593
+ # to gurantee skimmed tokens are never recovered
594
+ if all_skim_mask and hidden_states.shape[0] != 1:
595
+ skim_mask = skim_mask * all_skim_mask[-1]
596
+ all_skim_mask += (skim_mask, )
597
+
598
+ # 最大的一个不同之处:就是这个trunc掉了
599
+ if hidden_states.shape[0] == 1:
600
+ bool_skim_mask = skim_mask.to(dtype=torch.bool)
601
+ hidden_states = trunc_with_mask_batched(hidden_states, bool_skim_mask, 1)
602
+ attention_mask = trunc_with_mask_batched(attention_mask, bool_skim_mask, 3)
603
+ skim_mask = trunc_with_mask_batched(skim_mask, bool_skim_mask, 1)
604
+ if forward_skim_mask is None:
605
+ forward_skim_mask = torch.ones_like(bool_skim_mask).to(dtype=torch.bool)
606
+
607
+ layer_head_mask = head_mask[i] if head_mask is not None else None
608
+ past_key_value = past_key_values[i] if past_key_values is not None else None
609
+
610
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
611
+
612
+ if use_cache:
613
+ logger.warning(
614
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
615
+ "`use_cache=False`..."
616
+ )
617
+ use_cache = False
618
+
619
+ def create_custom_forward(module):
620
+ def custom_forward(*inputs):
621
+ return module(*inputs, past_key_value, output_attentions)
622
+
623
+ return custom_forward
624
+
625
+ layer_outputs = torch.utils.checkpoint.checkpoint(
626
+ create_custom_forward(layer_module),
627
+ hidden_states,
628
+ attention_mask,
629
+ layer_head_mask,
630
+ encoder_hidden_states,
631
+ encoder_attention_mask,
632
+ )
633
+ else:
634
+ layer_outputs = layer_module(
635
+ hidden_states,
636
+ attention_mask,
637
+ layer_head_mask,
638
+ encoder_hidden_states,
639
+ encoder_attention_mask,
640
+ past_key_value,
641
+ output_attentions,
642
+ skim_mask,
643
+ )
644
+
645
+ hidden_states = layer_outputs[0]
646
+ # print(hidden_states.shape)
647
+ if hidden_states.shape[0] == 1:
648
+ forward_skim_mask[forward_skim_mask.clone()] = bool_skim_mask
649
+ forward_hidden_states[forward_skim_mask] = hidden_states
650
+ else:
651
+ forward_hidden_states = forward_hidden_states * (1-skim_mask.view(*skim_mask.shape,1)) + hidden_states * skim_mask.view(*skim_mask.shape,1)
652
+
653
+ if use_cache:
654
+ next_decoder_cache += (layer_outputs[-1],)
655
+ if output_attentions:
656
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
657
+ if self.config.add_cross_attention:
658
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
659
+
660
+ if output_hidden_states:
661
+ all_hidden_states = all_hidden_states + (hidden_states,)
662
+
663
+
664
+ if not return_dict:
665
+ return tuple(
666
+ v
667
+ for v in [
668
+ forward_hidden_states,
669
+ next_decoder_cache,
670
+ all_hidden_states,
671
+ all_self_attentions,
672
+ all_cross_attentions,
673
+ ]
674
+ if v is not None
675
+ )
676
+ return BaseModelOutputWithPastAndCrossAttentionsSkim(
677
+ last_hidden_state=forward_hidden_states,
678
+ past_key_values=next_decoder_cache,
679
+ hidden_states=all_hidden_states,
680
+ attentions=all_self_attentions,
681
+ cross_attentions=all_cross_attentions,
682
+ attention_mask=attention_mask,
683
+ skim_mask=all_skim_mask,
684
+ )
685
+
686
+
687
+ class BertPooler(nn.Module):
688
+ def __init__(self, config):
689
+ super().__init__()
690
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
691
+ self.activation = nn.Tanh()
692
+
693
+ def forward(self, hidden_states):
694
+ # We "pool" the model by simply taking the hidden state corresponding
695
+ # to the first token.
696
+ first_token_tensor = hidden_states[:, 0]
697
+ pooled_output = self.dense(first_token_tensor)
698
+ pooled_output = self.activation(pooled_output)
699
+ return pooled_output
700
+
701
+
702
+ class BertPredictionHeadTransform(nn.Module):
703
+ def __init__(self, config):
704
+ super().__init__()
705
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
706
+ if isinstance(config.hidden_act, str):
707
+ self.transform_act_fn = ACT2FN[config.hidden_act]
708
+ else:
709
+ self.transform_act_fn = config.hidden_act
710
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
711
+
712
+ def forward(self, hidden_states):
713
+ hidden_states = self.dense(hidden_states)
714
+ hidden_states = self.transform_act_fn(hidden_states)
715
+ hidden_states = self.LayerNorm(hidden_states)
716
+ return hidden_states
717
+
718
+
719
+ class BertLMPredictionHead(nn.Module):
720
+ def __init__(self, config):
721
+ super().__init__()
722
+ self.transform = BertPredictionHeadTransform(config)
723
+
724
+ # The output weights are the same as the input embeddings, but there is
725
+ # an output-only bias for each token.
726
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
727
+
728
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
729
+
730
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
731
+ self.decoder.bias = self.bias
732
+
733
+ def forward(self, hidden_states):
734
+ hidden_states = self.transform(hidden_states)
735
+ hidden_states = self.decoder(hidden_states)
736
+ return hidden_states
737
+
738
+
739
+ class BertOnlyMLMHead(nn.Module):
740
+ def __init__(self, config):
741
+ super().__init__()
742
+ self.predictions = BertLMPredictionHead(config)
743
+
744
+ def forward(self, sequence_output):
745
+ prediction_scores = self.predictions(sequence_output)
746
+ return prediction_scores
747
+
748
+
749
+ class BertOnlyNSPHead(nn.Module):
750
+ def __init__(self, config):
751
+ super().__init__()
752
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
753
+
754
+ def forward(self, pooled_output):
755
+ seq_relationship_score = self.seq_relationship(pooled_output)
756
+ return seq_relationship_score
757
+
758
+
759
+ class BertPreTrainingHeads(nn.Module):
760
+ def __init__(self, config):
761
+ super().__init__()
762
+ self.predictions = BertLMPredictionHead(config)
763
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
764
+
765
+ def forward(self, sequence_output, pooled_output):
766
+ prediction_scores = self.predictions(sequence_output)
767
+ seq_relationship_score = self.seq_relationship(pooled_output)
768
+ return prediction_scores, seq_relationship_score
769
+
770
+
771
+ class BertPreTrainedModel(PreTrainedModel):
772
+ """
773
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
774
+ models.
775
+ """
776
+
777
+ config_class = BertConfig
778
+ load_tf_weights = load_tf_weights_in_bert
779
+ base_model_prefix = "bert"
780
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
781
+
782
+ def _init_weights(self, module):
783
+ """Initialize the weights"""
784
+ if hasattr(module, '_skim_initialized') and module._skim_initialized:
785
+ return
786
+ if isinstance(module, nn.Linear):
787
+ # Slightly different from the TF version which uses truncated_normal for initialization
788
+ # cf https://github.com/pytorch/pytorch/pull/5617
789
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
790
+ if module.bias is not None:
791
+ module.bias.data.zero_()
792
+ elif isinstance(module, nn.Embedding):
793
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
794
+ if module.padding_idx is not None:
795
+ module.weight.data[module.padding_idx].zero_()
796
+ elif isinstance(module, nn.LayerNorm):
797
+ module.bias.data.zero_()
798
+ module.weight.data.fill_(1.0)
799
+
800
+
801
+ @dataclass
802
+ class BertForPreTrainingOutput(ModelOutput):
803
+ """
804
+ Output type of :class:`~transformers.BertForPreTraining`.
805
+
806
+ Args:
807
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
808
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
809
+ (classification) loss.
810
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
811
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
812
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
813
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
814
+ before SoftMax).
815
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
816
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
817
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
818
+
819
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
820
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
821
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
822
+ sequence_length, sequence_length)`.
823
+
824
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
825
+ heads.
826
+ """
827
+
828
+ loss: Optional[torch.FloatTensor] = None
829
+ prediction_logits: torch.FloatTensor = None
830
+ seq_relationship_logits: torch.FloatTensor = None
831
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
832
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
833
+
834
+
835
+ BERT_START_DOCSTRING = r"""
836
+
837
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
838
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
839
+ pruning heads etc.)
840
+
841
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
842
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
843
+ general usage and behavior.
844
+
845
+ Parameters:
846
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
847
+ Initializing with a config file does not load the weights associated with the model, only the
848
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
849
+ weights.
850
+ """
851
+
852
+ BERT_INPUTS_DOCSTRING = r"""
853
+ Args:
854
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
855
+ Indices of input sequence tokens in the vocabulary.
856
+
857
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
858
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
859
+ details.
860
+
861
+ `What are input IDs? <../glossary.html#input-ids>`__
862
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
863
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
864
+
865
+ - 1 for tokens that are **not masked**,
866
+ - 0 for tokens that are **masked**.
867
+
868
+ `What are attention masks? <../glossary.html#attention-mask>`__
869
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
870
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
871
+ 1]``:
872
+
873
+ - 0 corresponds to a `sentence A` token,
874
+ - 1 corresponds to a `sentence B` token.
875
+
876
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
877
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
878
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
879
+ config.max_position_embeddings - 1]``.
880
+
881
+ `What are position IDs? <../glossary.html#position-ids>`_
882
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
883
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
884
+
885
+ - 1 indicates the head is **not masked**,
886
+ - 0 indicates the head is **masked**.
887
+
888
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
889
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
890
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
891
+ vectors than the model's internal embedding lookup matrix.
892
+ output_attentions (:obj:`bool`, `optional`):
893
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
894
+ tensors for more detail.
895
+ output_hidden_states (:obj:`bool`, `optional`):
896
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
897
+ more detail.
898
+ return_dict (:obj:`bool`, `optional`):
899
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
900
+ """
901
+
902
+
903
+ @add_start_docstrings(
904
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
905
+ BERT_START_DOCSTRING,
906
+ )
907
+ class BertModel(BertPreTrainedModel):
908
+ """
909
+
910
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
911
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
912
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
913
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
914
+
915
+ To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
916
+ set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
917
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
918
+ input to the forward pass.
919
+ """
920
+
921
+ def __init__(self, config, add_pooling_layer=True):
922
+ super().__init__(config)
923
+ self.config = config
924
+
925
+ self.embeddings = BertEmbeddings(config)
926
+ self.encoder = BertEncoder(config)
927
+
928
+ self.pooler = BertPooler(config) if add_pooling_layer else None
929
+
930
+ self.init_weights()
931
+
932
+ def get_input_embeddings(self):
933
+ return self.embeddings.word_embeddings
934
+
935
+ def set_input_embeddings(self, value):
936
+ self.embeddings.word_embeddings = value
937
+
938
+ def _prune_heads(self, heads_to_prune):
939
+ """
940
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
941
+ class PreTrainedModel
942
+ """
943
+ for layer, heads in heads_to_prune.items():
944
+ self.encoder.layer[layer].attention.prune_heads(heads)
945
+
946
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
947
+ @add_code_sample_docstrings(
948
+ processor_class=_TOKENIZER_FOR_DOC,
949
+ checkpoint=_CHECKPOINT_FOR_DOC,
950
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
951
+ config_class=_CONFIG_FOR_DOC,
952
+ )
953
+ def forward(
954
+ self,
955
+ input_ids=None,
956
+ attention_mask=None,
957
+ token_type_ids=None,
958
+ position_ids=None,
959
+ head_mask=None,
960
+ inputs_embeds=None,
961
+ encoder_hidden_states=None,
962
+ encoder_attention_mask=None,
963
+ past_key_values=None,
964
+ use_cache=None,
965
+ output_attentions=None,
966
+ output_hidden_states=None,
967
+ return_dict=None,
968
+ ):
969
+ r"""
970
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
971
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
972
+ the model is configured as a decoder.
973
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
974
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
975
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
976
+
977
+ - 1 for tokens that are **not masked**,
978
+ - 0 for tokens that are **masked**.
979
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
980
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
981
+
982
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
983
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
984
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
985
+ use_cache (:obj:`bool`, `optional`):
986
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
987
+ decoding (see :obj:`past_key_values`).
988
+ """
989
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
990
+ output_hidden_states = (
991
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
992
+ )
993
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
994
+
995
+ if self.config.is_decoder:
996
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
997
+ else:
998
+ use_cache = False
999
+ if input_ids is not None and inputs_embeds is not None:
1000
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1001
+ elif input_ids is not None:
1002
+ input_shape = input_ids.size()
1003
+ elif inputs_embeds is not None:
1004
+ input_shape = inputs_embeds.size()[:-1]
1005
+ else:
1006
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1007
+
1008
+ batch_size, seq_length = input_shape
1009
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1010
+
1011
+ # past_key_values_length
1012
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1013
+
1014
+ if attention_mask is None:
1015
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1016
+
1017
+ if token_type_ids is None:
1018
+ if hasattr(self.embeddings, "token_type_ids"):
1019
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
1020
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
1021
+ token_type_ids = buffered_token_type_ids_expanded
1022
+ else:
1023
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1024
+
1025
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1026
+ # ourselves in which case we just need to make it broadcastable to all heads.
1027
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
1028
+
1029
+ # If a 2D or 3D attention mask is provided for the cross-attention
1030
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1031
+ if self.config.is_decoder and encoder_hidden_states is not None:
1032
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1033
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1034
+ if encoder_attention_mask is None:
1035
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1036
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1037
+ else:
1038
+ encoder_extended_attention_mask = None
1039
+
1040
+ # Prepare head mask if needed
1041
+ # 1.0 in head_mask indicate we keep the head
1042
+ # attention_probs has shape bsz x n_heads x N x N
1043
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1044
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1045
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1046
+
1047
+ embedding_output = self.embeddings(
1048
+ input_ids=input_ids,
1049
+ position_ids=position_ids,
1050
+ token_type_ids=token_type_ids,
1051
+ inputs_embeds=inputs_embeds,
1052
+ past_key_values_length=past_key_values_length,
1053
+ )
1054
+ encoder_outputs = self.encoder(
1055
+ embedding_output,
1056
+ attention_mask=extended_attention_mask,
1057
+ head_mask=head_mask,
1058
+ encoder_hidden_states=encoder_hidden_states,
1059
+ encoder_attention_mask=encoder_extended_attention_mask,
1060
+ past_key_values=past_key_values,
1061
+ use_cache=use_cache,
1062
+ output_attentions=output_attentions,
1063
+ output_hidden_states=output_hidden_states,
1064
+ return_dict=return_dict,
1065
+ )
1066
+ sequence_output = encoder_outputs[0]
1067
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1068
+
1069
+ if not return_dict:
1070
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1071
+
1072
+ return BaseModelOutputWithPoolingAndCrossAttentionsSkim(
1073
+ last_hidden_state=sequence_output,
1074
+ pooler_output=pooled_output,
1075
+ past_key_values=encoder_outputs.past_key_values,
1076
+ hidden_states=encoder_outputs.hidden_states,
1077
+ attentions=encoder_outputs.attentions,
1078
+ cross_attentions=encoder_outputs.cross_attentions,
1079
+ attention_mask=encoder_outputs.attention_mask,
1080
+ skim_mask=encoder_outputs.skim_mask,
1081
+ )
1082
+
1083
+
1084
+ @add_start_docstrings(
1085
+ """
1086
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1087
+ sentence prediction (classification)` head.
1088
+ """,
1089
+ BERT_START_DOCSTRING,
1090
+ )
1091
+ class BertForPreTraining(BertPreTrainedModel):
1092
+ def __init__(self, config):
1093
+ super().__init__(config)
1094
+
1095
+ self.bert = BertModel(config)
1096
+ self.cls = BertPreTrainingHeads(config)
1097
+
1098
+ self.init_weights()
1099
+
1100
+ def get_output_embeddings(self):
1101
+ return self.cls.predictions.decoder
1102
+
1103
+ def set_output_embeddings(self, new_embeddings):
1104
+ self.cls.predictions.decoder = new_embeddings
1105
+
1106
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1107
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1108
+ def forward(
1109
+ self,
1110
+ input_ids=None,
1111
+ attention_mask=None,
1112
+ token_type_ids=None,
1113
+ position_ids=None,
1114
+ head_mask=None,
1115
+ inputs_embeds=None,
1116
+ labels=None,
1117
+ next_sentence_label=None,
1118
+ output_attentions=None,
1119
+ output_hidden_states=None,
1120
+ return_dict=None,
1121
+ ):
1122
+ r"""
1123
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
1124
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1125
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1126
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1127
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1128
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1129
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1130
+
1131
+ - 0 indicates sequence B is a continuation of sequence A,
1132
+ - 1 indicates sequence B is a random sequence.
1133
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1134
+ Used to hide legacy arguments that have been deprecated.
1135
+
1136
+ Returns:
1137
+
1138
+ Example::
1139
+
1140
+ >>> from transformers import BertTokenizer, BertForPreTraining
1141
+ >>> import torch
1142
+
1143
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1144
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
1145
+
1146
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1147
+ >>> outputs = model(**inputs)
1148
+
1149
+ >>> prediction_logits = outputs.prediction_logits
1150
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1151
+ """
1152
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1153
+
1154
+ outputs = self.bert(
1155
+ input_ids,
1156
+ attention_mask=attention_mask,
1157
+ token_type_ids=token_type_ids,
1158
+ position_ids=position_ids,
1159
+ head_mask=head_mask,
1160
+ inputs_embeds=inputs_embeds,
1161
+ output_attentions=output_attentions,
1162
+ output_hidden_states=output_hidden_states,
1163
+ return_dict=return_dict,
1164
+ )
1165
+
1166
+ sequence_output, pooled_output = outputs[:2]
1167
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1168
+
1169
+ total_loss = None
1170
+ if labels is not None and next_sentence_label is not None:
1171
+ loss_fct = CrossEntropyLoss()
1172
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1173
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1174
+ total_loss = masked_lm_loss + next_sentence_loss
1175
+
1176
+ if not return_dict:
1177
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1178
+ return ((total_loss,) + output) if total_loss is not None else output
1179
+
1180
+ return BertForPreTrainingOutput(
1181
+ loss=total_loss,
1182
+ prediction_logits=prediction_scores,
1183
+ seq_relationship_logits=seq_relationship_score,
1184
+ hidden_states=outputs.hidden_states,
1185
+ attentions=outputs.attentions,
1186
+ )
1187
+
1188
+
1189
+ @add_start_docstrings(
1190
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
1191
+ )
1192
+ class BertLMHeadModel(BertPreTrainedModel):
1193
+
1194
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1195
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1196
+
1197
+ def __init__(self, config):
1198
+ super().__init__(config)
1199
+
1200
+ if not config.is_decoder:
1201
+ logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
1202
+
1203
+ self.bert = BertModel(config, add_pooling_layer=False)
1204
+ self.cls = BertOnlyMLMHead(config)
1205
+
1206
+ self.init_weights()
1207
+
1208
+ def get_output_embeddings(self):
1209
+ return self.cls.predictions.decoder
1210
+
1211
+ def set_output_embeddings(self, new_embeddings):
1212
+ self.cls.predictions.decoder = new_embeddings
1213
+
1214
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1215
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1216
+ def forward(
1217
+ self,
1218
+ input_ids=None,
1219
+ attention_mask=None,
1220
+ token_type_ids=None,
1221
+ position_ids=None,
1222
+ head_mask=None,
1223
+ inputs_embeds=None,
1224
+ encoder_hidden_states=None,
1225
+ encoder_attention_mask=None,
1226
+ labels=None,
1227
+ past_key_values=None,
1228
+ use_cache=None,
1229
+ output_attentions=None,
1230
+ output_hidden_states=None,
1231
+ return_dict=None,
1232
+ ):
1233
+ r"""
1234
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1235
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1236
+ the model is configured as a decoder.
1237
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1238
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1239
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1240
+
1241
+ - 1 for tokens that are **not masked**,
1242
+ - 0 for tokens that are **masked**.
1243
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1244
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1245
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1246
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1247
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1248
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1249
+
1250
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1251
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1252
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1253
+ use_cache (:obj:`bool`, `optional`):
1254
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1255
+ decoding (see :obj:`past_key_values`).
1256
+
1257
+ Returns:
1258
+
1259
+ Example::
1260
+
1261
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1262
+ >>> import torch
1263
+
1264
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1265
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1266
+ >>> config.is_decoder = True
1267
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1268
+
1269
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1270
+ >>> outputs = model(**inputs)
1271
+
1272
+ >>> prediction_logits = outputs.logits
1273
+ """
1274
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1275
+ if labels is not None:
1276
+ use_cache = False
1277
+
1278
+ outputs = self.bert(
1279
+ input_ids,
1280
+ attention_mask=attention_mask,
1281
+ token_type_ids=token_type_ids,
1282
+ position_ids=position_ids,
1283
+ head_mask=head_mask,
1284
+ inputs_embeds=inputs_embeds,
1285
+ encoder_hidden_states=encoder_hidden_states,
1286
+ encoder_attention_mask=encoder_attention_mask,
1287
+ past_key_values=past_key_values,
1288
+ use_cache=use_cache,
1289
+ output_attentions=output_attentions,
1290
+ output_hidden_states=output_hidden_states,
1291
+ return_dict=return_dict,
1292
+ )
1293
+
1294
+ sequence_output = outputs[0]
1295
+ prediction_scores = self.cls(sequence_output)
1296
+
1297
+ lm_loss = None
1298
+ if labels is not None:
1299
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1300
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1301
+ labels = labels[:, 1:].contiguous()
1302
+ loss_fct = CrossEntropyLoss()
1303
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1304
+
1305
+ if not return_dict:
1306
+ output = (prediction_scores,) + outputs[2:]
1307
+ return ((lm_loss,) + output) if lm_loss is not None else output
1308
+
1309
+ return CausalLMOutputWithCrossAttentions(
1310
+ loss=lm_loss,
1311
+ logits=prediction_scores,
1312
+ past_key_values=outputs.past_key_values,
1313
+ hidden_states=outputs.hidden_states,
1314
+ attentions=outputs.attentions,
1315
+ cross_attentions=outputs.cross_attentions,
1316
+ )
1317
+
1318
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1319
+ input_shape = input_ids.shape
1320
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1321
+ if attention_mask is None:
1322
+ attention_mask = input_ids.new_ones(input_shape)
1323
+
1324
+ # cut decoder_input_ids if past is used
1325
+ if past is not None:
1326
+ input_ids = input_ids[:, -1:]
1327
+
1328
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1329
+
1330
+ def _reorder_cache(self, past, beam_idx):
1331
+ reordered_past = ()
1332
+ for layer_past in past:
1333
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1334
+ return reordered_past
1335
+
1336
+
1337
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1338
+ class BertForMaskedLM(BertPreTrainedModel):
1339
+
1340
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1341
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1342
+
1343
+ def __init__(self, config):
1344
+ super().__init__(config)
1345
+
1346
+ if config.is_decoder:
1347
+ logger.warning(
1348
+ "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
1349
+ "bi-directional self-attention."
1350
+ )
1351
+
1352
+ self.bert = BertModel(config, add_pooling_layer=False)
1353
+ self.cls = BertOnlyMLMHead(config)
1354
+
1355
+ self.init_weights()
1356
+
1357
+ def get_output_embeddings(self):
1358
+ return self.cls.predictions.decoder
1359
+
1360
+ def set_output_embeddings(self, new_embeddings):
1361
+ self.cls.predictions.decoder = new_embeddings
1362
+
1363
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1364
+ @add_code_sample_docstrings(
1365
+ processor_class=_TOKENIZER_FOR_DOC,
1366
+ checkpoint=_CHECKPOINT_FOR_DOC,
1367
+ output_type=MaskedLMOutput,
1368
+ config_class=_CONFIG_FOR_DOC,
1369
+ )
1370
+ def forward(
1371
+ self,
1372
+ input_ids=None,
1373
+ attention_mask=None,
1374
+ token_type_ids=None,
1375
+ position_ids=None,
1376
+ head_mask=None,
1377
+ inputs_embeds=None,
1378
+ encoder_hidden_states=None,
1379
+ encoder_attention_mask=None,
1380
+ labels=None,
1381
+ output_attentions=None,
1382
+ output_hidden_states=None,
1383
+ return_dict=None,
1384
+ ):
1385
+ r"""
1386
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1387
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1388
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1389
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1390
+ """
1391
+
1392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1393
+
1394
+ outputs = self.bert(
1395
+ input_ids,
1396
+ attention_mask=attention_mask,
1397
+ token_type_ids=token_type_ids,
1398
+ position_ids=position_ids,
1399
+ head_mask=head_mask,
1400
+ inputs_embeds=inputs_embeds,
1401
+ encoder_hidden_states=encoder_hidden_states,
1402
+ encoder_attention_mask=encoder_attention_mask,
1403
+ output_attentions=output_attentions,
1404
+ output_hidden_states=output_hidden_states,
1405
+ return_dict=return_dict,
1406
+ )
1407
+
1408
+ sequence_output = outputs[0]
1409
+ prediction_scores = self.cls(sequence_output)
1410
+
1411
+ masked_lm_loss = None
1412
+ if labels is not None:
1413
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1414
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1415
+
1416
+ if not return_dict:
1417
+ output = (prediction_scores,) + outputs[2:]
1418
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1419
+
1420
+ return MaskedLMOutput(
1421
+ loss=masked_lm_loss,
1422
+ logits=prediction_scores,
1423
+ hidden_states=outputs.hidden_states,
1424
+ attentions=outputs.attentions,
1425
+ )
1426
+
1427
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1428
+ input_shape = input_ids.shape
1429
+ effective_batch_size = input_shape[0]
1430
+
1431
+ # add a dummy token
1432
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1433
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1434
+ dummy_token = torch.full(
1435
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1436
+ )
1437
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1438
+
1439
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1440
+
1441
+
1442
+ @add_start_docstrings(
1443
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
1444
+ BERT_START_DOCSTRING,
1445
+ )
1446
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1447
+ def __init__(self, config):
1448
+ super().__init__(config)
1449
+
1450
+ self.bert = BertModel(config)
1451
+ self.cls = BertOnlyNSPHead(config)
1452
+
1453
+ self.init_weights()
1454
+
1455
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1456
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1457
+ def forward(
1458
+ self,
1459
+ input_ids=None,
1460
+ attention_mask=None,
1461
+ token_type_ids=None,
1462
+ position_ids=None,
1463
+ head_mask=None,
1464
+ inputs_embeds=None,
1465
+ labels=None,
1466
+ output_attentions=None,
1467
+ output_hidden_states=None,
1468
+ return_dict=None,
1469
+ **kwargs,
1470
+ ):
1471
+ r"""
1472
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1473
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1474
+ (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
1475
+
1476
+ - 0 indicates sequence B is a continuation of sequence A,
1477
+ - 1 indicates sequence B is a random sequence.
1478
+
1479
+ Returns:
1480
+
1481
+ Example::
1482
+
1483
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1484
+ >>> import torch
1485
+
1486
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1487
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1488
+
1489
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1490
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1491
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1492
+
1493
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1494
+ >>> logits = outputs.logits
1495
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1496
+ """
1497
+
1498
+ if "next_sentence_label" in kwargs:
1499
+ warnings.warn(
1500
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1501
+ FutureWarning,
1502
+ )
1503
+ labels = kwargs.pop("next_sentence_label")
1504
+
1505
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1506
+
1507
+ outputs = self.bert(
1508
+ input_ids,
1509
+ attention_mask=attention_mask,
1510
+ token_type_ids=token_type_ids,
1511
+ position_ids=position_ids,
1512
+ head_mask=head_mask,
1513
+ inputs_embeds=inputs_embeds,
1514
+ output_attentions=output_attentions,
1515
+ output_hidden_states=output_hidden_states,
1516
+ return_dict=return_dict,
1517
+ )
1518
+
1519
+ pooled_output = outputs[1]
1520
+
1521
+ seq_relationship_scores = self.cls(pooled_output)
1522
+
1523
+ next_sentence_loss = None
1524
+ if labels is not None:
1525
+ loss_fct = CrossEntropyLoss()
1526
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1527
+
1528
+ if not return_dict:
1529
+ output = (seq_relationship_scores,) + outputs[2:]
1530
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1531
+
1532
+ return NextSentencePredictorOutput(
1533
+ loss=next_sentence_loss,
1534
+ logits=seq_relationship_scores,
1535
+ hidden_states=outputs.hidden_states,
1536
+ attentions=outputs.attentions,
1537
+ )
1538
+
1539
+
1540
+ @add_start_docstrings(
1541
+ """
1542
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1543
+ output) e.g. for GLUE tasks.
1544
+ """,
1545
+ BERT_START_DOCSTRING,
1546
+ )
1547
+ class BertForSequenceClassification(BertPreTrainedModel):
1548
+ def __init__(self, config):
1549
+ super().__init__(config)
1550
+ self.num_labels = config.num_labels
1551
+ self.config = config
1552
+
1553
+ self.bert = BertModel(config)
1554
+ classifier_dropout = (
1555
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1556
+ )
1557
+ self.dropout = nn.Dropout(classifier_dropout)
1558
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1559
+
1560
+ self.skim_coefficient = config.skim_coefficient if hasattr(config, 'skim_coefficient') else 1
1561
+
1562
+ self.init_weights()
1563
+
1564
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1565
+ @add_code_sample_docstrings(
1566
+ processor_class=_TOKENIZER_FOR_DOC,
1567
+ checkpoint=_CHECKPOINT_FOR_DOC,
1568
+ output_type=SequenceClassifierOutput,
1569
+ config_class=_CONFIG_FOR_DOC,
1570
+ )
1571
+ def forward(
1572
+ self,
1573
+ input_ids=None,
1574
+ attention_mask=None,
1575
+ token_type_ids=None,
1576
+ position_ids=None,
1577
+ head_mask=None,
1578
+ inputs_embeds=None,
1579
+ labels=None,
1580
+ output_attentions=None,
1581
+ output_hidden_states=None,
1582
+ return_dict=None,
1583
+ ):
1584
+ r"""
1585
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1586
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1587
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1588
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1589
+ """
1590
+ # assert gumbel_softmax is not None
1591
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1592
+
1593
+ outputs = self.bert(
1594
+ input_ids,
1595
+ attention_mask=attention_mask,
1596
+ token_type_ids=token_type_ids,
1597
+ position_ids=position_ids,
1598
+ head_mask=head_mask,
1599
+ inputs_embeds=inputs_embeds,
1600
+ output_attentions=output_attentions,
1601
+ output_hidden_states=output_hidden_states,
1602
+ return_dict=return_dict,
1603
+ )
1604
+
1605
+ pooled_output = outputs[1]
1606
+
1607
+ pooled_output = self.dropout(pooled_output)
1608
+ logits = self.classifier(pooled_output)
1609
+
1610
+ loss = None
1611
+ if labels is not None:
1612
+ if self.config.problem_type is None:
1613
+ if self.num_labels == 1:
1614
+ self.config.problem_type = "regression"
1615
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1616
+ self.config.problem_type = "single_label_classification"
1617
+ else:
1618
+ self.config.problem_type = "multi_label_classification"
1619
+
1620
+ if self.config.problem_type == "regression":
1621
+ loss_fct = MSELoss()
1622
+ if self.num_labels == 1:
1623
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1624
+ else:
1625
+ loss = loss_fct(logits, labels)
1626
+ elif self.config.problem_type == "single_label_classification":
1627
+ loss_fct = CrossEntropyLoss()
1628
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1629
+ elif self.config.problem_type == "multi_label_classification":
1630
+ loss_fct = BCEWithLogitsLoss()
1631
+ loss = loss_fct(logits, labels)
1632
+ if not return_dict:
1633
+ output = (logits,) + outputs[2:]
1634
+ return ((loss,) + output) if loss is not None else output
1635
+
1636
+ skim_loss, neat_mac = 0.0, 0.0
1637
+ layer_neat_mac = list()
1638
+ all_tokens_length = torch.mean(torch.sum(attention_mask.to(torch.float32),dim=-1))
1639
+ for mask in outputs.skim_mask:
1640
+ accumulated_skim_mask = torch.mean(torch.sum(mask,dim=1))
1641
+ skim_loss += accumulated_skim_mask/mask.shape[1]
1642
+ layer_neat_mac.append(accumulated_skim_mask/all_tokens_length)
1643
+ neat_mac += accumulated_skim_mask/all_tokens_length
1644
+ skim_loss /= self.config.num_hidden_layers
1645
+ neat_mac /= self.config.num_hidden_layers
1646
+ classification_loss = loss
1647
+ # print(skim_loss, neat_mac, loss)
1648
+ # loss = skim_loss
1649
+ if labels is not None:
1650
+ loss = self.skim_coefficient * skim_loss + loss
1651
+
1652
+ return SequenceClassifierOutputSkim(
1653
+ loss=loss,
1654
+ logits=logits,
1655
+ hidden_states=outputs.hidden_states,
1656
+ attentions=outputs.attentions,
1657
+ attention_mask=outputs.attention_mask,
1658
+ skim_mask=outputs.skim_mask,
1659
+ skim_loss=skim_loss,
1660
+ classification_loss=classification_loss,
1661
+ tokens_remained=neat_mac,
1662
+ layer_tokens_remained=layer_neat_mac,
1663
+ )
1664
+
1665
+
1666
+ @add_start_docstrings(
1667
+ """
1668
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1669
+ softmax) e.g. for RocStories/SWAG tasks.
1670
+ """,
1671
+ BERT_START_DOCSTRING,
1672
+ )
1673
+ class BertForMultipleChoice(BertPreTrainedModel):
1674
+ def __init__(self, config):
1675
+ super().__init__(config)
1676
+
1677
+ self.bert = BertModel(config)
1678
+ classifier_dropout = (
1679
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1680
+ )
1681
+ self.dropout = nn.Dropout(classifier_dropout)
1682
+ self.classifier = nn.Linear(config.hidden_size, 1)
1683
+
1684
+ self.init_weights()
1685
+
1686
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1687
+ @add_code_sample_docstrings(
1688
+ processor_class=_TOKENIZER_FOR_DOC,
1689
+ checkpoint=_CHECKPOINT_FOR_DOC,
1690
+ output_type=MultipleChoiceModelOutput,
1691
+ config_class=_CONFIG_FOR_DOC,
1692
+ )
1693
+ def forward(
1694
+ self,
1695
+ input_ids=None,
1696
+ attention_mask=None,
1697
+ token_type_ids=None,
1698
+ position_ids=None,
1699
+ head_mask=None,
1700
+ inputs_embeds=None,
1701
+ labels=None,
1702
+ output_attentions=None,
1703
+ output_hidden_states=None,
1704
+ return_dict=None,
1705
+ ):
1706
+ r"""
1707
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1708
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1709
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1710
+ :obj:`input_ids` above)
1711
+ """
1712
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1713
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1714
+
1715
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1716
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1717
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1718
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1719
+ inputs_embeds = (
1720
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1721
+ if inputs_embeds is not None
1722
+ else None
1723
+ )
1724
+
1725
+ outputs = self.bert(
1726
+ input_ids,
1727
+ attention_mask=attention_mask,
1728
+ token_type_ids=token_type_ids,
1729
+ position_ids=position_ids,
1730
+ head_mask=head_mask,
1731
+ inputs_embeds=inputs_embeds,
1732
+ output_attentions=output_attentions,
1733
+ output_hidden_states=output_hidden_states,
1734
+ return_dict=return_dict,
1735
+ )
1736
+
1737
+ pooled_output = outputs[1]
1738
+
1739
+ pooled_output = self.dropout(pooled_output)
1740
+ logits = self.classifier(pooled_output)
1741
+ reshaped_logits = logits.view(-1, num_choices)
1742
+
1743
+ loss = None
1744
+ if labels is not None:
1745
+ loss_fct = CrossEntropyLoss()
1746
+ loss = loss_fct(reshaped_logits, labels)
1747
+
1748
+ if not return_dict:
1749
+ output = (reshaped_logits,) + outputs[2:]
1750
+ return ((loss,) + output) if loss is not None else output
1751
+
1752
+ return MultipleChoiceModelOutput(
1753
+ loss=loss,
1754
+ logits=reshaped_logits,
1755
+ hidden_states=outputs.hidden_states,
1756
+ attentions=outputs.attentions,
1757
+ )
1758
+
1759
+
1760
+ @add_start_docstrings(
1761
+ """
1762
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1763
+ Named-Entity-Recognition (NER) tasks.
1764
+ """,
1765
+ BERT_START_DOCSTRING,
1766
+ )
1767
+ class BertForTokenClassification(BertPreTrainedModel):
1768
+
1769
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1770
+
1771
+ def __init__(self, config):
1772
+ super().__init__(config)
1773
+ self.num_labels = config.num_labels
1774
+
1775
+ self.bert = BertModel(config, add_pooling_layer=False)
1776
+ classifier_dropout = (
1777
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1778
+ )
1779
+ self.dropout = nn.Dropout(classifier_dropout)
1780
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1781
+
1782
+ self.init_weights()
1783
+
1784
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1785
+ @add_code_sample_docstrings(
1786
+ processor_class=_TOKENIZER_FOR_DOC,
1787
+ checkpoint=_CHECKPOINT_FOR_DOC,
1788
+ output_type=TokenClassifierOutput,
1789
+ config_class=_CONFIG_FOR_DOC,
1790
+ )
1791
+ def forward(
1792
+ self,
1793
+ input_ids=None,
1794
+ attention_mask=None,
1795
+ token_type_ids=None,
1796
+ position_ids=None,
1797
+ head_mask=None,
1798
+ inputs_embeds=None,
1799
+ labels=None,
1800
+ output_attentions=None,
1801
+ output_hidden_states=None,
1802
+ return_dict=None,
1803
+ ):
1804
+ r"""
1805
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1806
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1807
+ 1]``.
1808
+ """
1809
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1810
+
1811
+ outputs = self.bert(
1812
+ input_ids,
1813
+ attention_mask=attention_mask,
1814
+ token_type_ids=token_type_ids,
1815
+ position_ids=position_ids,
1816
+ head_mask=head_mask,
1817
+ inputs_embeds=inputs_embeds,
1818
+ output_attentions=output_attentions,
1819
+ output_hidden_states=output_hidden_states,
1820
+ return_dict=return_dict,
1821
+ )
1822
+
1823
+ sequence_output = outputs[0]
1824
+
1825
+ sequence_output = self.dropout(sequence_output)
1826
+ logits = self.classifier(sequence_output)
1827
+
1828
+ loss = None
1829
+ if labels is not None:
1830
+ loss_fct = CrossEntropyLoss()
1831
+ # Only keep active parts of the loss
1832
+ if attention_mask is not None:
1833
+ active_loss = attention_mask.view(-1) == 1
1834
+ active_logits = logits.view(-1, self.num_labels)
1835
+ active_labels = torch.where(
1836
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1837
+ )
1838
+ loss = loss_fct(active_logits, active_labels)
1839
+ else:
1840
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1841
+
1842
+ if not return_dict:
1843
+ output = (logits,) + outputs[2:]
1844
+ return ((loss,) + output) if loss is not None else output
1845
+
1846
+ return TokenClassifierOutput(
1847
+ loss=loss,
1848
+ logits=logits,
1849
+ hidden_states=outputs.hidden_states,
1850
+ attentions=outputs.attentions,
1851
+ )
1852
+
1853
+
1854
+ @add_start_docstrings(
1855
+ """
1856
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1857
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1858
+ """,
1859
+ BERT_START_DOCSTRING,
1860
+ )
1861
+ class BertForQuestionAnswering(BertPreTrainedModel):
1862
+
1863
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1864
+
1865
+ def __init__(self, config):
1866
+ super().__init__(config)
1867
+ self.num_labels = config.num_labels
1868
+
1869
+ self.bert = BertModel(config, add_pooling_layer=False)
1870
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1871
+
1872
+ self.skim_coefficient = config.skim_coefficient if hasattr(config, 'skim_coefficient') else 1
1873
+
1874
+ self.init_weights()
1875
+
1876
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1877
+ @add_code_sample_docstrings(
1878
+ processor_class=_TOKENIZER_FOR_DOC,
1879
+ checkpoint=_CHECKPOINT_FOR_DOC,
1880
+ output_type=QuestionAnsweringModelOutput,
1881
+ config_class=_CONFIG_FOR_DOC,
1882
+ )
1883
+ def forward(
1884
+ self,
1885
+ input_ids=None,
1886
+ attention_mask=None,
1887
+ token_type_ids=None,
1888
+ position_ids=None,
1889
+ head_mask=None,
1890
+ inputs_embeds=None,
1891
+ start_positions=None,
1892
+ end_positions=None,
1893
+ output_attentions=None,
1894
+ output_hidden_states=None,
1895
+ return_dict=None,
1896
+ ):
1897
+ r"""
1898
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1899
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1900
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1901
+ sequence are not taken into account for computing the loss.
1902
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1903
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1904
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1905
+ sequence are not taken into account for computing the loss.
1906
+ """
1907
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1908
+
1909
+ outputs = self.bert(
1910
+ input_ids,
1911
+ attention_mask=attention_mask,
1912
+ token_type_ids=token_type_ids,
1913
+ position_ids=position_ids,
1914
+ head_mask=head_mask,
1915
+ inputs_embeds=inputs_embeds,
1916
+ output_attentions=output_attentions,
1917
+ output_hidden_states=output_hidden_states,
1918
+ return_dict=return_dict,
1919
+ )
1920
+
1921
+ sequence_output = outputs[0]
1922
+
1923
+ logits = self.qa_outputs(sequence_output)
1924
+ start_logits, end_logits = logits.split(1, dim=-1)
1925
+ start_logits = start_logits.squeeze(-1).contiguous()
1926
+ end_logits = end_logits.squeeze(-1).contiguous()
1927
+
1928
+ total_loss = None
1929
+ if start_positions is not None and end_positions is not None:
1930
+ # If we are on multi-GPU, split add a dimension
1931
+ if len(start_positions.size()) > 1:
1932
+ start_positions = start_positions.squeeze(-1)
1933
+ if len(end_positions.size()) > 1:
1934
+ end_positions = end_positions.squeeze(-1)
1935
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1936
+ ignored_index = start_logits.size(1)
1937
+ start_positions = start_positions.clamp(0, ignored_index)
1938
+ end_positions = end_positions.clamp(0, ignored_index)
1939
+
1940
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1941
+ start_loss = loss_fct(start_logits, start_positions)
1942
+ end_loss = loss_fct(end_logits, end_positions)
1943
+ total_loss = (start_loss + end_loss) / 2
1944
+
1945
+ skim_loss, neat_mac = 0.0, 0.0
1946
+ layer_neat_mac = list()
1947
+ all_tokens_length = torch.mean(torch.sum(attention_mask.to(torch.float32),dim=-1))
1948
+ for mask in outputs.skim_mask:
1949
+ accumulated_skim_mask = torch.mean(torch.sum(mask,dim=1))
1950
+ skim_loss += accumulated_skim_mask/mask.shape[1]
1951
+ layer_neat_mac.append(accumulated_skim_mask/all_tokens_length)
1952
+ neat_mac += accumulated_skim_mask/all_tokens_length
1953
+ skim_loss /= self.config.num_hidden_layers
1954
+ neat_mac /= self.config.num_hidden_layers
1955
+ qa_loss = total_loss
1956
+ if start_positions is not None and end_positions is not None:
1957
+ # print(skim_loss, neat_mac, loss)
1958
+ # loss = skim_loss
1959
+ total_loss = self.skim_coefficient * skim_loss + qa_loss
1960
+
1961
+ if not return_dict:
1962
+ output = (start_logits, end_logits) + outputs[2:]
1963
+ return ((total_loss,) + output) if total_loss is not None else output
1964
+
1965
+
1966
+ return QuestionAnsweringModelOutputSkim(
1967
+ loss=total_loss,
1968
+ start_logits=start_logits,
1969
+ end_logits=end_logits,
1970
+ hidden_states=outputs.hidden_states,
1971
+ attentions=outputs.attentions,
1972
+ attention_mask=outputs.attention_mask,
1973
+ skim_mask=outputs.skim_mask,
1974
+ skim_loss=skim_loss,
1975
+ classification_loss=qa_loss,
1976
+ tokens_remained=neat_mac,
1977
+ layer_tokens_remained=layer_neat_mac,
1978
+ )
1979
+
1980
+
1981
+ def test_BertEncoder():
1982
+ import transformers
1983
+
1984
+ logging.debug(f'Start unit test for BertEncoder')
1985
+
1986
+ config = transformers.BertConfig.from_pretrained('bert-base-uncased')
1987
+ # config.output_attentions = False
1988
+ encoder = BertEncoder(config)
1989
+
1990
+ rand_hidden_states = torch.rand((1,8,768))
1991
+ # rand_hidden_states = torch.rand((4,128,768))
1992
+
1993
+ encoder_outputs = encoder(rand_hidden_states)
1994
+
1995
+ logging.debug(f'output attention: {config.output_attentions}, {encoder_outputs[-1][0].shape}')
1996
+
1997
+ if __name__ == "__main__":
1998
+ import logging
1999
+
2000
+ logging.basicConfig(level=logging.DEBUG)
2001
+
2002
+ test_BertEncoder()
test_module/modeling_transkimer_roberta.py ADDED
@@ -0,0 +1,1624 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RoBERTa model. """
17
+
18
+ import math
19
+
20
+ import torch
21
+ import torch.utils.checkpoint
22
+ from packaging import version
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN, gelu
27
+ from transformers.file_utils import (
28
+ add_code_sample_docstrings,
29
+ add_start_docstrings,
30
+ add_start_docstrings_to_model_forward,
31
+ replace_return_docstrings,
32
+ )
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ BaseModelOutputWithPoolingAndCrossAttentions,
36
+ CausalLMOutputWithCrossAttentions,
37
+ MaskedLMOutput,
38
+ MultipleChoiceModelOutput,
39
+ QuestionAnsweringModelOutput,
40
+ SequenceClassifierOutput,
41
+ TokenClassifierOutput,
42
+ )
43
+ from transformers.modeling_utils import (
44
+ PreTrainedModel,
45
+ apply_chunking_to_forward,
46
+ find_pruneable_heads_and_indices,
47
+ prune_linear_layer,
48
+ )
49
+ from transformers.utils import logging
50
+ from transformers.models.roberta.configuration_roberta import RobertaConfig
51
+
52
+ from module.modeling_skim_predictor import SkimPredictor
53
+ from module.modeling_utils import BaseModelOutputWithPastAndCrossAttentionsSkim, BaseModelOutputWithPoolingAndCrossAttentionsSkim, SequenceClassifierOutputSkim
54
+
55
+
56
+ logger = logging.get_logger(__name__)
57
+
58
+ _CHECKPOINT_FOR_DOC = "roberta-base"
59
+ _CONFIG_FOR_DOC = "RobertaConfig"
60
+ _TOKENIZER_FOR_DOC = "RobertaTokenizer"
61
+
62
+ ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
63
+ "roberta-base",
64
+ "roberta-large",
65
+ "roberta-large-mnli",
66
+ "distilroberta-base",
67
+ "roberta-base-openai-detector",
68
+ "roberta-large-openai-detector",
69
+ # See all RoBERTa models at https://huggingface.co/models?filter=roberta
70
+ ]
71
+
72
+
73
+ class RobertaEmbeddings(nn.Module):
74
+ """
75
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
76
+ """
77
+
78
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
79
+ def __init__(self, config):
80
+ super().__init__()
81
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
82
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
83
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
84
+
85
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
86
+ # any TensorFlow checkpoint file
87
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
88
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
89
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
90
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
91
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
92
+ if version.parse(torch.__version__) > version.parse("1.6.0"):
93
+ self.register_buffer(
94
+ "token_type_ids",
95
+ torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
96
+ persistent=False,
97
+ )
98
+
99
+ # End copy
100
+ self.padding_idx = config.pad_token_id
101
+ self.position_embeddings = nn.Embedding(
102
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
103
+ )
104
+
105
+ def forward(
106
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
107
+ ):
108
+ if position_ids is None:
109
+ if input_ids is not None:
110
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
111
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
112
+ else:
113
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
114
+
115
+ if input_ids is not None:
116
+ input_shape = input_ids.size()
117
+ else:
118
+ input_shape = inputs_embeds.size()[:-1]
119
+
120
+ seq_length = input_shape[1]
121
+
122
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
123
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
124
+ # issue #5664
125
+ if token_type_ids is None:
126
+ if hasattr(self, "token_type_ids"):
127
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
128
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
129
+ token_type_ids = buffered_token_type_ids_expanded
130
+ else:
131
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
132
+
133
+ if inputs_embeds is None:
134
+ inputs_embeds = self.word_embeddings(input_ids)
135
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
136
+
137
+ embeddings = inputs_embeds + token_type_embeddings
138
+ if self.position_embedding_type == "absolute":
139
+ position_embeddings = self.position_embeddings(position_ids)
140
+ embeddings += position_embeddings
141
+ embeddings = self.LayerNorm(embeddings)
142
+ embeddings = self.dropout(embeddings)
143
+ return embeddings
144
+
145
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
146
+ """
147
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
148
+
149
+ Args:
150
+ inputs_embeds: torch.Tensor
151
+
152
+ Returns: torch.Tensor
153
+ """
154
+ input_shape = inputs_embeds.size()[:-1]
155
+ sequence_length = input_shape[1]
156
+
157
+ position_ids = torch.arange(
158
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
159
+ )
160
+ return position_ids.unsqueeze(0).expand(input_shape)
161
+
162
+
163
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
164
+ class RobertaSelfAttention(nn.Module):
165
+ def __init__(self, config):
166
+ super().__init__()
167
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
168
+ raise ValueError(
169
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
170
+ f"heads ({config.num_attention_heads})"
171
+ )
172
+
173
+ self.num_attention_heads = config.num_attention_heads
174
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
175
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
176
+
177
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
178
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
179
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
180
+
181
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
182
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
183
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
184
+ self.max_position_embeddings = config.max_position_embeddings
185
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
186
+
187
+ self.is_decoder = config.is_decoder
188
+
189
+ def transpose_for_scores(self, x):
190
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
191
+ x = x.view(*new_x_shape)
192
+ return x.permute(0, 2, 1, 3)
193
+
194
+ def forward(
195
+ self,
196
+ hidden_states,
197
+ attention_mask=None,
198
+ head_mask=None,
199
+ encoder_hidden_states=None,
200
+ encoder_attention_mask=None,
201
+ past_key_value=None,
202
+ output_attentions=False,
203
+ skim_mask=None,
204
+ ):
205
+ mixed_query_layer = self.query(hidden_states)
206
+
207
+ # If this is instantiated as a cross-attention module, the keys
208
+ # and values come from an encoder; the attention mask needs to be
209
+ # such that the encoder's padding tokens are not attended to.
210
+ is_cross_attention = encoder_hidden_states is not None
211
+
212
+ if is_cross_attention and past_key_value is not None:
213
+ # reuse k,v, cross_attentions
214
+ key_layer = past_key_value[0]
215
+ value_layer = past_key_value[1]
216
+ attention_mask = encoder_attention_mask
217
+ elif is_cross_attention:
218
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
219
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
220
+ attention_mask = encoder_attention_mask
221
+ elif past_key_value is not None:
222
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
223
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
224
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
225
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
226
+ else:
227
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
228
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
229
+
230
+ query_layer = self.transpose_for_scores(mixed_query_layer)
231
+
232
+ if self.is_decoder:
233
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
234
+ # Further calls to cross_attention layer can then reuse all cross-attention
235
+ # key/value_states (first "if" case)
236
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
237
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
238
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
239
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
240
+ past_key_value = (key_layer, value_layer)
241
+
242
+ # Take the dot product between "query" and "key" to get the raw attention scores.
243
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
244
+
245
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
246
+ seq_length = hidden_states.size()[1]
247
+ position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
248
+ position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
249
+ distance = position_ids_l - position_ids_r
250
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
251
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
252
+
253
+ if self.position_embedding_type == "relative_key":
254
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
255
+ attention_scores = attention_scores + relative_position_scores
256
+ elif self.position_embedding_type == "relative_key_query":
257
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
258
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
259
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
260
+
261
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
262
+ if attention_mask is not None:
263
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
264
+ attention_scores = attention_scores + attention_mask
265
+
266
+ # Normalize the attention scores to probabilities.
267
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
268
+
269
+ # mask attention probs during training for skimming
270
+ attention_probs = attention_probs * skim_mask[:, None, None, :]
271
+
272
+ # This is actually dropping out entire tokens to attend to, which might
273
+ # seem a bit unusual, but is taken from the original Transformer paper.
274
+ attention_probs = self.dropout(attention_probs)
275
+
276
+ # Mask heads if we want to
277
+ if head_mask is not None:
278
+ attention_probs = attention_probs * head_mask
279
+
280
+ context_layer = torch.matmul(attention_probs, value_layer)
281
+
282
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
283
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
284
+ context_layer = context_layer.view(*new_context_layer_shape)
285
+
286
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
287
+
288
+ if self.is_decoder:
289
+ outputs = outputs + (past_key_value,)
290
+ return outputs
291
+
292
+
293
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
294
+ class RobertaSelfOutput(nn.Module):
295
+ def __init__(self, config):
296
+ super().__init__()
297
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
298
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
299
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
300
+
301
+ def forward(self, hidden_states, input_tensor):
302
+ hidden_states = self.dense(hidden_states)
303
+ hidden_states = self.dropout(hidden_states)
304
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
305
+ return hidden_states
306
+
307
+
308
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
309
+ class RobertaAttention(nn.Module):
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.self = RobertaSelfAttention(config)
313
+ self.output = RobertaSelfOutput(config)
314
+ self.pruned_heads = set()
315
+
316
+ def prune_heads(self, heads):
317
+ if len(heads) == 0:
318
+ return
319
+ heads, index = find_pruneable_heads_and_indices(
320
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
321
+ )
322
+
323
+ # Prune linear layers
324
+ self.self.query = prune_linear_layer(self.self.query, index)
325
+ self.self.key = prune_linear_layer(self.self.key, index)
326
+ self.self.value = prune_linear_layer(self.self.value, index)
327
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
328
+
329
+ # Update hyper params and store pruned heads
330
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
331
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
332
+ self.pruned_heads = self.pruned_heads.union(heads)
333
+
334
+ def forward(
335
+ self,
336
+ hidden_states,
337
+ attention_mask=None,
338
+ head_mask=None,
339
+ encoder_hidden_states=None,
340
+ encoder_attention_mask=None,
341
+ past_key_value=None,
342
+ output_attentions=False,
343
+ skim_mask=None,
344
+ ):
345
+ self_outputs = self.self(
346
+ hidden_states,
347
+ attention_mask,
348
+ head_mask,
349
+ encoder_hidden_states,
350
+ encoder_attention_mask,
351
+ past_key_value,
352
+ output_attentions,
353
+ skim_mask,
354
+ )
355
+ attention_output = self.output(self_outputs[0], hidden_states)
356
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
357
+ return outputs
358
+
359
+
360
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
361
+ class RobertaIntermediate(nn.Module):
362
+ def __init__(self, config):
363
+ super().__init__()
364
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
365
+ if isinstance(config.hidden_act, str):
366
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
367
+ else:
368
+ self.intermediate_act_fn = config.hidden_act
369
+
370
+ def forward(self, hidden_states):
371
+ hidden_states = self.dense(hidden_states)
372
+ hidden_states = self.intermediate_act_fn(hidden_states)
373
+ return hidden_states
374
+
375
+
376
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
377
+ class RobertaOutput(nn.Module):
378
+ def __init__(self, config):
379
+ super().__init__()
380
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
381
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
382
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
383
+
384
+ def forward(self, hidden_states, input_tensor):
385
+ hidden_states = self.dense(hidden_states)
386
+ hidden_states = self.dropout(hidden_states)
387
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
388
+ return hidden_states
389
+
390
+
391
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
392
+ class RobertaLayer(nn.Module):
393
+ def __init__(self, config):
394
+ super().__init__()
395
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
396
+ self.seq_len_dim = 1
397
+ self.attention = RobertaAttention(config)
398
+ self.is_decoder = config.is_decoder
399
+ self.add_cross_attention = config.add_cross_attention
400
+ if self.add_cross_attention:
401
+ assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
402
+ self.crossattention = RobertaAttention(config)
403
+ self.intermediate = RobertaIntermediate(config)
404
+ self.output = RobertaOutput(config)
405
+
406
+ def forward(
407
+ self,
408
+ hidden_states,
409
+ attention_mask=None,
410
+ head_mask=None,
411
+ encoder_hidden_states=None,
412
+ encoder_attention_mask=None,
413
+ past_key_value=None,
414
+ output_attentions=False,
415
+ skim_mask=None,
416
+ ):
417
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
418
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
419
+ self_attention_outputs = self.attention(
420
+ hidden_states,
421
+ attention_mask,
422
+ head_mask,
423
+ output_attentions=output_attentions,
424
+ past_key_value=self_attn_past_key_value,
425
+ skim_mask=skim_mask,
426
+ )
427
+ attention_output = self_attention_outputs[0]
428
+
429
+ # if decoder, the last output is tuple of self-attn cache
430
+ if self.is_decoder:
431
+ outputs = self_attention_outputs[1:-1]
432
+ present_key_value = self_attention_outputs[-1]
433
+ else:
434
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
435
+
436
+ cross_attn_present_key_value = None
437
+ if self.is_decoder and encoder_hidden_states is not None:
438
+ assert hasattr(
439
+ self, "crossattention"
440
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
441
+
442
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
443
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
444
+ cross_attention_outputs = self.crossattention(
445
+ attention_output,
446
+ attention_mask,
447
+ head_mask,
448
+ encoder_hidden_states,
449
+ encoder_attention_mask,
450
+ cross_attn_past_key_value,
451
+ output_attentions,
452
+ )
453
+ attention_output = cross_attention_outputs[0]
454
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
455
+
456
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
457
+ cross_attn_present_key_value = cross_attention_outputs[-1]
458
+ present_key_value = present_key_value + cross_attn_present_key_value
459
+
460
+ layer_output = apply_chunking_to_forward(
461
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
462
+ )
463
+ outputs = (layer_output,) + outputs
464
+
465
+ # if decoder, return the attn key/values as the last output
466
+ if self.is_decoder:
467
+ outputs = outputs + (present_key_value,)
468
+
469
+ return outputs
470
+
471
+ def feed_forward_chunk(self, attention_output):
472
+ intermediate_output = self.intermediate(attention_output)
473
+ layer_output = self.output(intermediate_output, attention_output)
474
+ return layer_output
475
+
476
+
477
+
478
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
479
+ class RobertaEncoder(nn.Module):
480
+ def __init__(self, config):
481
+ super().__init__()
482
+ self.config = config
483
+ self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
484
+
485
+ # skim predictors for each layer
486
+ self.skim_predictors = nn.ModuleList([SkimPredictor(config.hidden_size, 2) for _ in range(config.num_hidden_layers)])
487
+
488
+ def forward(
489
+ self,
490
+ hidden_states,
491
+ attention_mask=None,
492
+ head_mask=None,
493
+ encoder_hidden_states=None,
494
+ encoder_attention_mask=None,
495
+ past_key_values=None,
496
+ use_cache=None,
497
+ output_attentions=False,
498
+ output_hidden_states=False,
499
+ return_dict=True,
500
+ ):
501
+ all_hidden_states = () if output_hidden_states else None
502
+ all_self_attentions = () if output_attentions else None
503
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
504
+ all_skim_mask = ()
505
+ forward_hidden_states = hidden_states.clone()
506
+
507
+ next_decoder_cache = () if use_cache else None
508
+ for i, layer_module in enumerate(self.layer):
509
+ if output_hidden_states:
510
+ all_hidden_states = all_hidden_states + (hidden_states,)
511
+
512
+ skim_mask = nn.functional.gumbel_softmax(self.skim_predictors[i](hidden_states[:,1:,:]), hard=True, tau=1)
513
+ skim_mask = skim_mask[:,:,1]
514
+ skim_mask_with_cls = torch.ones(skim_mask.shape[0], skim_mask.shape[1]+1, device=skim_mask.device)
515
+ skim_mask_with_cls[:,1:] = skim_mask
516
+ skim_mask = skim_mask_with_cls
517
+ # multiple current layer skim mask with last layer skim mask
518
+ # to gurantee skimmed tokens are never recovered
519
+ if all_skim_mask:
520
+ skim_mask = skim_mask * all_skim_mask[-1]
521
+ all_skim_mask += (skim_mask, )
522
+
523
+ layer_head_mask = head_mask[i] if head_mask is not None else None
524
+ past_key_value = past_key_values[i] if past_key_values is not None else None
525
+
526
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
527
+
528
+ if use_cache:
529
+ logger.warning(
530
+ "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
531
+ "`use_cache=False`..."
532
+ )
533
+ use_cache = False
534
+
535
+ def create_custom_forward(module):
536
+ def custom_forward(*inputs):
537
+ return module(*inputs, past_key_value, output_attentions)
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ skim_mask,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ forward_hidden_states = forward_hidden_states * (1-skim_mask.view(*skim_mask.shape,1)) + hidden_states * skim_mask.view(*skim_mask.shape,1)
563
+ if use_cache:
564
+ next_decoder_cache += (layer_outputs[-1],)
565
+ if output_attentions:
566
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
567
+ if self.config.add_cross_attention:
568
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
569
+
570
+ if output_hidden_states:
571
+ all_hidden_states = all_hidden_states + (hidden_states,)
572
+
573
+ if not return_dict:
574
+ return tuple(
575
+ v
576
+ for v in [
577
+ forward_hidden_states,
578
+ next_decoder_cache,
579
+ all_hidden_states,
580
+ all_self_attentions,
581
+ all_cross_attentions,
582
+ ]
583
+ if v is not None
584
+ )
585
+ return BaseModelOutputWithPastAndCrossAttentionsSkim(
586
+ last_hidden_state=forward_hidden_states,
587
+ past_key_values=next_decoder_cache,
588
+ hidden_states=all_hidden_states,
589
+ attentions=all_self_attentions,
590
+ cross_attentions=all_cross_attentions,
591
+ attention_mask=attention_mask,
592
+ skim_mask=all_skim_mask,
593
+ )
594
+
595
+
596
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
597
+ class RobertaPooler(nn.Module):
598
+ def __init__(self, config):
599
+ super().__init__()
600
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
601
+ self.activation = nn.Tanh()
602
+
603
+ def forward(self, hidden_states):
604
+ # We "pool" the model by simply taking the hidden state corresponding
605
+ # to the first token.
606
+ first_token_tensor = hidden_states[:, 0]
607
+ pooled_output = self.dense(first_token_tensor)
608
+ pooled_output = self.activation(pooled_output)
609
+ return pooled_output
610
+
611
+
612
+ class RobertaPreTrainedModel(PreTrainedModel):
613
+ """
614
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
615
+ models.
616
+ """
617
+
618
+ config_class = RobertaConfig
619
+ base_model_prefix = "roberta"
620
+
621
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
622
+ def _init_weights(self, module):
623
+ """Initialize the weights"""
624
+ if hasattr(module, '_skim_initialized') and module._skim_initialized:
625
+ return
626
+ if isinstance(module, nn.Linear):
627
+ # Slightly different from the TF version which uses truncated_normal for initialization
628
+ # cf https://github.com/pytorch/pytorch/pull/5617
629
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
630
+ if module.bias is not None:
631
+ module.bias.data.zero_()
632
+ elif isinstance(module, nn.Embedding):
633
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
634
+ if module.padding_idx is not None:
635
+ module.weight.data[module.padding_idx].zero_()
636
+ elif isinstance(module, nn.LayerNorm):
637
+ module.bias.data.zero_()
638
+ module.weight.data.fill_(1.0)
639
+
640
+ def update_keys_to_ignore(self, config, del_keys_to_ignore):
641
+ """Remove some keys from ignore list"""
642
+ if not config.tie_word_embeddings:
643
+ # must make a new list, or the class variable gets modified!
644
+ self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
645
+ self._keys_to_ignore_on_load_missing = [
646
+ k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
647
+ ]
648
+
649
+
650
+ ROBERTA_START_DOCSTRING = r"""
651
+
652
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
653
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
654
+ pruning heads etc.)
655
+
656
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
657
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
658
+ general usage and behavior.
659
+
660
+ Parameters:
661
+ config (:class:`~transformers.RobertaConfig`): Model configuration class with all the parameters of the
662
+ model. Initializing with a config file does not load the weights associated with the model, only the
663
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
664
+ weights.
665
+ """
666
+
667
+ ROBERTA_INPUTS_DOCSTRING = r"""
668
+ Args:
669
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
670
+ Indices of input sequence tokens in the vocabulary.
671
+
672
+ Indices can be obtained using :class:`~transformers.RobertaTokenizer`. See
673
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
674
+ details.
675
+
676
+ `What are input IDs? <../glossary.html#input-ids>`__
677
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
678
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
679
+
680
+ - 1 for tokens that are **not masked**,
681
+ - 0 for tokens that are **masked**.
682
+
683
+ `What are attention masks? <../glossary.html#attention-mask>`__
684
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
685
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
686
+ 1]``:
687
+
688
+ - 0 corresponds to a `sentence A` token,
689
+ - 1 corresponds to a `sentence B` token.
690
+
691
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
692
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
693
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
694
+ config.max_position_embeddings - 1]``.
695
+
696
+ `What are position IDs? <../glossary.html#position-ids>`_
697
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
698
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
699
+
700
+ - 1 indicates the head is **not masked**,
701
+ - 0 indicates the head is **masked**.
702
+
703
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
704
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
705
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
706
+ vectors than the model's internal embedding lookup matrix.
707
+ output_attentions (:obj:`bool`, `optional`):
708
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
709
+ tensors for more detail.
710
+ output_hidden_states (:obj:`bool`, `optional`):
711
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
712
+ more detail.
713
+ return_dict (:obj:`bool`, `optional`):
714
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
715
+ """
716
+
717
+
718
+ @add_start_docstrings(
719
+ "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
720
+ ROBERTA_START_DOCSTRING,
721
+ )
722
+ class RobertaModel(RobertaPreTrainedModel):
723
+ """
724
+
725
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
726
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
727
+ all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
728
+ Kaiser and Illia Polosukhin.
729
+
730
+ To behave as an decoder the model needs to be initialized with the :obj:`is_decoder` argument of the configuration
731
+ set to :obj:`True`. To be used in a Seq2Seq model, the model needs to initialized with both :obj:`is_decoder`
732
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
733
+ input to the forward pass.
734
+
735
+ .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762
736
+
737
+ """
738
+
739
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
740
+
741
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
742
+ def __init__(self, config, add_pooling_layer=True):
743
+ super().__init__(config)
744
+ self.config = config
745
+
746
+ self.embeddings = RobertaEmbeddings(config)
747
+ self.encoder = RobertaEncoder(config)
748
+
749
+ self.pooler = RobertaPooler(config) if add_pooling_layer else None
750
+
751
+ self.init_weights()
752
+
753
+ def get_input_embeddings(self):
754
+ return self.embeddings.word_embeddings
755
+
756
+ def set_input_embeddings(self, value):
757
+ self.embeddings.word_embeddings = value
758
+
759
+ def _prune_heads(self, heads_to_prune):
760
+ """
761
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
762
+ class PreTrainedModel
763
+ """
764
+ for layer, heads in heads_to_prune.items():
765
+ self.encoder.layer[layer].attention.prune_heads(heads)
766
+
767
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
768
+ @add_code_sample_docstrings(
769
+ processor_class=_TOKENIZER_FOR_DOC,
770
+ checkpoint=_CHECKPOINT_FOR_DOC,
771
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
772
+ config_class=_CONFIG_FOR_DOC,
773
+ )
774
+ # Copied from transformers.models.bert.modeling_bert.BertModel.forward
775
+ def forward(
776
+ self,
777
+ input_ids=None,
778
+ attention_mask=None,
779
+ token_type_ids=None,
780
+ position_ids=None,
781
+ head_mask=None,
782
+ inputs_embeds=None,
783
+ encoder_hidden_states=None,
784
+ encoder_attention_mask=None,
785
+ past_key_values=None,
786
+ use_cache=None,
787
+ output_attentions=None,
788
+ output_hidden_states=None,
789
+ return_dict=None,
790
+ ):
791
+ r"""
792
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
793
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
794
+ the model is configured as a decoder.
795
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
796
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
797
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
798
+
799
+ - 1 for tokens that are **not masked**,
800
+ - 0 for tokens that are **masked**.
801
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
802
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
803
+
804
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
805
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
806
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
807
+ use_cache (:obj:`bool`, `optional`):
808
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
809
+ decoding (see :obj:`past_key_values`).
810
+ """
811
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
812
+ output_hidden_states = (
813
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
814
+ )
815
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
816
+
817
+ if self.config.is_decoder:
818
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
819
+ else:
820
+ use_cache = False
821
+
822
+ if input_ids is not None and inputs_embeds is not None:
823
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
824
+ elif input_ids is not None:
825
+ input_shape = input_ids.size()
826
+ elif inputs_embeds is not None:
827
+ input_shape = inputs_embeds.size()[:-1]
828
+ else:
829
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
830
+
831
+ batch_size, seq_length = input_shape
832
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
833
+
834
+ # past_key_values_length
835
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
836
+
837
+ if attention_mask is None:
838
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
839
+
840
+ if token_type_ids is None:
841
+ if hasattr(self.embeddings, "token_type_ids"):
842
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
843
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
844
+ token_type_ids = buffered_token_type_ids_expanded
845
+ else:
846
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
847
+
848
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
849
+ # ourselves in which case we just need to make it broadcastable to all heads.
850
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
851
+
852
+ # If a 2D or 3D attention mask is provided for the cross-attention
853
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
854
+ if self.config.is_decoder and encoder_hidden_states is not None:
855
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
856
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
857
+ if encoder_attention_mask is None:
858
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
859
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
860
+ else:
861
+ encoder_extended_attention_mask = None
862
+
863
+ # Prepare head mask if needed
864
+ # 1.0 in head_mask indicate we keep the head
865
+ # attention_probs has shape bsz x n_heads x N x N
866
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
867
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
868
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
869
+
870
+ embedding_output = self.embeddings(
871
+ input_ids=input_ids,
872
+ position_ids=position_ids,
873
+ token_type_ids=token_type_ids,
874
+ inputs_embeds=inputs_embeds,
875
+ past_key_values_length=past_key_values_length,
876
+ )
877
+ encoder_outputs = self.encoder(
878
+ embedding_output,
879
+ attention_mask=extended_attention_mask,
880
+ head_mask=head_mask,
881
+ encoder_hidden_states=encoder_hidden_states,
882
+ encoder_attention_mask=encoder_extended_attention_mask,
883
+ past_key_values=past_key_values,
884
+ use_cache=use_cache,
885
+ output_attentions=output_attentions,
886
+ output_hidden_states=output_hidden_states,
887
+ return_dict=return_dict,
888
+ )
889
+ sequence_output = encoder_outputs[0]
890
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
891
+
892
+ if not return_dict:
893
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
894
+
895
+ return BaseModelOutputWithPoolingAndCrossAttentionsSkim(
896
+ last_hidden_state=sequence_output,
897
+ pooler_output=pooled_output,
898
+ past_key_values=encoder_outputs.past_key_values,
899
+ hidden_states=encoder_outputs.hidden_states,
900
+ attentions=encoder_outputs.attentions,
901
+ cross_attentions=encoder_outputs.cross_attentions,
902
+ attention_mask=encoder_outputs.attention_mask,
903
+ skim_mask=encoder_outputs.skim_mask,
904
+ )
905
+
906
+
907
+ @add_start_docstrings(
908
+ """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
909
+ )
910
+ class RobertaForCausalLM(RobertaPreTrainedModel):
911
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
912
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
913
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
914
+
915
+ def __init__(self, config):
916
+ super().__init__(config)
917
+
918
+ if not config.is_decoder:
919
+ logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
920
+
921
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
922
+ self.lm_head = RobertaLMHead(config)
923
+
924
+ # The LM head weights require special treatment only when they are tied with the word embeddings
925
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
926
+
927
+ self.init_weights()
928
+
929
+ def get_output_embeddings(self):
930
+ return self.lm_head.decoder
931
+
932
+ def set_output_embeddings(self, new_embeddings):
933
+ self.lm_head.decoder = new_embeddings
934
+
935
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
936
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
937
+ def forward(
938
+ self,
939
+ input_ids=None,
940
+ attention_mask=None,
941
+ token_type_ids=None,
942
+ position_ids=None,
943
+ head_mask=None,
944
+ inputs_embeds=None,
945
+ encoder_hidden_states=None,
946
+ encoder_attention_mask=None,
947
+ labels=None,
948
+ past_key_values=None,
949
+ use_cache=None,
950
+ output_attentions=None,
951
+ output_hidden_states=None,
952
+ return_dict=None,
953
+ ):
954
+ r"""
955
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
956
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
957
+ the model is configured as a decoder.
958
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
959
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
960
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
961
+
962
+ - 1 for tokens that are **not masked**,
963
+ - 0 for tokens that are **masked**.
964
+
965
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
966
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
967
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
968
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
969
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
970
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
971
+
972
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
973
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
974
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
975
+ use_cache (:obj:`bool`, `optional`):
976
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
977
+ decoding (see :obj:`past_key_values`).
978
+
979
+ Returns:
980
+
981
+ Example::
982
+
983
+ >>> from transformers import RobertaTokenizer, RobertaForCausalLM, RobertaConfig
984
+ >>> import torch
985
+
986
+ >>> tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
987
+ >>> config = RobertaConfig.from_pretrained("roberta-base")
988
+ >>> config.is_decoder = True
989
+ >>> model = RobertaForCausalLM.from_pretrained('roberta-base', config=config)
990
+
991
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
992
+ >>> outputs = model(**inputs)
993
+
994
+ >>> prediction_logits = outputs.logits
995
+ """
996
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
+ if labels is not None:
998
+ use_cache = False
999
+
1000
+ outputs = self.roberta(
1001
+ input_ids,
1002
+ attention_mask=attention_mask,
1003
+ token_type_ids=token_type_ids,
1004
+ position_ids=position_ids,
1005
+ head_mask=head_mask,
1006
+ inputs_embeds=inputs_embeds,
1007
+ encoder_hidden_states=encoder_hidden_states,
1008
+ encoder_attention_mask=encoder_attention_mask,
1009
+ past_key_values=past_key_values,
1010
+ use_cache=use_cache,
1011
+ output_attentions=output_attentions,
1012
+ output_hidden_states=output_hidden_states,
1013
+ return_dict=return_dict,
1014
+ )
1015
+
1016
+ sequence_output = outputs[0]
1017
+ prediction_scores = self.lm_head(sequence_output)
1018
+
1019
+ lm_loss = None
1020
+ if labels is not None:
1021
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1022
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1023
+ labels = labels[:, 1:].contiguous()
1024
+ loss_fct = CrossEntropyLoss()
1025
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1026
+
1027
+ if not return_dict:
1028
+ output = (prediction_scores,) + outputs[2:]
1029
+ return ((lm_loss,) + output) if lm_loss is not None else output
1030
+
1031
+ return CausalLMOutputWithCrossAttentions(
1032
+ loss=lm_loss,
1033
+ logits=prediction_scores,
1034
+ past_key_values=outputs.past_key_values,
1035
+ hidden_states=outputs.hidden_states,
1036
+ attentions=outputs.attentions,
1037
+ cross_attentions=outputs.cross_attentions,
1038
+ )
1039
+
1040
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1041
+ input_shape = input_ids.shape
1042
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1043
+ if attention_mask is None:
1044
+ attention_mask = input_ids.new_ones(input_shape)
1045
+
1046
+ # cut decoder_input_ids if past is used
1047
+ if past is not None:
1048
+ input_ids = input_ids[:, -1:]
1049
+
1050
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past}
1051
+
1052
+ def _reorder_cache(self, past, beam_idx):
1053
+ reordered_past = ()
1054
+ for layer_past in past:
1055
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1056
+ return reordered_past
1057
+
1058
+
1059
+ @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
1060
+ class RobertaForMaskedLM(RobertaPreTrainedModel):
1061
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1062
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1063
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1064
+
1065
+ def __init__(self, config):
1066
+ super().__init__(config)
1067
+
1068
+ if config.is_decoder:
1069
+ logger.warning(
1070
+ "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
1071
+ "bi-directional self-attention."
1072
+ )
1073
+
1074
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1075
+ self.lm_head = RobertaLMHead(config)
1076
+
1077
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1078
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1079
+
1080
+ self.init_weights()
1081
+
1082
+ def get_output_embeddings(self):
1083
+ return self.lm_head.decoder
1084
+
1085
+ def set_output_embeddings(self, new_embeddings):
1086
+ self.lm_head.decoder = new_embeddings
1087
+
1088
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1089
+ @add_code_sample_docstrings(
1090
+ processor_class=_TOKENIZER_FOR_DOC,
1091
+ checkpoint=_CHECKPOINT_FOR_DOC,
1092
+ output_type=MaskedLMOutput,
1093
+ config_class=_CONFIG_FOR_DOC,
1094
+ mask="<mask>",
1095
+ )
1096
+ def forward(
1097
+ self,
1098
+ input_ids=None,
1099
+ attention_mask=None,
1100
+ token_type_ids=None,
1101
+ position_ids=None,
1102
+ head_mask=None,
1103
+ inputs_embeds=None,
1104
+ encoder_hidden_states=None,
1105
+ encoder_attention_mask=None,
1106
+ labels=None,
1107
+ output_attentions=None,
1108
+ output_hidden_states=None,
1109
+ return_dict=None,
1110
+ ):
1111
+ r"""
1112
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1113
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1114
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1115
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1116
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1117
+ Used to hide legacy arguments that have been deprecated.
1118
+ """
1119
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1120
+
1121
+ outputs = self.roberta(
1122
+ input_ids,
1123
+ attention_mask=attention_mask,
1124
+ token_type_ids=token_type_ids,
1125
+ position_ids=position_ids,
1126
+ head_mask=head_mask,
1127
+ inputs_embeds=inputs_embeds,
1128
+ encoder_hidden_states=encoder_hidden_states,
1129
+ encoder_attention_mask=encoder_attention_mask,
1130
+ output_attentions=output_attentions,
1131
+ output_hidden_states=output_hidden_states,
1132
+ return_dict=return_dict,
1133
+ )
1134
+ sequence_output = outputs[0]
1135
+ prediction_scores = self.lm_head(sequence_output)
1136
+
1137
+ masked_lm_loss = None
1138
+ if labels is not None:
1139
+ loss_fct = CrossEntropyLoss()
1140
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1141
+
1142
+ if not return_dict:
1143
+ output = (prediction_scores,) + outputs[2:]
1144
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1145
+
1146
+ return MaskedLMOutput(
1147
+ loss=masked_lm_loss,
1148
+ logits=prediction_scores,
1149
+ hidden_states=outputs.hidden_states,
1150
+ attentions=outputs.attentions,
1151
+ )
1152
+
1153
+
1154
+ class RobertaLMHead(nn.Module):
1155
+ """Roberta Head for masked language modeling."""
1156
+
1157
+ def __init__(self, config):
1158
+ super().__init__()
1159
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1160
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1161
+
1162
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
1163
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1164
+ self.decoder.bias = self.bias
1165
+
1166
+ def forward(self, features, **kwargs):
1167
+ x = self.dense(features)
1168
+ x = gelu(x)
1169
+ x = self.layer_norm(x)
1170
+
1171
+ # project back to size of vocabulary with bias
1172
+ x = self.decoder(x)
1173
+
1174
+ return x
1175
+
1176
+ def _tie_weights(self):
1177
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
1178
+ self.bias = self.decoder.bias
1179
+
1180
+
1181
+ @add_start_docstrings(
1182
+ """
1183
+ RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1184
+ pooled output) e.g. for GLUE tasks.
1185
+ """,
1186
+ ROBERTA_START_DOCSTRING,
1187
+ )
1188
+ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1189
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1190
+
1191
+ def __init__(self, config):
1192
+ super().__init__(config)
1193
+ self.num_labels = config.num_labels
1194
+ self.config = config
1195
+
1196
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1197
+ self.classifier = RobertaClassificationHead(config)
1198
+
1199
+ self.skim_coefficient = config.skim_coefficient if hasattr(config, 'skim_coefficient') else 1
1200
+
1201
+ self.init_weights()
1202
+
1203
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1204
+ @add_code_sample_docstrings(
1205
+ processor_class=_TOKENIZER_FOR_DOC,
1206
+ checkpoint=_CHECKPOINT_FOR_DOC,
1207
+ output_type=SequenceClassifierOutput,
1208
+ config_class=_CONFIG_FOR_DOC,
1209
+ )
1210
+ def forward(
1211
+ self,
1212
+ input_ids=None,
1213
+ attention_mask=None,
1214
+ token_type_ids=None,
1215
+ position_ids=None,
1216
+ head_mask=None,
1217
+ inputs_embeds=None,
1218
+ labels=None,
1219
+ output_attentions=None,
1220
+ output_hidden_states=None,
1221
+ return_dict=None,
1222
+ ):
1223
+ r"""
1224
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1225
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1226
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1227
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1228
+ """
1229
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1230
+
1231
+ outputs = self.roberta(
1232
+ input_ids,
1233
+ attention_mask=attention_mask,
1234
+ token_type_ids=token_type_ids,
1235
+ position_ids=position_ids,
1236
+ head_mask=head_mask,
1237
+ inputs_embeds=inputs_embeds,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ )
1242
+ sequence_output = outputs[0]
1243
+ logits = self.classifier(sequence_output)
1244
+
1245
+ loss = None
1246
+ if labels is not None:
1247
+ if self.config.problem_type is None:
1248
+ if self.num_labels == 1:
1249
+ self.config.problem_type = "regression"
1250
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1251
+ self.config.problem_type = "single_label_classification"
1252
+ else:
1253
+ self.config.problem_type = "multi_label_classification"
1254
+
1255
+ if self.config.problem_type == "regression":
1256
+ loss_fct = MSELoss()
1257
+ if self.num_labels == 1:
1258
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1259
+ else:
1260
+ loss = loss_fct(logits, labels)
1261
+ elif self.config.problem_type == "single_label_classification":
1262
+ loss_fct = CrossEntropyLoss()
1263
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1264
+ elif self.config.problem_type == "multi_label_classification":
1265
+ loss_fct = BCEWithLogitsLoss()
1266
+ loss = loss_fct(logits, labels)
1267
+
1268
+ if not return_dict:
1269
+ output = (logits,) + outputs[2:]
1270
+ return ((loss,) + output) if loss is not None else output
1271
+
1272
+ skim_loss, neat_mac = 0.0, 0.0
1273
+ layer_neat_mac = list()
1274
+ all_tokens_length = torch.mean(torch.sum(attention_mask.to(torch.float32),dim=-1))
1275
+ for mask in outputs.skim_mask:
1276
+ accumulated_skim_mask = torch.mean(torch.sum(mask,dim=1))
1277
+ skim_loss += accumulated_skim_mask/mask.shape[1]
1278
+ layer_neat_mac.append(accumulated_skim_mask/all_tokens_length)
1279
+ neat_mac += accumulated_skim_mask/all_tokens_length
1280
+ skim_loss /= self.config.num_hidden_layers
1281
+ neat_mac /= self.config.num_hidden_layers
1282
+ classification_loss = loss
1283
+ # print(skim_loss, neat_mac, loss)
1284
+ # loss = skim_loss
1285
+ loss = self.skim_coefficient * skim_loss + loss
1286
+
1287
+ return SequenceClassifierOutputSkim(
1288
+ loss=loss,
1289
+ logits=logits,
1290
+ hidden_states=outputs.hidden_states,
1291
+ attentions=outputs.attentions,
1292
+ attention_mask=outputs.attention_mask,
1293
+ skim_mask=outputs.skim_mask,
1294
+ skim_loss=skim_loss,
1295
+ classification_loss=classification_loss,
1296
+ tokens_remained=neat_mac,
1297
+ layer_tokens_remained=layer_neat_mac,
1298
+ )
1299
+
1300
+
1301
+ @add_start_docstrings(
1302
+ """
1303
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1304
+ softmax) e.g. for RocStories/SWAG tasks.
1305
+ """,
1306
+ ROBERTA_START_DOCSTRING,
1307
+ )
1308
+ class RobertaForMultipleChoice(RobertaPreTrainedModel):
1309
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1310
+
1311
+ def __init__(self, config):
1312
+ super().__init__(config)
1313
+
1314
+ self.roberta = RobertaModel(config)
1315
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1316
+ self.classifier = nn.Linear(config.hidden_size, 1)
1317
+
1318
+ self.init_weights()
1319
+
1320
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1321
+ @add_code_sample_docstrings(
1322
+ processor_class=_TOKENIZER_FOR_DOC,
1323
+ checkpoint=_CHECKPOINT_FOR_DOC,
1324
+ output_type=MultipleChoiceModelOutput,
1325
+ config_class=_CONFIG_FOR_DOC,
1326
+ )
1327
+ def forward(
1328
+ self,
1329
+ input_ids=None,
1330
+ token_type_ids=None,
1331
+ attention_mask=None,
1332
+ labels=None,
1333
+ position_ids=None,
1334
+ head_mask=None,
1335
+ inputs_embeds=None,
1336
+ output_attentions=None,
1337
+ output_hidden_states=None,
1338
+ return_dict=None,
1339
+ ):
1340
+ r"""
1341
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1342
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1343
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1344
+ :obj:`input_ids` above)
1345
+ """
1346
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1347
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1348
+
1349
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1350
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1351
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1352
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1353
+ flat_inputs_embeds = (
1354
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1355
+ if inputs_embeds is not None
1356
+ else None
1357
+ )
1358
+
1359
+ outputs = self.roberta(
1360
+ flat_input_ids,
1361
+ position_ids=flat_position_ids,
1362
+ token_type_ids=flat_token_type_ids,
1363
+ attention_mask=flat_attention_mask,
1364
+ head_mask=head_mask,
1365
+ inputs_embeds=flat_inputs_embeds,
1366
+ output_attentions=output_attentions,
1367
+ output_hidden_states=output_hidden_states,
1368
+ return_dict=return_dict,
1369
+ )
1370
+ pooled_output = outputs[1]
1371
+
1372
+ pooled_output = self.dropout(pooled_output)
1373
+ logits = self.classifier(pooled_output)
1374
+ reshaped_logits = logits.view(-1, num_choices)
1375
+
1376
+ loss = None
1377
+ if labels is not None:
1378
+ loss_fct = CrossEntropyLoss()
1379
+ loss = loss_fct(reshaped_logits, labels)
1380
+
1381
+ if not return_dict:
1382
+ output = (reshaped_logits,) + outputs[2:]
1383
+ return ((loss,) + output) if loss is not None else output
1384
+
1385
+ return MultipleChoiceModelOutput(
1386
+ loss=loss,
1387
+ logits=reshaped_logits,
1388
+ hidden_states=outputs.hidden_states,
1389
+ attentions=outputs.attentions,
1390
+ )
1391
+
1392
+
1393
+ @add_start_docstrings(
1394
+ """
1395
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1396
+ Named-Entity-Recognition (NER) tasks.
1397
+ """,
1398
+ ROBERTA_START_DOCSTRING,
1399
+ )
1400
+ class RobertaForTokenClassification(RobertaPreTrainedModel):
1401
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1402
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1403
+
1404
+ def __init__(self, config):
1405
+ super().__init__(config)
1406
+ self.num_labels = config.num_labels
1407
+
1408
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1409
+ classifier_dropout = (
1410
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1411
+ )
1412
+ self.dropout = nn.Dropout(classifier_dropout)
1413
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1414
+
1415
+ self.init_weights()
1416
+
1417
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1418
+ @add_code_sample_docstrings(
1419
+ processor_class=_TOKENIZER_FOR_DOC,
1420
+ checkpoint=_CHECKPOINT_FOR_DOC,
1421
+ output_type=TokenClassifierOutput,
1422
+ config_class=_CONFIG_FOR_DOC,
1423
+ )
1424
+ def forward(
1425
+ self,
1426
+ input_ids=None,
1427
+ attention_mask=None,
1428
+ token_type_ids=None,
1429
+ position_ids=None,
1430
+ head_mask=None,
1431
+ inputs_embeds=None,
1432
+ labels=None,
1433
+ output_attentions=None,
1434
+ output_hidden_states=None,
1435
+ return_dict=None,
1436
+ ):
1437
+ r"""
1438
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1439
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1440
+ 1]``.
1441
+ """
1442
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1443
+
1444
+ outputs = self.roberta(
1445
+ input_ids,
1446
+ attention_mask=attention_mask,
1447
+ token_type_ids=token_type_ids,
1448
+ position_ids=position_ids,
1449
+ head_mask=head_mask,
1450
+ inputs_embeds=inputs_embeds,
1451
+ output_attentions=output_attentions,
1452
+ output_hidden_states=output_hidden_states,
1453
+ return_dict=return_dict,
1454
+ )
1455
+
1456
+ sequence_output = outputs[0]
1457
+
1458
+ sequence_output = self.dropout(sequence_output)
1459
+ logits = self.classifier(sequence_output)
1460
+
1461
+ loss = None
1462
+ if labels is not None:
1463
+ loss_fct = CrossEntropyLoss()
1464
+ # Only keep active parts of the loss
1465
+ if attention_mask is not None:
1466
+ active_loss = attention_mask.view(-1) == 1
1467
+ active_logits = logits.view(-1, self.num_labels)
1468
+ active_labels = torch.where(
1469
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1470
+ )
1471
+ loss = loss_fct(active_logits, active_labels)
1472
+ else:
1473
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1474
+
1475
+ if not return_dict:
1476
+ output = (logits,) + outputs[2:]
1477
+ return ((loss,) + output) if loss is not None else output
1478
+
1479
+ return TokenClassifierOutput(
1480
+ loss=loss,
1481
+ logits=logits,
1482
+ hidden_states=outputs.hidden_states,
1483
+ attentions=outputs.attentions,
1484
+ )
1485
+
1486
+
1487
+ class RobertaClassificationHead(nn.Module):
1488
+ """Head for sentence-level classification tasks."""
1489
+
1490
+ def __init__(self, config):
1491
+ super().__init__()
1492
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1493
+ classifier_dropout = (
1494
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1495
+ )
1496
+ self.dropout = nn.Dropout(classifier_dropout)
1497
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1498
+
1499
+ def forward(self, features, **kwargs):
1500
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1501
+ x = self.dropout(x)
1502
+ x = self.dense(x)
1503
+ x = torch.tanh(x)
1504
+ x = self.dropout(x)
1505
+ x = self.out_proj(x)
1506
+ return x
1507
+
1508
+
1509
+ @add_start_docstrings(
1510
+ """
1511
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1512
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1513
+ """,
1514
+ ROBERTA_START_DOCSTRING,
1515
+ )
1516
+ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
1517
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1518
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1519
+
1520
+ def __init__(self, config):
1521
+ super().__init__(config)
1522
+ self.num_labels = config.num_labels
1523
+
1524
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1525
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1526
+
1527
+ self.init_weights()
1528
+
1529
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1530
+ @add_code_sample_docstrings(
1531
+ processor_class=_TOKENIZER_FOR_DOC,
1532
+ checkpoint=_CHECKPOINT_FOR_DOC,
1533
+ output_type=QuestionAnsweringModelOutput,
1534
+ config_class=_CONFIG_FOR_DOC,
1535
+ )
1536
+ def forward(
1537
+ self,
1538
+ input_ids=None,
1539
+ attention_mask=None,
1540
+ token_type_ids=None,
1541
+ position_ids=None,
1542
+ head_mask=None,
1543
+ inputs_embeds=None,
1544
+ start_positions=None,
1545
+ end_positions=None,
1546
+ output_attentions=None,
1547
+ output_hidden_states=None,
1548
+ return_dict=None,
1549
+ ):
1550
+ r"""
1551
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1552
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1553
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1554
+ sequence are not taken into account for computing the loss.
1555
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1556
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1557
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1558
+ sequence are not taken into account for computing the loss.
1559
+ """
1560
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1561
+
1562
+ outputs = self.roberta(
1563
+ input_ids,
1564
+ attention_mask=attention_mask,
1565
+ token_type_ids=token_type_ids,
1566
+ position_ids=position_ids,
1567
+ head_mask=head_mask,
1568
+ inputs_embeds=inputs_embeds,
1569
+ output_attentions=output_attentions,
1570
+ output_hidden_states=output_hidden_states,
1571
+ return_dict=return_dict,
1572
+ )
1573
+
1574
+ sequence_output = outputs[0]
1575
+
1576
+ logits = self.qa_outputs(sequence_output)
1577
+ start_logits, end_logits = logits.split(1, dim=-1)
1578
+ start_logits = start_logits.squeeze(-1).contiguous()
1579
+ end_logits = end_logits.squeeze(-1).contiguous()
1580
+
1581
+ total_loss = None
1582
+ if start_positions is not None and end_positions is not None:
1583
+ # If we are on multi-GPU, split add a dimension
1584
+ if len(start_positions.size()) > 1:
1585
+ start_positions = start_positions.squeeze(-1)
1586
+ if len(end_positions.size()) > 1:
1587
+ end_positions = end_positions.squeeze(-1)
1588
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1589
+ ignored_index = start_logits.size(1)
1590
+ start_positions = start_positions.clamp(0, ignored_index)
1591
+ end_positions = end_positions.clamp(0, ignored_index)
1592
+
1593
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1594
+ start_loss = loss_fct(start_logits, start_positions)
1595
+ end_loss = loss_fct(end_logits, end_positions)
1596
+ total_loss = (start_loss + end_loss) / 2
1597
+
1598
+ if not return_dict:
1599
+ output = (start_logits, end_logits) + outputs[2:]
1600
+ return ((total_loss,) + output) if total_loss is not None else output
1601
+
1602
+ return QuestionAnsweringModelOutput(
1603
+ loss=total_loss,
1604
+ start_logits=start_logits,
1605
+ end_logits=end_logits,
1606
+ hidden_states=outputs.hidden_states,
1607
+ attentions=outputs.attentions,
1608
+ )
1609
+
1610
+
1611
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
1612
+ """
1613
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
1614
+ are ignored. This is modified from fairseq's `utils.make_positions`.
1615
+
1616
+ Args:
1617
+ x: torch.Tensor x:
1618
+
1619
+ Returns: torch.Tensor
1620
+ """
1621
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1622
+ mask = input_ids.ne(padding_idx).int()
1623
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
1624
+ return incremental_indices.long() + padding_idx
test_module/modeling_utils.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import torch
3
+ from transformers.file_utils import ModelOutput
4
+ from typing import Optional, Tuple
5
+
6
+ from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput
7
+
8
+ @dataclass
9
+ class BaseModelOutputWithPastAndCrossAttentionsSkim(ModelOutput):
10
+ last_hidden_state: torch.FloatTensor = None
11
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
12
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
13
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
14
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
15
+ attention_mask: Optional[torch.FloatTensor] = None
16
+ skim_mask: Optional[torch.FloatTensor] = None
17
+
18
+ @dataclass
19
+ class BaseModelOutputWithPoolingAndCrossAttentionsSkim(ModelOutput):
20
+ last_hidden_state: torch.FloatTensor = None
21
+ pooler_output: torch.FloatTensor = None
22
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
23
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
24
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
25
+ cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
26
+ attention_mask: Optional[torch.FloatTensor] = None
27
+ skim_mask: Optional[torch.FloatTensor] = None
28
+
29
+
30
+ @dataclass
31
+ class SequenceClassifierOutputSkim(ModelOutput):
32
+ loss: Optional[torch.FloatTensor] = None
33
+ logits: torch.FloatTensor = None
34
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
35
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
36
+ attention_mask: Optional[torch.FloatTensor] = None
37
+ skim_mask: Optional[torch.FloatTensor] = None
38
+ skim_loss: Optional[torch.FloatTensor] = None
39
+ classification_loss: Optional[torch.FloatTensor] = None
40
+ tokens_remained: Optional[torch.FloatTensor] = None
41
+ layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None
42
+
43
+ @dataclass
44
+ class QuestionAnsweringModelOutputSkim(QuestionAnsweringModelOutput):
45
+ attention_mask: Optional[torch.FloatTensor] = None
46
+ skim_mask: Optional[torch.FloatTensor] = None
47
+ skim_loss: Optional[torch.FloatTensor] = None
48
+ classification_loss: Optional[torch.FloatTensor] = None
49
+ tokens_remained: Optional[torch.FloatTensor] = None
50
+ layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None
51
+
52
+ @dataclass
53
+ class MaskedLMOutputSkim(MaskedLMOutput):
54
+ attention_mask: Optional[torch.FloatTensor] = None
55
+ skim_mask: Optional[torch.FloatTensor] = None
56
+ skim_loss: Optional[torch.FloatTensor] = None
57
+ classification_loss: Optional[torch.FloatTensor] = None
58
+ tokens_remained: Optional[torch.FloatTensor] = None
59
+ layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None
60
+
61
+ def masked_softmax(vec, mask, dim=1, eps=1e-6):
62
+ mask = mask[:,None,None,:]
63
+ exps = torch.exp(vec)
64
+ masked_exps = exps * mask.float() + eps
65
+ masked_sums = masked_exps.sum(dim, keepdim=True)
66
+ return (masked_exps/masked_sums)
67
+
68
+ def convert_softmax_mask_to_digit(skim_mask):
69
+ # skim_mask [batch, from, to, seq_len]
70
+ return (skim_mask == 0).to(dtype=torch.int64).unsqueeze(1).unsqueeze(1)
71
+
72
+ def trunc_with_mask_batched(input, mask, dim):
73
+ """
74
+ trunc a batched input at dim
75
+ e.g. hidden_states ([batch, seq_len, hidden_size])
76
+ attention_mask ([batch, layer, head, seq_len])
77
+ mask: [batch, seq_len]
78
+ """
79
+ assert input.shape[dim]==mask.shape[1]
80
+
81
+ if dim != 1:
82
+ input = input.transpose(1, dim)
83
+
84
+ transpose_shape = list(input.shape)
85
+ transpose_shape[1] = -1
86
+
87
+ trunc_input = input[mask].view(transpose_shape)
88
+
89
+ if dim != 1:
90
+ trunc_input = trunc_input.transpose(1, dim)
91
+
92
+ return trunc_input
utils/extend_auto_mapping.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def extend_lazy_auto_mapping(mapping, key, config_value, value):
2
+ prev_config_mapping = mapping._config_mapping
3
+ prev_model_mapping = mapping._model_mapping
4
+
5
+ prev_config_mapping[key] = config_value
6
+ prev_model_mapping[key] = value
7
+
8
+ def test_extending():
9
+ import transformers
10
+ extend_lazy_auto_mapping(transformers.models.auto.MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, 'test', transformers.models.bert.configuration_bert, transformers.models.bert.BertForSequenceClassification)
11
+
12
+ if __name__ == "__main__":
13
+ test_extending()
utils/utils_qa.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2020 The HuggingFace Team All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Post-processing utilities for question answering.
17
+ """
18
+ import collections
19
+ import json
20
+ import logging
21
+ import os
22
+ from typing import Optional, Tuple
23
+
24
+ import numpy as np
25
+ from tqdm.auto import tqdm
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ def postprocess_qa_predictions(
32
+ examples,
33
+ features,
34
+ predictions: Tuple[np.ndarray, np.ndarray],
35
+ version_2_with_negative: bool = False,
36
+ n_best_size: int = 20,
37
+ max_answer_length: int = 30,
38
+ null_score_diff_threshold: float = 0.0,
39
+ output_dir: Optional[str] = None,
40
+ prefix: Optional[str] = None,
41
+ log_level: Optional[int] = logging.WARNING,
42
+ ):
43
+ """
44
+ Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
45
+ original contexts. This is the base postprocessing functions for models that only return start and end logits.
46
+
47
+ Args:
48
+ examples: The non-preprocessed dataset (see the main script for more information).
49
+ features: The processed dataset (see the main script for more information).
50
+ predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
51
+ The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
52
+ first dimension must match the number of elements of :obj:`features`.
53
+ version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
54
+ Whether or not the underlying dataset contains examples with no answers.
55
+ n_best_size (:obj:`int`, `optional`, defaults to 20):
56
+ The total number of n-best predictions to generate when looking for an answer.
57
+ max_answer_length (:obj:`int`, `optional`, defaults to 30):
58
+ The maximum length of an answer that can be generated. This is needed because the start and end predictions
59
+ are not conditioned on one another.
60
+ null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
61
+ The threshold used to select the null answer: if the best answer has a score that is less than the score of
62
+ the null answer minus this threshold, the null answer is selected for this example (note that the score of
63
+ the null answer for an example giving several features is the minimum of the scores for the null answer on
64
+ each feature: all features must be aligned on the fact they `want` to predict a null answer).
65
+
66
+ Only useful when :obj:`version_2_with_negative` is :obj:`True`.
67
+ output_dir (:obj:`str`, `optional`):
68
+ If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
69
+ :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
70
+ answers, are saved in `output_dir`.
71
+ prefix (:obj:`str`, `optional`):
72
+ If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
73
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
74
+ ``logging`` log level (e.g., ``logging.WARNING``)
75
+ """
76
+ assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
77
+ all_start_logits, all_end_logits = predictions
78
+
79
+ assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
80
+
81
+ # Build a map example to its corresponding features.
82
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
83
+ features_per_example = collections.defaultdict(list)
84
+ for i, feature in enumerate(features):
85
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
86
+
87
+ # The dictionaries we have to fill.
88
+ all_predictions = collections.OrderedDict()
89
+ all_nbest_json = collections.OrderedDict()
90
+ if version_2_with_negative:
91
+ scores_diff_json = collections.OrderedDict()
92
+
93
+ # Logging.
94
+ logger.setLevel(log_level)
95
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
96
+
97
+ # Let's loop over all the examples!
98
+ for example_index, example in enumerate(tqdm(examples)):
99
+ # Those are the indices of the features associated to the current example.
100
+ feature_indices = features_per_example[example_index]
101
+
102
+ min_null_prediction = None
103
+ prelim_predictions = []
104
+
105
+ # Looping through all the features associated to the current example.
106
+ for feature_index in feature_indices:
107
+ # We grab the predictions of the model for this feature.
108
+ start_logits = all_start_logits[feature_index]
109
+ end_logits = all_end_logits[feature_index]
110
+ # This is what will allow us to map some the positions in our logits to span of texts in the original
111
+ # context.
112
+ offset_mapping = features[feature_index]["offset_mapping"]
113
+ # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
114
+ # available in the current feature.
115
+ token_is_max_context = features[feature_index].get("token_is_max_context", None)
116
+
117
+ # Update minimum null prediction.
118
+ feature_null_score = start_logits[0] + end_logits[0]
119
+ if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
120
+ min_null_prediction = {
121
+ "offsets": (0, 0),
122
+ "score": feature_null_score,
123
+ "start_logit": start_logits[0],
124
+ "end_logit": end_logits[0],
125
+ }
126
+
127
+ # Go through all possibilities for the `n_best_size` greater start and end logits.
128
+ start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
129
+ end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
130
+ for start_index in start_indexes:
131
+ for end_index in end_indexes:
132
+ # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
133
+ # to part of the input_ids that are not in the context.
134
+ if (
135
+ start_index >= len(offset_mapping)
136
+ or end_index >= len(offset_mapping)
137
+ or offset_mapping[start_index] is None
138
+ or offset_mapping[end_index] is None
139
+ ):
140
+ continue
141
+ # Don't consider answers with a length that is either < 0 or > max_answer_length.
142
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length:
143
+ continue
144
+ # Don't consider answer that don't have the maximum context available (if such information is
145
+ # provided).
146
+ if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
147
+ continue
148
+ prelim_predictions.append(
149
+ {
150
+ "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
151
+ "score": start_logits[start_index] + end_logits[end_index],
152
+ "start_logit": start_logits[start_index],
153
+ "end_logit": end_logits[end_index],
154
+ }
155
+ )
156
+ if version_2_with_negative:
157
+ # Add the minimum null prediction
158
+ prelim_predictions.append(min_null_prediction)
159
+ null_score = min_null_prediction["score"]
160
+
161
+ # Only keep the best `n_best_size` predictions.
162
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
163
+
164
+ # Add back the minimum null prediction if it was removed because of its low score.
165
+ if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
166
+ predictions.append(min_null_prediction)
167
+
168
+ # Use the offsets to gather the answer text in the original context.
169
+ context = example["context"]
170
+ for pred in predictions:
171
+ offsets = pred.pop("offsets")
172
+ pred["text"] = context[offsets[0] : offsets[1]]
173
+
174
+ # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
175
+ # failure.
176
+ if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
177
+ predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
178
+
179
+ # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
180
+ # the LogSumExp trick).
181
+ scores = np.array([pred.pop("score") for pred in predictions])
182
+ exp_scores = np.exp(scores - np.max(scores))
183
+ probs = exp_scores / exp_scores.sum()
184
+
185
+ # Include the probabilities in our predictions.
186
+ for prob, pred in zip(probs, predictions):
187
+ pred["probability"] = prob
188
+
189
+ # Pick the best prediction. If the null answer is not possible, this is easy.
190
+ if not version_2_with_negative:
191
+ all_predictions[example["id"]] = predictions[0]["text"]
192
+ else:
193
+ # Otherwise we first need to find the best non-empty prediction.
194
+ i = 0
195
+ while predictions[i]["text"] == "":
196
+ i += 1
197
+ best_non_null_pred = predictions[i]
198
+
199
+ # Then we compare to the null prediction using the threshold.
200
+ score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
201
+ scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
202
+ if score_diff > null_score_diff_threshold:
203
+ all_predictions[example["id"]] = ""
204
+ else:
205
+ all_predictions[example["id"]] = best_non_null_pred["text"]
206
+
207
+ # Make `predictions` JSON-serializable by casting np.float back to float.
208
+ all_nbest_json[example["id"]] = [
209
+ {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
210
+ for pred in predictions
211
+ ]
212
+
213
+ # If we have an output_dir, let's save all those dicts.
214
+ if output_dir is not None:
215
+ assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
216
+
217
+ prediction_file = os.path.join(
218
+ output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
219
+ )
220
+ nbest_file = os.path.join(
221
+ output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
222
+ )
223
+ if version_2_with_negative:
224
+ null_odds_file = os.path.join(
225
+ output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
226
+ )
227
+
228
+ logger.info(f"Saving predictions to {prediction_file}.")
229
+ with open(prediction_file, "w") as writer:
230
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
231
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
232
+ with open(nbest_file, "w") as writer:
233
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
234
+ if version_2_with_negative:
235
+ logger.info(f"Saving null_odds to {null_odds_file}.")
236
+ with open(null_odds_file, "w") as writer:
237
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
238
+
239
+ return all_predictions
240
+
241
+
242
+ def postprocess_qa_predictions_with_beam_search(
243
+ examples,
244
+ features,
245
+ predictions: Tuple[np.ndarray, np.ndarray],
246
+ version_2_with_negative: bool = False,
247
+ n_best_size: int = 20,
248
+ max_answer_length: int = 30,
249
+ start_n_top: int = 5,
250
+ end_n_top: int = 5,
251
+ output_dir: Optional[str] = None,
252
+ prefix: Optional[str] = None,
253
+ log_level: Optional[int] = logging.WARNING,
254
+ ):
255
+ """
256
+ Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
257
+ original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as
258
+ cls token predictions.
259
+
260
+ Args:
261
+ examples: The non-preprocessed dataset (see the main script for more information).
262
+ features: The processed dataset (see the main script for more information).
263
+ predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
264
+ The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
265
+ first dimension must match the number of elements of :obj:`features`.
266
+ version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
267
+ Whether or not the underlying dataset contains examples with no answers.
268
+ n_best_size (:obj:`int`, `optional`, defaults to 20):
269
+ The total number of n-best predictions to generate when looking for an answer.
270
+ max_answer_length (:obj:`int`, `optional`, defaults to 30):
271
+ The maximum length of an answer that can be generated. This is needed because the start and end predictions
272
+ are not conditioned on one another.
273
+ start_n_top (:obj:`int`, `optional`, defaults to 5):
274
+ The number of top start logits too keep when searching for the :obj:`n_best_size` predictions.
275
+ end_n_top (:obj:`int`, `optional`, defaults to 5):
276
+ The number of top end logits too keep when searching for the :obj:`n_best_size` predictions.
277
+ output_dir (:obj:`str`, `optional`):
278
+ If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
279
+ :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
280
+ answers, are saved in `output_dir`.
281
+ prefix (:obj:`str`, `optional`):
282
+ If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
283
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
284
+ ``logging`` log level (e.g., ``logging.WARNING``)
285
+ """
286
+ assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
287
+ start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
288
+
289
+ assert len(predictions[0]) == len(
290
+ features
291
+ ), f"Got {len(predictions[0])} predicitions and {len(features)} features."
292
+
293
+ # Build a map example to its corresponding features.
294
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
295
+ features_per_example = collections.defaultdict(list)
296
+ for i, feature in enumerate(features):
297
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
298
+
299
+ # The dictionaries we have to fill.
300
+ all_predictions = collections.OrderedDict()
301
+ all_nbest_json = collections.OrderedDict()
302
+ scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
303
+
304
+ # Logging.
305
+ logger.setLevel(log_level)
306
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
307
+
308
+ # Let's loop over all the examples!
309
+ for example_index, example in enumerate(tqdm(examples)):
310
+ # Those are the indices of the features associated to the current example.
311
+ feature_indices = features_per_example[example_index]
312
+
313
+ min_null_score = None
314
+ prelim_predictions = []
315
+
316
+ # Looping through all the features associated to the current example.
317
+ for feature_index in feature_indices:
318
+ # We grab the predictions of the model for this feature.
319
+ start_log_prob = start_top_log_probs[feature_index]
320
+ start_indexes = start_top_index[feature_index]
321
+ end_log_prob = end_top_log_probs[feature_index]
322
+ end_indexes = end_top_index[feature_index]
323
+ feature_null_score = cls_logits[feature_index]
324
+ # This is what will allow us to map some the positions in our logits to span of texts in the original
325
+ # context.
326
+ offset_mapping = features[feature_index]["offset_mapping"]
327
+ # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
328
+ # available in the current feature.
329
+ token_is_max_context = features[feature_index].get("token_is_max_context", None)
330
+
331
+ # Update minimum null prediction
332
+ if min_null_score is None or feature_null_score < min_null_score:
333
+ min_null_score = feature_null_score
334
+
335
+ # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits.
336
+ for i in range(start_n_top):
337
+ for j in range(end_n_top):
338
+ start_index = int(start_indexes[i])
339
+ j_index = i * end_n_top + j
340
+ end_index = int(end_indexes[j_index])
341
+ # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
342
+ # p_mask but let's not take any risk)
343
+ if (
344
+ start_index >= len(offset_mapping)
345
+ or end_index >= len(offset_mapping)
346
+ or offset_mapping[start_index] is None
347
+ or offset_mapping[end_index] is None
348
+ ):
349
+ continue
350
+ # Don't consider answers with a length negative or > max_answer_length.
351
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length:
352
+ continue
353
+ # Don't consider answer that don't have the maximum context available (if such information is
354
+ # provided).
355
+ if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
356
+ continue
357
+ prelim_predictions.append(
358
+ {
359
+ "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
360
+ "score": start_log_prob[i] + end_log_prob[j_index],
361
+ "start_log_prob": start_log_prob[i],
362
+ "end_log_prob": end_log_prob[j_index],
363
+ }
364
+ )
365
+
366
+ # Only keep the best `n_best_size` predictions.
367
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
368
+
369
+ # Use the offsets to gather the answer text in the original context.
370
+ context = example["context"]
371
+ for pred in predictions:
372
+ offsets = pred.pop("offsets")
373
+ pred["text"] = context[offsets[0] : offsets[1]]
374
+
375
+ # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
376
+ # failure.
377
+ if len(predictions) == 0:
378
+ predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
379
+
380
+ # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
381
+ # the LogSumExp trick).
382
+ scores = np.array([pred.pop("score") for pred in predictions])
383
+ exp_scores = np.exp(scores - np.max(scores))
384
+ probs = exp_scores / exp_scores.sum()
385
+
386
+ # Include the probabilities in our predictions.
387
+ for prob, pred in zip(probs, predictions):
388
+ pred["probability"] = prob
389
+
390
+ # Pick the best prediction and set the probability for the null answer.
391
+ all_predictions[example["id"]] = predictions[0]["text"]
392
+ if version_2_with_negative:
393
+ scores_diff_json[example["id"]] = float(min_null_score)
394
+
395
+ # Make `predictions` JSON-serializable by casting np.float back to float.
396
+ all_nbest_json[example["id"]] = [
397
+ {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
398
+ for pred in predictions
399
+ ]
400
+
401
+ # If we have an output_dir, let's save all those dicts.
402
+ if output_dir is not None:
403
+ assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
404
+
405
+ prediction_file = os.path.join(
406
+ output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
407
+ )
408
+ nbest_file = os.path.join(
409
+ output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
410
+ )
411
+ if version_2_with_negative:
412
+ null_odds_file = os.path.join(
413
+ output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
414
+ )
415
+
416
+ logger.info(f"Saving predictions to {prediction_file}.")
417
+ with open(prediction_file, "w") as writer:
418
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
419
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
420
+ with open(nbest_file, "w") as writer:
421
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
422
+ if version_2_with_negative:
423
+ logger.info(f"Saving null_odds to {null_odds_file}.")
424
+ with open(null_odds_file, "w") as writer:
425
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
426
+
427
+ return all_predictions, scores_diff_json