|
import torch.nn as nn |
|
import torch |
|
from transformers import AutoModel, AutoConfig |
|
|
|
class RefactorSpanModel(nn.Module):
    """Encoder with a sequence-level scoring head and a per-token span head.

    The encoder is built with ``AutoModel.from_config`` from the config found
    at ``base_model_path``, so this constructor does not load pretrained
    encoder weights; trained weights are expected to be supplied afterwards
    via ``load_state_dict`` (as the ``__main__`` block of this file does).

    Args:
        base_model_path: Hub id or local path whose config defines the encoder.
        dropout_prob: Dropout probability applied to the encoder outputs
            before each head.
    """

    def __init__(self, base_model_path='microsoft/codebert-base',
                 dropout_prob=0.5):
        super().__init__()

        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout_prob)

        # Size the heads from the encoder config rather than a hard-coded 768
        # (the value for codebert-base), so a different base model with another
        # hidden size does not fail at the first matmul.
        hidden_size = self.base_config.hidden_size

        # NOTE: attribute names form the state_dict keys; renaming them would
        # break loading of existing checkpoints with strict=True.
        self.classifier = nn.Linear(hidden_size, 1)
        self.start_span = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Score the whole sequence and every token position.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask forwarded to the encoder so padding
                positions can be ignored; ``None`` keeps the encoder's default
                behaviour (the only behaviour available before this argument
                existed).

        Returns:
            A ``(refactor, span)`` tuple: ``classifier`` applied to the pooled
            output (``outputs[1]``) and ``start_span`` applied to the per-token
            hidden states (``outputs[0]``), each with a trailing size-1 dim.
        """
        # Assumes a Hugging Face-style output ordering:
        # (last_hidden_state, pooler_output, ...).
        outputs = self.base_model(input_ids, attention_mask=attention_mask)

        # Order of the two dropout calls is kept so seeded runs consume RNG
        # exactly as before.
        outputs_pool = self.dropout(outputs[1])
        outputs_hidden = self.dropout(outputs[0])

        refactor = self.classifier(outputs_pool)
        span = self.start_span(outputs_hidden)
        return refactor, span
|
|
|
class RefactorModel(nn.Module):
    """Encoder with a single sequence-level scoring head.

    The encoder is built with ``AutoModel.from_config`` from the config found
    at ``base_model_path``, so this constructor does not load pretrained
    encoder weights; trained weights are expected to be supplied afterwards
    via ``load_state_dict`` (as the ``__main__`` block of this file does).

    Args:
        base_model_path: Hub id or local path whose config defines the encoder.
        dropout_prob: Dropout probability applied to the pooled output before
            the classifier.
    """

    def __init__(self, base_model_path='microsoft/codebert-base',
                 dropout_prob=0.5):
        super().__init__()

        self.base_config = AutoConfig.from_pretrained(base_model_path)
        self.base_model = AutoModel.from_config(self.base_config)
        self.dropout = nn.Dropout(dropout_prob)

        # Size the head from the encoder config rather than a hard-coded 768
        # (the value for codebert-base), so a different base model with another
        # hidden size does not fail at the first matmul.
        # NOTE: attribute names form the state_dict keys; keep them stable.
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Score the whole sequence.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask forwarded to the encoder so padding
                positions can be ignored; ``None`` keeps the encoder's default
                behaviour (the only behaviour available before this argument
                existed).

        Returns:
            ``classifier`` applied to the pooled encoder output (``outputs[1]``),
            with a trailing size-1 dim.
        """
        # Assumes a Hugging Face-style output ordering:
        # (last_hidden_state, pooler_output, ...).
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        outputs_pool = self.dropout(outputs[1])
        refactor = self.classifier(outputs_pool)
        return refactor
|
|
|
if __name__ == "__main__":
    # Smoke check: each released checkpoint must load exactly (strict=True)
    # into its matching architecture.
    for checkpoint, model_cls in (
        ('pytorch_model_RSP.bin', RefactorSpanModel),
        ('pytorch_model_RP.bin', RefactorModel),
    ):
        model = model_cls()
        # map_location="cpu" lets checkpoints saved from a GPU load on a
        # CPU-only machine.
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; on torch>=1.13 consider weights_only=True for plain
        # state dicts.
        state_dict = torch.load(checkpoint, map_location="cpu")
        # NOTE(review): checkpoints saved with older transformers versions may
        # contain 'base_model.embeddings.position_ids', which newer versions
        # do not register and strict=True rejects — confirm against the
        # installed transformers version.
        model.load_state_dict(state_dict, strict=True)
|
|