"""Gradio demo for Gaze-LLE gaze target estimation.

Faces are detected with RetinaFace, each face box is passed to the Gaze-LLE
model together with the image, and the predicted gaze heatmaps / in-frame
scores are rendered as overlays on the input image.
"""

from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image, ImageDraw, ImageFont

# load a simple face detector
from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

# load Gaze-LLE model
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)


def _scaled_px(width, height, fraction):
    """Return ``fraction`` of the shorter image side in whole pixels, at least 1.

    Clamping keeps outline widths and font sizes from collapsing to zero on
    small images.
    """
    return max(1, int(min(width, height) * fraction))


def _draw_inout_label(draw, bbox, inout_score, color, width, height):
    """Write the ``in-frame`` score just below a normalized bounding box.

    Args:
        draw: ``ImageDraw.Draw`` object for the image being annotated.
        bbox: ``(xmin, ymin, xmax, ymax)`` in normalized (0-1) coordinates.
        inout_score: Score to print; anything ``float()`` accepts.
        color: Fill color for the text.
        width: Image width in pixels.
        height: Image height in pixels.
    """
    xmin, _, _, ymax = bbox
    text = f"in-frame: {float(inout_score):.2f}"
    # Small vertical gap between the box and the label.
    text_offset = int(height * 0.01)
    draw.text(
        (xmin * width, ymax * height + text_offset),
        text,
        fill=color,
        font=ImageFont.load_default(size=_scaled_px(width, height, 0.05)),
    )


def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
    """Overlay one person's gaze heatmap on the image.

    Args:
        pil_image: Source PIL image.
        heatmap: 2-D heatmap as a torch tensor or numpy array; it is scaled by
            255 before conversion, so values are presumably in [0, 1].
        bbox: Optional normalized ``(xmin, ymin, xmax, ymax)`` head box to draw.
        inout_score: Optional in-frame score; only drawn when ``bbox`` is given.

    Returns:
        An RGBA PIL image with the jet-colored heatmap blended over the input.
    """
    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.detach().cpu().numpy()

    # Upsample the low-resolution heatmap to the image size, then color-map it.
    gray = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(
        pil_image.size, Image.Resampling.BILINEAR
    )
    colored = plt.cm.jet(np.array(gray) / 255.0)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)
    layer = Image.fromarray(colored).convert("RGBA")
    layer.putalpha(90)
    overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), layer)

    if bbox is not None:
        width, height = pil_image.size
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(overlay_image)
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline="lime",
            width=_scaled_px(width, height, 0.01),
        )
        if inout_score is not None:
            _draw_inout_label(draw, bbox, inout_score, "lime", width, height)

    return overlay_image


def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw every head box, its in-frame score and its most likely gaze point.

    Args:
        pil_image: Source PIL image.
        heatmaps: Per-person heatmap tensors, indexable in step with ``bboxes``.
        bboxes: Normalized ``(xmin, ymin, xmax, ymax)`` boxes, one per person.
        inout_scores: Per-person in-frame scores, or ``None`` when the model
            provides none (then no labels or gaze lines are drawn).
        inout_thresh: A gaze line is drawn only for scores above this value.

    Returns:
        An RGBA PIL image with all annotations drawn on it.
    """
    colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    gaze_stroke = _scaled_px(width, height, 0.005)

    for i, bbox in enumerate(bboxes):
        xmin, ymin, xmax, ymax = bbox
        color = colors[i % len(colors)]
        draw.rectangle(
            [xmin * width, ymin * height, xmax * width, ymax * height],
            outline=color,
            width=_scaled_px(width, height, 0.01),
        )

        if inout_scores is None:
            continue

        inout_score = float(inout_scores[i])
        _draw_inout_label(draw, bbox, inout_score, color, width, height)

        if inout_score > inout_thresh:
            heatmap_np = heatmaps[i].detach().cpu().numpy()
            # The heatmap's arg-max cell is taken as the gaze target.
            row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
            gaze_target_x = col / heatmap_np.shape[1] * width
            gaze_target_y = row / heatmap_np.shape[0] * height
            bbox_center_x = ((xmin + xmax) / 2) * width
            bbox_center_y = ((ymin + ymax) / 2) * height

            draw.ellipse(
                [(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)],
                fill=color,
                width=gaze_stroke,
            )
            draw.line(
                [(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)],
                fill=color,
                width=gaze_stroke,
            )

    return overlay_image


def main(image_input, progress=gr.Progress(track_tqdm=True)):
    """Run face detection and Gaze-LLE on one image and build the overlays.

    Args:
        image_input: Path of the uploaded image (the UI uses ``type="filepath"``).
        progress: Gradio progress tracker; kept in the signature so Gradio can
            inject it.

    Returns:
        A ``(combined_overlay, per_person_overlays)`` tuple: one image with all
        people annotated, and a list with one heatmap overlay per detected face.

    Raises:
        gr.Error: If no image was provided or no face was detected.
    """
    if image_input is None:
        raise gr.Error("Please provide an input image.")

    # load image; convert() returns a loaded copy, so the file can be closed.
    # RGB is forced because the detector and transform presumably expect three
    # channels (PNG uploads may be RGBA / palette / grayscale) - TODO confirm.
    with Image.open(image_input) as source_image:
        image = source_image.convert("RGB")
    width, height = image.size

    # detect faces
    resp = RetinaFace.detect_faces(np.array(image))
    print(resp)
    # NOTE(review): treat anything other than a non-empty dict as "no faces";
    # the exact empty-result type of detect_faces should be confirmed.
    if not isinstance(resp, dict) or not resp:
        raise gr.Error("No faces were detected in the image.")
    bboxes = [face["facial_area"] for face in resp.values()]
    print(bboxes)

    # prepare gazelle input
    img_tensor = transform(image).unsqueeze(0).to(device)
    box_scale = np.array([width, height, width, height])
    norm_bboxes = [[np.array(bbox) / box_scale for bbox in bboxes]]

    model_input = {
        "images": img_tensor,  # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes,  # [[img1_bbox1, img1_bbox2...], [img2_bbox1, img2_bbox2]...]
    }

    with torch.no_grad():
        output = model(model_input)

    person_heatmaps = output['heatmap'][0]
    person_inout = output['inout'][0] if output['inout'] is not None else None

    # debug output carried over from the original script
    print(person_heatmaps[0].shape)  # [64, 64] heatmap
    if model.inout:
        # gaze in frame score (if model supports inout prediction)
        print(output['inout'][0][0].item())

    # visualize predicted gaze heatmap for each person and gaze in/out of frame score
    heatmap_results = [
        visualize_heatmap(
            image,
            person_heatmap,
            norm_bbox,
            inout_score=person_inout[i] if person_inout is not None else None,
        )
        for i, (person_heatmap, norm_bbox) in enumerate(zip(person_heatmaps, norm_bboxes[0]))
    ]

    # combined visualization with maximal gaze points for each person
    result_gazed = visualize_all(
        image, person_heatmaps, norm_bboxes[0], person_inout, inout_thresh=0.5
    )

    return result_gazed, heatmap_results


css = """ div#col-container{ margin: 0 auto; max-width: 982px; } """

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
        # NOTE(review): this block holds only plain text; it looks as if link or
        # badge markup may be missing - confirm against the deployed app.
        gr.HTML("""
Duplicate this Space Follow me on HF
""")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")
                heatmaps = gr.Gallery(label="Heatmap")

    submit_button.click(
        fn=main,
        inputs=[input_image],
        outputs=[result, heatmaps],
    )

demo.queue().launch(show_api=False, show_error=True)