Spaces:
Runtime error
Runtime error
File size: 1,536 Bytes
2bd24e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import os
def adjust_learning_rate(optimizer, epoch, initial_lr=0.001, decay_epoch=10,
                         decay_rate=0.1, min_lr=1e-6):
    """Apply a step-decayed learning rate to every parameter group.

    The rate starts at ``initial_lr`` and is multiplied by ``decay_rate``
    once for each full ``decay_epoch`` epochs elapsed, but is never allowed
    to drop below ``min_lr``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts.
            The ``'lr'`` entry of each dict is overwritten in place.
        epoch: Current epoch number; epochs ``0 .. decay_epoch - 1`` use
            ``initial_lr`` unchanged.
        initial_lr: Learning rate before any decay is applied.
        decay_epoch: Number of epochs between successive decays.
        decay_rate: Multiplicative factor applied at each decay step.
        min_lr: Lower bound on the resulting learning rate.
    """
    # Integer division counts how many whole decay intervals have elapsed.
    completed_steps = epoch // decay_epoch
    lr = max(initial_lr * (decay_rate ** completed_steps), min_lr)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
class Save_Handle(object):
    """Keep at most ``max_num`` saved files, deleting the oldest on overflow.

    Paths are tracked in insertion order in ``save_list``. Once the list is
    full, appending a new path drops the oldest tracked path and removes the
    corresponding file from disk.
    """

    def __init__(self, max_num):
        """Create an empty handle.

        Args:
            max_num: Maximum number of paths to keep; must be at least 1.

        Raises:
            ValueError: If ``max_num`` is smaller than 1. Without this check
                the first ``append`` would fail with an ``IndexError`` when
                trying to evict from an empty list.
        """
        if max_num < 1:
            raise ValueError(
                "max_num must be at least 1, got {!r}".format(max_num))
        # Tracked paths, oldest first.
        self.save_list = []
        self.max_num = max_num

    def append(self, save_path):
        """Track ``save_path``, evicting and deleting the oldest file if full.

        Args:
            save_path: Path of the file that was just saved.
        """
        if len(self.save_list) < self.max_num:
            self.save_list.append(save_path)
            return

        # Update the bookkeeping first so the list stays consistent even if
        # the deletion below fails for a reason other than a missing file.
        oldest_path = self.save_list.pop(0)
        self.save_list.append(save_path)
        try:
            os.remove(oldest_path)
        except FileNotFoundError:
            # The file is already gone, which is the desired end state.
            pass
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        """Create a meter with all statistics zeroed."""
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running average: sum / count
        self.sum = 0    # weighted sum of all values seen
        self.count = 0  # total weight (number of samples) seen

    def update(self, val, n=1):
        """Record a new value.

        Args:
            val: The value to record (e.g. a per-batch mean).
            n: Weight of ``val``, typically the number of samples it covers.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(val, n=0) on a fresh
        # meter), which would otherwise raise ZeroDivisionError.
        if self.count:
            # The 1.0 factor forces true division for integer inputs.
            self.avg = 1.0 * self.sum / self.count
        else:
            self.avg = 0

    def get_avg(self):
        """Return the running average."""
        return self.avg

    def get_count(self):
        """Return the total weight accumulated so far."""
        return self.count
def set_trainable(model, requires_grad):
    """Set the ``requires_grad`` attribute on every parameter of ``model``.

    Args:
        model: Object whose ``parameters()`` method yields the parameters.
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    all_params = model.parameters()
    for weight in all_params:
        weight.requires_grad = requires_grad
def get_num_params(model, trainable_only=True):
    """Count the elements across the parameters of ``model``.

    Args:
        model: Object whose ``parameters()`` method yields parameters, each
            providing ``numel()`` and a ``requires_grad`` attribute.
        trainable_only: When true (the default), only parameters whose
            ``requires_grad`` is truthy are counted; when false, every
            parameter is counted.

    Returns:
        The summed ``numel()`` of the selected parameters (0 if there are none).
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )