from transformers import AutoConfig, PretrainedConfig


class BERTMultiAttentionConfig(PretrainedConfig):
    """Configuration class for the custom "bert_multi_attention" model type."""

    model_type = "bert_multi_attention"

    def __init__(
        self,
        transformer="bert-base-uncased",
        hidden_size=768,
        num_heads=8,
        dropout=0.1,
        rnn_hidden_size=128,
        rnn_num_layers=2,
        rnn_bidirectional=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Name of the pretrained transformer backbone to load.
        self.transformer = transformer
        # Hidden size of the backbone outputs (768 for bert-base-uncased).
        self.hidden_size = hidden_size
        # Number of heads in the additional multi-head attention layer.
        self.num_heads = num_heads
        self.dropout = dropout
        # Hyperparameters of the recurrent layer stacked on the encoder.
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn_num_layers = rnn_num_layers
        self.rnn_bidirectional = rnn_bidirectional


# Register the config so AutoConfig can resolve the "bert_multi_attention" model type.
AutoConfig.register("bert_multi_attention", BERTMultiAttentionConfig)
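

# Minimal usage sketch (assumption: run after the registration above). It
# round-trips the config through save_pretrained / AutoConfig.from_pretrained;
# the "bert_multi_attention_config" directory name is illustrative only.
if __name__ == "__main__":
    config = BERTMultiAttentionConfig(num_heads=12, rnn_hidden_size=256)
    config.save_pretrained("bert_multi_attention_config")

    # AutoConfig reads model_type from the saved config.json and, thanks to the
    # registration, resolves it back to BERTMultiAttentionConfig.
    reloaded = AutoConfig.from_pretrained("bert_multi_attention_config")
    assert isinstance(reloaded, BERTMultiAttentionConfig)
    print(reloaded)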
|