import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_simple_model import SimpleNNConfig
# Define the model class
class SimpleNN(PreTrainedModel):
    """Minimal model consisting of a single fully-connected layer.

    The layer maps ``config.input_size`` input features to
    ``config.num_classes`` outputs. No activation is applied in ``forward``,
    so the outputs are the raw linear scores.

    Args:
        config: A ``SimpleNNConfig`` that provides ``input_size`` and
            ``num_classes``.
    """

    config_class = SimpleNNConfig
    # PreTrainedModel defaults this to "input_ids"; this model's forward()
    # takes its input as ``x``, so advertise that name to library helpers
    # that look the main input up by name.
    main_input_name = "x"

    def __init__(self, config: SimpleNNConfig) -> None:
        super().__init__(config)
        self.dense = nn.Linear(config.input_size, config.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x``.

        Args:
            x: Input tensor whose last dimension equals ``config.input_size``
                (``nn.Linear`` operates on the last dimension).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``config.num_classes``.
        """
        return self.dense(x)