Text Classification
Safetensors
deberta-v2
celadon / modelling_deberta_multi.py
tcapelle's picture
missing files and rename
9a08e1a
raw
history blame
1.16 kB
from typing import Optional

import torch
from torch import nn, Tensor
from transformers import DebertaV2PreTrainedModel, DebertaV2Model
from transformers.modeling_outputs import BaseModelOutput

from .configuration_deberta_multi import MultiHeadDebertaV2Config
class MultiHeadDebertaForSequenceClassificationModel(DebertaV2PreTrainedModel):
    """DeBERTa-v2 encoder topped with several independent linear heads.

    One ``nn.Linear(hidden_size, 4)`` head is created per
    ``config.num_heads``. Every head reads the encoder's hidden state at the
    first token position and emits 4 logits.
    """

    config_class = MultiHeadDebertaV2Config

    def __init__(self, config):  # type: ignore
        """Build the encoder, the classification heads and the dropout layer.

        Args:
            config: Model configuration. This constructor reads
                ``hidden_size``, ``num_heads`` and ``hidden_dropout_prob``
                from it; the encoder consumes the rest.
        """
        super().__init__(config)
        self.deberta = DebertaV2Model(config)
        # Each head emits 4 logits. The width is fixed here rather than read
        # from the config, so changing it would alter the weight shapes.
        self.heads = nn.ModuleList(
            [nn.Linear(config.hidden_size, 4) for _ in range(config.num_heads)]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.post_init()

    def forward(
        self,
        input_ids: Optional["Tensor"] = None,
        attention_mask: Optional["Tensor"] = None,
    ) -> "BaseModelOutput":
        """Run the encoder and apply every classification head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            The encoder's output object with an extra ``"logits"`` entry of
            shape ``(batch, num_heads, 4)``: one row of logits per head.
        """
        # Force a dict-like output: the item assignment below would raise a
        # TypeError on the plain tuple returned when the config disables
        # ``return_dict``.
        outputs = self.deberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        # Hidden state at the first token position, shared by all heads.
        first_token_state = outputs[0][:, 0, :]
        # Dropout is deliberately applied once per head (not hoisted) so each
        # head keeps drawing its own mask in training mode, as before.
        per_head_logits = [
            head(self.dropout(first_token_state)) for head in self.heads
        ]
        outputs["logits"] = torch.stack(per_head_logits, dim=1)
        return outputs