KhaldiAbderrhmane committed (verified)
Commit 3eaaf2f · Parent: b04be07

Update model.py

Files changed (1):
  model.py  +1 -2
model.py CHANGED
@@ -53,7 +53,6 @@ class BERTMultiAttentionModel(PreTrainedModel):
         super(BERTMultiAttentionModel, self).__init__(config)
         self.config = config
 
-        # Initialize the transformer model
         self.transformer = AutoModel.from_pretrained(config.transformer)
         self.cross_attention = MultiHeadAttention(config)
         self.fc1 = nn.Linear(config.hidden_size * 2, 256)
@@ -77,7 +76,7 @@ class BERTMultiAttentionModel(PreTrainedModel):
         combined_output = self.dropout1(torch.relu(combined_output))
         combined_output = combined_output.unsqueeze(1)
         _, (hidden_state, _) = self.rnn(combined_output)
-        hidden_state_concat = torch.cat([hidden_state[0], hidden_state[1]], dim=-1)
+        hidden_state_concat = torch.cat([hidden_state[-2], hidden_state[-1]], dim=-1)
         hidden_state_proj = self.layer_norm_proj(self.fc_proj(hidden_state_concat))
         hidden_state_proj = self.dropout2(hidden_state_proj)
         final = self.fc_final(hidden_state_proj)
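
For context (not part of the commit): the index change matters when `self.rnn` is a multi-layer bidirectional LSTM, which the `_, (hidden_state, _)` unpacking and the doubled hidden size in the concat suggest. PyTorch stacks `h_n` layer by layer, forward direction before backward, so `hidden_state[0]` and `hidden_state[1]` are the first layer's final states whenever `num_layers > 1`, while `hidden_state[-2]` and `hidden_state[-1]` always select the last layer's forward and backward states. A minimal sketch, with illustrative sizes that are assumptions rather than values from model.py:

```python
# Illustrative sketch only; sizes and names are assumptions, not from model.py.
import torch
import torch.nn as nn

# Assumption: self.rnn in the model is a bidirectional, multi-layer LSTM.
rnn = nn.LSTM(input_size=512, hidden_size=128, num_layers=2,
              batch_first=True, bidirectional=True)

x = torch.randn(4, 1, 512)   # (batch, seq_len=1, features), as after unsqueeze(1)
_, (h_n, _) = rnn(x)

# h_n is (num_layers * num_directions, batch, hidden_size) = (4, 4, 128),
# ordered layer by layer, forward direction before backward:
#   h_n[0],  h_n[1]  -> first layer's forward/backward states (old code)
#   h_n[-2], h_n[-1] -> last layer's forward/backward states (the fix)
concat = torch.cat([h_n[-2], h_n[-1]], dim=-1)
print(concat.shape)          # torch.Size([4, 256])
```

With a single-layer LSTM the two indexings coincide, which is why the old code still ran; the negative indices just make the selection of the final layer's states robust to `num_layers`.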