from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    DebertaV2Model,
    PreTrainedModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput

try:
    from transformers.models.deberta_v2.modeling_deberta_v2 import StableDropout
except ImportError:
    # NOTE(review): `StableDropout` is not exported by every transformers
    # release; fall back to standard dropout so this module still imports.
    StableDropout = nn.Dropout

from .configuration_pure_deberta import PureDebertaConfig


def _valid_lengths(attention_mask, batch_size, seq_len):
    """Return the number of valid (non-padded) positions of each sequence as ints.

    The length of a row is the sum of its mask entries. Callers slice
    ``[:length]``, i.e. they treat the valid tokens as a prefix (right padding).
    When ``attention_mask`` is None every one of the ``seq_len`` positions is
    considered valid.
    """
    if attention_mask is None:
        return [seq_len] * batch_size
    return [int(n) for n in attention_mask.sum(dim=-1).tolist()]


class PFSA(torch.nn.Module):
    """Parameter-free self-attention.

    Reference: https://openreview.net/pdf?id=isodM5jTA7h

    Every token is scaled by ``(1 - sigmoid(c)) ** alpha``. Here ``c`` is the
    Pearson correlation, taken over the feature axis, between the token and the
    mean vector of its (unpadded) sequence.
    """

    def __init__(self, input_dim, alpha=1):
        super(PFSA, self).__init__()
        self.input_dim = input_dim
        self.alpha = alpha

    def forward_one_sample(self, x):
        """Apply PFSA to an unpadded ``[B, T, F]`` tensor (``B == 1`` from ``forward``)."""
        x = x.transpose(1, 2)[..., None]  # [B, F, T, 1]
        k = torch.mean(x, dim=[-1, -2], keepdim=True)  # [B, F, 1, 1]: mean over time
        x_centered = x - x.mean(dim=1, keepdim=True)
        k_centered = k - k.mean(dim=1, keepdim=True)
        kd = torch.sqrt(k_centered.pow(2).sum(dim=1, keepdim=True))  # [B, 1, 1, 1]
        qd = torch.sqrt(x_centered.pow(2).sum(dim=1, keepdim=True))  # [B, 1, T, 1]
        C_qk = (x_centered * k_centered).sum(dim=1, keepdim=True) / (qd * kd)
        A = (1 - torch.sigmoid(C_qk)) ** self.alpha
        out = x * A
        return out.squeeze(dim=-1).transpose(1, 2)

    def forward(self, input_values, attention_mask=None):
        """Apply PFSA to every sequence of a padded batch.

        Args:
            input_values: ``[B, T, F]`` features.
            attention_mask: optional ``[B, T]`` mask; ``None`` means no padding.

        Returns:
            ``[B, T, F]`` tensor. Positions beyond each sequence's length are zero.
        """
        b, t, f = input_values.shape
        out = torch.zeros_like(input_values)
        for i, n in enumerate(_valid_lengths(attention_mask, b, t)):
            # Attention statistics must only see the real tokens of the sample.
            out[i, :n] = self.forward_one_sample(input_values[i : i + 1, :n])[0]
        return out


