KhaldiAbderrhmane
commited on
Update model.py
Browse files
model.py
CHANGED
@@ -76,7 +76,7 @@ class BERTMultiAttentionModel(PreTrainedModel):
|
|
76 |
combined_output = self.dropout1(torch.relu(combined_output))
|
77 |
combined_output = combined_output.unsqueeze(1)
|
78 |
_, (hidden_state, _) = self.rnn(combined_output)
|
79 |
-
hidden_state_concat = torch.cat([hidden_state[
|
80 |
hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
|
81 |
hidden_state_proj = self.dropout2(hidden_state_proj)
|
82 |
final = self.fc_final(hidden_state_proj)
|
|
|
76 |
combined_output = self.dropout1(torch.relu(combined_output))
|
77 |
combined_output = combined_output.unsqueeze(1)
|
78 |
_, (hidden_state, _) = self.rnn(combined_output)
|
79 |
+
hidden_state_concat = torch.cat([hidden_state[0], hidden_state[1]], dim=-1) #-2 -1
|
80 |
hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
|
81 |
hidden_state_proj = self.dropout2(hidden_state_proj)
|
82 |
final = self.fc_final(hidden_state_proj)
|