import torch
import cv2
import numpy as np
from torchvision import models, transforms
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import gradio as gr

IMG_SIZE = 512
CLASSES = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"]

checkpoint = "./demo_checkpoint_convnext.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.load(checkpoint, map_location=device)

# Normalization statistics computed on the (green-channel) training data.
global_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.2786802, 0.2786802, 0.2786802],
                         [0.16637428, 0.16637428, 0.16637428])
])


def crop_image_from_gray(img, tol=7):
    """Crop away near-black borders, then replicate the single channel
    three times so downstream code receives an (H, W, 3) image."""
    mask = img > tol
    cropped = img[np.ix_(mask.any(1), mask.any(0))]
    return np.stack([cropped, cropped, cropped], axis=-1)


def circle_crop(img):
    """Mask everything outside the inscribed circle, then tight-crop."""
    height, width = img.shape
    x, y = width // 2, height // 2
    r = np.amin((x, y))
    circle_img = np.zeros((height, width), np.uint8)
    cv2.circle(circle_img, (x, y), int(r), 1, thickness=-1)
    img = cv2.bitwise_and(img, img, mask=circle_img)
    return crop_image_from_gray(img)


def preprocess(img):
    # Extract the green channel, which carries the most contrast in fundus images
    img = img[:, :, 1]
    # CLAHE for local contrast enhancement
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img = clahe.apply(img)
    # Circle crop
    img = circle_crop(img)
    # Resize
    return cv2.resize(img, (IMG_SIZE, IMG_SIZE))


def grad_campp(img):
    """Overlay a Grad-CAM++ heatmap for the top predicted class."""
    img = np.float32(img) / 255
    input_tensor = preprocess_image(img, mean=[0.5, 0.5, 0.5],
                                    std=[0.5, 0.5, 0.5]).to(device)
    # Use the last feature block of ConvNeXt as the target layer
    target_layers = [model.features[-1]]
    gradcampp = GradCAMPlusPlus(model=model, target_layers=target_layers)
    # targets=None defaults to the highest-scoring class
    grayscale_gradcampp = gradcampp(input_tensor=input_tensor, targets=None,
                                    eigen_smooth=False, aug_smooth=False)[0, :]
    # use_rgb=True keeps the overlay in RGB, matching Gradio's channel order
    return show_cam_on_image(img, grayscale_gradcampp, use_rgb=True)


def do_inference(img):
    img = preprocess(img)
    img_t = global_transforms(img)
    batch_t = torch.unsqueeze(img_t, 0)
    model.eval()
    gradcam_img = grad_campp(img)
    # No gradients are needed at test time, so wrap the forward
    # pass in no_grad to save memory
    with torch.no_grad():
        batch_t = batch_t.to(device)
        output = model(batch_t)
        # Class probabilities, sorted from most to least likely
        probs = torch.nn.functional.softmax(output, dim=1)
        order = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
    probs = probs.cpu().numpy()[0][order]
    labels = np.array(CLASSES)[order]
    return {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))}, gradcam_img


im = gr.inputs.Image(shape=(512, 512), image_mode='RGB', invert_colors=False,
                     source="upload", type="numpy")

title = "ConvNeXt for Diabetic Retinopathy Detection"
description = ""
examples = [['./examples/0_0.jpeg'], ['./examples/0_1.png'],
            ['./examples/1_0.jpeg'], ['./examples/1_1.png'],
            ['./examples/2_0.jpeg'], ['./examples/2_1.png'],
            ['./examples/3_0.jpeg'], ['./examples/3_1.png'],
            ['./examples/4_0.jpeg'], ['./examples/4_1.png']]

iface = gr.Interface(
    do_inference,
    im,
    outputs=[gr.outputs.Label(num_top_classes=5),
             gr.outputs.Image(label='Output image', type='pil')],
    live=False,
    interpretation=None,
    title=title,
    description=description,
    examples=examples)

#iface.test_launch()
iface.launch()