consistency
Browse files
model.py
CHANGED
@@ -23,7 +23,7 @@ class BertClassifier(nn.Module):
|
|
23 |
output = self.bert(input_ids, attention_mask=attention_mask)
|
24 |
logits = self.classifier(output.pooler_output)
|
25 |
loss = None
|
26 |
-
if labels
|
27 |
loss_fct = nn.CrossEntropyLoss()
|
28 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
29 |
return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=output.hidden_states,attentions=output.attentions)
|
|
|
23 |
output = self.bert(input_ids, attention_mask=attention_mask)
|
24 |
logits = self.classifier(output.pooler_output)
|
25 |
loss = None
|
26 |
+
if labels:
|
27 |
loss_fct = nn.CrossEntropyLoss()
|
28 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
29 |
return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=output.hidden_states,attentions=output.attentions)
|