from transformers import AutoConfig, PretrainedConfig


class BERTMultiAttentionConfig(PretrainedConfig):
    """Configuration for a model combining a BERT encoder, a multi-head
    attention layer, and an RNN head."""

    model_type = "bert_multi_attention"
    keys_to_ignore_at_inference = ["dropout"]

    def __init__(
        self,
        transformer="bert-base-uncased",   # name of the underlying pretrained encoder
        hidden_size=768,                   # encoder hidden size (768 for bert-base)
        num_heads=8,                       # heads in the multi-head attention layer
        dropout=0.1,
        rnn_hidden_size=128,               # hidden size of the RNN head
        rnn_num_layers=2,
        rnn_bidirectional=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.transformer = transformer
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.dropout = dropout
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers
        self.rnn_bidirectional = rnn_bidirectional


# Register the config so AutoConfig can resolve the "bert_multi_attention" model type.
AutoConfig.register("bert_multi_attention", BERTMultiAttentionConfig)
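

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original source): the
# directory name below is a placeholder. Because the class is registered with
# AutoConfig, a saved config can be reloaded generically by its model_type.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = BERTMultiAttentionConfig(num_heads=12, rnn_hidden_size=256)
    config.save_pretrained("./bert_multi_attention")  # writes config.json

    reloaded = AutoConfig.from_pretrained("./bert_multi_attention")
    assert isinstance(reloaded, BERTMultiAttentionConfig)
    print(reloaded.model_type)  # "bert_multi_attention"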