import gradio as gr
import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
import requests
from io import BytesIO
import numpy as np

from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

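# Load the Gaze-LLE model and its preprocessing transform from torch.hub,
# then switch to eval mode and move it to the selected device.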
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)


def main(image_input):
    image = Image.open(image_input)
    width, height = image.size

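    # Detect faces with RetinaFace; each detection exposes a pixel-space
    # bounding box under the "facial_area" key.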
    resp = RetinaFace.detect_faces(np.array(image))
    print(resp)
    bboxes = [resp[key]["facial_area"] for key in resp.keys()]
    print(bboxes)

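    # Preprocess the image for the model and normalize the face boxes
    # to [0, 1] coordinates relative to the image size.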
    img_tensor = transform(image).unsqueeze(0).to(device)
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

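    # Pack the batched image tensor and the per-image list of normalized
    # head boxes into the input dict consumed by the model.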
    input = {
        "images": img_tensor,
        "bboxes": norm_bboxes
    }

    with torch.no_grad():
        output = model(input)

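    # output['heatmap'][i][j] is the gaze heatmap for person j in image i;
    # output['inout'][i][j] is the in-frame score (closer to 1 means the
    # gaze target is predicted to be inside the frame).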
    img1_person1_heatmap = output['heatmap'][0][0]
    print(img1_person1_heatmap.shape)
    if model.inout:
        img1_person1_inout = output['inout'][0][0]
        print(img1_person1_inout.item())

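    # Overlay a single person's gaze heatmap on the image, optionally drawing
    # the head box and in-frame score. (Kept for reference; the demo itself
    # uses visualize_all below.)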
    def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
        heatmap = plt.cm.jet(np.array(heatmap) / 255.)
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap).convert("RGBA")
        heatmap.putalpha(90)
        overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

        if bbox is not None:
            width, height = pil_image.size
            xmin, ymin, xmax, ymax = bbox
            draw = ImageDraw.Draw(overlay_image)
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

            if inout_score is not None:
                text = f"in-frame: {inout_score:.2f}"
                text_width = draw.textlength(text)
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
        return overlay_image

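    # Draw every person's head box and in-frame score; when the gaze is
    # predicted to be in frame, mark the heatmap argmax as the gaze target
    # and connect it to the head center.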
    def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size

        for i in range(len(bboxes)):
            bbox = bboxes[i]
            xmin, ymin, xmax, ymax = bbox
            color = colors[i % len(colors)]
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

            if inout_scores is not None:
                inout_score = inout_scores[i]
                text = f"in-frame: {inout_score:.2f}"
                text_width = draw.textlength(text)
                text_height = int(height * 0.01)
                text_x = xmin * width
                text_y = ymax * height + text_height
                draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

            if inout_scores is not None and inout_score > inout_thresh:
                heatmap = heatmaps[i]
                heatmap_np = heatmap.detach().cpu().numpy()
                max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
                gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
                gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)], fill=color, width=int(0.005 * min(width, height)))
                draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005 * min(width, height)))

        return overlay_image

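    # Render the combined visualization for all detected people in this image.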
    result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)

    return result_gazed


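# Gradio UI: image input and submit button on the left, rendered result on the right.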
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")

        submit_button.click(
            fn=main,
            inputs=[input_image],
            outputs=[result]
        )

demo.queue().launch(show_api=False, show_error=True)