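# Evaluate a trained VGG19-based crowd-counting model on a benchmark test split,
# report MAE / MSE over the per-image count errors, and optionally save the
# predicted density maps as colorized images.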
import argparse
import os

import numpy as np
import torch

import datasets.crowd as crowd
from models import vgg19
parser = argparse.ArgumentParser(description='Test')
parser.add_argument('--device', default='0', help='gpu id to assign')
parser.add_argument('--crop-size', type=int, default=512,
                    help='the crop size used for the training images')
parser.add_argument('--model-path', type=str, default='pretrained_models/model_qnrf.pth',
                    help='saved model path')
parser.add_argument('--data-path', type=str,
                    default='data/QNRF-Train-Val-Test',
                    help='dataset root path')
parser.add_argument('--dataset', type=str, default='qnrf',
                    help='dataset name: qnrf, nwpu, sha, shb')
parser.add_argument('--pred-density-map-path', type=str, default='',
                    help='directory to save predicted density maps; empty to disable saving')
args = parser.parse_args()
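# Example invocation (illustrative; assumes this script is saved as test.py and
# that the weights and data live at the default paths above):
#   python test.py --dataset qnrf --data-path data/QNRF-Train-Val-Test \
#       --model-path pretrained_models/model_qnrf.pth --device 0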
os.environ['CUDA_VISIBLE_DEVICES'] = args.device  # restrict CUDA to the chosen GPU
device = torch.device('cuda')
model_path = args.model_path
crop_size = args.crop_size
data_path = args.data_path
# Each benchmark stores its evaluation split under a different directory name.
if args.dataset.lower() == 'qnrf':
    dataset = crowd.Crowd_qnrf(os.path.join(data_path, 'test'), crop_size, 8, method='val')
elif args.dataset.lower() == 'nwpu':
    dataset = crowd.Crowd_nwpu(os.path.join(data_path, 'val'), crop_size, 8, method='val')
elif args.dataset.lower() in ('sha', 'shb'):
    dataset = crowd.Crowd_sh(os.path.join(data_path, 'test_data'), crop_size, 8, method='val')
else:
    raise NotImplementedError('unsupported dataset: {}'.format(args.dataset))
# Test images keep their native (varying) resolutions, so they cannot be
# stacked into batches; the batch size is fixed at 1.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
                                         num_workers=1, pin_memory=True)
if args.pred_density_map_path:
    # cv2 is only needed for saving visualizations, so import it lazily.
    import cv2
    if not os.path.exists(args.pred_density_map_path):
        os.makedirs(args.pred_density_map_path)
model = vgg19()
model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
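# Evaluate: the predicted count for an image is the sum over its predicted
# density map; the per-image error is the ground-truth count minus that sum.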
image_errs = []
for inputs, count, name in dataloader:
    inputs = inputs.to(device)
    assert inputs.size(0) == 1, 'the batch size should be 1'
    with torch.no_grad():
        outputs, _ = model(inputs)
    img_err = count[0].item() - torch.sum(outputs).item()
    print(name, img_err, count[0].item(), torch.sum(outputs).item())
    image_errs.append(img_err)
    if args.pred_density_map_path:
        vis_img = outputs[0, 0].cpu().numpy()
        # Normalize density map values to [0, 1], then map them to [0, 255].
        vis_img = (vis_img - vis_img.min()) / (vis_img.max() - vis_img.min() + 1e-5)
        vis_img = (vis_img * 255).astype(np.uint8)
        vis_img = cv2.applyColorMap(vis_img, cv2.COLORMAP_JET)
        cv2.imwrite(os.path.join(args.pred_density_map_path, str(name[0]) + '.png'), vis_img)
image_errs = np.array(image_errs)
# Note: following the crowd-counting convention, "mse" below is actually the
# root mean squared error, sqrt(mean((gt - pred)^2)).
mse = np.sqrt(np.mean(np.square(image_errs)))
mae = np.mean(np.abs(image_errs))
print('{}: mae {}, mse {}\n'.format(model_path, mae, mse))