from transformers import AutoConfig, PretrainedConfig


class BERTMultiAttentionConfig(PretrainedConfig):
    """Configuration for the bert_multi_attention model."""

    model_type = "bert_multi_attention"
    keys_to_ignore_at_inference = ["dropout"]

    def __init__(
        self,
        transformer="bert-base-uncased",
        hidden_size=768,
        num_heads=8,
        dropout=0.1,
        rnn_hidden_size=128,
        rnn_num_layers=2,
        rnn_bidirectional=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Name of the underlying pretrained transformer checkpoint.
        self.transformer = transformer
        # Hidden size of the transformer outputs / attention layer.
        self.hidden_size = hidden_size
        # Number of heads in the multi-head attention layer.
        self.num_heads = num_heads
        self.dropout = dropout
        # RNN head settings.
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers
        self.rnn_bidirectional = rnn_bidirectional


# Register the config so AutoConfig can resolve the "bert_multi_attention" model type.
AutoConfig.register("bert_multi_attention", BERTMultiAttentionConfig)
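

# Usage sketch (illustrative, not part of the original module): instantiate the
# config, save it, and reload it through AutoConfig, which works because the
# model type was registered above. The local directory name is hypothetical.
if __name__ == "__main__":
    config = BERTMultiAttentionConfig(num_heads=12, rnn_hidden_size=256)
    config.save_pretrained("./bert_multi_attention_config")
    reloaded = AutoConfig.from_pretrained("./bert_multi_attention_config")
    assert reloaded.model_type == "bert_multi_attention"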