# NOTE: removed non-Python scrape artifacts that preceded this file
# (a "File size" line, a commit hash, and a line-number gutter) — they
# were not part of the source and broke parsing.
import torch.nn as nn
import torch
from transformers import AutoModel, AutoConfig
class RefactorSpanModel(nn.Module):
    """CodeBERT-based model with two heads: a sequence-level logit for
    "does this code need refactoring" and a per-token logit for the
    refactoring span start.

    The encoder is built from config only (random weights); callers are
    expected to load a fine-tuned checkpoint via ``load_state_dict``.
    """

    def __init__(self, base_model_path='microsoft/codebert-base'):
        super().__init__()
        # from_config builds the architecture WITHOUT downloading pretrained
        # weights — a fine-tuned checkpoint is loaded afterwards.
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(0.5)
        # Read the width from the config instead of hard-coding 768
        # (768 is codebert-base's hidden_size, so checkpoints still match).
        hidden_size = self.base_config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)  # sequence-level head
        self.start_span = nn.Linear(hidden_size, 1)  # per-token span head

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and both heads.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attention_mask: optional padding mask, shape (batch, seq_len).
                Defaults to None for backward compatibility (previously no
                mask was passed, so padding tokens were attended to).

        Returns:
            (refactor, span): refactor has shape (batch, 1) from the pooled
            output; span has shape (batch, seq_len, 1) from the last hidden
            states.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled output, outputs[0] the last hidden states
        # (tuple-style indexing into the transformers ModelOutput).
        outputs_pool = self.dropout(outputs[1])
        outputs_hidden = self.dropout(outputs[0])
        refactor = self.classifier(outputs_pool)
        span = self.start_span(outputs_hidden)
        return refactor, span
class RefactorModel(nn.Module):
    """CodeBERT-based sequence classifier producing a single logit for
    "does this code need refactoring".

    The encoder is built from config only (random weights); callers are
    expected to load a fine-tuned checkpoint via ``load_state_dict``.
    """

    def __init__(self, base_model_path='microsoft/codebert-base'):
        super().__init__()
        # from_config builds the architecture WITHOUT downloading pretrained
        # weights — a fine-tuned checkpoint is loaded afterwards.
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(0.5)
        # Read the width from the config instead of hard-coding 768
        # (768 is codebert-base's hidden_size, so checkpoints still match).
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder and the classification head.

        Args:
            input_ids: token id tensor, shape (batch, seq_len).
            attention_mask: optional padding mask, shape (batch, seq_len).
                Defaults to None for backward compatibility (previously no
                mask was passed, so padding tokens were attended to).

        Returns:
            refactor logit of shape (batch, 1), computed from the pooled
            output.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # outputs[1] is the pooled output (tuple-style indexing into the
        # transformers ModelOutput).
        outputs_pool = self.dropout(outputs[1])
        refactor = self.classifier(outputs_pool)
        return refactor
if __name__ == "__main__":
    # Smoke-test: build each model and load its fine-tuned checkpoint.
    # map_location="cpu" makes loading work on machines without a GPU even
    # when the checkpoint was saved from CUDA tensors; weights_only=True
    # restricts torch.load to tensors/state dicts instead of unpickling
    # arbitrary objects (safer for downloaded checkpoint files).
    checkpoint = 'pytorch_model_RSP.bin'
    model = RefactorSpanModel()
    state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    # print(model.base_model.embeddings.word_embeddings.weight)

    checkpoint = 'pytorch_model_RP.bin'
    model = RefactorModel()
    state_dict = torch.load(checkpoint, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict, strict=True)
    # print(model.base_model.embeddings.word_embeddings.weight)