# Gaze-LLE/scripts/eval_gazefollow.py
import argparse
import torch
from PIL import Image
import json
import os
import numpy as np
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from gazelle.model import get_gazelle_model, GazeLLE
from gazelle.backbone import DinoV2Backbone
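# Evaluates a pretrained Gaze-LLE model on the GazeFollow test set and reports AUC, Avg L2, and Min L2.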
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, default="./data/gazefollow")
parser.add_argument("--model_name", type=str, default="gazelle_dinov2_vitl14_inout")
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/gazelle_dinov2_vitl14_inout.pt")
parser.add_argument("--batch_size", type=int, default=128)
args = parser.parse_args()
class GazeFollow(torch.utils.data.Dataset):
    def __init__(self, path, img_transform):
        # Load the preprocessed GazeFollow test annotations (one entry per image, each with a list of heads).
        with open(os.path.join(path, "test_preprocessed.json"), "r") as f:
            self.images = json.load(f)
        self.path = path
        self.transform = img_transform
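    # Each item: transformed image, per-head normalized bboxes and GT gaze points, and the original image size.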
def __getitem__(self, idx):
item = self.images[idx]
image = self.transform(Image.open(os.path.join(self.path, item['path'])).convert("RGB"))
height = item['height']
width = item['width']
bboxes = [head['bbox_norm'] for head in item['heads']]
gazex = [head['gazex_norm'] for head in item['heads']]
gazey = [head['gazey_norm'] for head in item['heads']]
return image, bboxes, gazex, gazey, height, width
def __len__(self):
return len(self.images)
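# Images differ in the number of annotated heads, so the per-image bbox/gaze annotations are kept as Python
# lists rather than being stacked by the default collate function.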
def collate(batch):
images, bboxes, gazex, gazey, height, width = zip(*batch)
return torch.stack(images), list(bboxes), list(gazex), list(gazey), list(height), list(width)
# GazeFollow calculates AUC using original image size with GT (x,y) coordinates set to 1 and everything else as 0
# References:
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L78
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/imutils.py#L67
# https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/utils/evaluation.py#L7
def gazefollow_auc(heatmap, gt_gazex, gt_gazey, height, width):
target_map = np.zeros((height, width))
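    # Mark each valid GT gaze point as 1 in a full-resolution target map; negative coordinates denote
    # invalid/missing annotations and are skipped.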
for point in zip(gt_gazex, gt_gazey):
if point[0] >= 0:
x, y = map(int, [point[0]*float(width), point[1]*float(height)])
x = min(x, width - 1)
y = min(y, height - 1)
target_map[y, x] = 1
resized_heatmap = torch.nn.functional.interpolate(heatmap.unsqueeze(dim=0).unsqueeze(dim=0), (height, width), mode='bilinear').squeeze()
auc = roc_auc_score(target_map.flatten(), resized_heatmap.cpu().flatten())
return auc
# Reference: https://github.com/ejcgt/attention-target-detection/blob/acd264a3c9e6002b71244dea8c1873e5c5818500/eval_on_gazefollow.py#L81
def gazefollow_l2(heatmap, gt_gazex, gt_gazey):
    # Predicted gaze point = argmax of the heatmap, normalized to [0, 1] by the heatmap resolution (64x64 here).
    h, w = heatmap.shape
    argmax = heatmap.flatten().argmax().item()
    pred_y, pred_x = np.unravel_index(argmax, (h, w))
    pred_x = pred_x / float(w)
    pred_y = pred_y / float(h)
gazex = np.array(gt_gazex)
gazey = np.array(gt_gazey)
avg_l2 = np.sqrt((pred_x - gazex.mean())**2 + (pred_y - gazey.mean())**2)
all_l2s = np.sqrt((pred_x - gazex)**2 + (pred_y - gazey)**2)
min_l2 = all_l2s.min().item()
return avg_l2, min_l2
@torch.no_grad()
def main():
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on {}".format(device))
model, transform = get_gazelle_model(args.model_name)
model.load_gazelle_state_dict(torch.load(args.ckpt_path, weights_only=True))
model.to(device)
model.eval()
dataset = GazeFollow(args.data_path, transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, collate_fn=collate)
aucs = []
min_l2s = []
avg_l2s = []
    for images, bboxes, gazex, gazey, height, width in tqdm(dataloader, desc="Evaluating"):
preds = model.forward({"images": images.to(device), "bboxes": bboxes})
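        # preds['heatmap'][i][j] is the predicted gaze heatmap for head j of image i.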
# eval each instance (head)
for i in range(images.shape[0]): # per image
for j in range(len(bboxes[i])): # per head
auc = gazefollow_auc(preds['heatmap'][i][j], gazex[i][j], gazey[i][j], height[i], width[i])
avg_l2, min_l2 = gazefollow_l2(preds['heatmap'][i][j], gazex[i][j], gazey[i][j])
aucs.append(auc)
avg_l2s.append(avg_l2)
min_l2s.append(min_l2)
print("AUC: {}".format(np.array(aucs).mean()))
print("Avg L2: {}".format(np.array(avg_l2s).mean()))
print("Min L2: {}".format(np.array(min_l2s).mean()))
if __name__ == "__main__":
main()
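# Example invocation (paths are the argparse defaults above; adjust to your setup):
#   python scripts/eval_gazefollow.py --data_path ./data/gazefollow --ckpt_path ./checkpoints/gazelle_dinov2_vitl14_inout.pt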