How to load a pretrained custom model using `from_pretrained`

Hi there,

I want to create a custom model that includes a transformer, train it for a few epochs, and save it with `save_pretrained`. I would then like to load it in a different notebook with `from_pretrained` for inference.

Suppose I followed this guide and created a custom model named `CustomModel` with something like:

```python
from torch import nn
from transformers import AutoModelForSequenceClassification, PreTrainedModel


class CustomModel(PreTrainedModel):
    def __init__(self, config, transformer_model_name, n_dims=1000, n_factors=50, n_classes=10):
        super().__init__(config)
        self.embs = nn.Embedding(n_dims, n_factors)
        # The transformer's 512 outputs are concatenated with the embedding below
        self.text_transformer = AutoModelForSequenceClassification.from_pretrained(
            transformer_model_name, num_labels=512)
        self.linear_layers = nn.Sequential(
            nn.Linear(n_factors + 512, 256, bias=False),
            nn.LeakyReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(),
            nn.Linear(256, n_classes),
        )
```

Assuming I have already trained and saved the model, can I later use `CustomModel.from_pretrained(model_dir)` to load it in a different notebook? I tried something along these lines, but calling `from_pretrained` raised `AttributeError: 'NoneType' object has no attribute 'from_pretrained'`.

I'd like to understand what I did wrong and what the best approach would be. Any suggestions would be much appreciated.

Thank you very much!


Yes, you can subclass `PreTrainedModel` to inherit methods such as `from_pretrained`, `save_pretrained` and `push_to_hub`.

Alternatively, you can use the `PyTorchModelHubMixin` class from the `huggingface_hub` library, which gives a plain `nn.Module` the same methods:

```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin

class CustomModel(nn.Module, PyTorchModelHubMixin):
    ...
```

Thank you, this is exactly what I was looking for!

Nice, thanks!

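The reason for this error is that the class definition is missing the `config_class` attribute. Without it, `config_class` keeps the default value of `None` inherited from `PreTrainedModel`, so when `from_pretrained` tries to load the saved config through `cls.config_class.from_pretrained(...)`, it fails with exactly this `AttributeError`. Set it to your config class: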

```python
class CustomModel(PreTrainedModel):
    config_class = CustomConfig

    def __init__(self, config, ...
    ...
```

The guide linked in the original question includes an example of this.
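Putting the pieces together, here is a minimal self-contained sketch of the `config_class` fix. `CustomConfig`, its fields, and the `"custom-toy"` model type are hypothetical, and the transformer sub-model from the question is left out so the example runs without downloading anything.

```python
import tempfile

import torch
from torch import nn
from transformers import PretrainedConfig, PreTrainedModel


class CustomConfig(PretrainedConfig):
    model_type = "custom-toy"  # hypothetical identifier

    def __init__(self, n_dims=100, n_factors=8, n_classes=3, **kwargs):
        # Hyperparameters stored here are saved to config.json by save_pretrained
        self.n_dims = n_dims
        self.n_factors = n_factors
        self.n_classes = n_classes
        super().__init__(**kwargs)


class CustomModel(PreTrainedModel):
    config_class = CustomConfig  # tells from_pretrained how to load config.json

    def __init__(self, config):
        super().__init__(config)
        self.embs = nn.Embedding(config.n_dims, config.n_factors)
        self.head = nn.Linear(config.n_factors, config.n_classes)
        self.post_init()

    def forward(self, ids):
        return self.head(self.embs(ids))


model = CustomModel(CustomConfig(n_classes=5))
model.eval()
ids = torch.tensor([1, 2, 3])

with tempfile.TemporaryDirectory() as model_dir:
    model.save_pretrained(model_dir)                   # writes config.json and the weights
    reloaded = CustomModel.from_pretrained(model_dir)  # no AttributeError now
    with torch.no_grad():
        same_output = torch.allclose(model(ids), reloaded(ids))
    n_classes = reloaded.config.n_classes

print(same_output, n_classes)
```

If `__init__` needs extra arguments (like `transformer_model_name` in the question), one option is to store them as fields on the config, so `from_pretrained(model_dir)` can rebuild the model from `config.json` alone.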