vijul.shah commited on
Commit
9acc552
·
1 Parent(s): f0adec0

Input Video and Predictions as output video added

Browse files
app.py CHANGED
@@ -1,431 +1,113 @@
1
- # takn from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py
2
-
3
- # streamlit run app.py
4
- from io import BytesIO
5
  import os
6
  import sys
7
- import cv2
8
- import matplotlib.pyplot as plt
9
- import numpy as np
10
- import streamlit as st
11
- import torch
12
  import tempfile
13
- from PIL import Image
14
- from torchvision import models
15
- from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
16
- from torchvision import transforms
17
-
18
- from torchcam.methods import CAM
19
- from torchcam import methods as torchcam_methods
20
- from torchcam.utils import overlay_mask
21
  import os.path as osp
 
 
 
 
 
 
22
 
23
  root_path = osp.abspath(osp.join(__file__, osp.pardir))
24
  sys.path.append(root_path)
25
 
26
- from preprocessing.dataset_creation import EyeDentityDatasetCreation
27
- from utils import get_model
28
  from registry_utils import import_registered_modules
 
 
 
 
 
 
 
 
 
 
29
 
30
  import_registered_modules()
31
- # from torchcam.methods._utils import locate_candidate_layer
32
 
33
- CAM_METHODS = [
34
- "CAM",
35
- # "GradCAM",
36
- # "GradCAMpp",
37
- # "SmoothGradCAMpp",
38
- # "ScoreCAM",
39
- # "SSCAM",
40
- # "ISCAM",
41
- # "XGradCAM",
42
- # "LayerCAM",
43
- ]
44
- TV_MODELS = [
45
- "ResNet18",
46
- "ResNet50",
47
- ]
48
  SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
49
  UPSCALE = [2, 4]
50
  UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
51
  LABEL_MAP = ["left_pupil", "right_pupil"]
52
 
53
 
54
- @torch.no_grad()
55
- def _load_model(model_configs, device="cpu"):
56
- model_path = os.path.join(root_path, model_configs["model_path"])
57
- model_configs.pop("model_path")
58
- model_dict = torch.load(model_path, map_location=device)
59
- model = get_model(model_configs=model_configs)
60
- model.load_state_dict(model_dict)
61
- model = model.to(device)
62
- model = model.eval()
63
- return model
64
-
65
-
66
- def extract_frames(video_path):
67
- vidcap = cv2.VideoCapture(video_path)
68
- frames = []
69
- success, image = vidcap.read()
70
- count = 0
71
- while success:
72
- # Convert the frame to RGB (cv2 uses BGR by default)
73
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
74
- frames.append(image_rgb)
75
- success, image = vidcap.read()
76
- count += 1
77
- vidcap.release()
78
- return frames
79
-
80
-
81
- # Function to check if a file is an image
82
- def is_image(file_extension):
83
- return file_extension.lower() in ["png", "jpeg", "jpg"]
84
-
85
-
86
- # Function to check if a file is a video
87
- def is_video(file_extension):
88
- return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
89
-
90
-
91
- def resize_frame(frame, max_width, max_height):
92
- image = Image.fromarray(frame)
93
- original_size = image.size
94
-
95
- # Resize the frame similarly to the image resizing logic
96
- if original_size[0] == original_size[1] and original_size[0] >= 256:
97
- max_size = (256, 256)
98
- else:
99
- max_size = list(original_size)
100
- if original_size[0] >= 640:
101
- max_size[0] = 640
102
- elif original_size[0] < 64:
103
- max_size[0] = 64
104
- if original_size[1] >= 480:
105
- max_size[1] = 480
106
- elif original_size[1] < 32:
107
- max_size[1] = 32
108
-
109
- image.thumbnail(max_size)
110
- return image
111
-
112
-
113
  def main():
114
- # Wide mode
115
  st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
116
-
117
- # Designing the interface
118
  st.title("EyeDentify Playground")
119
- # For newline
120
- st.write("\n")
121
- # Set the columns
122
  cols = st.columns((1, 1))
123
- # cols = st.columns((1, 1, 1))
124
- cols[0].header("Input image")
125
- # cols[1].header("Raw CAM")
126
  cols[-1].header("Prediction")
127
 
128
- # Sidebar
129
- # File selection
130
  st.sidebar.title("Upload Face or Eye")
131
- # Disabling warning
132
- st.set_option("deprecation.showfileUploaderEncoding", False)
133
- # Choose your own image
134
  uploaded_file = st.sidebar.file_uploader(
135
  "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
136
  )
 
137
  if uploaded_file is not None:
138
- # Get file extension
139
  file_extension = uploaded_file.name.split(".")[-1]
140
- input_imgs = []
141
 
142
  if is_image(file_extension):
143
- input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
144
- # print("input_img before = ", input_img.size)
145
- max_size = [input_img.size[0], input_img.size[1]]
146
- cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
147
- if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
148
- max_size[0] = 256
149
- max_size[1] = 256
150
- else:
151
- if input_img.size[0] >= 640:
152
- max_size[0] = 640
153
- elif input_img.size[0] < 64:
154
- max_size[0] = 64
155
- if input_img.size[1] >= 480:
156
- max_size[1] = 480
157
- elif input_img.size[1] < 32:
158
- max_size[1] = 32
159
- input_img.thumbnail((max_size[0], max_size[1])) # Bicubic resampling
160
- input_imgs.append(input_img)
161
- # print("input_img after = ", input_img.size)
162
- # cols[0].image(input_img)
163
- fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
164
- # Display the input image
165
- axs0.imshow(input_imgs[0])
166
- axs0.axis("off")
167
- axs0.set_title("Input Image")
168
 
169
- # Display the plot
170
- cols[0].pyplot(fig0)
171
- cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
172
-
173
- # TODO: show the face features extracted from the image under 'input image' column
174
  elif is_video(file_extension):
175
  tfile = tempfile.NamedTemporaryFile(delete=False)
176
  tfile.write(uploaded_file.read())
177
  video_path = tfile.name
178
-
179
- # Extract frames from the video
180
- frames = extract_frames(video_path)
181
- print(f"Extracted {len(frames)} frames from the video")
182
-
183
- # Process the frames
184
- for i, frame in enumerate(frames):
185
- input_imgs.append(resize_frame(frame, 640, 480))
186
-
187
- os.remove(video_path)
188
-
189
- fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
190
- # Display the input image
191
- axs0.imshow(input_imgs[0])
192
- axs0.axis("off")
193
- axs0.set_title("Input Image")
194
-
195
- # Display the plot
196
- cols[0].pyplot(fig0)
197
- # cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
198
 
199
  st.sidebar.title("Setup")
200
-
201
- # Upscale selection
202
- upscale = "-"
203
- # upscale = st.sidebar.selectbox(
204
- # "Upscale",
205
- # ["-"] + UPSCALE,
206
- # help="Upscale the uploaded image 2 or 4 times. Keep blank for no upscaling",
207
- # )
208
-
209
- # Upscale method selection
210
- if upscale != "-":
211
- upscale_method_or_model = st.sidebar.selectbox(
212
- "Upscale Method / Model",
213
- UPSCALE_METHODS + SR_METHODS,
214
- help="Select a method or model to upscale the uploaded image",
215
- )
216
- else:
217
- upscale_method_or_model = None
218
-
219
- # Pupil selection
220
  pupil_selection = st.sidebar.selectbox(
221
- "Pupil Selection",
222
- ["-"] + LABEL_MAP,
223
- help="Select left or right pupil OR keep blank for both pupil diameter estimation",
224
- )
225
-
226
- # Model selection
227
- tv_model = st.sidebar.selectbox(
228
- "Classification model",
229
- TV_MODELS,
230
- help="Supported Models for Pupil Diameter Estimation",
231
  )
