# taken from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py
# streamlit run app.py
from io import BytesIO
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
from PIL import Image
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
from torchvision import transforms
from torchcam.methods import CAM
from torchcam import methods as torchcam_methods
from torchcam.utils import overlay_mask

import os.path as osp

# Make the project-local packages imported below resolvable when the app is
# launched with `streamlit run` from another working directory.
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from preprocessing.dataset_creation import EyeDentityDatasetCreation
from utils import get_model
from registry_utils import import_registered_modules

import_registered_modules()

# from torchcam.methods._utils import locate_candidate_layer

# CAM algorithms available in torchcam; only plain CAM is enabled in the UI.
CAM_METHODS = [
    "CAM",
    # "GradCAM",
    # "GradCAMpp",
    # "SmoothGradCAMpp",
    # "ScoreCAM",
    # "SSCAM",
    # "ISCAM",
    # "XGradCAM",
    # "LayerCAM",
]
# Backbones with weights under pre_trained_models/<name>/<eye_type>.pt.
TV_MODELS = [
    "ResNet18",
    "ResNet50",
]
SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
UPSCALE = [2, 4]
UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
LABEL_MAP = ["left_pupil", "right_pupil"]

# Eye models to run for each "Pupil Selection" sidebar value ("-" = both).
_EYES_BY_SELECTION = {
    "-": ["left_eye", "right_eye"],
    "left_pupil": ["left_eye"],
    "right_pupil": ["right_eye"],
}


@torch.no_grad()
def _load_model(model_configs, device="cpu"):
    """Build a registered model, load its weights and return it in eval mode.

    Args:
        model_configs: Dict holding a ``"model_path"`` entry (a state-dict
            file, relative to ``root_path`` or absolute) plus the remaining
            configuration passed to ``get_model``. The dict is not modified.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The model on ``device`` in evaluation mode.
    """
    # Copy so the caller's dict keeps its "model_path" entry.
    model_configs = dict(model_configs)
    model_path = os.path.join(root_path, model_configs.pop("model_path"))
    model_dict = torch.load(model_path, map_location=device)
    model = get_model(model_configs=model_configs)
    model.load_state_dict(model_dict)
    return model.to(device).eval()


def _display_size_limit(width, height):
    """Return the ``(max_width, max_height)`` box used to thumbnail an upload.

    Square images of at least 256 px are limited to 256x256. Otherwise the
    width is capped at 640 and the height at 480. The lower bounds (64 / 32)
    only change the box; ``Image.thumbnail`` never enlarges an image.
    """
    if width == height and width >= 256:
        return 256, 256
    max_width, max_height = width, height
    if width >= 640:
        max_width = 640
    elif width < 64:
        max_width = 64
    if height >= 480:
        max_height = 480
    elif height < 32:
        max_height = 32
    return max_width, max_height


def _extract_eyes(input_img, pupil_selection, upscale, upscale_method_or_model):
    """Crop the eyes from ``input_img`` and turn them into model-ready batches.

    Runs ``EyeDentityDatasetCreation`` on the image. If its result contains an
    ``"eyes"`` entry, the ``"left_eye"`` / ``"right_eye"`` crops found there
    are used. Otherwise the whole image stands in for the eye(s) chosen by
    ``pupil_selection``: both for ``"-"``, only one for
    ``"left_pupil"`` / ``"right_pupil"``.

    Returns:
        ``(left_eye, right_eye)``. Each is a ``(1, 3, 32, 64)`` float tensor,
        or ``None`` when that eye is unavailable.
    """
    if upscale == "-":
        sr_configs = None
    else:
        sr_configs = {
            "method": upscale_method_or_model,
            "params": {"upscale": upscale},
        }
    feature_extraction_configs = {
        "blink_detection": False,
        "upscale": upscale,
        "extraction_library": "mediapipe",
    }
    ds_results = EyeDentityDatasetCreation(
        feature_extraction_configs=feature_extraction_configs,
        sr_configs=sr_configs,
    )(np.array(input_img))

    preprocess_function = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Resize(
                [32, 64],
                interpolation=transforms.InterpolationMode.BICUBIC,
                antialias=True,
            ),
        ]
    )

    if ds_results is None or "eyes" not in ds_results:
        # No eye crops: feed the whole image to the selected pupil model(s).
        whole_image = preprocess_function(input_img).unsqueeze(0)
        left_eye = None if pupil_selection == "right_pupil" else whole_image
        right_eye = None if pupil_selection == "left_pupil" else whole_image
        return left_eye, right_eye

    eye_crops = ds_results["eyes"]
    batches = []
    for eye_name in ("left_eye", "right_eye"):
        crop = eye_crops.get(eye_name)
        if crop is None:
            batches.append(None)
        else:
            eye_pil = to_pil_image(crop).convert("RGB")
            batches.append(preprocess_function(eye_pil).unsqueeze(0))
    return batches[0], batches[1]


def _get_target_layer(model, registered_model_name):
    """Return the last convolution of ``model.resnet.layer4`` used for CAM.

    Raises:
        ValueError: if no target layer is known for ``registered_model_name``.
    """
    if registered_model_name == "ResNet18":
        return model.resnet.layer4[-1].conv2
    if registered_model_name == "ResNet50":
        return model.resnet.layer4[-1].conv3
    raise ValueError(
        f"No target layer available for selected model: {registered_model_name}"
    )


