# pure-deberta-large / modeling_pure_deberta.py
# Uploaded by yangwang825: "Upload PureDebertaForSequenceClassification" (commit 1729fca, verified).
import torch
import torch.nn as nn
from transformers import (
DebertaV2Model,
PreTrainedModel,
)
from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout
from typing import Union, Tuple, Optional
from transformers.modeling_outputs import SequenceClassifierOutput
from .configuration_pure_deberta import PureDebertaConfig
class PFSA(torch.nn.Module):
    """
    Parameter-free self-attention over token embeddings.

    https://openreview.net/pdf?id=isodM5jTA7h

    Each valid token vector is scaled by ``(1 - sigmoid(C)) ** alpha``, where ``C``
    is the Pearson correlation, taken across the feature axis, between the token
    and the mean of the sample's valid tokens. Tokens that look like the mean are
    down-weighted.

    Args:
        input_dim: feature size F (stored; not used in the computation).
        alpha: exponent applied to the attention weights.
    """

    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        """
        Re-weight the tokens of ``x``.

        x: [B, T, F] (called with B == 1 and only the valid tokens) -> [B, T, F]
        """
        x = x.transpose(1, 2)[..., None]  # [B, F, T, 1]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)  # [B, F, 1, 1]: mean token
        # Centre both vectors along the feature axis (dim=1) once and reuse them.
        x_c = x - x.mean(dim=1, keepdim=True)
        k_c = k - k.mean(dim=1, keepdim=True)
        kd = torch.sqrt(k_c.pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt(x_c.pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        C_qk = (x_c * k_c).sum(dim=1, keepdim=True) / (qd * kd)  # [B, 1, T, 1]
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        return out.squeeze(dim=-1).transpose(1, 2)

    def forward(self, input_values, attention_mask=None):
        """
        Apply PFSA per sample, only to the leading valid positions.

        input_values: [B, T, F]
        attention_mask: [B, T] whose row sums give the number of valid leading
            positions (assumes right padding — TODO confirm with tokenizer setup).
            ``None`` means every position is valid.

        Returns [B, T, F]; positions beyond a sample's valid length are zero.
        """
        b, t, f = input_values.shape
        if attention_mask is None:
            lengths = [t] * b
        else:
            lengths = [int(n) for n in attention_mask.sum(dim=1).tolist()]
        out = []
        for x, n in zip(input_values, lengths):
            x = x.view(1, t, f)
            x_out = self.forward_one_sample(x[:, :n, :])
            # Padded positions stay zero.
            x_expanded = torch.zeros_like(x)
            x_expanded[:, : x_out.shape[-2], : x_out.shape[-1]] = x_out
            out.append(x_expanded)
        return torch.cat(out, dim=0)
class PURE(torch.nn.Module):
    """
    Post-processing of token embeddings: principal-component removal (PCR)
    followed by parameter-free self-attention (PFSA).

    Args:
        in_dim: feature size F of the incoming embeddings (forwarded to PFSA).
        target_rank: rank ``q`` requested from ``torch.pca_lowrank`` per sample;
            clamped to the sample's valid length and to F.
        npc: number of leading principal directions removed from every token.
        center: forwarded to ``torch.pca_lowrank``.
        num_iters: forwarded to ``torch.pca_lowrank`` as ``niter``.
        alpha: exponent of the PFSA weights.
        do_pcr / do_pfsa: enable the two stages.
        **kwargs: also accepts the config-style aliases ``svd_rank`` (-> target_rank),
            ``num_pc_to_remove`` (-> npc), ``disable_pcr`` (-> not do_pcr) and
            ``disable_pfsa`` (-> not do_pfsa). A non-None alias overrides the
            corresponding argument. Any other extra keyword is accepted and ignored.
    """

    def __init__(
        self,
        in_dim,
        target_rank=5,
        npc=1,
        center=False,
        num_iters=1,
        alpha=1,
        do_pcr=True,
        do_pfsa=True,
        *args, **kwargs
    ):
        super().__init__()
        # Config-driven callers pass these names; without this mapping they would
        # fall into **kwargs and be silently ignored.
        if kwargs.get("svd_rank") is not None:
            target_rank = kwargs["svd_rank"]
        if kwargs.get("num_pc_to_remove") is not None:
            npc = kwargs["num_pc_to_remove"]
        if kwargs.get("disable_pcr") is not None:
            do_pcr = not kwargs["disable_pcr"]
        if kwargs.get("disable_pfsa") is not None:
            do_pfsa = not kwargs["disable_pfsa"]
        self.in_dim = in_dim
        self.target_rank = target_rank
        self.npc = npc
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = do_pcr
        self.do_pfsa = do_pfsa
        self.attention = PFSA(in_dim, alpha=alpha)

    @staticmethod
    def _valid_lengths(attention_mask, batch_size, seqlen):
        """Return per-sample valid lengths as Python ints (``seqlen`` each when mask is None)."""
        if attention_mask is None:
            return [seqlen] * batch_size
        return [int(n) for n in attention_mask.sum(dim=1).tolist()]

    def _compute_pc(self, X, attention_mask):
        """
        Compute each sample's leading principal directions.

        X: (B, T, F). Only each sample's first ``sum(mask)`` rows are used.
        Returns a list of B tensors of shape [K, F] with K <= npc.
        """
        bs, seqlen, dim = X.shape
        pcs = []
        for x, n in zip(X, self._valid_lengths(attention_mask, bs, seqlen)):
            # torch.pca_lowrank requires an int q with q <= min(rows, cols).
            q = min(self.target_rank, n, dim)
            _, _, V = torch.pca_lowrank(x[:n, :], q=q, center=self.center, niter=self.num_iters)
            pcs.append(V.transpose(0, 1)[: self.npc, :])  # [K, F]
        return pcs

    def _remove_pc(self, X, pcs):
        """
        Subtract from every row its projection onto that sample's principal directions.

        X: [B, T, F]; pcs: list of B tensors [K, F] -> [B, T, F].
        All T rows, including padded ones, are projected.
        """
        return torch.stack([x - x @ pc.transpose(0, 1) @ pc for x, pc in zip(X, pcs)])

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """
        PCR -> PFSA.

        input_values: (B, T, F)
        attention_mask: (B, T) whose row sums give the number of valid leading
            positions (assumes right padding). ``None`` treats all positions as valid.
        """
        x = input_values
        if attention_mask is None:
            attention_mask = torch.ones(x.shape[:2], dtype=torch.long, device=x.device)
        if self.do_pcr:
            xx = self._remove_pc(x, self._compute_pc(x, attention_mask))
        else:
            xx = x
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx
class StatisticsPooling(torch.nn.Module):
    """
    Mean and/or standard-deviation pooling over the time axis.

    Arguments
    ---------
    return_mean : bool
        Include the (noise-perturbed) mean.
    return_std : bool
        Include the standard deviation (plus ``eps``).

    Raises
    ------
    ValueError
        If both statistics are disabled.
    """

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()
        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            Mini-batch of shape (B, T, F).
        attention_mask : torch.Tensor, optional
            (B, T) mask whose row sums give the number of valid leading steps
            per sample (assumes right padding). When None, all T steps are pooled.

        Returns
        -------
        torch.Tensor
            (B, 1, F) for a single statistic, or (B, 1, 2F) with [mean, std].
            Random noise in [eps, 9*eps] is added to the mean.
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            # Absolute per-sample valid lengths. Using them directly stays correct
            # even when no sample fills all T steps (e.g. padding="max_length").
            lengths = [int(n) for n in torch.sum(attention_mask, dim=1).tolist()]
            mean = []
            std = []
            for snt_id, actual_size in enumerate(lengths):
                # Avoiding padded time steps
                frames = x[snt_id, 0:actual_size, ...]
                if self.return_mean:
                    mean.append(torch.mean(frames, dim=0))
                if self.return_std:
                    std.append(torch.std(frames, dim=0))
            if self.return_mean:
                mean = torch.stack(mean)
            if self.return_std:
                std = torch.stack(std)
        if self.return_mean:
            gnoise = self._get_gauss_noise(mean.size(), device=mean.device)
            mean += gnoise
        if self.return_std:
            std = std + self.eps
        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1)
            pooled_stats = pooled_stats.unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        elif self.return_std:
            pooled_stats = std.unsqueeze(1)
        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise.

        Gaussian draws are min-max rescaled to [0, 1] and mapped to [eps, 9*eps].

        Arguments
        ---------
        shape_of_tensor : torch.Size
            It represents the size of tensor for generating Gaussian noise.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        # A single-element (or constant) draw would otherwise give 0 / 0 = NaN.
        gnoise /= torch.clamp(torch.max(gnoise), min=torch.finfo(gnoise.dtype).tiny)
        gnoise = self.eps * ((1 - 9) * gnoise + 9)
        return gnoise
class DebertaV2PreTrainedModel(PreTrainedModel):
    """
    Abstract base for the PURE-DeBERTa models.

    Binds ``PureDebertaConfig``, sets the checkpoint prefix ``"deberta"`` and
    defines how Linear and Embedding weights are initialised.
    """

    config_class = PureDebertaConfig
    base_model_prefix = "deberta"
    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Re-initialise ``module`` in place.

        Linear and Embedding weights are drawn from N(0, config.initializer_range).
        Linear biases and an Embedding's padding row are zeroed. Other module
        types are left untouched.
        """
        is_linear = isinstance(module, nn.Linear)
        if not (is_linear or isinstance(module, nn.Embedding)):
            return
        weight = module.weight.data
        weight.normal_(mean=0.0, std=self.config.initializer_range)
        if is_linear:
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            weight[module.padding_idx].zero_()
class PureDebertaForSequenceClassification(DebertaV2PreTrainedModel):
    """
    DeBERTa-v2 encoder -> PURE token post-processing -> masked mean pooling ->
    dropout -> linear head, for sequence classification or regression.
    """

    def __init__(
        self,
        config,
    ):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.deberta = DebertaV2Model(config)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
        # Pass config values under PURE's own parameter names. PURE accepts
        # arbitrary **kwargs, so unknown names (e.g. `svd_rank=`) are silently
        # dropped and its defaults would be used instead.
        self.pure = PURE(
            in_dim=config.hidden_size,
            target_rank=config.svd_rank,
            npc=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            do_pcr=not config.disable_pcr,
            do_pfsa=not config.disable_pfsa,
            # Accepted by PURE via **kwargs but not used by it.
            disable_covariance=config.disable_covariance,
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.dropout = StableDropout(drop_out)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 0 works for both the ModelOutput and the plain-tuple return.
        token_embeddings = outputs[0]
        if attention_mask is None:
            # Treat a missing mask as "every position is valid" so the per-sample
            # PURE and pooling steps have lengths to work with.
            attention_mask = torch.ones(
                token_embeddings.shape[:2], dtype=torch.long, device=token_embeddings.device
            )
        token_embeddings = self.pure(token_embeddings, attention_mask)
        # Pool only valid positions: padded rows (zero-filled by PFSA) would
        # otherwise dilute the mean and make logits depend on padding length.
        pooled_output = self.mean(token_embeddings, attention_mask).squeeze(1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            # The base model emits no pooler output; the return_dict branch below
            # reads only hidden_states and attentions. Extra outputs therefore
            # start at index 1 (index 2 would drop hidden_states).
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )