fffiloni committed (verified)
Commit 33c5278 · 1 Parent(s): 50964d7

add credentials + heatmaps outputs

Files changed (1): app.py (+36 -7)
app.py CHANGED
@@ -17,7 +17,7 @@ model, transform = torch.hub.load("fkryan/gazelle", "gazelle_dinov2_vitl14_inout
 model.eval()
 model.to(device)
 
-def main(image_input):
+def main(image_input, progress=gr.Progress(track_tqdm=True)):
     # load image
     image = Image.open(image_input)
     width, height = image.size
@@ -73,6 +73,10 @@ def main(image_input):
         draw.text((text_x, text_y), text, fill="lime", font=ImageFont.load_default(size=int(min(width, height) * 0.05)))
         return overlay_image
 
+    heatmap_results = []
+    for i in range(len(bboxes)):
+        overlay_img = visualize_heatmap(image, output['heatmap'][0][i], norm_bboxes[0][i], inout_score=output['inout'][0][i] if output['inout'] is not None else None)
+        heatmap_results.append(overlay_img)
 
     # combined visualization with maximal gaze points for each person
 
@@ -113,21 +117,46 @@ def main(image_input):
 
     result_gazed = visualize_all(image, output['heatmap'][0], norm_bboxes[0], output['inout'][0] if output['inout'] is not None else None, inout_thresh=0.5)
 
-    return result_gazed
-
-
-with gr.Blocks() as demo:
-    with gr.Column():
+    return result_gazed, heatmap_results
+
+css="""
+div#col-container{
+    margin: 0 auto;
+    max-width: 982px;
+}
+"""
+
+with gr.Blocks(css=css) as demo:
+    with gr.Column(elem_id="col-container"):
+        gr.Markdown("# Gaze-LLE: Gaze Target Estimation via Large-Scale Learned Encoders")
+        gr.Markdown("A transformer approach for estimating gaze targets that leverages the power of pretrained visual foundation models. Gaze-LLE provides a streamlined gaze architecture that learns only a lightweight gaze decoder on top of a frozen, pretrained visual encoder (DINOv2). Gaze-LLE learns 1-2 orders of magnitude fewer parameters than prior works and doesn't require any extra input modalities like depth and pose!")
+        gr.HTML("""
+        <div style="display:flex;column-gap:4px;">
+            <a href="https://github.com/fkryan/gazelle">
+                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
+            </a>
+            <a href="https://arxiv.org/abs/2412.09586">
+                <img src='https://img.shields.io/badge/ArXiv-Paper-red'>
+            </a>
+            <a href="https://huggingface.co/spaces/fffiloni/Gaze-LLE?duplicate=true">
+                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-sm.svg" alt="Duplicate this Space">
+            </a>
+            <a href="https://huggingface.co/fffiloni">
+                <img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/follow-me-on-HF-sm-dark.svg" alt="Follow me on HF">
+            </a>
+        </div>
+        """)
         with gr.Row():
             with gr.Column():
                 input_image = gr.Image(label="Image Input", type="filepath")
                 submit_button = gr.Button("Submit")
             with gr.Column():
                 result = gr.Image(label="Result")
+                heatmaps = gr.Gallery(label="Heatmap")
 
         submit_button.click(
            fn = main,
            inputs = [input_image],
-           outputs = [result]
+           outputs = [result, heatmaps]
        )
 demo.queue().launch(show_api=False, show_error=True)
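
The per-person loop added in the second hunk calls a visualize_heatmap helper that is defined elsewhere in app.py and is not shown in this diff. As a rough sketch only, the snippet below illustrates what such a per-person overlay function might do; the name sketch_visualize_heatmap, its signature, and the rendering choices are assumptions for illustration, not the Space's actual implementation.

# Hypothetical illustration only: a per-person gaze-heatmap overlay in the spirit of
# the visualize_heatmap(...) call added above. Names, signature, and rendering
# details are assumptions; the real helper lives elsewhere in app.py.
import torch
import torch.nn.functional as F
from PIL import Image, ImageDraw

def sketch_visualize_heatmap(image, heatmap, bbox=None, inout_score=None, alpha=0.6):
    # upsample the model's low-resolution heatmap to the input image size
    heatmap = heatmap.detach().float().cpu()
    heatmap = F.interpolate(heatmap[None, None], size=(image.height, image.width),
                            mode="bilinear", align_corners=False)[0, 0]
    # normalize to [0, 1] and use it as the alpha channel of a red overlay
    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
    overlay = Image.new("RGBA", image.size, (255, 0, 0, 0))
    overlay.putalpha(Image.fromarray((heatmap.numpy() * 255 * alpha).astype("uint8")))
    out = Image.alpha_composite(image.convert("RGBA"), overlay)

    if bbox is not None:
        # bbox is assumed normalized (xmin, ymin, xmax, ymax); draw the head box
        xmin, ymin, xmax, ymax = bbox
        draw = ImageDraw.Draw(out)
        draw.rectangle([xmin * image.width, ymin * image.height,
                        xmax * image.width, ymax * image.height],
                       outline="lime", width=3)
        if inout_score is not None:
            # annotate the in-frame probability next to the box
            draw.text((xmin * image.width, ymax * image.height),
                      f"in-frame: {float(inout_score):.2f}", fill="lime")
    return out.convert("RGB")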
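
On the UI side, main now returns a tuple that the click handler maps positionally onto two output components; gr.Gallery accepts a list of PIL images, and the new progress=gr.Progress(track_tqdm=True) argument is injected by Gradio itself so tqdm loops inside main surface as a progress bar. Below is a minimal, self-contained sketch of that wiring under placeholder names (demo_fn, inp, btn, out, gallery), not the Space's code.

# Minimal sketch of the output wiring this commit introduces: one gr.Image plus a
# gr.Gallery fed from a (image, list_of_images) return value. Placeholder names only.
import gradio as gr
from PIL import Image

def demo_fn(image_path, progress=gr.Progress(track_tqdm=True)):
    img = Image.open(image_path)
    combined = img.copy()                  # stands in for result_gazed
    per_person = [img.copy(), img.copy()]  # stands in for heatmap_results
    return combined, per_person

with gr.Blocks() as demo:
    inp = gr.Image(label="Image Input", type="filepath")
    btn = gr.Button("Submit")
    out = gr.Image(label="Result")
    gallery = gr.Gallery(label="Heatmap")
    # the two return values map positionally onto the two output components
    btn.click(fn=demo_fn, inputs=[inp], outputs=[out, gallery])

demo.queue().launch()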