from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import MaskedLMOutput, QuestionAnsweringModelOutput


@dataclass
class BaseModelOutputWithPastAndCrossAttentionsSkim(ModelOutput):
    """BaseModelOutputWithPastAndCrossAttentions extended with attention_mask and skim_mask fields."""

    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None


@dataclass
class BaseModelOutputWithPoolingAndCrossAttentionsSkim(ModelOutput):
    """BaseModelOutputWithPoolingAndCrossAttentions extended with attention_mask and skim_mask fields."""

    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None


@dataclass
class SequenceClassifierOutputSkim(ModelOutput):
    """Sequence classification output with skim-specific fields: masks, loss terms, and remaining-token statistics."""

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class QuestionAnsweringModelOutputSkim(QuestionAnsweringModelOutput):
    """QuestionAnsweringModelOutput with the same skim-specific fields added."""

    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
class MaskedLMOutputSkim(MaskedLMOutput):
    """MaskedLMOutput with the same skim-specific fields added."""

    attention_mask: Optional[torch.FloatTensor] = None
    skim_mask: Optional[torch.FloatTensor] = None
    skim_loss: Optional[torch.FloatTensor] = None
    classification_loss: Optional[torch.FloatTensor] = None
    tokens_remained: Optional[torch.FloatTensor] = None
    layer_tokens_remained: Optional[Tuple[torch.FloatTensor]] = None
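

# A minimal usage sketch (illustration only, not part of the original module):
# these ModelOutput subclasses behave like the standard transformers outputs,
# so the skim-specific fields are reachable by attribute or by key. The
# numbers below are placeholders, not values produced by any real model.
def _output_example():
    output = SequenceClassifierOutputSkim(
        logits=torch.randn(2, 2),
        skim_loss=torch.tensor(0.1),
        classification_loss=torch.tensor(0.5),
        tokens_remained=torch.tensor(0.6),
    )
    # ModelOutput supports both attribute access and dict-style access.
    assert output.tokens_remained is output["tokens_remained"]
    return output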


def masked_softmax(vec, mask, dim=1, eps=1e-6):
    # Softmax of `vec` along `dim` with masked positions suppressed.
    # `mask` is [batch, seq_len] and is reshaped to [batch, 1, 1, seq_len] so
    # it broadcasts over attention scores; `eps` keeps the denominator
    # non-zero even when every position in a row is masked out.
    mask = mask[:, None, None, :]
    exps = torch.exp(vec)
    masked_exps = exps * mask.float() + eps
    masked_sums = masked_exps.sum(dim, keepdim=True)
    return masked_exps / masked_sums
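

# A minimal usage sketch of `masked_softmax` (illustration only, not part of
# the original module). The shapes are assumptions based on the broadcasting
# above: attention scores of [batch, heads, seq_len, seq_len] and a padding
# mask of [batch, seq_len].
def _masked_softmax_example():
    batch, heads, seq_len = 2, 4, 8
    scores = torch.randn(batch, heads, seq_len, seq_len)
    padding_mask = torch.ones(batch, seq_len)
    padding_mask[:, 6:] = 0  # pretend the last two tokens are padding
    probs = masked_softmax(scores, padding_mask, dim=-1)
    # Each row sums to 1 and padded columns receive near-zero probability.
    assert probs.shape == scores.shape
    return probs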


def convert_softmax_mask_to_digit(skim_mask):
    # Returns an int64 mask of shape [batch, from(1), to(1), seq_len] that is
    # 1 exactly where skim_mask is 0 (i.e. skimmed positions), 0 elsewhere.
    return (skim_mask == 0).to(dtype=torch.int64).unsqueeze(1).unsqueeze(1)


def trunc_with_mask_batched(input, mask, dim):
    """
    Truncate a batched tensor along `dim`, keeping only the positions where
    `mask` is True.

    e.g. hidden_states [batch, seq_len, hidden_size] (dim=1)
         attention_mask [batch, layer, head, seq_len] (dim=3)
    mask: [batch, seq_len] boolean; every batch element must retain the same
    number of positions so the result can be viewed back into a batch.
    """
    assert input.shape[dim] == mask.shape[1]
    # Move the truncated dimension next to the batch dimension, select the
    # kept positions, then restore the original dimension order.
    if dim != 1:
        input = input.transpose(1, dim)
    transpose_shape = list(input.shape)
    transpose_shape[1] = -1
    trunc_input = input[mask].view(transpose_shape)
    if dim != 1:
        trunc_input = trunc_input.transpose(1, dim)
    return trunc_input
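

# A minimal usage sketch of `trunc_with_mask_batched` (illustration only, not
# part of the original module). The tensor names and sizes are assumptions;
# note that the mask must keep the same number of positions in every batch
# element, otherwise the `.view` above cannot restore the batch dimension.
def _trunc_example():
    batch, seq_len, hidden_size = 2, 8, 16
    hidden_states = torch.randn(batch, seq_len, hidden_size)
    # Boolean mask keeping the first 5 tokens of every sequence.
    keep_mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    keep_mask[:, :5] = True
    truncated = trunc_with_mask_batched(hidden_states, keep_mask, dim=1)
    assert truncated.shape == (batch, 5, hidden_size)
    return truncated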