import gradio as gr |
import torch |
import matplotlib.pyplot as plt |
from PIL import Image, ImageDraw, ImageFont |
import requests |
from io import BytesIO |
import numpy as np |
from retinaface import RetinaFace |
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# torch.hub entry point returning the gaze model together with the image
# transform that prepares inputs for it.
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.to(device)
model.eval()  # inference only: disable training-time behaviour
def _visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
    """Draw each face box, its in-frame score and a line to its gaze target.

    Args:
        pil_image: The PIL image the predictions were made on.
        heatmaps: Per-face gaze heatmaps in the same order as ``bboxes``, as
            torch tensors or numpy arrays (assumed 2-D (H, W) — TODO confirm
            against the gazelle output).
        bboxes: Face boxes as (xmin, ymin, xmax, ymax) normalized to [0, 1].
        inout_scores: Per-face in-frame scores, or None when the model has no
            in/out head (then every face gets a gaze target).
        inout_thresh: When scores are available, the gaze target is drawn only
            for faces scoring strictly above this value.

    Returns:
        An RGBA copy of ``pil_image`` with the annotations drawn on it.
    """
    colors = ["lime", "tomato", "cyan", "fuchsia", "yellow"]
    overlay_image = pil_image.convert("RGBA")
    draw = ImageDraw.Draw(overlay_image)
    width, height = pil_image.size
    short_side = min(width, height)

    # Sizes scale with the image. Clamp to >= 1: int() truncates to 0 on small
    # images, PIL draws no rectangle outline at width 0, and a 0 font size is invalid.
    box_width = max(1, int(short_side * 0.01))
    line_width = max(1, int(short_side * 0.005))
    text_gap = int(height * 0.01)
    font = ImageFont.load_default(size=max(1, int(short_side * 0.05)))  # load once, not per face

    for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
        color = colors[i % len(colors)]
        left, top = xmin * width, ymin * height
        right, bottom = xmax * width, ymax * height
        draw.rectangle([left, top, right, bottom], outline=color, width=box_width)

        if inout_scores is not None:
            inout_score = float(inout_scores[i])
            draw.text((left, bottom + text_gap), f"in-frame: {inout_score:.2f}", fill=color, font=font)
            # Faces predicted to look outside the frame have no target to draw.
            if not inout_score > inout_thresh:
                continue

        heatmap = heatmaps[i]
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        # Location of the heatmap peak, mapped from heatmap cells to image pixels.
        row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)[:2]
        gaze_x = col / heatmap.shape[1] * width
        gaze_y = row / heatmap.shape[0] * height
        face_center = ((left + right) / 2, (top + bottom) / 2)
        draw.ellipse([(gaze_x - 5, gaze_y - 5), (gaze_x + 5, gaze_y + 5)], fill=color, width=line_width)
        draw.line([face_center, (gaze_x, gaze_y)], fill=color, width=line_width)

    return overlay_image


def main(image_input):
    """Gradio callback: detect faces, predict where each one looks, draw the result.

    Args:
        image_input: Path of the uploaded image (the input component is created
            with ``type="filepath"``), or None if nothing was uploaded.

    Returns:
        RGBA PIL image with one colored box per detected face, its in-frame
        score when the model provides one, and a line to its gaze target.

    Raises:
        gr.Error: If no image was provided or no face was detected; Gradio
            shows the message to the user instead of a stack trace.
    """
    if image_input is None:
        raise gr.Error("Please upload an image before submitting.")

    # Convert up front: RGBA, greyscale and palette images would break the
    # 3-channel model transform. The with-block closes the opened file.
    with Image.open(image_input) as opened:
        image = opened.convert("RGB")
    width, height = image.size

    # NOTE(review): retina-face follows OpenCV conventions and expects a BGR
    # array, while PIL yields RGB; confirm for the pinned retina-face version.
    bgr = np.ascontiguousarray(np.asarray(image)[:, :, ::-1])
    faces = RetinaFace.detect_faces(bgr)
    # Treat anything but a non-empty dict as "no faces", so the user gets a
    # readable message rather than a crash or a model call with zero boxes.
    if not isinstance(faces, dict) or not faces:
        raise gr.Error("No faces were detected in the image.")

    # "facial_area" is a pixel-space (x1, y1, x2, y2) box; the model takes boxes
    # normalized to [0, 1], grouped per image (one image in this batch).
    scale = np.array([width, height, width, height])
    norm_bboxes = [[np.array(face["facial_area"]) / scale for face in faces.values()]]

    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model({"images": img_tensor, "bboxes": norm_bboxes})

    # Index 0 selects the single image in the batch. "inout" may be None,
    # presumably when the model has no in/out head.
    inout = output.get("inout")
    return _visualize_all(
        image,
        output["heatmap"][0],
        norm_bboxes[0],
        None if inout is None else inout[0],
        inout_thresh=0.5,
    )
# Two-column UI: image upload + submit button on the left, annotated result on the right.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                # type="filepath": main() receives the path of the uploaded file.
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")
    submit_button.click(fn=main, inputs=[input_image], outputs=[result])

# Start the server only when run as a script, so importing this module does
# not launch it as a side effect.
if __name__ == "__main__":
    demo.queue().launch(show_api=False, show_error=True)