def _render_prediction(column, model, eye_tensor, cam_method, target_layer):
    """Predict the pupil diameter for ``eye_tensor`` and draw its CAM overlay.

    The prediction is written to ``column`` as a heading, followed by a figure
    showing the eye image next to the CAM overlaid on it.
    """
    cam_extractor = torchcam_methods.__dict__[cam_method](
        model,
        target_layer=target_layer,
        fc_layer=model.resnet.fc,
        input_shape=eye_tensor.shape,
    )
    try:
        out = model(eye_tensor)
        column.markdown(
            f"<h3>Predicted Pupil Diameter: {out[0].item():.2f} mm</h3>",
            unsafe_allow_html=True,
        )
        # Retrieve the CAM for output index 0 and fuse if several were returned.
        act_maps = cam_extractor(0, out)
        activation_map = (
            act_maps[0] if len(act_maps) == 1 else cam_extractor.fuse_cams(act_maps)
        )
    finally:
        # Detach the forward hooks torchcam registered on the model.
        cam_extractor.remove_hooks()

    input_image_pil = to_pil_image(eye_tensor.squeeze(0))
    activation_map_pil = to_pil_image(activation_map, mode="F")
    result = overlay_mask(input_image_pil, activation_map_pil, alpha=0.5)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(input_image_pil)
    axs[0].axis("off")
    axs[0].set_title("Input Image")
    axs[1].imshow(result)
    axs[1].axis("off")
    axs[1].set_title("Overlayed CAM")
    column.pyplot(fig)
    # pyplot keeps every figure alive until closed; free it after rendering.
    plt.close(fig)
    column.text(f"eye image size: {eye_tensor.shape[-1]} x {eye_tensor.shape[-2]}")


def main():
    """Build the Streamlit page: upload an image, predict pupil diameters, show CAMs."""
    # Wide mode
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")

    # Designing the interface
    st.title("EyeDentify Playground")
    # For newline
    st.write("\n")
    # Left column: the uploaded image; right column: predictions.
    cols = st.columns((1, 1))
    cols[0].header("Input image")
    cols[-1].header("Prediction")

    # Sidebar: file selection
    st.sidebar.title("Upload Face or Eye")
    # Disabling warning
    # NOTE(review): newer Streamlit releases dropped this option and raise on
    # set_option — confirm against the pinned Streamlit version.
    st.set_option("deprecation.showfileUploaderEncoding", False)

    # Choose your own image
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image", type=["png", "jpeg", "jpg"]
    )
    input_img = None
    if uploaded_file is not None:
        # getvalue() returns the full upload regardless of the read cursor.
        input_img = Image.open(BytesIO(uploaded_file.getvalue()), mode="r").convert(
            "RGB"
        )
        width, height = input_img.size
        cols[0].text(f"Input Image: {width} x {height}")
        # Bicubic resampling; keeps aspect ratio and never enlarges.
        input_img.thumbnail(_display_size_limit(width, height))

        fig0, axs0 = plt.subplots(1, 1, figsize=(10, 10))
        axs0.imshow(input_img)
        axs0.axis("off")
        axs0.set_title("Input Image")
        cols[0].pyplot(fig0)
        plt.close(fig0)
        # Report the size thumbnail actually produced, not the bounding box.
        cols[0].text(f"Input Image Resized: {input_img.size[0]} x {input_img.size[1]}")

    st.sidebar.title("Setup")
    # Upscaling is disabled; to enable it, replace "-" with a sidebar
    # selectbox over ["-"] + UPSCALE.
    upscale = "-"
    if upscale != "-":
        upscale_method_or_model = st.sidebar.selectbox(
            "Upscale Method / Model",
            UPSCALE_METHODS + SR_METHODS,
            help="Select a method or model to upscale the uploaded image",
        )
    else:
        upscale_method_or_model = None

    # Pupil selection
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection",
        ["-"] + LABEL_MAP,
        help="Select left or right pupil OR keep blank for both pupil diameter estimation",
    )
    # Model selection
    tv_model = st.sidebar.selectbox(
        "Classification model",
        TV_MODELS,
        help="Supported Models for Pupil Diameter Estimation",
    )
    # Only plain CAM is offered (see CAM_METHODS).
    cam_method = "CAM"

    st.sidebar.write("\n")

    if not st.sidebar.button("Predict Diameter & Compute CAM"):
        return
    if uploaded_file is None:
        st.sidebar.error("Please upload an image first")
        return

    with st.spinner("Analyzing..."):
        left_eye, right_eye = _extract_eyes(
            input_img, pupil_selection, upscale, upscale_method_or_model
        )
        eye_batches = {"left_eye": left_eye, "right_eye": right_eye}

        for eye_type in _EYES_BY_SELECTION[pupil_selection]:
            eye_tensor = eye_batches[eye_type]
            if eye_tensor is None:
                # Report the missing eye instead of aborting the whole page.
                cols[-1].error(
                    f"Could not extract the {eye_type.replace('_', ' ')} from the image"
                )
                continue

            model_configs = {
                "model_path": os.path.join(
                    "pre_trained_models", tv_model, f"{eye_type}.pt"
                ),
                "registered_model_name": tv_model,
                "num_classes": 1,
            }
            model = _load_model(model_configs)
            target_layer = _get_target_layer(model, tv_model)
            _render_prediction(cols[-1], model, eye_tensor, cam_method, target_layer)


if __name__ == "__main__":
    main()