KhaldiAbderrhmane committed on
Commit df13d8f · verified · 1 Parent(s): 71bd99f

Upload model

Files changed (3)
  1. config.json +2 -1
  2. model.py +93 -87
  3. model.safetensors +1 -1
config.json CHANGED
@@ -3,7 +3,8 @@
     "BERTMultiAttentionModel"
   ],
   "auto_map": {
-    "AutoConfig": "config.BERTMultiAttentionConfig"
+    "AutoConfig": "config.BERTMultiAttentionConfig",
+    "AutoModel": "model.BERTMultiAttentionModel"
   },
   "dropout": 0.1,
   "hidden_size": 768,
model.py CHANGED
@@ -1,87 +1,93 @@
-from transformers import PreTrainedModel, AutoModel
-import torch
-import torch.nn as nn
-import math
-from .config import BERTMultiAttentionConfig
-
-class MultiHeadAttention(nn.Module):
-    def __init__(self, config):
-        super(MultiHeadAttention, self).__init__()
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_heads
-        self.head_dim = config.hidden_size // config.num_heads
-
-        self.query = nn.Linear(config.hidden_size, config.hidden_size)
-        self.key = nn.Linear(config.hidden_size, config.hidden_size)
-        self.value = nn.Linear(config.hidden_size, config.hidden_size)
-        self.out = nn.Linear(config.hidden_size, config.hidden_size)
-
-        self.layer_norm_q = nn.LayerNorm(config.hidden_size)
-        self.layer_norm_k = nn.LayerNorm(config.hidden_size)
-        self.layer_norm_v = nn.LayerNorm(config.hidden_size)
-        self.layer_norm_out = nn.LayerNorm(config.hidden_size)
-
-        self.dropout = nn.Dropout(config.dropout)
-
-    def forward(self, query, key, value):
-        batch_size = query.size(0)
-
-        query = self.layer_norm_q(self.query(query))
-        key = self.layer_norm_k(self.key(key))
-        value = self.layer_norm_v(self.value(value))
-
-        query = query.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        key = key.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-        value = value.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-
-        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
-        attention_weights = nn.Softmax(dim=-1)(attention_scores)
-        attention_weights = self.dropout(attention_weights)
-
-        attended_values = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
-        attended_values = attended_values.view(batch_size, -1, self.hidden_size)
-
-        out = self.layer_norm_out(self.out(attended_values))
-        out = self.dropout(out)
-
-        return out
-
-class BERTMultiAttentionModel(PreTrainedModel):
-    config_class = BERTMultiAttentionConfig
-
-    def __init__(self, config):
-        super(BERTMultiAttentionModel, self).__init__(config)
-        self.config = config
-
-        self.transformer = AutoModel.from_pretrained(config.transformer)
-        self.cross_attention = MultiHeadAttention(config)
-        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
-        self.layer_norm_fc1 = nn.LayerNorm(256)
-        self.dropout1 = nn.Dropout(config.dropout)
-        self.rnn = nn.LSTM(input_size=256, hidden_size=config.rnn_hidden_size, num_layers=config.rnn_num_layers, batch_first=True, bidirectional=config.rnn_bidirectional, dropout=config.dropout)
-        self.layer_norm_rnn = nn.LayerNorm(256)
-        self.dropout2 = nn.Dropout(config.dropout)
-        self.fc_proj = nn.Linear(256, 256)
-        self.layer_norm_proj = nn.LayerNorm(256)
-        self.dropout3 = nn.Dropout(config.dropout)
-        self.fc_final = nn.Linear(256, 1)
-
-    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
-        output1 = self.transformer(input_ids1, attention_mask1)[0]
-        output2 = self.transformer(input_ids2, attention_mask2)[0]
-        attended_output = self.cross_attention(output1, output2, output2)
-        combined_output = torch.cat([output1, attended_output], dim=2)
-        combined_output = torch.mean(combined_output, dim=1)
-        combined_output = self.layer_norm_fc1(self.fc1(combined_output))
-        combined_output = self.dropout1(torch.relu(combined_output))
-        combined_output = combined_output.unsqueeze(1)
-        _, (hidden_state, _) = self.rnn(combined_output)
-        hidden_state_concat = torch.cat([hidden_state[0], hidden_state[1]], dim=-1) #-2 -1
-        hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
-        hidden_state_proj = self.dropout2(hidden_state_proj)
-        final = self.fc_final(hidden_state_proj)
-        final = self.dropout3(final)
-        return torch.sigmoid(final)
-
-
-AutoModel.register(BERTMultiAttentionConfig, BERTMultiAttentionModel)
+from transformers import PreTrainedModel, AutoModel
+import torch
+import torch.nn as nn
+import math
+from .config import BERTMultiAttentionConfig
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, config):
+        super(MultiHeadAttention, self).__init__()
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_heads
+        self.head_dim = config.hidden_size // config.num_heads
+
+        self.query = nn.Linear(config.hidden_size, config.hidden_size)
+        self.key = nn.Linear(config.hidden_size, config.hidden_size)
+        self.value = nn.Linear(config.hidden_size, config.hidden_size)
+        self.out = nn.Linear(config.hidden_size, config.hidden_size)
+
+        self.layer_norm_q = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_k = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_v = nn.LayerNorm(config.hidden_size)
+        self.layer_norm_out = nn.LayerNorm(config.hidden_size)
+
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, query, key, value):
+        batch_size = query.size(0)
+
+        query = self.layer_norm_q(self.query(query))
+        key = self.layer_norm_k(self.key(key))
+        value = self.layer_norm_v(self.value(value))
+
+        query = query.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        key = key.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+        value = value.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
+        attention_weights = nn.Softmax(dim=-1)(attention_scores)
+        attention_weights = self.dropout(attention_weights)
+
+        attended_values = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
+        attended_values = attended_values.view(batch_size, -1, self.hidden_size)
+
+        out = self.layer_norm_out(self.out(attended_values))
+        out = self.dropout(out)
+
+        return out
+
+class BERTMultiAttentionModel(PreTrainedModel):
+    config_class = BERTMultiAttentionConfig
+
+    def __init__(self, config):
+        super(BERTMultiAttentionModel, self).__init__(config)
+        self.config = config
+
+        self.transformer = AutoModel.from_pretrained(config.transformer)
+        self.cross_attention = MultiHeadAttention(config)
+        self.fc1 = nn.Linear(config.hidden_size * 2, 256)
+        self.layer_norm_fc1 = nn.LayerNorm(256)
+        self.dropout1 = nn.Dropout(config.dropout)
+        self.rnn = nn.LSTM(input_size=256, hidden_size=config.rnn_hidden_size, num_layers=config.rnn_num_layers, batch_first=True, bidirectional=config.rnn_bidirectional, dropout=config.dropout)
+        self.layer_norm_rnn = nn.LayerNorm(256)
+        self.dropout2 = nn.Dropout(config.dropout)
+        self.fc_proj = nn.Linear(256, 256)
+        self.layer_norm_proj = nn.LayerNorm(256)
+        self.dropout3 = nn.Dropout(config.dropout)
+        self.fc_final = nn.Linear(256, 1)
+
+    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
+        output1 = self.transformer(input_ids1, attention_mask=attention_mask1)[0]
+        output2 = self.transformer(input_ids2, attention_mask=attention_mask2)[0]
+
+        attended_output = self.cross_attention(output1, output2, output2)
+        combined_output = torch.cat([output1, attended_output], dim=2)
+        combined_output = torch.mean(combined_output, dim=1)
+
+        combined_output = self.layer_norm_fc1(self.fc1(combined_output))
+        combined_output = self.dropout1(torch.relu(combined_output))
+        combined_output = combined_output.unsqueeze(1)
+
+        _, (hidden_state, _) = self.rnn(combined_output)
+        hidden_state_concat = torch.cat([hidden_state[0], hidden_state[1]], dim=-1)
+
+        hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
+        hidden_state_proj = self.dropout2(hidden_state_proj)
+
+        final = self.fc_final(hidden_state_proj)
+        final = self.dropout3(final)
+
+        return torch.sigmoid(final)
+
+
+AutoModel.register(BERTMultiAttentionConfig, BERTMultiAttentionModel)
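
The rewritten forward() now passes the attention masks as keyword arguments and takes two tokenized inputs, returning a sigmoid score. A usage sketch under stated assumptions: the backbone tokenizer is guessed as "bert-base-uncased" (config.transformer is not shown in this diff), and `model` is an instance loaded as in the config.json note above.

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed backbone checkpoint

enc1 = tokenizer("first sentence", return_tensors="pt", padding=True, truncation=True)
enc2 = tokenizer("second sentence", return_tensors="pt", padding=True, truncation=True)

with torch.no_grad():
    score = model(
        input_ids1=enc1["input_ids"],
        attention_mask1=enc1["attention_mask"],
        input_ids2=enc2["input_ids"],
        attention_mask2=enc2["attention_mask"],
    )

# score has shape (batch_size, 1); values lie in (0, 1) because of the final sigmoid.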
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ab15f6096f5cc6774861f5ead97f64985cb51f4829ddbfcfb09b3bc7db18fd19
+oid sha256:698a78c534e0418159d4ebc2779fb7bc72726b7e924c9c3447e3d7f2fc09e8bd
 size 452438124