232
-
233
- cam_method = "CAM"
234
- # cam_method = st.sidebar.selectbox(
235
- # "CAM method",
236
- # CAM_METHODS,
237
- # help="The way your class activation map will be computed",
238
- # )
239
- # target_layer = st.sidebar.text_input(
240
- # "Target layer",
241
- # default_layer,
242
- # help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
243
- # )
244
-
245
- st.sidebar.write("\n")
246
 
247
  if st.sidebar.button("Predict Diameter & Compute CAM"):
248
  if uploaded_file is None:
249
- st.sidebar.error("Please upload an image first")
250
-
251
  else:
252
  with st.spinner("Analyzing..."):
253
- model = None
254
- for input_img in input_imgs:
255
- if upscale == "-":
256
- sr_configs = None
257
- else:
258
- sr_configs = {
259
- "method": upscale_method_or_model,
260
- "params": {"upscale": upscale},
261
- }
262
- config_file = {
263
- "sr_configs": sr_configs,
264
- "feature_extraction_configs": {
265
- "blink_detection": False,
266
- "upscale": upscale,
267
- "extraction_library": "mediapipe",
268
- },
269
- }
270
-
271
- img = np.array(input_img)
272
- # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
273
- # if img.shape[0] > max_size or img.shape[1] > max_size:
274
- # img = cv2.resize(img, (max_size, max_size))
275
-
276
- ds_results = EyeDentityDatasetCreation(
277
- feature_extraction_configs=config_file["feature_extraction_configs"],
278
- sr_configs=config_file["sr_configs"],
279
- )(img)
280
- # if ds_results is not None:
281
- # print("ds_results = ", ds_results.keys())
282
-
283
- preprocess_steps = [
284
- transforms.ToTensor(),
285
- transforms.Resize(
286
- [32, 64],
287
- # interpolation=transforms.InterpolationMode.BILINEAR,
288
- interpolation=transforms.InterpolationMode.BICUBIC,
289
- antialias=True,
290
- ),
291
- ]
292
- preprocess_function = transforms.Compose(preprocess_steps)
293
-
294
- left_eye = None
295
- right_eye = None
296
-
297
- if ds_results is None:
298
- # print("type of input_img = ", type(input_img))
299
- input_img = preprocess_function(input_img)
300
- input_img = input_img.unsqueeze(0)
301
- if pupil_selection == "left_pupil":
302
- left_eye = input_img
303
- elif pupil_selection == "right_pupil":
304
- right_eye = input_img
305
- else:
306
- left_eye = input_img
307
- right_eye = input_img
308
- # print("type of left_eye = ", type(left_eye))
309
- # print("type of right_eye = ", type(right_eye))
310
- elif "eyes" in ds_results.keys():
311
- if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
312
- left_eye = ds_results["eyes"]["left_eye"]
313
- # print("type of left_eye = ", type(left_eye))
314
- left_eye = to_pil_image(left_eye).convert("RGB")
315
- # print("type of left_eye = ", type(left_eye))
316
-
317
- left_eye = preprocess_function(left_eye)
318
- # print("type of left_eye = ", type(left_eye))
319
-
320
- left_eye = left_eye.unsqueeze(0)
321
- if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
322
- right_eye = ds_results["eyes"]["right_eye"]
323
- # print("type of right_eye = ", type(right_eye))
324
- right_eye = to_pil_image(right_eye).convert("RGB")
325
- # print("type of right_eye = ", type(right_eye))
326
-
327
- right_eye = preprocess_function(right_eye)
328
- # print("type of right_eye = ", type(right_eye))
329
-
330
- right_eye = right_eye.unsqueeze(0)
331
- else:
332
- # print("type of input_img = ", type(input_img))
333
- input_img = preprocess_function(input_img)
334
- input_img = input_img.unsqueeze(0)
335
- if pupil_selection == "left_pupil":
336
- left_eye = input_img
337
- elif pupil_selection == "right_pupil":
338
- right_eye = input_img
339
- else:
340
- left_eye = input_img
341
- right_eye = input_img
342
- # print("type of left_eye = ", type(left_eye))
343
- # print("type of right_eye = ", type(right_eye))
344
-
345
- # print("left_eye = ", left_eye.shape)
346
- # print("right_eye = ", right_eye.shape)
347
-
348
- if pupil_selection == "-":
349
- selected_eyes = ["left_eye", "right_eye"]
350
- elif pupil_selection == "left_pupil":
351
- selected_eyes = ["left_eye"]
352
- elif pupil_selection == "right_pupil":
353
- selected_eyes = ["right_eye"]
354
-
355
- for eye_type in selected_eyes:
356
-
357
- if model is None:
358
- model_configs = {
359
- "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
360
- "registered_model_name": tv_model,
361
- "num_classes": 1,
362
- }
363
- registered_model_name = model_configs["registered_model_name"]
364
- model = _load_model(model_configs)
365
-
366
- if registered_model_name == "ResNet18":
367
- target_layer = model.resnet.layer4[-1].conv2
368
- elif registered_model_name == "ResNet50":
369
- target_layer = model.resnet.layer4[-1].conv3
370
- else:
371
- raise Exception(f"No target layer available for selected model: {registered_model_name}")
372
-
373
- if left_eye is not None and eye_type == "left_eye":
374
- input_img = left_eye
375
- elif right_eye is not None and eye_type == "right_eye":
376
- input_img = right_eye
377
- else:
378
- raise Exception("Wrong Data")
379
-
380
- if cam_method is not None:
381
- cam_extractor = torchcam_methods.__dict__[cam_method](
382
- model,
383
- target_layer=target_layer,
384
- fc_layer=model.resnet.fc,
385
- input_shape=input_img.shape,
386
- )
387
-
388
- # with torch.no_grad():
389
- out = model(input_img)
390
- cols[-1].markdown(
391
- f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
392
- unsafe_allow_html=True,
393
- )
394
- # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")
395
-
396
- # Retrieve the CAM
397
- act_maps = cam_extractor(0, out)
398
-
399
- # Fuse the CAMs if there are several
400
- activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
401
-
402
- # Convert input image and activation map to PIL images
403
- input_image_pil = to_pil_image(input_img.squeeze(0))
404
- activation_map_pil = to_pil_image(activation_map, mode="F")
405
-
406
- # Create the overlayed CAM result
407
- result = overlay_mask(
408
- input_image_pil,
409
- activation_map_pil,
410
- alpha=0.5,
411
- )
412
-
413
- # Create a subplot with 1 row and 2 columns
414
- fig, axs = plt.subplots(1, 2, figsize=(10, 5))
415
-
416
- # Display the input image
417
- axs[0].imshow(input_image_pil)
418
- axs[0].axis("off")
419
- axs[0].set_title("Input Image")
420
-
421
- # Display the overlayed CAM result
422
- axs[1].imshow(result)
423
- axs[1].axis("off")
424
- axs[1].set_title("Overlayed CAM")
425
 
