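# Gradio demo for Gaze-LLE gaze target estimation: detect faces with RetinaFace,
# run Gaze-LLE on the image plus the normalized head boxes, and draw each
# person's predicted gaze target on top of the input image.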
import gradio as gr
import torch
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw, ImageFont

# RetinaFace face detector, used to find head bounding boxes for Gaze-LLE
from retinaface import RetinaFace

device = "cuda" if torch.cuda.is_available() else "cpu"

# load the Gaze-LLE model (DINOv2 ViT-L backbone with an in/out-of-frame head) from torch.hub
model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout")
model.eval()
model.to(device)

def main(image_input):
    # load image (convert to RGB so RGBA/grayscale uploads don't break the
    # detector or the model transform)
    image = Image.open(image_input).convert("RGB")
    width, height = image.size

    # detect faces; RetinaFace returns a dict mapping face keys to metadata,
    # including "facial_area" = [xmin, ymin, xmax, ymax] in pixels
    resp = RetinaFace.detect_faces(np.array(image))
    if not isinstance(resp, dict) or len(resp) == 0:
        raise gr.Error("No faces detected in the image.")
    bboxes = [resp[key]["facial_area"] for key in resp.keys()]

    # prepare Gaze-LLE input: the transformed image batch plus, per image, a
    # list of head bounding boxes normalized to [0, 1]
    img_tensor = transform(image).unsqueeze(0).to(device)
    norm_bboxes = [[np.array(bbox) / np.array([width, height, width, height]) for bbox in bboxes]]

    model_input = {
        "images": img_tensor,   # [num_images, 3, 448, 448]
        "bboxes": norm_bboxes   # [[img1_bbox1, img1_bbox2, ...], [img2_bbox1, ...]]
    }

    with torch.no_grad():
        output = model(model_input)

    # output['heatmap'][i][j] is a [64, 64] gaze heatmap for person j in image i;
    # output['inout'][i][j] is that person's probability of looking in-frame
    # (only for *_inout model variants)

    # per-person heatmap overlay (standalone helper; the app's output uses the
    # combined visualize_all view below)

    def visualize_heatmap(pil_image, heatmap, bbox=None, inout_score=None):
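        """Overlay one person's gaze heatmap (jet colormap, fixed alpha) on the
        image; optionally draw their head box and in-frame score."""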
        if isinstance(heatmap, torch.Tensor):
            heatmap = heatmap.detach().cpu().numpy()
        heatmap = Image.fromarray((heatmap * 255).astype(np.uint8)).resize(pil_image.size, Image.Resampling.BILINEAR)
        heatmap = plt.cm.jet(np.array(heatmap) / 255.)
        heatmap = (heatmap[:, :, :3] * 255).astype(np.uint8)
        heatmap = Image.fromarray(heatmap).convert("RGBA")
        heatmap.putalpha(90)
        overlay_image = Image.alpha_composite(pil_image.convert("RGBA"), heatmap)

        if bbox is not None:
            width, height = pil_image.size
            xmin, ymin, xmax, ymax = bbox
            draw = ImageDraw.Draw(overlay_image)
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline="lime", width=int(min(width, height) * 0.01))

            if inout_score is not None:
                text = f"in-frame: {inout_score:.2f}"
                text_x = xmin * width
                text_y = ymax * height + int(height * 0.01)
                draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
        return overlay_image


    # combined visualization: head boxes, in-frame scores, and a line from each
    # person to the argmax of their gaze heatmap

    def visualize_all(pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
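        """Draw each person's head box and in-frame score, and for anyone whose
        score clears inout_thresh, mark the heatmap argmax as their gaze target."""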
        colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size

        for i in range(len(bboxes)):
            bbox = bboxes[i]
            xmin, ymin, xmax, ymax = bbox
            color = colors[i % len(colors)]
            draw.rectangle([xmin * width, ymin * height, xmax * width, ymax * height], outline=color, width=int(min(width, height) * 0.01))

            if inout_scores is not None:
                inout_score = inout_scores[i]
                text = f"in-frame: {inout_score:.2f}"
                text_x = xmin * width
                text_y = ymax * height + int(height * 0.01)
                draw.text((text_x, text_y), text, fill=color, font=ImageFont.load_default(size=int(min(width, height) * 0.05)))

            if inout_scores is not None and inout_score > inout_thresh:
                # gaze target = argmax of the 64x64 heatmap, scaled to image pixels
                heatmap = heatmaps[i]
                heatmap_np = heatmap.detach().cpu().numpy()
                max_index = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
                gaze_target_x = max_index[1] / heatmap_np.shape[1] * width
                gaze_target_y = max_index[0] / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                # dot on the gaze target, line from head center to target
                draw.ellipse([(gaze_target_x - 5, gaze_target_y - 5), (gaze_target_x + 5, gaze_target_y + 5)], fill=color)
                draw.line([(bbox_center_x, bbox_center_y), (gaze_target_x, gaze_target_y)], fill=color, width=int(0.005 * min(width, height)))

        return overlay_image

    result_gazed = visualize_all(
        image,
        output['heatmap'][0],
        norm_bboxes[0],
        output['inout'][0] if output['inout'] is not None else None,
        inout_thresh=0.5,
    )

    return result_gazed


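# Gradio UI: upload an image on the left, get the combined gaze visualization back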
with gr.Blocks() as demo: 
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Image Input", type="filepath")
                submit_button = gr.Button("Submit")
            with gr.Column():
                result = gr.Image(label="Result")

    submit_button.click(
        fn=main,
        inputs=[input_image],
        outputs=[result]
    )

demo.queue().launch(show_api=False, show_error=True)