KhaldiAbderrhmane committed
Commit 1143c26 · verified · 1 Parent(s): 56bfb42

Upload 3 files

Files changed (3)
  1. __init__.py +2 -0
  2. config.py +28 -0
  3. model.py +94 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
+from similarity.config import BERTMultiAttentionConfig
+from similarity.model import BERTMultiAttentionModel, MultiHeadAttention
config.py ADDED
@@ -0,0 +1,28 @@
+from transformers import PretrainedConfig, AutoConfig
+
+class BERTMultiAttentionConfig(PretrainedConfig):
+    model_type = "bert_multi_attention"
+    keys_to_ignore_at_inference = ["dropout"]
+
+    def __init__(
+        self,
+        transformer="bert-base-uncased",
+        hidden_size=768,
+        num_heads=8,
+        dropout=0.1,
+        rnn_hidden_size=128,
+        rnn_num_layers=2,
+        rnn_bidirectional=True,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.transformer = transformer
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.rnn_hidden_size = rnn_hidden_size
+        self.rnn_num_layers = rnn_num_layers
+        self.rnn_bidirectional = rnn_bidirectional
+
+
+AutoConfig.register("bert_multi_attention", BERTMultiAttentionConfig)
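
A brief usage sketch (not part of the diff), assuming the files above are importable as the similarity package referenced in __init__.py; the output directory name is illustrative:

    from similarity.config import BERTMultiAttentionConfig

    # Defaults come from the __init__ signature above; any field can be overridden.
    config = BERTMultiAttentionConfig(num_heads=12, rnn_hidden_size=256)
    print(config.model_type)                 # "bert_multi_attention"
    config.save_pretrained("./my-config")    # writes config.json with these values
    reloaded = BERTMultiAttentionConfig.from_pretrained("./my-config")

Because the class is registered with AutoConfig under the "bert_multi_attention" model type, AutoConfig.from_pretrained can also resolve saved configs of this kind.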
model.py ADDED
@@ -0,0 +1,94 @@
+from transformers import PreTrainedModel, AutoModel
+import torch
+import torch.nn as nn
+import math
+from .config import BERTMultiAttentionConfig
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config):
+        super(MultiHeadAttention, self).__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = config.hidden_size // config.num_heads
+
+        self.query = nn.Linear(config.hidden_size, config.hidden_size)
+        self.key = nn.Linear(config.hidden_size, config.hidden_size)
+        self.value = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.layer_norm_q = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_k = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_v = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_out = nn.LayerNorm(config.hidden_size)
+
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, query, key, value):
+        batch_size = query.size(0)
+
+        query = self.layer_norm_q(self.query(query))
+        key = self.layer_norm_k(self.key(key))
+        value = self.layer_norm_v(self.value(value))
+
+        query = query.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        key = key.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        value = value.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        attention_weights = nn.Softmax(dim=-1)(attention_scores)
+        attention_weights = self.dropout(attention_weights)
+
+        attended_values = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
+        attended_values = attended_values.view(batch_size, -1, self.hidden_size)
+
+        out = self.layer_norm_out(self.out(attended_values))
+        out = self.dropout(out)
+
+        return out
+
+class BERTMultiAttentionModel(PreTrainedModel):
+    config_class = BERTMultiAttentionConfig
+
+    def __init__(self, config):
+        super(BERTMultiAttentionModel, self).__init__(config)
+        self.config = config
+
+        # Initialize the transformer model
+        self.transformer = AutoModel.from_pretrained(config.transformer)
+        self.cross_attention = MultiHeadAttention(config)
+        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
+        self.layer_norm_fc1 = nn.LayerNorm(256)
+        self.dropout1 = nn.Dropout(config.dropout)
+        self.rnn = nn.LSTM(input_size=256, hidden_size=config.rnn_hidden_size, num_layers=config.rnn_num_layers, batch_first=True, bidirectional=config.rnn_bidirectional, dropout=config.dropout)
+        self.layer_norm_rnn = nn.LayerNorm(256)
+        self.dropout2 = nn.Dropout(config.dropout)
+        self.fc_proj = nn.Linear(256, 256)
+        self.layer_norm_proj = nn.LayerNorm(256)
+        self.dropout3 = nn.Dropout(config.dropout)
+        self.fc_final = nn.Linear(256, 1)
+
+    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
+        output1 = self.transformer(input_ids1, attention_mask=attention_mask1)[0]
+        output2 = self.transformer(input_ids2, attention_mask=attention_mask2)[0]
+
+        attended_output = self.cross_attention(output1, output2, output2)
+        combined_output = torch.cat([output1, attended_output], dim=2)
+        combined_output = torch.mean(combined_output, dim=1)
+
+        combined_output = self.layer_norm_fc1(self.fc1(combined_output))
+        combined_output = self.dropout1(torch.relu(combined_output))
+        combined_output = combined_output.unsqueeze(1)
+
+        _, (hidden_state, _) = self.rnn(combined_output)
+        hidden_state_concat = torch.cat([hidden_state[-2], hidden_state[-1]], dim=-1)
+
+        hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
+        hidden_state_proj = self.dropout2(hidden_state_proj)
+
+        final = self.fc_final(hidden_state_proj)
+        final = self.dropout3(final)
+
+        return torch.sigmoid(final)
+
+
+AutoModel.register(BERTMultiAttentionConfig, BERTMultiAttentionModel)
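
A minimal end-to-end sketch (not part of the diff) of how the model above could be exercised, assuming the repository is importable as the similarity package from __init__.py; the example sentences are placeholders, and the tokenizer is loaded from config.transformer (default "bert-base-uncased"):

    import torch
    from transformers import AutoTokenizer
    from similarity.config import BERTMultiAttentionConfig
    from similarity.model import BERTMultiAttentionModel

    # Build the config with its defaults, then the model on top of it.
    config = BERTMultiAttentionConfig()
    model = BERTMultiAttentionModel(config).eval()

    # Tokenize the two sentences whose similarity should be scored.
    tokenizer = AutoTokenizer.from_pretrained(config.transformer)
    enc1 = tokenizer("A man is playing a guitar.", return_tensors="pt")
    enc2 = tokenizer("Someone plays an instrument.", return_tensors="pt")

    with torch.no_grad():
        score = model(
            input_ids1=enc1["input_ids"],
            attention_mask1=enc1["attention_mask"],
            input_ids2=enc2["input_ids"],
            attention_mask2=enc2["attention_mask"],
        )
    print(score.item())  # sigmoid output in [0, 1]

The forward pass cross-attends the first sentence's token states over the second, mean-pools the concatenated features, runs them through the bidirectional LSTM, and projects the final hidden states down to a single similarity logit passed through a sigmoid.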