426
- # Display the plot
427
- cols[-1].pyplot(fig)
428
- cols[-1].text(f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
 
431
  if __name__ == "__main__":
 
 
 
 
 
1
  import os
2
  import sys
 
 
 
 
 
3
  import tempfile
 
 
 
 
 
 
 
 
4
  import os.path as osp
5
+ from PIL import Image
6
+ from io import BytesIO
7
+ import numpy as np
8
+ import streamlit as st
9
+ from PIL import ImageOps
10
+ from matplotlib import pyplot as plt
11
 
12
  root_path = osp.abspath(osp.join(__file__, osp.pardir))
13
  sys.path.append(root_path)
14
 
 
 
15
  from registry_utils import import_registered_modules
16
+ from app_utils import (
17
+ extract_frames,
18
+ is_image,
19
+ is_video,
20
+ display_results,
21
+ overlay_text_on_frame,
22
+ process_frames,
23
+ process_video,
24
+ resize_frame,
25
+ )
26
 
27
  import_registered_modules()
 
28
 
29
+ CAM_METHODS = ["CAM"]
30
+ TV_MODELS = ["ResNet18", "ResNet50"]
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
32
  UPSCALE = [2, 4]
33
  UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
34
  LABEL_MAP = ["left_pupil", "right_pupil"]
35
 
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def main():
 
38
  st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
 
 
39
  st.title("EyeDentify Playground")
 
 
 
40
  cols = st.columns((1, 1))
41
+ cols[0].header("Input")
 
 
42
  cols[-1].header("Prediction")
43
 
 
 
44
  st.sidebar.title("Upload Face or Eye")
 
 
 
45
  uploaded_file = st.sidebar.file_uploader(
46
  "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
47
  )
48
+
49
  if uploaded_file is not None:
 
50
  file_extension = uploaded_file.name.split(".")[-1]
 
51
 
52
  if is_image(file_extension):
53
+ input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
54
+ # NOTE: images taken with phone camera has an EXIF data field which often rotates images taken with the phone in a tilted position. PIL has a utility function that removes this data and ‘uprights’ the image.
55
+ input_img = ImageOps.exif_transpose(input_img)
56
+ input_img = resize_frame(input_img, max_width=640, max_height=480)
57
+ input_img = resize_frame(input_img, max_width=640, max_height=480)
58
+ cols[0].image(input_img, use_column_width=True)
59
+ input_img.save("out.jpg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
 
 
 
 
 
61
  elif is_video(file_extension):
62
  tfile = tempfile.NamedTemporaryFile(delete=False)
63
  tfile.write(uploaded_file.read())
64
  video_path = tfile.name
65
+ video_frames = extract_frames(video_path)
66
+ cols[0].video(video_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  st.sidebar.title("Setup")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  pupil_selection = st.sidebar.selectbox(
70
+ "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
 
 
 
 
 
 
 
 
 
71
  )
72
+ tv_model = st.sidebar.selectbox("Classification model", ["ResNet18", "ResNet50"], help="Supported Models")
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  if st.sidebar.button("Predict Diameter & Compute CAM"):
75
  if uploaded_file is None:
76
+ st.sidebar.error("Please upload an image or video")
 
77
  else:
78
  with st.spinner("Analyzing..."):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ if is_image(file_extension):
81
+ input_frames, output_frames, predicted_diameters, face_frames = process_frames(
82
+ [input_img], tv_model, pupil_selection, cam_method=CAM_METHODS[-1]
83
+ )
84
+ for ff in face_frames:
85
+ if ff["has_face"]:
86
+ cols[1].image(face_frames[0]["img"], use_column_width=True)
87
+
88
+ input_frames_keys = input_frames.keys()
89
+ video_cols = cols[1].columns(len(input_frames_keys))
90
+ for i, eye_type in enumerate(input_frames_keys):
91
+ video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
92
+
93
+ output_frames_keys = output_frames.keys()
94
+ fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
95
+ for i, eye_type in enumerate(output_frames_keys):
96
+ height, width, c = output_frames[eye_type][0].shape
97
+ video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
98
+
99
+ frame = np.zeros((height, width, c), dtype=np.uint8)
100
+ text = f"{predicted_diameters[eye_type][0]:.2f}"
101
+ frame = overlay_text_on_frame(frame, text)
102
+ video_cols[i].image(frame, use_column_width=True)
103
+
104
+ elif is_video(file_extension):
105
+ output_video_path = f"{root_path}/tmp.webm"
106
+ process_video(
107
+ cols, video_frames, tv_model, pupil_selection, output_video_path, cam_method=CAM_METHODS[-1]
108
+ )
109
+
110
+ os.remove(video_path)
111
 
112
 
113
  if __name__ == "__main__":
app_old.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # takn from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py
2
+
3
+ # streamlit run app.py
4
+ from io import BytesIO
5
+ import os
6
+ import sys
7
+ import cv2
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import streamlit as st
11
+ import torch
12
+ import tempfile
13
+ from PIL import Image
14
+ from torchvision import models
15
+ from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
16
+ from torchvision import transforms
17
+
18
+ from torchcam.methods import CAM
19
+ from torchcam import methods as torchcam_methods
20
+ from torchcam.utils import overlay_mask
21
+ import os.path as osp
22
+
23
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
24
+ sys.path.append(root_path)
25
+
26
+ from preprocessing.dataset_creation import EyeDentityDatasetCreation
27
+ from utils import get_model
28
+ from registry_utils import import_registered_modules
29
+
30
+ import_registered_modules()
31
+ # from torchcam.methods._utils import locate_candidate_layer
32
+
33
+ CAM_METHODS = [
34
+ "CAM",
35
+ # "GradCAM",
36
+ # "GradCAMpp",
37
+ # "SmoothGradCAMpp",
38
+ # "ScoreCAM",
39
+ # "SSCAM",
40
+ # "ISCAM",
41
+ # "XGradCAM",
42
+ # "LayerCAM",
43
+ ]
44
+ TV_MODELS = [
45
+ "ResNet18",
46
+ "ResNet50",
47
+ ]
48
+ SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
49
+ UPSCALE = [2, 4]
50
+ UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
51
+ LABEL_MAP = ["left_pupil", "right_pupil"]
52
+
53
+
54
+ @torch.no_grad()
55
+ def _load_model(model_configs, device="cpu"):
56
+ model_path = os.path.join(root_path, model_configs["model_path"])
57
+ model_configs.pop("model_path")
58
+ model_dict = torch.load(model_path, map_location=device)
59
+ model = get_model(model_configs=model_configs)
60
+ model.load_state_dict(model_dict)
61
+ model = model.to(device)
62
+ model = model.eval()
63
+ return model
64
+
65
+
66
+ def extract_frames(video_path):
67
+ vidcap = cv2.VideoCapture(video_path)
68
+ frames = []
69
+ success, image = vidcap.read()
70
+ count = 0
71
+ while success:
72
+ # Convert the frame to RGB (cv2 uses BGR by default)
73
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
74
+ frames.append(image_rgb)
75
+ success, image = vidcap.read()
76
+ count += 1
77
+ vidcap.release()
78
+ return frames
79
+
80
+
81
+ # Function to check if a file is an image
82
+ def is_image(file_extension):
83
+ return file_extension.lower() in ["png", "jpeg", "jpg"]
84
+
85
+
86
+ # Function to check if a file is a video
87
+ def is_video(file_extension):
88
+ return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
89
+
90
+
91
+ def resize_frame(frame, max_width, max_height):
92
+ image = Image.fromarray(frame)
93
+ original_size = image.size
94
+
95
+ # Resize the frame similarly to the image resizing logic
96
+ if original_size[0] == original_size[1] and original_size[0] >= 256:
97
+ max_size = (256, 256)
98
+ else:
99
+ max_size = list(original_size)
100
+ if original_size[0] >= 640:
101
+ max_size[0] = 640
102
+ elif original_size[0] < 64:
103
+ max_size[0] = 64
104
+ if original_size[1] >= 480:
105
+ max_size[1] = 480
106
+ elif original_size[1] < 32:
107
+ max_size[1] = 32
108
+
109
+ image.thumbnail(max_size)
110
+ return image
111
+
112
+
113
+ def main():
114
+ # Wide mode
115
+ st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
116
+
117
+ # Designing the interface
118
+ st.title("EyeDentify Playground")
119
+ # For newline
120
+ st.write("\n")
121
+ # Set the columns
122
+ cols = st.columns((1, 1))
123
+ # cols = st.columns((1, 1, 1))
124
+ cols[0].header("Input image")
125
+ # cols[1].header("Raw CAM")
126
+ cols[-1].header("Prediction")
127
+
128
+ # Sidebar
129
+ # File selection
130
+ st.sidebar.title("Upload Face or Eye")
131
+ # Disabling warning
132
+ st.set_option("deprecation.showfileUploaderEncoding", False)
133
+ # Choose your own image
134
+ uploaded_file = st.sidebar.file_uploader(
135
+ "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
136
+ )
137
+ if uploaded_file is not None:
138
+ # Get file extension
139
+ file_extension = uploaded_file.name.split(".")[-1]
140
+ input_imgs = []
141
+
142
+ if is_image(file_extension):
143
+ input_img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
144
+ # print("input_img before = ", input_img.size)
145
+ max_size = [input_img.size[0], input_img.size[1]]
146
+ cols[0].text(f"Input Image: {max_size[0]} x {max_size[1]}")
147
+ if input_img.size[0] == input_img.size[1] and input_img.size[0] >= 256:
148
+ max_size[0] = 256
149
+ max_size[1] = 256
150
+ else:
151
+ if input_img.size[0] >= 640:
152
+ max_size[0] = 640
153
+ elif input_img.size[0] < 64:
154
+ max_size[0] = 64
155
+ if input_img.size[1] >= 480:
156
+ max_size[1] = 480
157
+ elif input_img.size[1] < 32:
158
+ max_size[1] = 32
159
+ input_img.thumbnail((max_size[0], max_size[1])) # Bicubic resampling
160
+ input_imgs.append(input_img)
161
+ # print("input_img after = ", input_img.size)
162
+ # cols[0].image(input_img)
163
+ fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
164
+ # Display the input image
165
+ axs0.imshow(input_imgs[0])
166
+ axs0.axis("off")
167
+ axs0.set_title("Input Image")
168
+
169
+ # Display the plot
170
+ cols[0].pyplot(fig0)
171
+ cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
172
+
173
+ # TODO: show the face features extracted from the image under 'input image' column
174
+ elif is_video(file_extension):
175
+ tfile = tempfile.NamedTemporaryFile(delete=False)
176
+ tfile.write(uploaded_file.read())
177
+ video_path = tfile.name
178
+
179
+ # Extract frames from the video
180
+ frames = extract_frames(video_path)
181
+ print(f"Extracted {len(frames)} frames from the video")
182
+
183
+ # Process the frames
184
+ for i, frame in enumerate(frames):
185
+ input_imgs.append(resize_frame(frame, 640, 480))
186
+
187
+ os.remove(video_path)
188
+
189
+ fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
190
+ # Display the input image
191
+ axs0.imshow(input_imgs[0])
192
+ axs0.axis("off")
193
+ axs0.set_title("Input Image")
194
+
195
+ # Display the plot
196
+ cols[0].pyplot(fig0)
197
+ # cols[0].text(f"Input Image Resized: {max_size[0]} x {max_size[1]}")
198
+
199
+ st.sidebar.title("Setup")
200
+
201
+ # Upscale selection
202
+ upscale = "-"
203
+ # upscale = st.sidebar.selectbox(
204
+ # "Upscale",
205
+ # ["-"] + UPSCALE,
206
+ # help="Upscale the uploaded image 2 or 4 times. Keep blank for no upscaling",
207
+ # )
208
+
209
+ # Upscale method selection
210
+ if upscale != "-":
211
+ upscale_method_or_model = st.sidebar.selectbox(
212
+ "Upscale Method / Model",
213
+ UPSCALE_METHODS + SR_METHODS,
214
+ help="Select a method or model to upscale the uploaded image",
215
+ )
216
+ else:
217
+ upscale_method_or_model = None
218
+
219
+ # Pupil selection
220
+ pupil_selection = st.sidebar.selectbox(
221
+ "Pupil Selection",
222
+ ["-"] + LABEL_MAP,
223
+ help="Select left or right pupil OR keep blank for both pupil diameter estimation",
224
+ )
225
+
226
+ # Model selection
227
+ tv_model = st.sidebar.selectbox(
228
+ "Classification model",
229
+ TV_MODELS,
230
+ help="Supported Models for Pupil Diameter Estimation",
231
+ )
232
+
233
+ cam_method = "CAM"
234
+ # cam_method = st.sidebar.selectbox(
235
+ # "CAM method",
236
+ # CAM_METHODS,
237
+ # help="The way your class activation map will be computed",
238
+ # )
239
+ # target_layer = st.sidebar.text_input(
240
+ # "Target layer",
241
+ # default_layer,
242
+ # help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
243
+ # )
244
+
245
+ st.sidebar.write("\n")
246
+
247
+ if st.sidebar.button("Predict Diameter & Compute CAM"):
248
+ if uploaded_file is None:
249
+ st.sidebar.error("Please upload an image first")
250
+
251
+ else:
252
+ with st.spinner("Analyzing..."):
253
+ model = None
254
+ for input_img in input_imgs:
255
+ if upscale == "-":
256
+ sr_configs = None
257
+ else:
258
+ sr_configs = {
259
+ "method": upscale_method_or_model,
260
+ "params": {"upscale": upscale},
261
+ }
262
+ config_file = {
263
+ "sr_configs": sr_configs,
264
+ "feature_extraction_configs": {
265
+ "blink_detection": False,
266
+ "upscale": upscale,
267
+ "extraction_library": "mediapipe",
268
+ },
269
+ }
270
+
271
+ img = np.array(input_img)
272
+ # img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
273
+ # if img.shape[0] > max_size or img.shape[1] > max_size:
274
+ # img = cv2.resize(img, (max_size, max_size))
275
+
276
+ ds_results = EyeDentityDatasetCreation(
277
+ feature_extraction_configs=config_file["feature_extraction_configs"],
278
+ sr_configs=config_file["sr_configs"],
279
+ )(img)
280
+
281
+ # if ds_results is not None:
282
+ # print("ds_results = ", ds_results.keys())
283
+ # NOTE:
284
+ # ds_results.keys() contains ===> 'full_imgs', 'faces', 'eyes', 'blinks', 'iris'
285
+
286
+ preprocess_steps = [
287
+ transforms.ToTensor(),
288
+ transforms.Resize(
289
+ [32, 64],
290
+ interpolation=transforms.InterpolationMode.BICUBIC,
291
+ antialias=True,
292
+ ),
293
+ ]
294
+ preprocess_function = transforms.Compose(preprocess_steps)
295
+
296
+ left_eye = None
297
+ right_eye = None
298
+
299
+ if ds_results is None:
300
+ # print("type of input_img = ", type(input_img))
301
+ input_img = preprocess_function(input_img)
302
+ input_img = input_img.unsqueeze(0)
303
+ if pupil_selection == "left_pupil":
304
+ left_eye = input_img
305
+ elif pupil_selection == "right_pupil":
306
+ right_eye = input_img
307
+ else:
308
+ left_eye = input_img
309
+ right_eye = input_img
310
+ # print("type of left_eye = ", type(left_eye))
311
+ # print("type of right_eye = ", type(right_eye))
312
+ elif "eyes" in ds_results.keys():
313
+ if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
314
+ left_eye = ds_results["eyes"]["left_eye"]
315
+ # print("type of left_eye = ", type(left_eye))
316
+ left_eye = to_pil_image(left_eye).convert("RGB")
317
+ # print("type of left_eye = ", type(left_eye))
318
+
319
+ left_eye = preprocess_function(left_eye)
320
+ # print("type of left_eye = ", type(left_eye))
321
+
322
+ left_eye = left_eye.unsqueeze(0)
323
+ if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
324
+ right_eye = ds_results["eyes"]["right_eye"]
325
+ # print("type of right_eye = ", type(right_eye))
326
+ right_eye = to_pil_image(right_eye).convert("RGB")
327
+ # print("type of right_eye = ", type(right_eye))
328
+
329
+ right_eye = preprocess_function(right_eye)
330
+ # print("type of right_eye = ", type(right_eye))
331
+
332
+ right_eye = right_eye.unsqueeze(0)
333
+ else:
334
+ # print("type of input_img = ", type(input_img))
335
+ input_img = preprocess_function(input_img)
336
+ input_img = input_img.unsqueeze(0)
337
+ if pupil_selection == "left_pupil":
338
+ left_eye = input_img
339
+ elif pupil_selection == "right_pupil":
340
+ right_eye = input_img
341
+ else:
342
+ left_eye = input_img
343
+ right_eye = input_img
344
+ # print("type of left_eye = ", type(left_eye))
345
+ # print("type of right_eye = ", type(right_eye))
346
+
347
+ # print("left_eye = ", left_eye.shape)
348
+ # print("right_eye = ", right_eye.shape)
349
+
350
+ if pupil_selection == "-":
351
+ selected_eyes = ["left_eye", "right_eye"]
352
+ elif pupil_selection == "left_pupil":
353
+ selected_eyes = ["left_eye"]
354
+ elif pupil_selection == "right_pupil":
355
+ selected_eyes = ["right_eye"]
356
+
357
+ for eye_type in selected_eyes:
358
+
359
+ if model is None:
360
+ model_configs = {
361
+ "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
362
+ "registered_model_name": tv_model,
363
+ "num_classes": 1,
364
+ }
365
+ registered_model_name = model_configs["registered_model_name"]
366
+ model = _load_model(model_configs)
367
+
368
+ if registered_model_name == "ResNet18":
369
+ target_layer = model.resnet.layer4[-1].conv2
370
+ elif registered_model_name == "ResNet50":
371
+ target_layer = model.resnet.layer4[-1].conv3
372
+ else:
373
+ raise Exception(f"No target layer available for selected model: {registered_model_name}")
374
+
375
+ if left_eye is not None and eye_type == "left_eye":
376
+ input_img = left_eye
377
+ elif right_eye is not None and eye_type == "right_eye":
378
+ input_img = right_eye
379
+ else:
380
+ raise Exception("Wrong Data")
381
+
382
+ if cam_method is not None:
383
+ cam_extractor = torchcam_methods.__dict__[cam_method](
384
+ model,
385
+ target_layer=target_layer,
386
+ fc_layer=model.resnet.fc,
387
+ input_shape=input_img.shape,
388
+ )
389
+
390
+ # with torch.no_grad():
391
+ out = model(input_img)
392
+ cols[-1].markdown(
393
+ f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
394
+ unsafe_allow_html=True,
395
+ )
396
+ # cols[-1].text(f"Predicted Pupil Diameter: {out[0].item():.2f}")
397
+
398
+ # Retrieve the CAM
399
+ act_maps = cam_extractor(0, out)
400
+
401
+ # Fuse the CAMs if there are several
402
+ activation_map = act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
403
+
404
+ # Convert input image and activation map to PIL images
405
+ input_image_pil = to_pil_image(input_img.squeeze(0))
406
+ activation_map_pil = to_pil_image(activation_map, mode="F")
407
+
408
+ # Create the overlayed CAM result
409
+ result = overlay_mask(
410
+ input_image_pil,
411
+ activation_map_pil,
412
+ alpha=0.5,
413
+ )
414
+
415
+ # Create a subplot with 1 row and 2 columns
416
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
417
+
418
+ # Display the input image
419
+ axs[0].imshow(input_image_pil)
420
+ axs[0].axis("off")
421
+ axs[0].set_title("Input Image")
422
+
423
+ # Display the overlayed CAM result
424
+ axs[1].imshow(result)
425
+ axs[1].axis("off")
426
+ axs[1].set_title("Overlayed CAM")
427
+
428
+ # Display the plot
429
+ cols[-1].pyplot(fig)
430
+ cols[-1].text(f"eye image size: {input_img.shape[-1]} x {input_img.shape[-2]}")
431
+
432
+
433
+ if __name__ == "__main__":
434
+ main()
app_utils.py ADDED
@@ -0,0 +1,374 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from io import BytesIO
3
+ import os
4
+ import sys
5
+ import cv2
6
+ from matplotlib import pyplot as plt
7
+ import numpy as np
8
+ import streamlit as st
9
+ import torch
10
+ import tempfile
11
+ from PIL import Image
12
+ from torchvision.transforms.functional import to_pil_image
13
+ from torchvision import transforms
14
+
15
+ from torchcam.methods import CAM
16
+ from torchcam import methods as torchcam_methods
17
+ from torchcam.utils import overlay_mask
18
+ import os.path as osp
19
+
20
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
21
+ sys.path.append(root_path)
22
+
23
+ from preprocessing.dataset_creation import EyeDentityDatasetCreation
24
+ from utils import get_model
25
+
26
+
27
+ @torch.no_grad()
28
+ def load_model(model_configs, device="cpu"):
29
+ """Loads the pre-trained model."""
30
+ model_path = os.path.join(root_path, model_configs["model_path"])
31
+ model_dict = torch.load(model_path, map_location=device)
32
+ model = get_model(model_configs=model_configs)
33
+ model.load_state_dict(model_dict)
34
+ model = model.to(device).eval()
35
+ return model
36
+
37
+
38
+ def extract_frames(video_path):
39
+ """Extracts frames from a video file."""
40
+ vidcap = cv2.VideoCapture(video_path)
41
+ frames = []
42
+ success, image = vidcap.read()
43
+ while success:
44
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
45
+ frames.append(image_rgb)
46
+ success, image = vidcap.read()
47
+ vidcap.release()
48
+ return frames
49
+
50
+
51
+ def resize_frame(image, max_width=640, max_height=480):
52
+ if not isinstance(image, Image.Image):
53
+ image = Image.fromarray(image)
54
+ original_size = image.size
55
+
56
+ # Resize the frame similarly to the image resizing logic
57
+ if original_size[0] == original_size[1] and original_size[0] >= 256:
58
+ max_size = (256, 256)
59
+ else:
60
+ max_size = list(original_size)
61
+ if original_size[0] >= max_width:
62
+ max_size[0] = max_width
63
+ elif original_size[0] < 64:
64
+ max_size[0] = 64
65
+ if original_size[1] >= max_height:
66
+ max_size[1] = max_height
67
+ elif original_size[1] < 32:
68
+ max_size[1] = 32
69
+
70
+ image.thumbnail(max_size)
71
+ # image = image.resize(max_size)
72
+ return image
73
+
74
+
75
+ def is_image(file_extension):
76
+ """Checks if the file is an image."""
77
+ return file_extension.lower() in ["png", "jpeg", "jpg"]
78
+
79
+
80
+ def is_video(file_extension):
81
+ """Checks if the file is a video."""
82
+ return file_extension.lower() in ["mp4", "avi", "mov", "mkv", "webm"]
83
+
84
+
85
+ def display_results(input_image, cam_frame, pupil_diameter, cols):
86
+ """Displays the input image and overlayed CAM result."""
87
+ fig, axs = plt.subplots(1, 2, figsize=(10, 5))
88
+ axs[0].imshow(input_image)
89
+ axs[0].axis("off")
90
+ axs[0].set_title("Input Image")
91
+ axs[1].imshow(cam_frame)
92
+ axs[1].axis("off")
93
+ axs[1].set_title("Overlayed CAM")
94
+ cols[-1].pyplot(fig)
95
+ cols[-1].text(f"Pupil Diameter: {pupil_diameter:.2f} mm")
96
+
97
+
98
+ def preprocess_image(input_img, max_size=(256, 256)):
99
+ """Resizes and preprocesses an image."""
100
+ input_img.thumbnail(max_size)
101
+ preprocess_steps = [
102
+ transforms.ToTensor(),
103
+ transforms.Resize([32, 64], interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
104
+ ]
105
+ return transforms.Compose(preprocess_steps)(input_img).unsqueeze(0)
106
+
107
+
108
+ def overlay_text_on_frame(frame, text, position=(16, 20)):
109
+ """Write text on the image frame using OpenCV."""
110
+ return cv2.putText(frame, text, position, cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, cv2.LINE_AA)
111
+
112
+
113
+ def process_frames(input_imgs, tv_model, pupil_selection, cam_method):
114
+ upscale = "-"
115
+ upscale_method_or_model = "-"
116
+ if upscale == "-":
117
+ sr_configs = None
118
+ else:
119
+ sr_configs = {
120
+ "method": upscale_method_or_model,
121
+ "params": {"upscale": upscale},
122
+ }
123
+ config_file = {
124
+ "sr_configs": sr_configs,
125
+ "feature_extraction_configs": {
126
+ "blink_detection": False,
127
+ "upscale": upscale,
128
+ "extraction_library": "mediapipe",
129
+ },
130
+ }
131
+ left_pupil_model = None
132
+ right_pupil_model = None
133
+ face_frames = []
134
+ output_frames = {}
135
+ input_frames = {}
136
+ predicted_diameters = {}
137
+
138
+ if pupil_selection == "both":
139
+ selected_eyes = ["left_eye", "right_eye"]
140
+
141
+ elif pupil_selection == "left_pupil":
142
+ selected_eyes = ["left_eye"]
143
+
144
+ elif pupil_selection == "right_pupil":
145
+ selected_eyes = ["right_eye"]
146
+
147
+ for eye_type in selected_eyes:
148
+ model_configs = {
149
+ "model_path": root_path + f"/pre_trained_models/{tv_model}/{eye_type}.pt",
150
+ "registered_model_name": tv_model,
151
+ "num_classes": 1,
152
+ }
153
+ if eye_type == "left_eye":
154
+ left_pupil_model = load_model(model_configs)
155
+ left_pupil_cam_extractor = None
156
+ output_frames[eye_type] = []
157
+ input_frames[eye_type] = []
158
+ predicted_diameters[eye_type] = []
159
+ else:
160
+ right_pupil_model = load_model(model_configs)
161
+ right_pupil_cam_extractor = None
162
+ output_frames[eye_type] = []
163
+ input_frames[eye_type] = []
164
+ predicted_diameters[eye_type] = []
165
+
166
+ ds_creation = EyeDentityDatasetCreation(
167
+ feature_extraction_configs=config_file["feature_extraction_configs"],
168
+ sr_configs=config_file["sr_configs"],
169
+ )
170
+
171
+ preprocess_steps = [
172
+ transforms.ToTensor(),
173
+ transforms.Resize(
174
+ [32, 64],
175
+ interpolation=transforms.InterpolationMode.BICUBIC,
176
+ antialias=True,
177
+ ),
178
+ ]
179
+ preprocess_function = transforms.Compose(preprocess_steps)
180
+
181
+ for input_img in input_imgs:
182
+
183
+ img = np.array(input_img)
184
+ ds_results = ds_creation(img)
185
+
186
+ left_eye = None
187
+ right_eye = None
188
+ blinked = False
189
+
190
+ if ds_results is not None and "face" in ds_results:
191
+ face_img = to_pil_image(ds_results["face"])
192
+ has_face = True
193
+ else:
194
+ face_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
195
+ has_face = False
196
+ face_frames.append({"has_face": has_face, "img": face_img})
197
+
198
+ if ds_results is not None and "eyes" in ds_results.keys():
199
+ blinked = ds_results["eyes"]["blinked"]
200
+ if not blinked:
201
+ if "left_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["left_eye"] is not None:
202
+ left_eye = ds_results["eyes"]["left_eye"]
203
+ left_eye = to_pil_image(left_eye).convert("RGB")
204
+ left_eye = preprocess_function(left_eye)
205
+ left_eye = left_eye.unsqueeze(0)
206
+ if "right_eye" in ds_results["eyes"].keys() and ds_results["eyes"]["right_eye"] is not None:
207
+ right_eye = ds_results["eyes"]["right_eye"]
208
+ right_eye = to_pil_image(right_eye).convert("RGB")
209
+ right_eye = preprocess_function(right_eye)
210
+ right_eye = right_eye.unsqueeze(0)
211
+ else:
212
+ input_img = preprocess_function(input_img)
213
+ input_img = input_img.unsqueeze(0)
214
+ if pupil_selection == "left_pupil":
215
+ left_eye = input_img
216
+ elif pupil_selection == "right_pupil":
217
+ right_eye = input_img
218
+ else:
219
+ left_eye = input_img
220
+ right_eye = input_img
221
+
222
+ for eye_type in selected_eyes:
223
+ if left_eye is not None and eye_type == "left_eye":
224
+ if left_pupil_cam_extractor is None:
225
+ if tv_model == "ResNet18":
226
+ target_layer = left_pupil_model.resnet.layer4[-1].conv2
227
+ elif tv_model == "ResNet50":
228
+ target_layer = left_pupil_model.resnet.layer4[-1].conv3
229
+ else:
230
+ raise Exception(f"No target layer available for selected model: {tv_model}")
231
+ left_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
232
+ left_pupil_model,
233
+ target_layer=target_layer,
234
+ fc_layer=left_pupil_model.resnet.fc,
235
+ input_shape=left_eye.shape,
236
+ )
237
+ output = left_pupil_model(left_eye)
238
+ predicted_diameter = output[0].item()
239
+ act_maps = left_pupil_cam_extractor(0, output)
240
+ activation_map = act_maps[0] if len(act_maps) == 1 else left_pupil_cam_extractor.fuse_cams(act_maps)
241
+ input_image_pil = to_pil_image(left_eye.squeeze(0))
242
+ elif right_eye is not None and eye_type == "right_eye":
243
+ if right_pupil_cam_extractor is None:
244
+ if tv_model == "ResNet18":
245
+ target_layer = right_pupil_model.resnet.layer4[-1].conv2
246
+ elif tv_model == "ResNet50":
247
+ target_layer = right_pupil_model.resnet.layer4[-1].conv3
248
+ else:
249
+ raise Exception(f"No target layer available for selected model: {tv_model}")
250
+ right_pupil_cam_extractor = torchcam_methods.__dict__[cam_method](
251
+ right_pupil_model,
252
+ target_layer=target_layer,
253
+ fc_layer=right_pupil_model.resnet.fc,
254
+ input_shape=right_eye.shape,
255
+ )
256
+ output = right_pupil_model(right_eye)
257
+ predicted_diameter = output[0].item()
258
+ act_maps = right_pupil_cam_extractor(0, output)
259
+ activation_map = act_maps[0] if len(act_maps) == 1 else right_pupil_cam_extractor.fuse_cams(act_maps)
260
+ input_image_pil = to_pil_image(right_eye.squeeze(0))
261
+
262
+ if blinked:
263
+ zeros_img = to_pil_image(np.zeros((256, 256, 3), dtype=np.uint8))
264
+ input_image_pil = zeros_img
265
+ result = zeros_img
266
+ predicted_diameter = 0
267
+ else:
268
+ # Create CAM overlay
269
+ activation_map_pil = to_pil_image(activation_map, mode="F")
270
+ result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)
271
+
272
+ # Add frame and predicted diameter to lists
273
+ input_frames[eye_type].append(np.array(input_image_pil))
274
+ output_frames[eye_type].append(np.array(result))
275
+ predicted_diameters[eye_type].append(predicted_diameter)
276
+
277
+ return input_frames, output_frames, predicted_diameters, face_frames
278
+
279
+
280
+ # Function to display video with autoplay and loop
281
+ def display_video_with_autoplay(video_col, video_path):
282
+ video_html = f"""
283
+ <video width="100%" height="auto" autoplay loop muted>
284
+ <source src="data:video/mp4;base64,{video_path}" type="video/mp4">
285
+ </video>
286
+ """
287
+ video_col.markdown(video_html, unsafe_allow_html=True)
288
+
289
+
290
+ def get_codec_and_extension(file_format):
291
+ """Return codec and file extension based on the format."""
292
+ if file_format == "mp4":
293
+ return "H264", ".mp4"
294
+ elif file_format == "avi":
295
+ return "MJPG", ".avi"
296
+ elif file_format == "webm":
297
+ return "VP80", ".webm"
298
+ else:
299
+ return "MJPG", ".avi"
300
+
301
+
302
+ def process_video(cols, video_frames, tv_model, pupil_selection, output_path, cam_method):
303
+
304
+ resized_frames = []
305
+ for i, frame in enumerate(video_frames):
306
+ input_img = resize_frame(frame, max_width=640, max_height=480)
307
+ # input_img = Image.fromarray(input_img)
308
+ resized_frames.append(input_img)
309
+
310
+ input_frames, output_frames, predicted_diameters, face_frames = process_frames(
311
+ resized_frames, tv_model, pupil_selection, cam_method
312
+ )
313
+
314
+ file_format = output_path.split(".")[-1]
315
+ codec, extension = get_codec_and_extension(file_format)
316
+
317
+ video_cols = cols[1].columns(len(input_frames.keys()))
318
+
319
+ for i, eye_type in enumerate(input_frames.keys()):
320
+ in_frames = input_frames[eye_type]
321
+ height, width, _ = in_frames[0].shape
322
+ fourcc = cv2.VideoWriter_fourcc(*codec)
323
+ fps = 10.0
324
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
325
+ for frame in in_frames:
326
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
327
+ out.release()
328
+
329
+ with open(output_path, "rb") as video_file:
330
+ video_bytes = video_file.read()
331
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
332
+ display_video_with_autoplay(video_cols[i], video_base64)
333
+
334
+ os.remove(output_path)
335
+
336
+ for i, eye_type in enumerate(output_frames.keys()):
337
+ out_frames = output_frames[eye_type]
338
+ height, width, _ = out_frames[0].shape
339
+ fourcc = cv2.VideoWriter_fourcc(*codec)
340
+ fps = 10.0
341
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
342
+ for j, frame in enumerate(out_frames):
343
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
344
+ out.release()
345
+
346
+ with open(output_path, "rb") as video_file:
347
+ video_bytes = video_file.read()
348
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
349
+ display_video_with_autoplay(video_cols[i], video_base64)
350
+
351
+ os.remove(output_path)
352
+
353
+ for i, eye_type in enumerate(output_frames.keys()):
354
+
355
+ out_frames = output_frames[eye_type]
356
+ height, width, _ = out_frames[0].shape
357
+ fourcc = cv2.VideoWriter_fourcc(*codec)
358
+ fps = 10.0
359
+ out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
360
+
361
+ for diameter in predicted_diameters[eye_type]:
362
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
363
+ text = f"{diameter:.2f}"
364
+ frame = overlay_text_on_frame(frame, text)
365
+ out.write(frame)
366
+ out.release()
367
+
368
+ with open(output_path, "rb") as video_file:
369
+ video_bytes = video_file.read()
370
+ video_base64 = base64.b64encode(video_bytes).decode("utf-8")
371
+ display_video_with_autoplay(video_cols[i], video_base64)
372
+ os.remove(output_path)
373
+
374
+ return predicted_diameters
config.yml CHANGED
@@ -2,8 +2,9 @@ seed: 42
2
 
3
  feature_extraction_configs:
4
  blink_detection: true
 
5
  extraction_library: "mediapipe"
6
- show_features: ['full_imgs', 'faces', 'eyes', 'blinks', 'iris']
7
 
8
  model_configs:
9
  models_path: "pre_trained_models"
 
2
 
3
  feature_extraction_configs:
4
  blink_detection: true
5
+ upscale: 1
6
  extraction_library: "mediapipe"
7
+ show_features: ['faces', 'eyes', 'blinks']
8
 
9
  model_configs:
10
  models_path: "pre_trained_models"
feature_extraction/extractor_mediapipe.py CHANGED
@@ -18,9 +18,7 @@ class ExtractorMediaPipe:
18
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  # ========== Face Extraction ==========
21
- self.face_detector = mp.solutions.face_detection.FaceDetection(
22
- model_selection=0, min_detection_confidence=0.5
23
- )
24
  self.face_mesh = mp.solutions.face_mesh.FaceMesh(
25
  max_num_faces=1,
26
  static_image_mode=True,
@@ -169,19 +167,11 @@ class ExtractorMediaPipe:
169
  left_eye_landmark3 = landmarks[left_indices[12]]
170
  left_eye_landmark4 = landmarks[left_indices[4]]
171
 
172
- right_eye_horizontal_distance = self.euclideanDistance(
173
- right_eye_landmark1, right_eye_landmark2
174
- )
175
- right_eye_vertical_distance = self.euclideanDistance(
176
- right_eye_landmark3, right_eye_landmark4
177
- )
178
 
179
- left_eye_vertical_distance = self.euclideanDistance(
180
- left_eye_landmark3, left_eye_landmark4
181
- )
182
- left_eye_horizontal_distance = self.euclideanDistance(
183
- left_eye_landmark1, left_eye_landmark2
184
- )
185
 
186
  right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
187
  left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance
@@ -192,10 +182,7 @@ class ExtractorMediaPipe:
192
 
193
  def extract_eyes_regions(self, image, landmarks, eye_indices):
194
  h, w, _ = image.shape
195
- points = [
196
- (int(landmarks[idx].x * w), int(landmarks[idx].y * h))
197
- for idx in eye_indices
198
- ]
199
 
200
  x_min = min([p[0] for p in points])
201
  x_max = max([p[0] for p in points])
@@ -261,21 +248,14 @@ class ExtractorMediaPipe:
261
 
262
  if blink_detection:
263
  mesh_coordinates = self.landmarksDetection(image, results, False)
264
- eyes_ratio = self.blinkRatio(
265
- mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE
266
- )
267
- if (
268
- eyes_ratio > self.blink_lower_thresh
269
- and eyes_ratio <= self.blink_upper_thresh
270
- ):
271
  # print(
272
  # "I think person blinked. eyes_ratio = ",
273
  # eyes_ratio,
274
  # "Confirming with ViT model...",
275
  # )
276
- blinked = self.blink_detection_model(
277
- left_eye=left_eye, right_eye=right_eye
278
- )
279
  # if blinked:
280
  # print("Yes, person blinked. Confirmed by model")
281
  # else:
@@ -298,9 +278,7 @@ class ExtractorMediaPipe:
298
  iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)
299
 
300
  # Perform adaptive thresholding
301
- _, iris_img_mask = cv2.threshold(
302
- iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
303
- )
304
 
305
  # Invert the mask
306
  segmented_mask = cv2.bitwise_not(iris_img_mask)
@@ -335,9 +313,7 @@ class ExtractorMediaPipe:
335
 
336
  cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]
337
 
338
- left_iris_segmented_data = self.segment_iris(
339
- cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB)
340
- )
341
 
