Spaces:
Runtime error
Runtime error
import argparse | |
import os | |
import torch | |
from train_helper import Trainer | |
def str2bool(v):
    """Interpret a command-line string as a boolean.

    The strings 'yes', 'true', 't' and '1' (case-insensitive) map to True;
    every other string maps to False. A value that is already a ``bool`` is
    returned unchanged, so calling this on an already-converted value is
    harmless instead of raising AttributeError.
    """
    if isinstance(v, bool):
        return v
    return v.lower() in ("yes", "true", "t", "1")
# Per-dataset values for options whose sensible default depends on the
# dataset. They are applied only when the option is not given explicitly
# on the command line.
_DATASET_DEFAULTS = {
    'qnrf': {'crop_size': 512, 'val_epoch': 5},
    'nwpu': {'crop_size': 384, 'val_epoch': 50},
    'sha': {'crop_size': 256, 'val_epoch': 5},
    'shb': {'crop_size': 512, 'val_epoch': 5},
}


def parse_args(argv=None):
    """Parse the training command line.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) parses ``sys.argv[1:]``, exactly like
            ``ArgumentParser.parse_args``.

    Returns:
        An ``argparse.Namespace`` with every option. ``crop_size`` and
        ``val_epoch`` take dataset-specific defaults (see
        ``_DATASET_DEFAULTS``) unless they were given explicitly.

    Raises:
        NotImplementedError: if ``--dataset`` is not one of qnrf, nwpu,
            sha or shb (compared case-insensitively).
    """
    parser = argparse.ArgumentParser(description='Train')
    parser.add_argument('--data-dir', default='data/UCF-Train-Val-Test', help='data path')
    parser.add_argument('--dataset', default='qnrf', help='dataset name: qnrf, nwpu, sha, shb')
    parser.add_argument('--lr', type=float, default=1e-5,
                        help='the initial learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        help='the weight decay')
    parser.add_argument('--resume', default='', type=str,
                        help='the path of resume training model')
    parser.add_argument('--max-epoch', type=int, default=1000,
                        help='max training epoch')
    # None means "use the dataset default" (filled in below).
    parser.add_argument('--val-epoch', type=int, default=None,
                        help='the epoch interval between validations '
                             '(default: 50 for nwpu, 5 otherwise)')
    parser.add_argument('--val-start', type=int, default=50,
                        help='the epoch start to val')
    parser.add_argument('--batch-size', type=int, default=10,
                        help='train batch size')
    parser.add_argument('--device', default='0', help='assign device')
    parser.add_argument('--num-workers', type=int, default=3,
                        help='the num of training process')
    # None means "use the dataset default" (filled in below).
    parser.add_argument('--crop-size', type=int, default=None,
                        help='the crop size of the train image '
                             '(default depends on --dataset)')
    parser.add_argument('--wot', type=float, default=0.1, help='weight on OT loss')
    parser.add_argument('--wtv', type=float, default=0.01, help='weight on TV loss')
    parser.add_argument('--reg', type=float, default=10.0,
                        help='entropy regularization in sinkhorn')
    parser.add_argument('--num-of-iter-in-ot', type=int, default=100,
                        help='sinkhorn iterations')
    parser.add_argument('--norm-cood', type=int, default=0, help='whether to norm cood when computing distance')
    args = parser.parse_args(argv)

    try:
        dataset_defaults = _DATASET_DEFAULTS[args.dataset.lower()]
    except KeyError:
        raise NotImplementedError(
            f"unsupported dataset {args.dataset!r}; "
            f"expected one of {', '.join(_DATASET_DEFAULTS)}"
        ) from None

    # Respect explicit user values; only fill options left unset.
    for name, value in dataset_defaults.items():
        if getattr(args, name) is None:
            setattr(args, name, value)
    return args
def main():
    """Parse CLI options, configure GPU visibility and cuDNN, then train."""
    args = parse_args()
    # CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so set it before
    # touching any torch GPU settings.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device.strip()
    torch.backends.cudnn.benchmark = True
    trainer = Trainer(args)
    trainer.setup()
    trainer.train()


if __name__ == '__main__':
    main()