Spaces:
Running
on
Zero
Running
on
Zero
import gradio as gr | |
import torch | |
import matplotlib.pyplot as plt | |
from PIL import Image, ImageDraw, ImageFont | |
import requests | |
from io import BytesIO | |
import numpy as np | |
# load a simple face detector | |
from retinaface import RetinaFace | |
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Fetch the Gaze-LLE checkpoint (with in/out-of-frame head) and its matching
# preprocessing transform through torch.hub, then switch the network to
# inference mode and move it onto the selected device.
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)
def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    """Overlay one gaze heatmap (and optionally a box and score) on an image.

    Args:
        pil_image: PIL image the heatmap is rendered over.
        heatmap: 2-D ``torch.Tensor`` or NumPy array. Values are multiplied by
            255 before conversion to ``uint8``, so they are assumed to lie in
            ``[0, 1]``; anything outside that range is clipped rather than
            left to wrap around.
        bbox: optional ``(xmin, ymin, xmax, ymax)`` expressed as fractions of
            the image width/height. When given, the box is outlined in lime.
        inout_score: optional scalar formatted with ``:.2f`` and written just
            below the box. It is only drawn when ``bbox`` is also given,
            because the label is positioned relative to the box.

    Returns:
        A new RGBA PIL image; ``pil_image`` itself is not modified.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Scale to 8-bit and stretch the low-resolution map to the image size.
    heatmap_u8 = np.clip(np.asarray(heatmap) * 255, 0, 255).astype(np.uint8)
    resized = Image.fromarray(heatmap_u8).resize(pil_image.size, Image.Resampling.BILINEAR)

    # Colorize with the jet colormap, drop the colormap's own alpha channel
    # and use a constant, mostly transparent alpha instead.
    colored = plt.cm.jet(np.array(resized) / 255.)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)
    heatmap_rgba = Image.fromarray(colored).convert("RGBA")
    heatmap_rgba.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap_rgba)

    if bbox is not None:
        width, height = pil_image.size
        short_side = min(width, height)
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        # Keep the outline at least 1 px wide: on small images the 1 % rule
        # truncates to 0, and a zero-width outline is not drawn at all.
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline="lime",
            width=max(1, int(short_side * 0.01)),
        )

        if inout_score is not None:
            text = f"in-frame: {inout_score:.2f}"
            # Small vertical gap between the bottom edge of the box and the label.
            text_offset = int(height * 0.01)
            draw.text(
                (xmin * width, ymax * height + text_offset),
                text,
                fill="lime",
                # Guard against a zero font size on very small images.
                font=ImageFont.load_default(size=max(1, int(short_side * 0.05))),
            )
    return overlay_image
def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw every detected box, its in-frame score and its gaze target on one image.

    Args:
        pil_image: PIL image to annotate (a converted RGBA copy is drawn on).
        heatmaps: indexable collection of 2-D heatmaps, one per box; each may
            be a ``torch.Tensor`` or a NumPy array.
        bboxes: sequence of ``(xmin, ymin, xmax, ymax)`` boxes expressed as
            fractions of the image width/height.
        inout_scores: indexable collection of per-box scalar scores, or
            ``None``. When ``None`` only the boxes are drawn: no labels and
            no gaze targets.
        inout_thresh: a gaze target is drawn only for boxes whose score is
            strictly greater than this value.

    Returns:
        The annotated RGBA PIL image.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    short_side = min(width, height)

    # Size-dependent drawing parameters are the same for every box, so they
    # are computed once. Each is clamped to at least 1 because the
    # proportional rules truncate to 0 on small images (a zero-width
    # rectangle outline is not drawn at all).
    box_stroke = max(1, int(short_side * 0.01))
    gaze_stroke = max(1, int(0.005 * short_side))
    font = ImageFont.load_default(size=max(1, int(short_side * 0.05)))
    text_offset = int(height * 0.01)

    for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        # Cycle through the palette when there are more boxes than colors.
        color = colors[i % len(colors)]
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline=color,
            width=box_stroke,
        )

        if inout_scores is None:
            continue

        inout_score = inout_scores[i]
        draw.text(
            (xmin * width, ymax * height + text_offset),
            f"in-frame: {inout_score:.2f}",
            fill=color,
            font=font,
        )

        if inout_score > inout_thresh:
            heatmap = heatmaps[i]
            if isinstance(heatmap, torch.Tensor):
                heatmap = heatmap.detach().cpu().numpy()
            heatmap_np = np.asarray(heatmap)

            # The gaze target is the heatmap's arg-max cell, rescaled from
            # heatmap coordinates to image pixels.
            row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = col / heatmap_np.shape[1] * width
            gaze_target_y = row / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            # Filled dot (5 px radius) at the target, connected to the box center.
            draw.ellipse(
                [(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)],
                fill=color,
            )
            draw.line(
                [(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)],
                fill=color,
                width=gaze_stroke,
            )
    return overlay_image
def main(image_input, progress=gr.Progress(track_tqdm=True)):
    """Detect faces in an image, run Gaze-LLE and build the visualizations.

    Args:
        image_input: value accepted by ``Image.open`` (the UI passes a file
            path), or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker; not referenced in the body.

    Returns:
        ``(result_gazed, heatmap_results)``: the combined image produced by
        ``visualize_all`` and a list with one ``visualize_heatmap`` overlay
        per detected face.

    Raises:
        gr.Error: if no image was provided or no face was detected.
    """
    if image_input is None:
        raise gr.Error("Please provide an input image.")

    # Load the image and close the underlying file right away. Converting to
    # RGB normalizes RGBA / palette / grayscale uploads to three channels.
    # NOTE(review): assumes the hub transform expects 3-channel input — confirm.
    with Image.open(image_input) as opened:
        image = opened.convert("RGB")
    width, height = image.size

    # detect faces
    resp = RetinaFace.detect_faces(np.array(image))
    print(resp)
    # NOTE(review): the detector is expected to return a dict keyed per face;
    # anything else (or an empty dict) is treated as "no faces" — confirm
    # against the installed retinaface version.
    if not isinstance(resp, dict) or not resp:
        raise gr.Error("No face detected in the input image.")
    bboxes = [face["facial_area"] for face in resp.values()]
    print(bboxes)

    # prepare gazelle input: boxes are divided by (w, h, w, h), i.e. turned
    # into fractions of the image size.
    img_tensor = transform(image).unsqueeze(0).to(device)
    scale = np.array([width, height, width, height])
    norm_bboxes = [[np.array(bbox) / scale for bbox in bboxes]]

    model_input = {
        "images": img_tensor,  # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes,  # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(model_input)

    # Only one image is submitted, so index 0 selects its per-person results.
    heatmaps = output['heatmap'][0]
    inout = output['inout']
    inout_scores = inout[0] if inout is not None else None
    face_boxes = norm_bboxes[0]

    # Debug output for the first detected person.
    print(heatmaps[0].shape)
    if model.inout:
        print(output['inout'][0][0].item())  # gaze in frame score (if model supports inout prediction)

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score
    heatmap_results = []
    for i, face_box in enumerate(face_boxes):
        score = inout_scores[i] if inout_scores is not None else None
        heatmap_results.append(
            visualize_heatmap(image, heatmaps[i], face_box, inout_score=score)
        )

    # combined visualization with maximal gaze points for each person
    result_gazed = visualize_all(image, heatmaps, face_boxes, inout_scores, inout_thresh=0.5)

    return result_gazed, heatmap_results
# Center the app and cap its width.
css="""
div#col-container{
margin: 0 auto;
max-width: 982px;
}
"""

# UI layout: title, description and badge links on top; below them the input
# image with its submit button on the left and the results on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        gr.HTML("""
<div style="display:flex;column-gap:4px;">
<a href="https://github.com/fkryan/gazelle">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
<a href="https://arxiv.org/abs/2412.09586">
<img src='https://img.shields.io/badge/ArXiv-Paper-red'>
</a>
<a href="https://huggingface.co/spaces/fffiloni/Gaze-LLE?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
</a>
<a href="https://huggingface.co/fffiloni">
<img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
</a>
</div>
""")
        with gr.Row():
            with gr.Column():
                # "filepath" makes Gradio hand the handler a path string.
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                # `columns` is a Gallery layout option, not an Image one, so
                # it is set on the gallery that shows the per-face heatmaps.
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap", columns=3)

    # Wire the button to the inference handler: one input image in, the
    # combined result image and the gallery of per-face overlays out.
    submit_button.click(
        fn = main,
        inputs = [input_image],
        outputs = [result, heatmaps]
    )

# Enable request queuing, hide the API page and surface handler errors in the UI.
demo.queue().launch(show_api=False, show_error=True)