#!/usr/bin/env python

from collections import OrderedDict
from os.path import join

import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel, DistributedDataParallel

from .utilities import check_makedirs


def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
    """Step learning rate policy: decay by `multiplier` every `step_epoch` epochs."""
    lr = base_lr * (multiplier ** (epoch // step_epoch))
    return lr
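
# A minimal usage sketch (hypothetical values): with base_lr=0.1 and
# step_epoch=30, epochs 0-29 train at lr=0.1, epochs 30-59 at lr=0.01, etc.
#   lr = step_learning_rate(0.1, epoch=35, step_epoch=30)  # -> 0.01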


def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Poly learning rate policy: lr = base_lr * (1 - curr_iter / max_iter) ** power."""
    lr = base_lr * (1 - float(curr_iter) / max_iter) ** power
    return lr
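
# A minimal usage sketch (hypothetical values): the lr decays smoothly from
# base_lr at iteration 0 toward 0 at max_iter.
#   lr = poly_learning_rate(0.01, curr_iter=5000, max_iter=10000)  # ~0.00536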


def adjust_learning_rate(optimizer, lr):
    """Set `lr` on every parameter group of `optimizer`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
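
# Typical training-loop usage, assuming `optimizer` is any torch.optim
# optimizer and one of the schedules above:
#   lr = poly_learning_rate(0.01, curr_iter, max_iter)
#   adjust_learning_rate(optimizer, lr)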


def save_checkpoint(model, other_state=None, sav_path='', filename='model.pth.tar', stage=1):
    """Save model weights (plus any extra state in `other_state`) to sav_path/filename."""
    if other_state is None:  # avoid a shared mutable default argument
        other_state = {}
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        weight = model.module.state_dict()  # unwrap the parallel wrapper
    elif isinstance(model, torch.nn.Module):
        weight = model.state_dict()
    else:
        raise ValueError('model must be an nn.Module, DataParallel, or DistributedDataParallel!')
    check_makedirs(sav_path)
    if stage == 2:  # stage 2: drop the (frozen) vqvae part from the checkpoint
        for k in list(weight.keys()):
            if 'autoencoder' in k:
                weight.pop(k)
    other_state['state_dict'] = weight
    torch.save(other_state, join(sav_path, filename))
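
# A usage sketch (the epoch/optimizer keys are illustrative, not required):
#   save_checkpoint(model,
#                   other_state={'epoch': epoch,
#                                'optimizer': optimizer.state_dict()},
#                   sav_path='checkpoints',
#                   filename='epoch_%d.pth.tar' % epoch)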


def load_state_dict(model, state_dict, strict=True):
    """Load `state_dict` into `model`, unwrapping any (Distributed)DataParallel wrapper."""
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        model.module.load_state_dict(state_dict, strict=strict)
    else:
        model.load_state_dict(state_dict, strict=strict)
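
# A usage sketch: restore weights written by save_checkpoint above
# (the path is illustrative):
#   ckpt = torch.load('checkpoints/model.pth.tar', map_location='cpu')
#   load_state_dict(model, ckpt['state_dict'], strict=True)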


def state_dict_remove_module(state_dict):
    """Strip the leading 'module.' prefix that (Distributed)DataParallel adds to keys."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # strip only a leading 'module.'; k.replace('module.', '') would also
        # mangle any submodule whose name happens to contain 'module.'
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict
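
# A usage sketch: load a checkpoint saved from a wrapped (DataParallel)
# model into a plain nn.Module (the path is illustrative):
#   ckpt = torch.load('model.pth.tar', map_location='cpu')
#   model.load_state_dict(state_dict_remove_module(ckpt['state_dict']))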


def reduce_tensor(tensor, args):
    """All-reduce `tensor` across ranks and return its mean over `args.world_size`."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
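
# A usage sketch: average a scalar loss across ranks for logging, assuming
# torch.distributed is initialized and args.world_size holds the process count:
#   loss_avg = reduce_tensor(loss.detach(), args)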