import os
import time
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
import numpy as np
from datetime import datetime
from datasets.crowd import Crowd_qnrf, Crowd_nwpu, Crowd_sh
from models import vgg19
from losses.ot_loss import OT_Loss
from utils.pytorch_utils import Save_Handle, AverageMeter
import utils.log_utils as log_utils
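

# Collate function for the training loader: images and discretized ground-truth maps
# share a fixed crop size and are stacked into batch tensors, while the per-image
# point annotations vary in length and are passed through unstacked (one tensor per image).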
def train_collate(batch):
    transposed_batch = list(zip(*batch))
    images = torch.stack(transposed_batch[0], 0)
    points = transposed_batch[1]  # the number of points per image is not fixed; keep it as a list of tensors
    gt_discretes = torch.stack(transposed_batch[2], 0)
    return images, points, gt_discretes


class Trainer(object):
    def __init__(self, args):
        self.args = args

    def setup(self):
        args = self.args
        sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            args.crop_size, args.wot, args.wtv, args.reg, args.num_of_iter_in_ot, args.norm_cood)
        self.save_dir = os.path.join('ckpts', sub_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.logger = log_utils.get_logger(os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
        log_utils.print_config(vars(args), self.logger)
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
            self.logger.info('using {} gpu(s)'.format(self.device_count))
        else:
            raise Exception("GPU is not available")
        downsample_ratio = 8
        if args.dataset.lower() == 'qnrf':
            self.datasets = {x: Crowd_qnrf(os.path.join(args.data_dir, x),
                                           args.crop_size, downsample_ratio, x) for x in ['train', 'val']}
        elif args.dataset.lower() == 'nwpu':
            self.datasets = {x: Crowd_nwpu(os.path.join(args.data_dir, x),
                                           args.crop_size, downsample_ratio, x) for x in ['train', 'val']}
        elif args.dataset.lower() in ('sha', 'shb'):
            self.datasets = {'train': Crowd_sh(os.path.join(args.data_dir, 'train_data'),
                                               args.crop_size, downsample_ratio, 'train'),
                             'val': Crowd_sh(os.path.join(args.data_dir, 'test_data'),
                                             args.crop_size, downsample_ratio, 'val')}
        else:
            raise NotImplementedError
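        # Training uses the custom collate, mini-batches, shuffling and pinned memory;
        # validation feeds single full images through the default collate.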
        self.dataloaders = {x: DataLoader(self.datasets[x],
                                          collate_fn=(train_collate if x == 'train' else default_collate),
                                          batch_size=(args.batch_size if x == 'train' else 1),
                                          shuffle=(x == 'train'),
                                          num_workers=args.num_workers * self.device_count,
                                          pin_memory=(x == 'train'))
                            for x in ['train', 'val']}
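        # VGG19-based density regressor; its forward pass returns both an unnormalized
        # density map and a normalized one (consumed by the OT and TV losses below).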
        self.model = vgg19()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.start_epoch = 0
        if args.resume:
            self.logger.info('loading pretrained model from ' + args.resume)
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume, self.device))
        else:
            self.logger.info('random initialization')
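        # Loss components: OT loss on the normalized density map, an L1 loss on the
        # total count, and a per-pixel TV (L1) loss weighted by the ground-truth count.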
        self.ot_loss = OT_Loss(args.crop_size, downsample_ratio, args.norm_cood, self.device,
                               args.num_of_iter_in_ot, args.reg)
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """Training process."""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch + 1):
            self.logger.info('-' * 5 + 'Epoch {}/{}'.format(epoch, args.max_epoch) + '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    def train_epoch(self):
        epoch_ot_loss = AverageMeter()
        epoch_ot_obj_value = AverageMeter()
        epoch_wd = AverageMeter()
        epoch_count_loss = AverageMeter()
        epoch_tv_loss = AverageMeter()
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode
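        # Each batch performs one optimization step on
        # loss = (wot-weighted) OT loss + count loss + (wtv-weighted) TV loss.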
        for step, (inputs, points, gt_discrete) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            gt_discrete = gt_discrete.to(self.device)
            N = inputs.size(0)
            with torch.set_grad_enabled(True):
                outputs, outputs_normed = self.model(inputs)
                # Compute OT loss.
                ot_loss, wd, ot_obj_value = self.ot_loss(outputs_normed, outputs, points)
                ot_loss = ot_loss * self.args.wot
                ot_obj_value = ot_obj_value * self.args.wot
                epoch_ot_loss.update(ot_loss.item(), N)
                epoch_ot_obj_value.update(ot_obj_value.item(), N)
                epoch_wd.update(wd, N)
                # Compute counting loss.
                count_loss = self.mae(outputs.sum(1).sum(1).sum(1),
                                      torch.from_numpy(gd_count).float().to(self.device))
                epoch_count_loss.update(count_loss.item(), N)
                # Compute TV loss.
                gd_count_tensor = (torch.from_numpy(gd_count).float().to(self.device)
                                   .unsqueeze(1).unsqueeze(2).unsqueeze(3))
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1)
                           * torch.from_numpy(gd_count).float().to(self.device)).mean(0) * self.args.wtv
                epoch_tv_loss.update(tv_loss.item(), N)
                loss = ot_loss + count_loss + tv_loss
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                pred_count = torch.sum(outputs.view(N, -1), dim=1).detach().cpu().numpy()
                pred_err = pred_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(np.abs(pred_err)), N)
        self.logger.info(
            'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(), epoch_wd.get_avg(),
                    epoch_ot_obj_value.get_avg(), epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(), time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, '{}_ckpt.tar'.format(self.epoch))
        torch.save({
            'epoch': self.epoch,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'model_state_dict': model_state_dic
        }, save_path)
        self.save_list.append(save_path)
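
    # Validation runs on full images, one per batch; the per-image error is the
    # annotated count minus the sum of the predicted density map.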
    def val_epoch(self):
        args = self.args
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluation mode
        epoch_res = []
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, 'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs, _ = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)
        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.info('Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
                         .format(self.epoch, mse, mae, time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            self.logger.info("save best mse {:.2f} mae {:.2f} model epoch {}".format(self.best_mse,
                                                                                     self.best_mae,
                                                                                     self.epoch))
            torch.save(model_state_dic, os.path.join(self.save_dir, 'best_model_{}.pth'.format(self.best_count)))
            self.best_count += 1
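

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical driver showing how this Trainer is typically wired up.
# The argument names mirror those referenced above (data_dir, dataset, lr, crop_size,
# wot, wtv, reg, num_of_iter_in_ot, norm_cood, ...); the defaults here are illustrative
# and the project's real entry point (e.g. a train.py with argparse) may differ.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train a crowd-counting model with OT, count and TV losses')
    parser.add_argument('--data-dir', default='data/QNRF-Train-Val-Test')
    parser.add_argument('--dataset', default='qnrf', choices=['qnrf', 'nwpu', 'sha', 'shb'])
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--resume', default='', help='path to a .tar or .pth checkpoint to resume from')
    parser.add_argument('--max-epoch', type=int, default=1000)
    parser.add_argument('--val-epoch', type=int, default=5)
    parser.add_argument('--val-start', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--crop-size', type=int, default=512)
    parser.add_argument('--wot', type=float, default=0.1, help='weight of the OT loss')
    parser.add_argument('--wtv', type=float, default=0.01, help='weight of the TV loss')
    parser.add_argument('--reg', type=float, default=10.0, help='entropic regularization used by the OT loss')
    parser.add_argument('--num-of-iter-in-ot', type=int, default=100)
    parser.add_argument('--norm-cood', type=int, default=0, help='whether to normalize coordinates in the OT cost')
    args = parser.parse_args()

    trainer = Trainer(args)
    trainer.setup()
    trainer.train()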