# RefactorBERT / RefactorModels.py
# Uploaded by kevinjesse ("Upload 3 files", commit 646aebf, 1.83 kB).
import torch.nn as nn
import torch
from transformers import AutoModel, AutoConfig
class RefactorSpanModel(nn.Module):
    """CodeBERT encoder with a sequence-level refactor head and a per-token span head.

    The encoder is built from its configuration only (``AutoModel.from_config``),
    so its weights start randomly initialised. Load trained weights afterwards
    with ``load_state_dict``.

    Args:
        base_model_path: Hugging Face model id or local path whose config
            defines the encoder architecture.
        dropout: Dropout probability applied to the encoder outputs before
            both heads.
    """

    def __init__(self, base_model_path='microsoft/codebert-base', dropout=0.5):
        super().__init__()
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout)
        # Size the heads from the config rather than hard-coding 768. For
        # codebert-base hidden_size is 768, so parameter names and shapes
        # (and therefore checkpoint compatibility) are unchanged.
        hidden_size = self.base_config.hidden_size
        self.classifier = nn.Linear(hidden_size, 1)
        self.start_span = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Score the whole sequence and every token position.

        Args:
            input_ids: Token id tensor, presumably (batch, seq_len).
            attention_mask: Optional mask with the same shape as ``input_ids``
                (1 = real token, 0 = padding). ``None`` attends to every
                position, which matches the previous behaviour.

        Returns:
            A ``(refactor, span)`` tuple of raw logits (no activation applied):
            ``refactor`` comes from the pooled output, shape (batch, 1);
            ``span`` comes from the per-token hidden states, shape
            (batch, seq_len, 1).
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # Index access works for both ModelOutput and tuple returns:
        # [0] = last hidden state, [1] = pooler output.
        outputs_pool = self.dropout(outputs[1])
        outputs_hidden = self.dropout(outputs[0])
        refactor = self.classifier(outputs_pool)
        span = self.start_span(outputs_hidden)
        return refactor, span


class RefactorModel(nn.Module):
    """CodeBERT encoder with a single sequence-level refactor head.

    The encoder is built from its configuration only (``AutoModel.from_config``),
    so its weights start randomly initialised. Load trained weights afterwards
    with ``load_state_dict``.

    Args:
        base_model_path: Hugging Face model id or local path whose config
            defines the encoder architecture.
        dropout: Dropout probability applied to the pooled output before the
            classifier.
    """

    def __init__(self, base_model_path='microsoft/codebert-base', dropout=0.5):
        super().__init__()
        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout)
        # Size the head from the config rather than hard-coding 768. For
        # codebert-base hidden_size is 768, so checkpoint compatibility holds.
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return one raw refactor logit per sequence.

        Args:
            input_ids: Token id tensor, presumably (batch, seq_len).
            attention_mask: Optional mask with the same shape as ``input_ids``
                (1 = real token, 0 = padding). ``None`` attends to every
                position, which matches the previous behaviour.

        Returns:
            Logits of shape (batch, 1), computed from the pooled output with
            no activation applied.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        # [1] is the pooler output; index access works for ModelOutput and tuples.
        outputs_pool = self.dropout(outputs[1])
        refactor = self.classifier(outputs_pool)
        return refactor


if __name__ == "__main__":
    # Smoke test: build each architecture and check that its checkpoint matches
    # exactly (strict=True raises on missing or unexpected keys).
    # map_location='cpu' lets the checkpoints load on CPU-only machines even if
    # they were saved from GPU tensors.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True on torch versions that support it).
    checkpoint = 'pytorch_model_RSP.bin'
    model = RefactorSpanModel()
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=True)

    checkpoint = 'pytorch_model_RP.bin'
    model = RefactorModel()
    model.load_state_dict(torch.load(checkpoint, map_location='cpu'), strict=True)