kevinjesse commited on
Commit
646aebf
·
1 Parent(s): bb0def0

Upload 3 files

Browse files
RefactorModels.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ from transformers import AutoModel, AutoConfig
4
+
5
+ class RefactorSpanModel(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ base_model_path = 'microsoft/codebert-base'
9
+ self.base_config = AutoConfig.from_pretrained(base_model_path)
10
+ self.base_model = AutoModel.from_config(self.base_config)
11
+ self.dropout = nn.Dropout(0.5)
12
+ self.classifier = nn.Linear(768, 1)
13
+ self.start_span = nn.Linear(768, 1)
14
+
15
+ def forward(self, input_ids):
16
+ outputs = self.base_model(input_ids)
17
+ outputs_pool = self.dropout(outputs[1]) #use pooler output...
18
+ outputs_hidden = self.dropout(outputs[0])
19
+ refactor = self.classifier(outputs_pool)
20
+ span = self.start_span(outputs_hidden)
21
+ return refactor, span
22
+
23
+ class RefactorModel(nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ base_model_path = 'microsoft/codebert-base'
27
+ self.base_config = AutoConfig.from_pretrained(base_model_path)
28
+ self.base_model = AutoModel.from_config(self.base_config)
29
+ self.dropout = nn.Dropout(0.5)
30
+ self.classifier = nn.Linear(768, 1)
31
+
32
+ def forward(self, input_ids):
33
+ outputs = self.base_model(input_ids)
34
+ outputs_pool = self.dropout(outputs[1]) #use pooler output...
35
+ refactor = self.classifier(outputs_pool)
36
+ return refactor
37
+
38
+ if __name__ == "__main__":
39
+ checkpoint = 'pytorch_model_RSP.bin'
40
+ model = RefactorSpanModel()
41
+ model.load_state_dict(torch.load(checkpoint), strict=True)
42
+ #print(model.base_model.embeddings.word_embeddings.weight)
43
+
44
+ checkpoint = 'pytorch_model_RP.bin'
45
+ model = RefactorModel()
46
+ model.load_state_dict(torch.load(checkpoint), strict=True)
47
+ #print(model.base_model.embeddings.word_embeddings.weight)
pytorch_model_RP.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2bf67df05f4bed987db3343a44d1c6a95eae260e557e8bfb3e05fffaf40ccc8
3
+ size 498678473
pytorch_model_RSP.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1b310c8c8915cf8e5d3c81dd47a1acccb93455202cc5de7a7141e62441cc649
3
+ size 498682203