MultiTalk-Code / base /base_model.py
ameerazam08's picture
Upload folder using huggingface_hub
6931c7b verified
raw
history blame contribute delete
858 Bytes
import torch.nn as nn
import numpy as np
class BaseModel(nn.Module):
    """
    Base class for all models.

    Subclasses must override :meth:`forward`. :meth:`summary` reports the
    module structure and its number of trainable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, *x):
        """
        Forward pass logic; must be implemented by subclasses.

        :param x: Positional model inputs (meaning defined by the subclass).
        :return: Model output
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement forward()')

    def summary(self, logger, writer=None):
        """
        Log the model structure and its trainable-parameter count.

        :param logger: Object with an ``info(msg)`` method (e.g. a
            ``logging.Logger``). It receives the module itself, then a line
            with the trainable-parameter count in millions (3 decimals).
        :param writer: Optional object with an ``add_text(tag, text)`` method
            (presumably a TensorBoard ``SummaryWriter`` -- confirm with
            callers). Skipped when ``None``.
        :return: None
        """
        # Only parameters with requires_grad=True are counted (frozen ones are
        # excluded). numel() yields 1 for 0-dim parameters, matching the old
        # np.prod(p.size()) result without the numpy round-trip.
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        params = n_trainable / 1e6  # Unit is Mega
        logger.info(self)
        logger.info(f'===>Trainable parameters: {params:.3f} M')
        if writer is not None:
            writer.add_text('Model Summary', f'Trainable parameters: {params:.3f} M')