KhaldiAbderrhmane committed
Update model.py
model.py
CHANGED
@@ -53,7 +53,6 @@ class BERTMultiAttentionModel(PreTrainedModel):
         super(BERTMultiAttentionModel, self).__init__(config)
         self.config = config

-        # Initialize the transformer model
         self.transformer = AutoModel.from_pretrained(config.transformer)
         self.cross_attention = MultiHeadAttention(config)
         self.fc1 = nn.Linear(config.hidden_size * 2, 256)
@@ -77,7 +76,7 @@ class BERTMultiAttentionModel(PreTrainedModel):
         combined_output = self.dropout1(torch.relu(combined_output))
         combined_output = combined_output.unsqueeze(1)
         _, (hidden_state, _) = self.rnn(combined_output)
-        hidden_state_concat = torch.cat([hidden_state[
+        hidden_state_concat = torch.cat([hidden_state[-2], hidden_state[-1]], dim=-1)
         hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
         hidden_state_proj = self.dropout2(hidden_state_proj)
         final = self.fc_final(hidden_state_proj)
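For context on the fix: when an nn.LSTM is bidirectional, the hidden state it returns stacks the final forward and backward states for each layer, so hidden_state[-2] and hidden_state[-1] are the two directions of the top layer. A minimal sketch of the corrected concatenation, assuming self.rnn is a bidirectional nn.LSTM (the sizes below are illustrative, not taken from the repo):

# Sketch (not from the commit): why the fix indexes hidden_state[-2] and
# hidden_state[-1]. Assumes a bidirectional LSTM; hidden_size=128, batch=4
# are arbitrary illustrative values.
import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=1,
              batch_first=True, bidirectional=True)

x = torch.randn(4, 1, 512)         # (batch, seq_len=1, features)
_, (hidden_state, _) = rnn(x)      # hidden_state: (num_layers * 2, batch, 128)

# Final forward and backward states of the top layer, concatenated on the
# feature dimension -> (batch, 256).
hidden_state_concat = torch.cat([hidden_state[-2], hidden_state[-1]], dim=-1)
print(hidden_state_concat.shape)   # torch.Size([4, 256])

Concatenating on the last dimension yields a (batch, 2 * hidden_size) tensor, which presumably matches the doubled input width that fc_proj is defined with.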