# Gaze-LLE
<div style="text-align:center;">
<img src="./assets/the_office.png" height="100"/>
<img src="./assets/MLB_1.gif" height="100"/>
<img src="./assets/succession.png" height="100"/>
<img src="./assets/CBS_2.gif" height="100"/>
</div>
[Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders](https://arxiv.org/abs/2412.09586) \
[Fiona Ryan](https://fkryan.github.io/), Ajay Bati, [Sangmin Lee](https://sites.google.com/view/sangmin-lee), [Daniel Bolya](https://dbolya.github.io/), [Judy Hoffman](https://faculty.cc.gatech.edu/~judy/)\*, [James M. Rehg](https://rehg.org/)\*
This is the official implementation for Gaze-LLE, a transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!
<div style="text-align:center;">
<img src="./assets/gazelle_arch.png" height="200"/>
</div>
## Installation
Clone this repo, then create the virtual environment.
```
conda env create -f environment.yml
conda activate gazelle
pip install -e .
```
If your system supports it, consider installing [xformers](https://github.com/facebookresearch/xformers) to speed up attention computation. The command below installs a build for CUDA 11.8; change the index URL to match your CUDA version.
```
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118
```
## Pretrained Models
We provide the following pretrained models for download.
| Name | Backbone type | Backbone name | Training data | Checkpoint |
| ---- | ------------- | ------------- |-------------- | ---------- |
| ```gazelle_dinov2_vitb14``` | DINOv2 ViT-B | ```dinov2_vitb14```| GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14.pt) |
| ```gazelle_dinov2_vitl14``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14.pt) |
| ```gazelle_dinov2_vitb14_inout``` | DINOv2 ViT-B | ```dinov2_vitb14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitb14_inout.pt) |
| ```gazelle_dinov2_vitl14_inout``` | DINOv2 ViT-L | ```dinov2_vitl14``` | GazeFollow -> VideoAttentionTarget | [Download](https://github.com/fkryan/gazelle/releases/download/v1.0.0/gazelle_dinov2_vitl14_inout.pt) |
Note that our Gaze-LLE checkpoints contain only the gaze decoder weights. The DINOv2 backbone weights are downloaded from ```facebookresearch/dinov2``` on PyTorch Hub when our code creates the Gaze-LLE model.
The GazeFollow-trained models output a spatial heatmap of gaze locations over the scene. Its values lie in the range ```[0,1]```, where 1 represents the highest probability that the location is a gaze target. The models that are additionally finetuned on VideoAttentionTarget also predict an in/out-of-frame gaze score in the range ```[0,1]```, where 1 means the person's gaze target is inside the frame.
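You may need a single gaze point rather than the full heatmap. One common post-processing step is to take the location of the heatmap maximum. The sketch below is an illustration under our own assumptions and is not part of the Gaze-LLE API:

- The helper names ```heatmap_to_gaze_point```, ```to_pixels``` and ```is_in_frame``` are hypothetical.
- The heatmap is assumed to be indexed ```[row][col]``` (i.e. ```[y][x]```).
- The ```0.5``` in/out threshold is an arbitrary choice.

The sketch works on plain Python lists, so you can pass in e.g. ```predicted_heatmap.tolist()```.

```python
from typing import List, Optional, Tuple


def heatmap_to_gaze_point(heatmap: List[List[float]]) -> Tuple[float, float]:
    """Return the (x, y) of the heatmap maximum in [0,1] normalized image coordinates.

    Assumption: heatmap[row][col], i.e. rows index y and columns index x.
    """
    height, width = len(heatmap), len(heatmap[0])
    best_val, best_row, best_col = float("-inf"), 0, 0
    for r, row in enumerate(heatmap):
        for c, val in enumerate(row):
            if val > best_val:  # strict ">" means the first maximum wins on ties
                best_val, best_row, best_col = val, r, c
    # Use the center of the winning cell so the point lies inside [0, 1].
    return (best_col + 0.5) / width, (best_row + 0.5) / height


def to_pixels(point: Tuple[float, float], image_width: int, image_height: int) -> Tuple[float, float]:
    """Scale a normalized (x, y) point to pixel coordinates of the original image."""
    x, y = point
    return x * image_width, y * image_height


def is_in_frame(inout_score: Optional[float], threshold: float = 0.5) -> Optional[bool]:
    """Binarize the in/out score (hypothetical threshold); None is passed through."""
    if inout_score is None:
        return None
    return inout_score >= threshold


if __name__ == "__main__":
    # Toy 4x4 heatmap with its peak at row 1, column 3.
    toy = [[0.0] * 4 for _ in range(4)]
    toy[1][3] = 1.0
    point = heatmap_to_gaze_point(toy)
    print(point)                       # (0.875, 0.375)
    print(to_pixels(point, 640, 480))  # (560.0, 180.0)
    print(is_in_frame(0.7))            # True
```

With the inference example in the Usage section, these helpers would be called as ```heatmap_to_gaze_point(predicted_heatmap.tolist())``` and ```is_in_frame(predicted_inout.item())```. Taking the argmax discards the rest of the heatmap. If several locations score similarly, you may prefer to inspect the heatmap directly.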
### PyTorch Hub
The models are also available on PyTorch Hub for easy use without installing from source.
```
import torch

# load one of the following models
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitb14_inout')
model, transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
```
## Usage
### Colab Demo Notebook
Check out our [Demo Notebook](https://colab.research.google.com/drive/1TSoyFvNs1-au9kjOZN_fo5ebdzngSPDq?usp=sharing) on Google Colab for how to detect gaze for all people in an image.
### Gaze Prediction
Gaze-LLE is set up for multi-person inference. For a single image, Gaze-LLE encodes the scene only once and then uses those features to predict the gaze of multiple people in the image.

The input has two parts:

- a batch of image tensors;
- for each image, a list of bounding boxes around the heads of the people whose gaze should be predicted.

The bounding boxes are tuples of the form ```(xmin, ymin, xmax, ymax)``` in ```[0,1]``` normalized image coordinates. Below we show how to perform inference for a single person in a single image.
```
from PIL import Image
import torch
from gazelle.model import get_gazelle_model
model, transform = get_gazelle_model("gazelle_dinov2_vitl14_inout")
model.load_gazelle_state_dict(torch.load("/path/to/checkpoint.pt", weights_only=True))
model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
image = Image.open("path/to/image.png").convert("RGB")
input = {
"images": transform(image).unsqueeze(dim=0).to(device), # tensor of shape [1, 3, 448, 448]
"bboxes": [[(0.1, 0.2, 0.5, 0.7)]] # list of lists of bbox tuples
}
with torch.no_grad():
    output = model(input)
predicted_heatmap = output["heatmap"][0][0] # access prediction for first person in first image. Tensor of size [64, 64]
predicted_inout = output["inout"][0][0] # in/out of frame score (1 = in frame) (output["inout"] will be None for non-inout models)
```
We empirically find that Gaze-LLE is effective without a bounding box input for scenes with just one person. However, providing a bounding box can improve results. It is also necessary in scenes with multiple people, to specify whose gaze to estimate. To run inference without a bounding box, use None in place of a bounding box tuple in the bbox list (e.g. ```input["bboxes"] = [[None]]``` in the example above).
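Head detectors often return boxes in pixel coordinates, so these need to be divided by the image width and height before being passed in. The sketch below shows one way to build the nested ```bboxes``` list for a batch of images.

A few assumptions apply:

- ```normalize_bbox``` and ```build_bbox_input``` are hypothetical helpers for illustration, not functions provided by the ```gazelle``` package.
- We assume the transform resizes the whole image without cropping, so coordinates normalized by the original image size remain valid.

Each inner list holds one entry per person. ```None``` is passed through unchanged for the no-bounding-box case.

```python
from typing import List, Optional, Sequence, Tuple

BBox = Tuple[float, float, float, float]  # (xmin, ymin, xmax, ymax)


def normalize_bbox(bbox_px: BBox, image_width: int, image_height: int) -> BBox:
    """Convert a pixel-space head box to [0,1] normalized image coordinates."""
    xmin, ymin, xmax, ymax = bbox_px
    if not (0 <= xmin < xmax <= image_width and 0 <= ymin < ymax <= image_height):
        raise ValueError(f"bbox {bbox_px} is not inside a {image_width}x{image_height} image")
    return (xmin / image_width, ymin / image_height,
            xmax / image_width, ymax / image_height)


def build_bbox_input(
    heads_per_image: Sequence[Sequence[Optional[BBox]]],
    image_sizes: Sequence[Tuple[int, int]],
) -> List[List[Optional[BBox]]]:
    """Build the list of lists used for input["bboxes"].

    heads_per_image[i] holds pixel boxes (or None) for image i;
    image_sizes[i] is that image's original (width, height).
    """
    if len(heads_per_image) != len(image_sizes):
        raise ValueError("need one list of heads per image")
    batch = []
    for heads, (width, height) in zip(heads_per_image, image_sizes):
        batch.append([None if box is None else normalize_bbox(box, width, height)
                      for box in heads])
    return batch


if __name__ == "__main__":
    # Two people in the first image, no box (single person) in the second.
    bboxes = build_bbox_input(
        heads_per_image=[[(64, 96, 320, 336), (400, 50, 560, 230)], [None]],
        image_sizes=[(640, 480), (800, 600)],
    )
    print(bboxes)
```

For a batch of two images like this, the matching ```images``` entry would have shape ```[2, 3, 448, 448]```, e.g. built with ```torch.stack([transform(img) for img in images])```.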
We also provide a function to visualize the predicted heatmap for an image.
```
import matplotlib.pyplot as plt
from gazelle.utils import visualize_heatmap
viz = visualize_heatmap(image, predicted_heatmap)
plt.imshow(viz)
plt.show()
```
## Evaluate
We provide evaluation scripts for GazeFollow and VideoAttentionTarget below to reproduce our results from our checkpoints.
### GazeFollow
Download the GazeFollow dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset). We provide a preprocessing script ```data_prep/preprocess_gazefollow.py```, which preprocesses and compiles the annotations into a JSON file for each split within the dataset folder. Run the preprocessing script as
```
python data_prep/preprocess_gazefollow.py --data_path /path/to/gazefollow/data_new
```
Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.
```
python scripts/eval_gazefollow.py \
--data_path /path/to/gazefollow/data_new \
--model_name gazelle_dinov2_vitl14 \
--ckpt_path /path/to/checkpoint.pt \
--batch_size 128
```
### VideoAttentionTarget
Download the VideoAttentionTarget dataset [here](https://github.com/ejcgt/attention-target-detection?tab=readme-ov-file#dataset-1). We provide a preprocessing script ```data_prep/preprocess_vat.py```, which preprocesses and compiles the annotations into a JSON file for each split within the dataset folder. Run the preprocessing script as
```
python data_prep/preprocess_vat.py --data_path /path/to/videoattentiontarget
```
Download the pretrained model checkpoints above and use ```--model_name``` and ```--ckpt_path``` to specify the model type and checkpoint for evaluation.
```
python scripts/eval_vat.py \
--data_path /path/to/videoattentiontarget \
--model_name gazelle_dinov2_vitl14_inout \
--ckpt_path /path/to/checkpoint.pt \
--batch_size 64
```
## Citation
```
@article{ryan2024gazelle,
author = {Ryan, Fiona and Bati, Ajay and Lee, Sangmin and Bolya, Daniel and Hoffman, Judy and Rehg, James M},
title = {Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders},
journal = {arXiv preprint arXiv:2412.09586},
year = {2024},
}
```
## References
- Our models are built on top of pretrained DINOv2 models from PyTorch Hub ([GitHub repo](https://github.com/facebookresearch/dinov2)).
- Our GazeFollow and VideoAttentionTarget preprocessing code is based on [Detecting Attended Targets in Video](https://github.com/ejcgt/attention-target-detection).
- We use [PyTorch Image Models (timm)](https://github.com/huggingface/pytorch-image-models) for our transformer implementation.
- We use [xFormers](https://github.com/facebookresearch/xformers) for efficient multi-head attention.