|
|
|
import torch |
|
from os.path import join |
|
import torch.distributed as dist |
|
from .utilities import check_makedirs |
|
from collections import OrderedDict |
|
from torch.nn.parallel import DataParallel, DistributedDataParallel |
|
|
|
|
|
def step_learning_rate(base_lr, epoch, step_epoch, multiplier=0.1):
    """Return the step-decayed learning rate for ``epoch``.

    The base rate is scaled by ``multiplier`` once for every ``step_epoch``
    epochs that have fully elapsed (``epoch // step_epoch``).

    Args:
        base_lr: Learning rate at epoch 0.
        epoch: Current epoch index.
        step_epoch: Number of epochs between successive decays.
        multiplier: Factor applied at each decay step.

    Returns:
        The decayed learning rate.
    """
    completed_steps = epoch // step_epoch
    decay = multiplier ** completed_steps
    return base_lr * decay
|
|
|
|
|
def poly_learning_rate(base_lr, curr_iter, max_iter, power=0.9):
    """Return the learning rate under the polynomial ("poly") decay policy.

    Computes ``base_lr * (1 - curr_iter / max_iter) ** power``, with the
    remaining fraction clamped at zero.

    Args:
        base_lr: Learning rate at iteration 0.
        curr_iter: Current iteration index.
        max_iter: Iteration at which the rate reaches zero.
        power: Exponent of the polynomial decay.

    Returns:
        The decayed learning rate; ``0.0`` once ``curr_iter >= max_iter``.
    """
    # Clamp at zero: past max_iter the fraction goes negative, and a negative
    # float raised to a fractional power evaluates to a complex number in
    # Python 3, which is not a usable learning rate.
    remaining = max(0.0, 1 - float(curr_iter) / max_iter)
    return base_lr * remaining ** power
|
|
|
|
|
def adjust_learning_rate(optimizer, lr):
    """Set the ``'lr'`` entry of every parameter group in ``optimizer``.

    Args:
        optimizer: Object exposing an iterable ``param_groups`` of dicts.
        lr: Learning rate to write into each group.
    """
    for group in optimizer.param_groups:
        group.update(lr=lr)
|
|
|
|
|
def save_checkpoint(model, other_state=None, sav_path='', filename='model.pth.tar', stage=1):
    """Save a model's weights together with extra state to a checkpoint file.

    The weights are stored under the ``'state_dict'`` key of the saved dict,
    next to whatever entries ``other_state`` provides. For models wrapped in
    ``DataParallel`` / ``DistributedDataParallel`` the inner ``module`` is
    saved.

    Args:
        model: An ``nn.Module``, optionally wrapped in ``DataParallel`` or
            ``DistributedDataParallel``.
        other_state: Optional mapping of additional entries to store in the
            checkpoint. It is copied, not modified.
        sav_path: Directory the checkpoint is written to; passed to
            ``check_makedirs`` before saving.
        filename: File name of the checkpoint inside ``sav_path``.
        stage: When equal to ``2``, every weight whose key contains
            ``'autoencoder'`` is left out of the checkpoint.

    Raises:
        ValueError: If ``model`` is not an ``nn.Module``.
    """
    if isinstance(model, (DistributedDataParallel, DataParallel)):
        weight = model.module.state_dict()
    elif isinstance(model, torch.nn.Module):
        weight = model.state_dict()
    else:
        raise ValueError('model must be nn.Module or nn.DataParallel!')
    check_makedirs(sav_path)

    if stage == 2:
        # Delete in place (rather than rebuilding the dict) so any extra
        # attributes attached to the state-dict object are kept.
        for key in list(weight.keys()):
            if 'autoencoder' in key:
                del weight[key]

    # Work on a copy: a mutable default or the caller's own dict must not
    # accumulate a 'state_dict' entry across calls.
    checkpoint = dict(other_state) if other_state is not None else {}
    checkpoint['state_dict'] = weight
    torch.save(checkpoint, join(sav_path, filename))
|
|
|
|
|
|
|
def load_state_dict(model, state_dict, strict=True):
    """Load ``state_dict`` into ``model``, looking through parallel wrappers.

    When ``model`` is a ``DataParallel`` or ``DistributedDataParallel``
    instance, the weights are loaded into its wrapped ``module``; otherwise
    they are loaded into ``model`` itself.

    Args:
        model: Module (possibly wrapped) that receives the weights.
        state_dict: Mapping of parameter names to values.
        strict: Forwarded to the module's ``load_state_dict``.
    """
    is_wrapped = isinstance(model, (DistributedDataParallel, DataParallel))
    target = model.module if is_wrapped else model
    target.load_state_dict(state_dict, strict=strict)
|
|
|
|
|
def state_dict_remove_module(state_dict):
    """Return a copy of ``state_dict`` with ``module.`` path components removed.

    A ``module`` component is dropped wherever it appears as a whole
    dot-separated segment followed by another segment, e.g.
    ``'module.conv.weight'`` -> ``'conv.weight'`` and
    ``'enc.module.fc.bias'`` -> ``'enc.fc.bias'``. Names that merely contain
    the text, such as ``'submodule.weight'``, are left untouched.

    Args:
        state_dict: Mapping of parameter names to values.

    Returns:
        A new ``OrderedDict`` with the renamed keys, in the original order.
        The input mapping is not modified.
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        parts = k.split('.')
        last = len(parts) - 1
        # Match whole components only; a plain substring replace would also
        # mangle names like 'submodule.weight'. The final component is kept
        # because it is never followed by a dot.
        name = '.'.join(
            part for i, part in enumerate(parts)
            if part != 'module' or i == last
        )
        new_state_dict[name] = v
    return new_state_dict
|
|
|
|
|
def reduce_tensor(tensor, args):
    """Return ``tensor`` summed over all processes and divided by world size.

    The reduction runs on a clone, so the input tensor is left unchanged.

    Args:
        tensor: Tensor to reduce.
        args: Object whose ``world_size`` attribute is used as the divisor.

    Returns:
        The cloned tensor after the all-reduce sum and in-place division.
    """
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    averaged.div_(args.world_size)
    return averaged
|
|