simple_nn_model / simpleModel.py
yasinamp's picture
Upload SimpleNN
8b17fbc verified
raw
history blame contribute delete
403 Bytes
from transformers import PreTrainedModel
import torch.nn as nn
from .configuration_simple_model import SimpleNNConfig
# Define the model class
class SimpleNN(PreTrainedModel):
    """Single-layer linear classifier wrapped as a ``PreTrainedModel``.

    The whole network is one ``nn.Linear`` layer that maps
    ``config.input_size`` input features to ``config.num_classes`` outputs.
    No activation is applied here, so ``forward`` returns the raw output of
    that layer.
    """

    # Configuration class associated with this model.
    config_class = SimpleNNConfig

    def __init__(self, config: SimpleNNConfig) -> None:
        """Build the model from *config*.

        Args:
            config: Configuration object; its ``input_size`` and
                ``num_classes`` attributes size the linear layer.
        """
        super().__init__(config)
        # The attribute name ``dense`` determines the parameter names
        # (``dense.weight`` / ``dense.bias``), so it is kept unchanged.
        self.dense = nn.Linear(config.input_size, config.num_classes)

    def forward(self, x):
        """Apply the linear layer to *x* and return the result.

        Args:
            x: Input passed straight to ``nn.Linear``; presumably a tensor
                whose last dimension equals ``config.input_size`` -- confirm
                against callers.

        Returns:
            The output of ``self.dense(x)``, with no further processing.
        """
        return self.dense(x)