class PURE(torch.nn.Module):
    """Principal-component removal (PCR) followed by PFSA over token features.

    Args:
        in_dim: feature size ``F``.
        target_rank: ``q`` for ``torch.pca_lowrank``, capped at the sequence length.
        npc: number of leading principal directions removed from each sequence.
        center: forwarded to ``torch.pca_lowrank``.
        num_iters: ``niter`` for ``torch.pca_lowrank``.
        alpha: PFSA exponent.
        do_pcr: enable the PCR stage.
        do_pfsa: enable the PFSA stage.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        in_dim,
        target_rank=5,
        npc=1,
        center=False,
        num_iters=1,
        alpha=1,
        do_pcr=True,
        do_pfsa=True,
        *args,
        **kwargs
    ):
        super().__init__()
        self.in_dim = in_dim
        self.target_rank = target_rank
        self.npc = npc
        self.center = center
        self.num_iters = num_iters
        self.do_pcr = do_pcr
        self.do_pfsa = do_pfsa
        self.attention = PFSA(in_dim, alpha=alpha)

    def _compute_pc(self, X, attention_mask):
        """Return, per sequence, its top principal directions as a ``[K, F]`` tensor.

        Only the first ``length`` rows of each sequence in ``X`` (``[B, T, F]``)
        are used. The result is a list because ``K = min(npc, q)`` can differ
        between sequences.
        """
        bs, seqlen, _ = X.shape
        pcs = []
        for x, n in zip(X, _valid_lengths(attention_mask, bs, seqlen)):
            q = min(self.target_rank, n)
            _, _, V = torch.pca_lowrank(x[:n], q=q, center=self.center, niter=self.num_iters)
            pcs.append(V.transpose(0, 1)[: self.npc])  # [K, F]
        return pcs

    def _remove_pc(self, X, pcs):
        """Subtract each sequence's projection onto its principal directions.

        ``X`` is ``[B, T, F]`` and ``pcs[i]`` is ``[K_i, F]``. The result is
        ``[B, T, F]``. All ``T`` rows are projected, including padding rows.
        """
        return torch.stack([x - x @ pc.transpose(0, 1) @ pc for x, pc in zip(X, pcs)])

    def forward(self, input_values, attention_mask=None, *args, **kwargs):
        """Apply PCR and then PFSA (each stage optional) to ``[B, T, F]`` features.

        ``attention_mask`` may be None, in which case every position counts as valid.
        """
        xx = input_values
        if self.do_pcr:
            xx = self._remove_pc(xx, self._compute_pc(xx, attention_mask))
        if self.do_pfsa:
            xx = self.attention(xx, attention_mask)
        return xx


class StatisticsPooling(torch.nn.Module):
    """Mean and/or standard-deviation pooling over the time axis (dim 1)."""

    def __init__(self, return_mean=True, return_std=True):
        super().__init__()

        # Small value for GaussNoise
        self.eps = 1e-5
        self.return_mean = return_mean
        self.return_std = return_std
        if not (self.return_mean or self.return_std):
            raise ValueError(
                "both of statistics are equal to False \n"
                "consider enabling mean and/or std statistic pooling"
            )

    def forward(self, input_values, attention_mask=None):
        """Calculates mean and std for a batch (input tensor).

        Arguments
        ---------
        input_values : torch.Tensor
            ``[B, T, F]`` tensor for a mini-batch.
        attention_mask : torch.Tensor, optional
            ``[B, T]`` mask. When given, only the first ``mask.sum()`` positions
            of each row are pooled. When None, all ``T`` positions are pooled.

        Returns
        -------
        ``[B, 1, F]`` (mean or std only) or ``[B, 1, 2F]`` (mean then std).
        A tiny random noise term is added to the mean.
        """
        x = input_values
        if attention_mask is None:
            if self.return_mean:
                mean = x.mean(dim=1)
            if self.return_std:
                std = x.std(dim=1)
        else:
            # Use the absolute valid lengths; they must not be rescaled by the
            # longest sequence in the batch.
            lengths = _valid_lengths(attention_mask, x.shape[0], x.shape[1])
            if self.return_mean:
                mean = torch.stack([x[i, :n, ...].mean(dim=0) for i, n in enumerate(lengths)])
            if self.return_std:
                std = torch.stack([x[i, :n, ...].std(dim=0) for i, n in enumerate(lengths)])

        if self.return_mean:
            mean += self._get_gauss_noise(mean.size(), device=mean.device)
        if self.return_std:
            std = std + self.eps

        # Append mean and std of the batch
        if self.return_mean and self.return_std:
            pooled_stats = torch.cat((mean, std), dim=1).unsqueeze(1)
        elif self.return_mean:
            pooled_stats = mean.unsqueeze(1)
        else:
            pooled_stats = std.unsqueeze(1)
        return pooled_stats

    def _get_gauss_noise(self, shape_of_tensor, device="cpu"):
        """Returns a tensor of epsilon Gaussian noise with values in ``[eps, 9 * eps]``.

        Arguments
        ---------
        shape_of_tensor : torch.Size
            The size of the noise tensor to generate.
        """
        gnoise = torch.randn(shape_of_tensor, device=device)
        gnoise -= torch.min(gnoise)
        # Guard against 0/0 (NaN) when every sample is equal, e.g. a single element.
        gnoise /= torch.max(gnoise).clamp_min(torch.finfo(gnoise.dtype).tiny)
        return self.eps * ((1 - 9) * gnoise + 9)


class DebertaV2PreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = PureDebertaConfig
    base_model_prefix = "deberta"
    _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range); zero biases and padding rows."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class PureDebertaForSequenceClassification(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder -> PURE (PCR + PFSA) -> masked mean pooling -> dropout -> linear head."""

    def __init__(
        self,
        config,
    ):
        super().__init__(config)

        self.num_labels = config.num_labels
        self.config = config
        self.deberta = DebertaV2Model(config)
        drop_out = getattr(config, "cls_dropout", None)
        drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out

        # Map the config fields onto PURE's actual parameter names. Previously
        # they were passed under other names, swallowed by PURE's **kwargs and
        # silently ignored. PURE has no covariance option, so
        # `config.disable_covariance` is not forwarded.
        self.pure = PURE(
            in_dim=config.hidden_size,
            target_rank=config.svd_rank,
            npc=config.num_pc_to_remove,
            center=config.center,
            num_iters=config.num_iters,
            alpha=config.alpha,
            do_pcr=not config.disable_pcr,
            do_pfsa=not config.disable_pfsa,
        )
        self.mean = StatisticsPooling(return_mean=True, return_std=False)
        self.dropout = StableDropout(drop_out)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.deberta(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Index 0 is the last hidden state for both the tuple and the ModelOutput
        # forms; `.last_hidden_state` fails when return_dict=False.
        token_embeddings = outputs[0]
        token_embeddings = self.pure(token_embeddings, attention_mask)
        # Pool only the real tokens so that padding cannot shift the pooled vector.
        pooled_output = self.mean(token_embeddings, attention_mask).squeeze(1)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            # DebertaV2Model has no pooler output, so everything after the last
            # hidden state (hidden_states / attentions) starts at index 1.
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )