File size: 4,744 Bytes
9c9498f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import argparse
import torch
from PIL import Image
import json
import os
import numpy as np
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from gazelle.model import get_gazelle_model
from gazelle.model import GazeLLE
from gazelle.backbone import DinoV2Backbone
# Command-line options. Parsing happens at module level because main() reads
# the resulting module-level `args` namespace.
parser = argparse.ArgumentParser(
    description="Evaluate a gazelle model checkpoint on the GazeFollow test set "
                "and print AUC, Avg L2 and Min L2.",
)
parser.add_argument(
    "--data_path",
    type=str,
    default="./data/gazefollow",
    help="dataset root containing test_preprocessed.json and the images it "
         "references (default: %(default)s)",
)
parser.add_argument(
    "--model_name",
    type=str,
    default="gazelle_dinov2_vitl14_inout",
    help="model name passed to get_gazelle_model (default: %(default)s)",
)
parser.add_argument(
    "--ckpt_path",
    type=str,
    default="./checkpoints/gazelle_dinov2_vitl14_inout.pt",
    help="checkpoint file loaded with torch.load (default: %(default)s)",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    help="number of images per DataLoader batch (default: %(default)s)",
)
args = parser.parse_args()
class GazeFollow(torch.utils.data.Dataset):
    """Image-level dataset read from ``test_preprocessed.json``.

    Each sample is one image together with the normalized annotations of
    every head listed for that image in the JSON index.
    """

    def __init__(self, path, img_transform):
        """Load the annotation index.

        Args:
            path: Dataset root. Must contain ``test_preprocessed.json``; the
                image paths stored in that file are joined onto this directory.
            img_transform: Callable applied to each image after it has been
                converted to RGB.
        """
        # Use a context manager so the annotation file is closed right after
        # parsing instead of leaking the handle until garbage collection.
        with open(os.path.join(path, "test_preprocessed.json"), "rb") as f:
            self.images = json.load(f)
        self.path = path
        self.transform = img_transform

    def __getitem__(self, idx):
        """Return ``(image, bboxes, gazex, gazey, height, width)`` for one image.

        ``bboxes``, ``gazex`` and ``gazey`` are lists with one entry per head,
        taken from the ``bbox_norm``, ``gazex_norm`` and ``gazey_norm`` fields.
        ``height`` and ``width`` are the values stored in the JSON record.
        """
        item = self.images[idx]
        # convert() produces a new in-memory image, so the underlying file can
        # be closed as soon as the conversion is done.
        with Image.open(os.path.join(self.path, item['path'])) as img:
            rgb = img.convert("RGB")
        image = self.transform(rgb)
        height = item['height']
        width = item['width']
        heads = item['heads']
        bboxes = [head['bbox_norm'] for head in heads]
        gazex = [head['gazex_norm'] for head in heads]
        gazey = [head['gazey_norm'] for head in heads]
        return image, bboxes, gazex, gazey, height, width

    def __len__(self):
        """Return the number of images in the annotation index."""
        return len(self.images)
def collate(batch):
    """Merge dataset samples into one batch.

    Images are stacked into a single tensor along a new leading dimension.
    The remaining five fields (bboxes, gazex, gazey, height, width) are kept
    as plain Python lists with one entry per sample, since the number of
    heads can differ between images.
    """
    columns = [list(column) for column in zip(*batch)]
    images, bboxes, gazex, gazey, height, width = columns
    return torch.stack(images), bboxes, gazex, gazey, height, width
# GazeFollow calculates AUC using original image size with GT (x,y) coordinates set to 1 and everything else as 0
# References:
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L78
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/imutils.py#L67
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/evaluation.py#L7
def gazefollow_auc(heatmap, gt_gazex, gt_gazey, height, width):
    """Compute the AUC of one predicted heatmap against its ground-truth gaze points.

    Args:
        heatmap: 2-D tensor; two leading singleton dimensions are added so it
            can be bilinearly resized to ``(height, width)``.
        gt_gazex: Normalized ground-truth x coordinates, one per annotation.
        gt_gazey: Normalized ground-truth y coordinates, paired with ``gt_gazex``.
        height: Height of the target map in pixels.
        width: Width of the target map in pixels.

    Returns:
        The value of ``roc_auc_score`` between the flattened binary target map
        and the flattened resized heatmap.
    """
    target_map = np.zeros((height, width))
    for gx, gy in zip(gt_gazex, gt_gazey):
        # Annotations with a negative x coordinate are skipped.
        if gx >= 0:
            scaled = [gx * float(width), gy * float(height)]
            col, row = (int(value) for value in scaled)
            # A coordinate of exactly 1.0 would index one past the last pixel.
            col = min(col, width - 1)
            row = min(row, height - 1)
            target_map[row, col] = 1

    batched = heatmap.unsqueeze(dim=0).unsqueeze(dim=0)
    resized_heatmap = torch.nn.functional.interpolate(
        batched, (height, width), mode='bilinear'
    ).squeeze()
    labels = target_map.flatten()
    scores = resized_heatmap.cpu().flatten()
    return roc_auc_score(labels, scores)
