import argparse
import os

import torch

from train_helper import Trainer

# Settings that parse_args() forces onto the parsed namespace for each
# supported dataset, keyed by the lower-cased dataset name.
_DATASET_OVERRIDES = {
    'qnrf': {'crop_size': 512},
    'nwpu': {'crop_size': 384, 'val_epoch': 50},
    'sha': {'crop_size': 256},
    'shb': {'crop_size': 512},
}


def str2bool(v):
    """Return True if *v* is "yes", "true", "t" or "1" (case-insensitive)."""
    return v.lower() in ("yes", "true", "t", "1")


def parse_args(argv=None):
    """Parse the training command line and apply per-dataset presets.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` holding every option. ``crop_size`` is
        always replaced by the preset of the selected dataset, and for
        ``nwpu`` ``val_epoch`` is replaced as well, regardless of what was
        passed on the command line.

    Raises:
        NotImplementedError: If ``--dataset`` is not one of qnrf, nwpu,
            sha or shb (compared case-insensitively).
    """
    parser = argparse.ArgumentParser(description='Train')
    parser.add_argument('--data-dir', default='data/UCF-Train-Val-Test',
                        help='data path')
    parser.add_argument('--dataset', default='qnrf',
                        help='dataset name: qnrf, nwpu, sha, shb')
    parser.add_argument('--lr', type=float, default=1e-5,
                        help='the initial learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        help='the weight decay')
    parser.add_argument('--resume', default='', type=str,
                        help='the path of resume training model')
    parser.add_argument('--max-epoch', type=int, default=1000,
                        help='max training epoch')
    parser.add_argument('--val-epoch', type=int, default=5,
                        help='the num of steps to log training information')
    parser.add_argument('--val-start', type=int, default=50,
                        help='the epoch start to val')
    parser.add_argument('--batch-size', type=int, default=10,
                        help='train batch size')
    parser.add_argument('--device', default='0',
                        help='assign device')
    parser.add_argument('--num-workers', type=int, default=3,
                        help='the num of training process')
    parser.add_argument('--crop-size', type=int, default=512,
                        help='the crop size of the train image')
    parser.add_argument('--wot', type=float, default=0.1,
                        help='weight on OT loss')
    parser.add_argument('--wtv', type=float, default=0.01,
                        help='weight on TV loss')
    parser.add_argument('--reg', type=float, default=10.0,
                        help='entropy regularization in sinkhorn')
    parser.add_argument('--num-of-iter-in-ot', type=int, default=100,
                        help='sinkhorn iterations')
    parser.add_argument('--norm-cood', type=int, default=0,
                        help='whether to norm cood when computing distance')

    args = parser.parse_args(argv)

    overrides = _DATASET_OVERRIDES.get(args.dataset.lower())
    if overrides is None:
        raise NotImplementedError(
            'unsupported dataset {!r}; expected one of: {}'.format(
                args.dataset, ', '.join(_DATASET_OVERRIDES)))

    # NOTE: the presets win over any --crop-size / --val-epoch given on the
    # command line; this matches the script's established behavior.
    for option, value in overrides.items():
        setattr(args, option, value)
    return args


if __name__ == '__main__':
    args = parse_args()
    torch.backends.cudnn.benchmark = True
    # Restrict the GPUs visible to this process to the requested device id(s).
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
    trainer = Trainer(args)
    trainer.setup()
    trainer.train()