Alexeym12 commited on
Commit
9fb80bf
·
1 Parent(s): f2c6c91

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +5 -4
README.md CHANGED
@@ -9,7 +9,7 @@ We introduce the model for multilabel ESG risks classification. There is 47 clas
9
 
10
  ## Usage
11
  ```python
12
-
13
  from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
14
  import torch
15
  #Mean Pooling - Take attention mask into account for correct averaging
@@ -45,10 +45,11 @@ class ESGify(MPNetPreTrainedModel):
45
  outputs = self.mpnet(input_ids=input_ids,
46
  attention_mask=attention_mask)
47
 
48
- # mean pooling dataset
49
  logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
50
- # Feed input to classifier to compute logits
51
-
 
52
  return logits
53
 
54
  model = ESGify.from_pretrained('ai-lab/ESGify')
 
9
 
10
  ## Usage
11
  ```python
12
+ from collections import OrderedDict
13
  from transformers import MPNetPreTrainedModel, MPNetModel, AutoTokenizer
14
  import torch
15
  #Mean Pooling - Take attention mask into account for correct averaging
 
45
  outputs = self.mpnet(input_ids=input_ids,
46
  attention_mask=attention_mask)
47
 
48
+ # mean pooling dataset and eed input to classifier to compute logits
49
  logits = self.classifier( mean_pooling(outputs['last_hidden_state'],attention_mask))
50
+
51
+ # apply sigmoid
52
+ logits = 1.0 / (1.0 + torch.exp(-logits))
53
  return logits
54
 
55
  model = ESGify.from_pretrained('ai-lab/ESGify')