# Reference: https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L81
def gazefollow_l2(heatmap, gt_gazex, gt_gazey):
    """Compute the average and minimum L2 error of the heatmap's peak location.

    The predicted gaze point is the argmax cell of ``heatmap``, expressed in
    normalized coordinates by dividing its column/row index by the heatmap
    width/height. The heatmap size is read from ``heatmap.shape`` rather than
    assumed to be 64x64, so other resolutions give correct results too.

    Args:
        heatmap: Tensor whose last two dimensions are (rows, columns).
        gt_gazex: Normalized ground-truth x coordinates, one per annotation.
        gt_gazey: Normalized ground-truth y coordinates, paired with ``gt_gazex``.

    Returns:
        ``(avg_l2, min_l2)``: the distance from the prediction to the mean of
        the ground-truth points, and the smallest distance from the prediction
        to any single ground-truth point.
    """
    rows, cols = heatmap.shape[-2:]
    argmax = heatmap.flatten().argmax().item()
    pred_y, pred_x = np.unravel_index(argmax, (rows, cols))
    pred_x = pred_x / float(cols)
    pred_y = pred_y / float(rows)

    gazex = np.array(gt_gazex)
    gazey = np.array(gt_gazey)

    avg_l2 = np.sqrt((pred_x - gazex.mean())**2 + (pred_y - gazey.mean())**2)
    all_l2s = np.sqrt((pred_x - gazex)**2 + (pred_y - gazey)**2)
    min_l2 = all_l2s.min().item()

    return avg_l2, min_l2
@torch.no_grad()
def main():
    """Evaluate the configured checkpoint on the test set and print the metrics.

    Reads ``args.model_name``, ``args.ckpt_path``, ``args.data_path`` and
    ``args.batch_size``, runs the model over every batch, scores each head
    with ``gazefollow_auc`` and ``gazefollow_l2``, and prints the mean AUC,
    Avg L2 and Min L2 over all heads.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Running on {}".format(device))

    model, transform = get_gazelle_model(args.model_name)
    # Load onto the CPU first: without map_location, a checkpoint saved from a
    # GPU cannot be deserialized on a CPU-only host, which is exactly the
    # fallback this script selects above. The model is moved to `device` next.
    state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)
    model.load_gazelle_state_dict(state_dict)
    model.to(device)
    model.eval()

    dataset = GazeFollow(args.data_path, transform)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)

    aucs = []
    min_l2s = []
    avg_l2s = []

    for images, bboxes, gazex, gazey, height, width in tqdm(dataloader, desc="Evaluating", total=len(dataloader)):
        preds = model.forward({"images": images.to(device), "bboxes": bboxes})

        # Score every head of every image in the batch.
        # NOTE(review): preds['heatmap'] is indexed as [image][head] here —
        # confirm against the model's output format.
        for i in range(images.shape[0]):  # per image
            for j in range(len(bboxes[i])):  # per head
                heatmap = preds['heatmap'][i][j]
                auc = gazefollow_auc(heatmap, gazex[i][j], gazey[i][j], height[i], width[i])
                avg_l2, min_l2 = gazefollow_l2(heatmap, gazex[i][j], gazey[i][j])
                aucs.append(auc)
                avg_l2s.append(avg_l2)
                min_l2s.append(min_l2)

    print("AUC: {}".format(np.array(aucs).mean()))
    print("Avg L2: {}".format(np.array(avg_l2s).mean()))
    print("Min L2: {}".format(np.array(min_l2s).mean()))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()