"""CodeBERT-based models for refactoring prediction.

``RefactorSpanModel`` scores a whole input sequence and every token in it.
``RefactorModel`` only produces the sequence-level score. Both build their
encoder from a config, so the weights are meant to come from a fine-tuned
checkpoint loaded via ``load_state_dict`` (see ``_load_checkpoint``).
"""
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel

DEFAULT_BASE_MODEL_PATH = 'microsoft/codebert-base'


def _build_encoder(base_model_path):
    """Return ``(config, encoder)`` for the model identified by *base_model_path*.

    The encoder is created with ``AutoModel.from_config``, so its weights are
    freshly initialised, not pretrained. Callers are expected to overwrite
    them by loading a fine-tuned state dict.
    """
    config = AutoConfig.from_pretrained(base_model_path)
    return config, AutoModel.from_config(config)


class RefactorSpanModel(nn.Module):
    """Encoder with a sequence-level refactor head and a per-token span head.

    Args:
        base_model_path: Hugging Face model id or local path used to build
            the encoder config.
        dropout_prob: Dropout probability applied before both heads.
    """

    def __init__(self, base_model_path=DEFAULT_BASE_MODEL_PATH, dropout_prob=0.5):
        super().__init__()
        self.base_config, self.base_model = _build_encoder(base_model_path)
        # Size the heads from the config instead of hard-coding 768, so other
        # encoder widths work. For codebert-base this is still 768.
        hidden_size = self.base_config.hidden_size
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, 1)
        self.start_span = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return ``(refactor, span)`` logits for *input_ids*.

        Args:
            input_ids: Token ids passed straight to the encoder.
            attention_mask: Optional mask forwarded to the encoder. ``None``
                keeps the original behaviour of attending to every position.

        Returns:
            ``refactor``: one logit per sequence, from the pooler output.
            ``span``: one logit per token, from the last hidden states.
            Presumably shaped (batch, 1) and (batch, seq_len, 1); confirm
            against the encoder in use.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs[1])  # pooler output
        hidden = self.dropout(outputs[0])  # token-level hidden states
        refactor = self.classifier(pooled)
        span = self.start_span(hidden)
        return refactor, span


class RefactorModel(nn.Module):
    """Encoder with only a sequence-level refactor head.

    Args:
        base_model_path: Hugging Face model id or local path used to build
            the encoder config.
        dropout_prob: Dropout probability applied before the head.
    """

    def __init__(self, base_model_path=DEFAULT_BASE_MODEL_PATH, dropout_prob=0.5):
        super().__init__()
        self.base_config, self.base_model = _build_encoder(base_model_path)
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(self.base_config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return one refactor logit per sequence, computed from the pooler output.

        ``attention_mask`` is optional. ``None`` keeps the original behaviour
        of attending to every position.
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        pooled = self.dropout(outputs[1])  # pooler output
        return self.classifier(pooled)


def _load_checkpoint(model, checkpoint):
    """Load the state dict stored at *checkpoint* into *model* and return it.

    ``strict=True`` makes a missing or unexpected key raise instead of being
    silently ignored.
    """
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only
    # hosts. The model is built on the CPU, so the final result is the same.
    # NOTE(security): torch.load unpickles the file; only load trusted
    # checkpoints.
    state_dict = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state_dict, strict=True)
    return model


if __name__ == "__main__":
    _load_checkpoint(RefactorSpanModel(), 'pytorch_model_RSP.bin')
    _load_checkpoint(RefactorModel(), 'pytorch_model_RP.bin')