yangwang825 committed
Commit 1729fca (verified) · 1 Parent(s): c700a08

Upload PureDebertaForSequenceClassification

Files changed (3):
  1. config.json +7 -1
  2. model.safetensors +3 -0
  3. modeling_pure_deberta.py +337 -0
config.json CHANGED
@@ -1,8 +1,13 @@
 {
+  "_name_or_path": "microsoft/deberta-v3-large",
   "alpha": 1,
+  "architectures": [
+    "PureDebertaForSequenceClassification"
+  ],
   "attention_probs_dropout_prob": 0.1,
   "auto_map": {
-    "AutoConfig": "configuration_pure_deberta.PureDebertaConfig"
+    "AutoConfig": "configuration_pure_deberta.PureDebertaConfig",
+    "AutoModelForSequenceClassification": "modeling_pure_deberta.PureDebertaForSequenceClassification"
   },
   "center": false,
   "disable_covariance": true,
@@ -35,6 +40,7 @@
   "relative_attention": true,
   "share_att_key": true,
   "svd_rank": 5,
+  "torch_dtype": "float32",
   "transformers_version": "4.46.2",
   "type_vocab_size": 0,
   "vocab_size": 128100
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a506f94bfac95d49b2c7729c2ce3956233a9963c6e777dfbd17c4b314331d56b
+size 1736105856
modeling_pure_deberta.py ADDED
@@ -0,0 +1,337 @@
+import torch
+import torch.nn as nn
+from transformers import (
+    DebertaV2Model,
+    PreTrainedModel,
+)
+from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout
+from typing import Union, Tuple, Optional
+from transformers.modeling_outputs import SequenceClassifierOutput
+
+from .configuration_pure_deberta import PureDebertaConfig
+
+
+class PFSA(torch.nn.Module):
+    """
+    https://openreview.net/pdf?id=isodM5jTA7h
+    """
+    def __init__(self, input_dim, alpha=1):
+        super(PFSA, self).__init__()
+        self.input_dim = input_dim
+        self.alpha = alpha
+
+    def forward_one_sample(self, x):
+        x = x.transpose(1, 2)[..., None]
+        k = torch.mean(x, dim=[-1, -2], keepdim=True)
+        kd = torch.sqrt((k - k.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
+        qd = torch.sqrt((x - x.mean(dim=1, keepdim=True)).pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
+        C_qk = (((x - x.mean(dim=1, keepdim=True)) * (k - k.mean(dim=1, keepdim=True))).sum(dim=1, keepdim=True)) / (qd * kd)
+        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
+        out = x * A
+        out = out.squeeze(dim=-1).transpose(1, 2)
+        return out
+
+    def forward(self, input_values, attention_mask=None):
+        """
+        x: [B, T, F]
+        """
+        out = []
+        b, t, f = input_values.shape
+        for x, mask in zip(input_values, attention_mask):
+            x = x.view(1, t, f)
+            x_in = x[:, :sum(mask), :]
+            x_out = self.forward_one_sample(x_in)
+            x_expanded = torch.zeros_like(x, device=x.device)
+            x_expanded[:, :x_out.shape[-2], :x_out.shape[-1]] = x_out
+            out.append(x_expanded)
+        out = torch.vstack(out)
+        out = out.view(b, t, f)
+        return out
+
+
+class PURE(torch.nn.Module):
+
+    def __init__(
+        self,
+        in_dim,
+        target_rank=5,
+        npc=1,
+        center=False,
+        num_iters=1,
+        alpha=1,
+        do_pcr=True,
+        do_pfsa=True,
+        *args, **kwargs
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.target_rank = target_rank
+        self.npc = npc
+        self.center = center
+        self.num_iters = num_iters
+        self.do_pcr = do_pcr
+        self.do_pfsa = do_pfsa
+        self.attention = PFSA(in_dim, alpha=alpha)
+
+    def _compute_pc(self, X, attention_mask):
+        """
+        x: (B, T, F)
+        """
+        pcs = []
+        bs, seqlen, dim = X.shape
+        for x, mask in zip(X, attention_mask):
+            x_ = x[:sum(mask), :]
+            q = min(self.target_rank, sum(mask))
+            _, _, V = torch.pca_lowrank(x_, q=q, center=self.center, niter=self.num_iters)
+            pc = V.transpose(0, 1)[:self.npc, :]  # pc: [K, F]
+            pcs.append(pc)
+        # pcs = torch.vstack(pcs)
+        # pcs = pcs.view(bs, self.num_pc_to_remove, dim)
+        return pcs
+
+    def _remove_pc(self, X, pcs):
+        """
+        [B, T, F], [B, ..., F]
+        """
+        b, t, f = X.shape
+        out = []
+        for i, (x, pc) in enumerate(zip(X, pcs)):
+            # v = []
+            # for j, t in enumerate(x):
+            #     t_ = t
+            #     for c_ in c:
+            #         t_ = t_.view(f, 1) - c_.view(f, 1) @ c_.view(1, f) @ t.view(f, 1)
+            #     v.append(t_.transpose(-1, -2))
+            # v = torch.vstack(v)
+            v = x - x @ pc.transpose(0, 1) @ pc
+            out.append(v[None, ...])
+        out = torch.vstack(out)
+        return out
+
+    def forward(self, input_values, attention_mask=None, *args, **kwargs):
+        """
+        PCR -> Attention
+        x: (B, T, F)
+        """
+        x = input_values
+        if self.do_pcr:
+            pc = self._compute_pc(x, attention_mask)  # pc: [B, K, F]
+            xx = self._remove_pc(x, pc)
+            # xx = xt - xt @ pc.transpose(1, 2) @ pc  # [B, T, F] * [B, F, K] * [B, K, F] = [B, T, F]
+        else:
+            xx = x
+        if self.do_pfsa:
+            xx = self.attention(xx, attention_mask)
+        return xx
+
+
+class StatisticsPooling(torch.nn.Module):
+
+    def __init__(self, return_mean=True, return_std=True):
+        super().__init__()
+
+        # Small value for GaussNoise
+        self.eps = 1e-5
+        self.return_mean = return_mean
+        self.return_std = return_std
+        if not (self.return_mean or self.return_std):
+            raise ValueError(
+                "both of statistics are equal to False \n"
+                "consider enabling mean and/or std statistic pooling"
+            )
+
+    def forward(self, input_values, attention_mask=None):
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            It represents a tensor for a mini-batch.
+        """
+        x = input_values
+        if attention_mask is None:
+            if self.return_mean:
+                mean = x.mean(dim=1)
+            if self.return_std:
+                std = x.std(dim=1)
+        else:
+            mean = []
+            std = []
+            for snt_id in range(x.shape[0]):
+                # Avoiding padded time steps
+                lengths = torch.sum(attention_mask, dim=1)
+                relative_lengths = lengths / torch.max(lengths)
+                actual_size = torch.round(relative_lengths[snt_id] * x.shape[1]).int()
+                # actual_size = int(torch.round(lengths[snt_id] * x.shape[1]))
+
+                # computing statistics
+                if self.return_mean:
+                    mean.append(
+                        torch.mean(x[snt_id, 0:actual_size, ...], dim=0)
+                    )
+                if self.return_std:
+                    std.append(torch.std(x[snt_id, 0:actual_size, ...], dim=0))
+            if self.return_mean:
+                mean = torch.stack(mean)
+            if self.return_std:
+                std = torch.stack(std)
+
+        if self.return_mean:
+            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
+            gnoise = gnoise
+            mean += gnoise
+        if self.return_std:
+            std = std + self.eps
+
+        # Append mean and std of the batch
+        if self.return_mean and self.return_std:
+            pooled_stats = torch.cat((mean, std), dim=1)
+            pooled_stats = pooled_stats.unsqueeze(1)
+        elif self.return_mean:
+            pooled_stats = mean.unsqueeze(1)
+        elif self.return_std:
+            pooled_stats = std.unsqueeze(1)
+
+        return pooled_stats
+
+    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
+        """Returns a tensor of epsilon Gaussian noise.
+
+        Arguments
+        ---------
+        shape_of_tensor : tensor
+            It represents the size of tensor for generating Gaussian noise.
+        """
+        gnoise = torch.randn(shape_of_tensor, device=device)
+        gnoise -= torch.min(gnoise)
+        gnoise /= torch.max(gnoise)
+        gnoise = self.eps * ((1 - 9) * gnoise + 9)
+
+        return gnoise
+
+
+class DebertaV2PreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = PureDebertaConfig
+    base_model_prefix = "deberta"
+    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Linear):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class PureDebertaForSequenceClassification(DebertaV2PreTrainedModel):
+
+    def __init__(
+        self,
+        config,
+    ):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.config = config
+
+        self.deberta = DebertaV2Model(config)
+        drop_out = getattr(config, "cls_dropout", None)
+        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+        self.pure = PURE(
+            in_dim=config.hidden_size,
+            svd_rank=config.svd_rank,
+            num_pc_to_remove=config.num_pc_to_remove,
+            center=config.center,
+            num_iters=config.num_iters,
+            alpha=config.alpha,
+            disable_pcr=config.disable_pcr,
+            disable_pfsa=config.disable_pfsa,
+            disable_covariance=config.disable_covariance
+        )
+        self.mean = StatisticsPooling(return_mean=True, return_std=False)
+        self.dropout = StableDropout(drop_out)
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.deberta(
+            input_ids,
+            token_type_ids=token_type_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        token_embeddings = outputs.last_hidden_state
+        token_embeddings = self.pure(token_embeddings, attention_mask)
+        pooled_output = self.mean(token_embeddings).squeeze(1)
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = nn.MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = nn.CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = nn.BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )