from transformers import PretrainedConfig, AutoConfig

class BERTMultiAttentionConfig(PretrainedConfig):
    """Configuration for a BERT-backed model with multi-head attention and an RNN layer.

    Stores the backbone checkpoint name plus attention, dropout, and RNN hyperparameters.
    """

    model_type = "bert_multi_attention"
    keys_to_ignore_at_inference = ["dropout"]

    def __init__(
        self,
        transformer="bert-base-uncased",
        hidden_size=768,
        num_heads=8,
        dropout=0.1,
        rnn_hidden_size=128,
        rnn_num_layers=2,
        rnn_bidirectional=True,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.transformer = transformer
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers
        self.rnn_bidirectional = rnn_bidirectional


AutoConfig.register("bert_multi_attention", BERTMultiAttentionConfig)
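

# Minimal usage sketch (illustrative, not part of the original module): once the
# config class is registered above, it can be saved to disk and reloaded through
# AutoConfig. The directory name below is only an example.
if __name__ == "__main__":
    config = BERTMultiAttentionConfig(num_heads=12, rnn_hidden_size=256)
    config.save_pretrained("bert_multi_attention_config")  # writes config.json
    reloaded = AutoConfig.from_pretrained("bert_multi_attention_config")
    print(type(reloaded).__name__, reloaded.num_heads, reloaded.rnn_hidden_size)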