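"""EyeDentify Playground: a Streamlit app that estimates pupil diameter from an
uploaded face/eye image or video and visualizes the class activation maps (CAM)
behind each prediction."""
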
import os
import sys
import tempfile
import os.path as osp
from io import BytesIO

import numpy as np
import streamlit as st
from PIL import Image, ImageOps
from matplotlib import pyplot as plt
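
# Make the repository root importable so the project-local modules below resolve.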
root_path = osp.abspath(osp.join(__file__, osp.pardir))
sys.path.append(root_path)

from registry_utils import import_registered_modules
from app_utils import (
    extract_frames,
    is_image,
    is_video,
    display_results,
    overlay_text_on_frame,
    process_frames,
    process_video,
    resize_frame,
)
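
# Importing the registered modules populates the project's component registry
# (assumed side effect; needed before models can be looked up by name).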
import_registered_modules()
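
# UI option lists: CAM method, torchvision backbones, super-resolution methods,
# upscale factors and interpolation modes, and the pupil labels offered for selection.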
CAM_METHODS = ["CAM"]
TV_MODELS = ["ResNet18", "ResNet50"]
SR_METHODS = ["GFPGAN", "CodeFormer", "RealESRGAN", "SRResNet", "HAT"]
UPSCALE = [2, 4]
UPSCALE_METHODS = ["BILINEAR", "BICUBIC"]
LABEL_MAP = ["left_pupil", "right_pupil"]


def main():
    st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
    st.title("EyeDentify Playground")
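
    # Two-column layout: input media on the left, prediction/CAM output on the right.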
    cols = st.columns((1, 1))
    cols[0].header("Input")
    cols[-1].header("Prediction")

    st.sidebar.title("Upload Face or Eye")
    uploaded_file = st.sidebar.file_uploader(
        "Upload Image or Video", type=["png", "jpeg", "jpg", "mp4", "avi", "mov", "mkv", "webm"]
    )
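
    # Preview the uploaded media and record how many frames will be processed.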
    if uploaded_file is not None:
        file_extension = uploaded_file.name.split(".")[-1]
        if is_image(file_extension):
            input_img = Image.open(BytesIO(uploaded_file.read())).convert("RGB")
            # NOTE: Images taken with a phone camera often carry an EXIF orientation tag
            # that rotates pictures shot with the phone tilted. exif_transpose applies the
            # stored orientation and strips the tag, so the image ends up upright.
            input_img = ImageOps.exif_transpose(input_img)
            input_img = resize_frame(input_img, max_width=640, max_height=480)
            cols[0].image(input_img, use_column_width=True)
            st.session_state.total_frames = 1
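        # For videos, persist the upload to a temporary file so it can be read back by path.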
        elif is_video(file_extension):
            tfile = tempfile.NamedTemporaryFile(delete=False)
            tfile.write(uploaded_file.read())
            video_path = tfile.name
            video_frames = extract_frames(video_path)
            cols[0].video(video_path)
            st.session_state.total_frames = len(video_frames)

        st.session_state.current_frame = 0
        st.session_state.frame_placeholder = cols[0].empty()
        txt = f"<p style='font-size:20px;'> Number of Frames Processed: <strong>{st.session_state.current_frame} / {st.session_state.total_frames}</strong> </p>"
        st.session_state.frame_placeholder.markdown(txt, unsafe_allow_html=True)
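
    # Sidebar controls: which pupil(s) to estimate and which classification backbone to use.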
    st.sidebar.title("Setup")
    pupil_selection = st.sidebar.selectbox(
        "Pupil Selection", ["both"] + LABEL_MAP, help="Select left or right pupil OR both for diameter estimation"
    )
    tv_model = st.sidebar.selectbox("Classification model", TV_MODELS, help="Supported Models")
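
    # Inference runs only when the button is clicked; Streamlit reruns the script on every interaction.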
    if st.sidebar.button("Predict Diameter & Compute CAM"):
        if uploaded_file is None:
            st.sidebar.error("Please upload an image or video")
        else:
            with st.spinner("Analyzing..."):
                if is_image(file_extension):
                    input_frames, output_frames, predicted_diameters, face_frames = process_frames(
                        cols,
                        [input_img],
                        tv_model,
                        pupil_selection,
                        cam_method=CAM_METHODS[-1],
                        output_path=None,
                        codec=None,
                    )
                    # for ff in face_frames:
                    #     if ff["has_face"]:
                    #         cols[1].image(face_frames[0]["img"], use_column_width=True)
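
                    # One results column per selected eye: cropped input, CAM overlay, then the predicted diameter.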
                    input_frames_keys = input_frames.keys()
                    video_cols = cols[1].columns(len(input_frames_keys))
                    for i, eye_type in enumerate(input_frames_keys):
                        video_cols[i].image(input_frames[eye_type][-1], use_column_width=True)
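
                    # For each eye, show the CAM overlay and render the predicted diameter onto a blank frame.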
                    output_frames_keys = output_frames.keys()
                    fig, axs = plt.subplots(1, len(output_frames_keys), figsize=(10, 5))
                    for i, eye_type in enumerate(output_frames_keys):
                        height, width, c = output_frames[eye_type][0].shape
                        video_cols[i].image(output_frames[eye_type][-1], use_column_width=True)
                        frame = np.zeros((height, width, c), dtype=np.uint8)
                        text = f"{predicted_diameters[eye_type][0]:.2f}"
                        frame = overlay_text_on_frame(frame, text)
                        video_cols[i].image(frame, use_column_width=True)
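
                # Videos are processed frame by frame into a temporary .webm next to the app,
                # and the uploaded temp file is removed afterwards.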
                elif is_video(file_extension):
                    output_video_path = f"{root_path}/tmp.webm"
                    process_video(
                        cols, video_frames, tv_model, pupil_selection, output_video_path, cam_method=CAM_METHODS[-1]
                    )
                    os.remove(video_path)


if __name__ == "__main__":
    main()