KhaldiAbderrhmane committed
Commit b04be07 (verified) · 1 Parent(s): f043c7c

Update model.py

Files changed (1)
  1. model.py +88 -94
model.py CHANGED
@@ -1,94 +1,88 @@
- from transformers import PreTrainedModel, AutoModel
- import torch
- import torch.nn as nn
- import math
- from .config import BERTMultiAttentionConfig
-
- class MultiHeadAttention(nn.Module):
-     def __init__(self, config):
-         super(MultiHeadAttention, self).__init__()
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_heads
-         self.head_dim = config.hidden_size // config.num_heads
-
-         self.query = nn.Linear(config.hidden_size, config.hidden_size)
-         self.key = nn.Linear(config.hidden_size, config.hidden_size)
-         self.value = nn.Linear(config.hidden_size, config.hidden_size)
-         self.out = nn.Linear(config.hidden_size, config.hidden_size)
-
-         self.layer_norm_q = nn.LayerNorm(config.hidden_size)
-         self.layer_norm_k = nn.LayerNorm(config.hidden_size)
-         self.layer_norm_v = nn.LayerNorm(config.hidden_size)
-         self.layer_norm_out = nn.LayerNorm(config.hidden_size)
-
-         self.dropout = nn.Dropout(config.dropout)
-
-     def forward(self, query, key, value):
-         batch_size = query.size(0)
-
-         query = self.layer_norm_q(self.query(query))
-         key = self.layer_norm_k(self.key(key))
-         value = self.layer_norm_v(self.value(value))
-
-         query = query.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-         key = key.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-         value = value.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
-
-         attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
-         attention_weights = nn.Softmax(dim=-1)(attention_scores)
-         attention_weights = self.dropout(attention_weights)
-
-         attended_values = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
-         attended_values = attended_values.view(batch_size, -1, self.hidden_size)
-
-         out = self.layer_norm_out(self.out(attended_values))
-         out = self.dropout(out)
-
-         return out
-
- class BERTMultiAttentionModel(PreTrainedModel):
-     config_class = BERTMultiAttentionConfig
-
-     def __init__(self, config):
-         super(BERTMultiAttentionModel, self).__init__(config)
-         self.config = config
-
-         # Initialize the transformer model
-         self.transformer = AutoModel.from_pretrained(config.transformer)
-         self.cross_attention = MultiHeadAttention(config)
-         self.fc1 = nn.Linear(config.hidden_size * 2, 256)
-         self.layer_norm_fc1 = nn.LayerNorm(256)
-         self.dropout1 = nn.Dropout(config.dropout)
-         self.rnn = nn.LSTM(input_size=256, hidden_size=config.rnn_hidden_size, num_layers=config.rnn_num_layers, batch_first=True, bidirectional=config.rnn_bidirectional, dropout=config.dropout)
-         self.layer_norm_rnn = nn.LayerNorm(256)
-         self.dropout2 = nn.Dropout(config.dropout)
-         self.fc_proj = nn.Linear(256, 256)
-         self.layer_norm_proj = nn.LayerNorm(256)
-         self.dropout3 = nn.Dropout(config.dropout)
-         self.fc_final = nn.Linear(256, 1)
-
-     def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
-         output1 = self.transformer(input_ids1, attention_mask=attention_mask1)[0]
-         output2 = self.transformer(input_ids2, attention_mask=attention_mask2)[0]
-
-         attended_output = self.cross_attention(output1, output2, output2)
-         combined_output = torch.cat([output1, attended_output], dim=2)
-         combined_output = torch.mean(combined_output, dim=1)
-
-         combined_output = self.layer_norm_fc1(self.fc1(combined_output))
-         combined_output = self.dropout1(torch.relu(combined_output))
-         combined_output = combined_output.unsqueeze(1)
-
-         _, (hidden_state, _) = self.rnn(combined_output)
-         hidden_state_concat = torch.cat([hidden_state[-2], hidden_state[-1]], dim=-1)
-
-         hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
-         hidden_state_proj = self.dropout2(hidden_state_proj)
-
-         final = self.fc_final(hidden_state_proj)
-         final = self.dropout3(final)
-
-         return torch.sigmoid(final)
-
-
- AutoModel.register(BERTMultiAttentionConfig, BERTMultiAttentionModel)
+ from transformers import PreTrainedModel, AutoModel
+ import torch
+ import torch.nn as nn
+ import math
+ from .config import BERTMultiAttentionConfig
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, config):
+         super(MultiHeadAttention, self).__init__()
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_heads
+         self.head_dim = config.hidden_size // config.num_heads
+
+         self.query = nn.Linear(config.hidden_size, config.hidden_size)
+         self.key = nn.Linear(config.hidden_size, config.hidden_size)
+         self.value = nn.Linear(config.hidden_size, config.hidden_size)
+         self.out = nn.Linear(config.hidden_size, config.hidden_size)
+
+         self.layer_norm_q = nn.LayerNorm(config.hidden_size)
+         self.layer_norm_k = nn.LayerNorm(config.hidden_size)
+         self.layer_norm_v = nn.LayerNorm(config.hidden_size)
+         self.layer_norm_out = nn.LayerNorm(config.hidden_size)
+
+         self.dropout = nn.Dropout(config.dropout)
+
+     def forward(self, query, key, value):
+         batch_size = query.size(0)
+
+         query = self.layer_norm_q(self.query(query))
+         key = self.layer_norm_k(self.key(key))
+         value = self.layer_norm_v(self.value(value))
+
+         query = query.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         key = key.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+         value = value.view(batch_size, -1, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
+
+         attention_scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
+         attention_weights = nn.Softmax(dim=-1)(attention_scores)
+         attention_weights = self.dropout(attention_weights)
+
+         attended_values = torch.matmul(attention_weights, value).permute(0, 2, 1, 3).contiguous()
+         attended_values = attended_values.view(batch_size, -1, self.hidden_size)
+
+         out = self.layer_norm_out(self.out(attended_values))
+         out = self.dropout(out)
+
+         return out
+
+ class BERTMultiAttentionModel(PreTrainedModel):
+     config_class = BERTMultiAttentionConfig
+
+     def __init__(self, config):
+         super(BERTMultiAttentionModel, self).__init__(config)
+         self.config = config
+
+         # Initialize the transformer model
+         self.transformer = AutoModel.from_pretrained(config.transformer)
+         self.cross_attention = MultiHeadAttention(config)
+         self.fc1 = nn.Linear(config.hidden_size * 2, 256)
+         self.layer_norm_fc1 = nn.LayerNorm(256)
+         self.dropout1 = nn.Dropout(config.dropout)
+         self.rnn = nn.LSTM(input_size=256, hidden_size=config.rnn_hidden_size, num_layers=config.rnn_num_layers, batch_first=True, bidirectional=config.rnn_bidirectional, dropout=config.dropout)
+         self.layer_norm_rnn = nn.LayerNorm(256)
+         self.dropout2 = nn.Dropout(config.dropout)
+         self.fc_proj = nn.Linear(256, 256)
+         self.layer_norm_proj = nn.LayerNorm(256)
+         self.dropout3 = nn.Dropout(config.dropout)
+         self.fc_final = nn.Linear(256, 1)
+
+     def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
+         output1 = self.transformer(input_ids1, attention_mask1)[0]
+         output2 = self.transformer(input_ids2, attention_mask2)[0]
+         attended_output = self.cross_attention(output1, output2, output2)
+         combined_output = torch.cat([output1, attended_output], dim=2)
+         combined_output = torch.mean(combined_output, dim=1)
+         combined_output = self.layer_norm_fc1(self.fc1(combined_output))
+         combined_output = self.dropout1(torch.relu(combined_output))
+         combined_output = combined_output.unsqueeze(1)
+         _, (hidden_state, _) = self.rnn(combined_output)
+         hidden_state_concat = torch.cat([hidden_state[0], hidden_state[1]], dim=-1)
+         hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
+         hidden_state_proj = self.dropout2(hidden_state_proj)
+         final = self.fc_final(hidden_state_proj)
+         final = self.dropout3(final)
+         return torch.sigmoid(final)
+
+
+ AutoModel.register(BERTMultiAttentionConfig, BERTMultiAttentionModel)
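
For reference, a minimal usage sketch of the forward signature kept by this commit (scoring one text pair). It is not part of the commit and rests on assumptions: the package name bert_multi_attention is hypothetical (model.py imports config relatively, so the files must live in a package), BERTMultiAttentionConfig() is assumed to provide workable defaults whose transformer field names a BERT-style checkpoint, and rnn_hidden_size * 2 is assumed to equal 256 so the bidirectional hidden-state concat matches fc_proj.

# Hypothetical usage sketch -- package name, config defaults, and checkpoint are assumptions.
import torch
from transformers import AutoTokenizer

from bert_multi_attention.config import BERTMultiAttentionConfig  # assumed package layout
from bert_multi_attention.model import BERTMultiAttentionModel

config = BERTMultiAttentionConfig()        # assumed to carry usable defaults
model = BERTMultiAttentionModel(config)
model.eval()

# Tokenize the two inputs with the same backbone tokenizer named in the config.
tokenizer = AutoTokenizer.from_pretrained(config.transformer)
enc1 = tokenizer(["how do I reset my password?"], return_tensors="pt", padding=True)
enc2 = tokenizer(["steps to recover a forgotten password"], return_tensors="pt", padding=True)

with torch.no_grad():
    score = model(
        enc1["input_ids"], enc1["attention_mask"],
        enc2["input_ids"], enc2["attention_mask"],
    )

print(score.shape)  # expected (batch_size, 1): sigmoid output in [0, 1]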