import torch.nn as nn
import numpy as np  # kept for backward compatibility; not used by BaseModel itself


class BaseModel(nn.Module):
    """Base class for all models.

    Subclasses are expected to override :meth:`forward`; :meth:`summary`
    reports the number of trainable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *x):
        """Forward pass logic.

        :param x: Positional inputs of the model; their meaning is defined
            by the subclass.
        :return: Model output
        :raises NotImplementedError: Always, until a subclass overrides it.
        """
        raise NotImplementedError

    def summary(self, logger, writer):
        """Log the model and its number of trainable parameters.

        Only parameters with ``requires_grad`` set are counted. The count is
        reported in millions with three decimals.

        :param logger: Object exposing ``info(msg)``; receives the model
            itself first, then the formatted parameter count.
        :param writer: Object exposing ``add_text(tag, text)`` (presumably a
            TensorBoard-style writer -- confirm against the caller), or
            ``None`` to skip writing.
        """
        # numel() is the element count of a tensor, i.e. the product of its
        # dimensions (1 for a 0-dim parameter).
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        params = trainable / 1e6  # Unit is Mega

        logger.info(self)
        logger.info('===>Trainable parameters: %.3f M' % params)
        if writer is not None:
            writer.add_text('Model Summary', 'Trainable parameters: %.3f M' % params)