342
  # Crop the right iris to be exactly 16*upscaled x 16*upscaled
343
  r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
@@ -347,9 +323,7 @@ class ExtractorMediaPipe:
347
 
348
  cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]
349
 
350
- right_iris_segmented_data = self.segment_iris(
351
- cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB)
352
- )
353
 
354
  return {
355
  "left_iris": {
 
18
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
  # ========== Face Extraction ==========
21
+ self.face_detector = mp.solutions.face_detection.FaceDetection(model_selection=0, min_detection_confidence=0.5)
 
 
22
  self.face_mesh = mp.solutions.face_mesh.FaceMesh(
23
  max_num_faces=1,
24
  static_image_mode=True,
 
167
  left_eye_landmark3 = landmarks[left_indices[12]]
168
  left_eye_landmark4 = landmarks[left_indices[4]]
169
 
170
+ right_eye_horizontal_distance = self.euclideanDistance(right_eye_landmark1, right_eye_landmark2)
171
+ right_eye_vertical_distance = self.euclideanDistance(right_eye_landmark3, right_eye_landmark4)
 
 
 
 
172
 
173
+ left_eye_vertical_distance = self.euclideanDistance(left_eye_landmark3, left_eye_landmark4)
174
+ left_eye_horizontal_distance = self.euclideanDistance(left_eye_landmark1, left_eye_landmark2)
 
 
 
 
175
 
176
  right_eye_ratio = right_eye_vertical_distance / right_eye_horizontal_distance
177
  left_eye_ratio = left_eye_vertical_distance / left_eye_horizontal_distance
 
182
 
183
  def extract_eyes_regions(self, image, landmarks, eye_indices):
184
  h, w, _ = image.shape
185
+ points = [(int(landmarks[idx].x * w), int(landmarks[idx].y * h)) for idx in eye_indices]
 
 
 
186
 
187
  x_min = min([p[0] for p in points])
188
  x_max = max([p[0] for p in points])
 
248
 
249
  if blink_detection:
250
  mesh_coordinates = self.landmarksDetection(image, results, False)
251
+ eyes_ratio = self.blinkRatio(mesh_coordinates, self.RIGHT_EYE, self.LEFT_EYE)
252
+ if eyes_ratio > self.blink_lower_thresh and eyes_ratio <= self.blink_upper_thresh:
 
 
 
 
 
253
  # print(
254
  # "I think person blinked. eyes_ratio = ",
255
  # eyes_ratio,
256
  # "Confirming with ViT model...",
257
  # )
258
+ blinked = self.blink_detection_model(left_eye=left_eye, right_eye=right_eye)
 
 
259
  # if blinked:
260
  # print("Yes, person blinked. Confirmed by model")
261
  # else:
 
278
  iris_img_blur = cv2.GaussianBlur(iris_img_gray, (5, 5), 0)
279
 
280
  # Perform adaptive thresholding
281
+ _, iris_img_mask = cv2.threshold(iris_img_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
 
 
282
 
283
  # Invert the mask
284
  segmented_mask = cv2.bitwise_not(iris_img_mask)
 
313
 
314
  cropped_left_iris = image[l_y1:l_y2, l_x1:l_x2]
315
 
316
+ left_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_left_iris, cv2.COLOR_BGR2RGB))
 
 
317
 
318
  # Crop the right iris to be exactly 16*upscaled x 16*upscaled
319
  r_x1 = max(int(r_cx) - (8 * self.upscale), 0)
 
323
 
324
  cropped_right_iris = image[r_y1:r_y2, r_x1:r_x2]
325
 
326
+ right_iris_segmented_data = self.segment_iris(cv2.cvtColor(cropped_right_iris, cv2.COLOR_BGR2RGB))
 
 
327
 
328
  return {
329
  "left_iris": {
feature_extraction/features_extractor.py CHANGED
@@ -14,9 +14,7 @@ warnings.filterwarnings("ignore")
14
 
15
  class FeaturesExtractor:
16
 
17
- def __init__(
18
- self, extraction_library="mediapipe", blink_detection=False, upscale=1
19
- ):
20
  self.upscale = upscale
21
  self.blink_detection = blink_detection
22
  self.extraction_library = extraction_library
 
14
 
15
  class FeaturesExtractor:
16
 
17
+ def __init__(self, extraction_library="mediapipe", blink_detection=False, upscale=1):
 
 
18
  self.upscale = upscale
19
  self.blink_detection = blink_detection
20
  self.extraction_library = extraction_library
image.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+
4
+ # Load the original face image
5
+ face_image = cv2.imread("path_to_face_image.jpg")
6
+
7
+ # Suppose CAM_left and CAM_right are the CAM results for the eyes (each 32x64)
8
+ CAM_left = cv2.imread("path_to_CAM_left.jpg") # or generated by your model
9
+ CAM_right = cv2.imread("path_to_CAM_right.jpg") # or generated by your model
10
+
11
+ # Example bounding boxes for the left and right eye
12
+ left_eye_bbox = (x_left, y_left, width_left, height_left)
13
+ right_eye_bbox = (x_right, y_right, width_right, height_right)
14
+
15
+ # Resize CAM images if needed (they should be 32x64, but resize to match bbox size)
16
+ CAM_left_resized = cv2.resize(CAM_left, (width_left, height_left))
17
+ CAM_right_resized = cv2.resize(CAM_right, (width_right, height_right))
18
+
19
+ # Create a copy of the face image to overlay the CAM results
20
+ face_with_CAM = face_image.copy()
21
+
22
+ # Overlay left eye CAM
23
+ face_with_CAM[y_left : y_left + height_left, x_left : x_left + width_left] = CAM_left_resized
24
+
25
+ # Overlay right eye CAM
26
+ face_with_CAM[y_right : y_right + height_right, x_right : x_right + width_right] = CAM_right_resized
27
+
28
+ # Save or display the result
29
+ cv2.imwrite("face_with_CAM_overlay.jpg", face_with_CAM)
30
+ cv2.imshow("Face with CAM Overlay", face_with_CAM)
31
+ cv2.waitKey(0)
32
+ cv2.destroyAllWindows()
video.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+
4
+ # Load the video
5
+ video_path = "path_to_video.mp4"
6
+ cap = cv2.VideoCapture(video_path)
7
+
8
+ # Video properties
9
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
10
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
11
+ fps = cap.get(cv2.CAP_PROP_FPS)
12
+
13
+ # Create a VideoWriter object for the output video
14
+ out = cv2.VideoWriter("output_with_CAM.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (frame_width, frame_height))
15
+
16
+ # Process each frame
17
+ while True:
18
+ ret, frame = cap.read()
19
+ if not ret:
20
+ break # End of the video
21
+
22
+ # Detect landmarks for left and right eye bounding boxes (example)
23
+ left_eye_bbox = (x_left, y_left, width_left, height_left)
24
+ right_eye_bbox = (x_right, y_right, width_right, height_right)
25
+
26
+ # Crop the eyes
27
+ left_eye = frame[y_left : y_left + height_left, x_left : x_left + width_left]
28
+ right_eye = frame[y_right : y_right + height_right, x_right : x_right + width_right]
29
+
30
+ # Generate CAMs for left and right eyes
31
+ CAM_left = generate_CAM(left_eye) # Use your model here
32
+ CAM_right = generate_CAM(right_eye) # Use your model here
33
+
34
+ # Resize CAMs if necessary
35
+ CAM_left_resized = cv2.resize(CAM_left, (width_left, height_left))
36
+ CAM_right_resized = cv2.resize(CAM_right, (width_right, height_right))
37
+
38
+ # Overlay the CAMs onto the original frame
39
+ frame[y_left : y_left + height_left, x_left : x_left + width_left] = CAM_left_resized
40
+ frame[y_right : y_right + height_right, x_right : x_right + width_right] = CAM_right_resized
41
+
42
+ # Write the processed frame to the output video
43
+ out.write(frame)
44
+
45
+ # Release resources
46
+ cap.release()
47
+ out.release()
48
+ cv2.destroyAllWindows()