Spaces:
Sleeping
Sleeping
import cv2 | |
import gradio as gr | |
import json | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from einops import rearrange | |
from importlib import import_module | |
from pytorch_grad_cam import GradCAM | |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
from skimage.exposure import match_histograms | |
from skp.utils import load_model_from_config, load_kfold_ensemble_as_list | |
class ModelForGradCAM(nn.Module):
    """Adapter presenting a dict-based model as a plain tensor -> tensor module.

    The wrapped model is called with ``{"x": tensor}`` and returns a dict; this
    wrapper lets it be called with a bare tensor (as done for Grad-CAM in this
    file) and returns only the ``"logits1"`` entry of the output dict.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        batch = {"x": x}
        outputs = self.model(batch)
        return outputs["logits1"]
def _pluralize(count, unit: str) -> str:
    """Return "<count> <unit>" with the unit pluralized unless count is exactly 1."""
    return f"{count} {unit}" if count == 1 else f"{count} {unit}s"


def convert_bone_age_to_string(bone_age: float) -> str:
    """Format an age given in months as human-readable years and months.

    The value is split into whole years (floor division by 12) and a month
    remainder rounded with Python's ``round`` (ties to even). A remainder that
    rounds up to 12 is carried over into the years.

    Args:
        bone_age: Age in months.

    Returns:
        "N months" when under one year, "N years" when the month remainder is
        zero, otherwise "N years, M months". Each unit is singular when its
        count is exactly 1 (e.g. "1 year, 1 month").
    """
    years = round(bone_age // 12)
    months = round(bone_age - years * 12)
    # Rounding the remainder can reach a full year (e.g. 11.6 months).
    if months == 12:
        years += 1
        months = 0
    if years == 0:
        return _pluralize(months, "month")
    if months == 0:
        return _pluralize(years, "year")
    # Pluralize each part independently so one year reads "1 year", not "1 years".
    return f"{_pluralize(years, 'year')}, {_pluralize(months, 'month')}"
# Run on GPU when available; models and input tensors are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Cropping model: its output is used in predict_bone_age as a bounding box for
# cropping the uploaded radiograph.
cfg_crop = import_module("skp.configs.boneage.cfg_crop_simple_resize").cfg
crop_model = load_model_from_config(
    cfg_crop, weights_path="crop.pt", device=device, eval_mode=True
)

# Bone-age ensemble: one config (backbone overridden here) with three weight
# files, loaded via load_kfold_ensemble_as_list.
cfg = import_module("skp.configs.boneage.cfg_female_channel_reg_cls_match_hist").cfg
cfg.backbone = "convnextv2_tiny"
model_list = load_kfold_ensemble_as_list(
    cfg, [f"net{i}.pt" for i in range(3)], device=device, eval_mode=True
)

# Reference image for histogram matching, read as grayscale (flag 0).
_ref_gray = cv2.imread("ref_img.png", 0)
if _ref_gray is None:
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    raise FileNotFoundError("Could not read reference image 'ref_img.png'")
ref_img = rearrange(_ref_gray, "h w -> h w 1 ")

# Greulich & Pyle reference ages keyed by sex ("male"/"female" are the keys
# read in predict_bone_age); compared against predictions expressed in months.
with open("greulich_and_pyle_ages.json", "r", encoding="utf-8") as f:
    greulich_and_pyle_ages = json.load(f)["bone_ages"]
greulich_and_pyle_ages = {k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()}

# Grad-CAM uses only the first ensemble member, targeting its last backbone stage.
model_grad_cam = ModelForGradCAM(model_list[0])
target_layers = [model_grad_cam.model.backbone.stages[-1]]
def _crop_to_hand(image):
    """Crop an (h, w, 1) image to the box predicted by ``crop_model``.

    The crop model's output is read as (x, y, w, h) normalized to the image
    size. The scaled box is clipped to the image bounds; if nothing remains,
    the uncropped image is returned.
    """
    img_h, img_w = image.shape[:2]
    crop_input = cfg_crop.val_transforms(image=image)["image"]
    crop_input = rearrange(torch.from_numpy(crop_input), "h w c -> 1 c h w")
    with torch.inference_mode():
        box = crop_model({"x": crop_input.to(device).float()}, return_loss=False)[
            "logits"
        ][0].cpu()
        # Scale out-of-place: inference tensors must not be modified in place
        # outside inference mode.
        box = box * torch.tensor([img_w, img_h, img_w, img_h])
    bx, by, bw, bh = box.numpy().astype("int")
    # Clip so a slightly out-of-range prediction cannot yield negative
    # (wrap-around) slice indices or run past the image edge.
    x1, y1 = max(bx, 0), max(by, 0)
    x2, y2 = min(bx + bw, img_w), min(by + bh, img_h)
    if x2 <= x1 or y2 <= y1:
        return image
    return image[y1:y2, x1:x2]


def _build_model_input(cropped, female):
    """Histogram-match to ``ref_img``, transform, and append the sex channel(s).

    The sex channel(s) are all 255 for female and all 0 for male. Returns a
    batch-of-one tensor laid out as (1, c, h, w).
    """
    matched = match_histograms(cropped, ref_img)
    img = cfg.val_transforms(image=matched)["image"]
    sex_channel = np.zeros_like(img)
    if female:
        sex_channel[...] = 255
    combined = np.concatenate([img, sex_channel], axis=-1)
    return rearrange(torch.from_numpy(combined), "h w c -> 1 c h w")


def _predict_months(model_input):
    """Return the ensemble-mean expected bone age in months.

    Each model's "logits1" is softmaxed and reduced to an expected value, with
    the class index taken as the age in months.
    """
    preds = []
    with torch.inference_mode():
        for net in model_list:
            logits = net({"x": model_input.to(device).float()}, return_loss=False)[
                "logits1"
            ][0].cpu()
            classes = torch.arange(logits.shape[0])
            preds.append((logits.softmax(0) * classes).sum().numpy())
    return np.mean(preds)


def _closest_gp_ages(months, female):
    """Return the two Greulich & Pyle reference ages nearest to ``months``."""
    gp_ages = greulich_and_pyle_ages["female" if female else "male"]
    order = np.argsort(np.abs(months - gp_ages))
    return gp_ages[order[0]], gp_ages[order[1]]


def _render_output_image(model_input, months, show_heatmap):
    """Return the model's first input channel as uint8 RGB, optionally blended
    with a Grad-CAM heatmap for the class nearest to ``months``."""
    base = cv2.cvtColor(
        model_input[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
    )
    if not show_heatmap:
        return base
    targets = [ClassifierOutputTarget(round(months))]
    with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
        grayscale_cam = cam(
            input_tensor=model_input.to(device).float(),
            targets=targets,
            eigen_smooth=True,
        )
    heatmap = cv2.applyColorMap(
        (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
    )
    image_weight = 0.6
    # applyColorMap yields BGR; reverse channels to RGB before blending.
    blended = (1 - image_weight) * heatmap[..., ::-1] + image_weight * base
    return blended.astype("uint8")


def predict_bone_age(Radiograph, Sex, Heatmap):
    """Gradio callback: predict bone age and optionally render a Grad-CAM heatmap.

    Args:
        Radiograph: Grayscale radiograph array laid out as (h, w).
        Sex: Radio index; 0 = male, 1 = female.
        Heatmap: Radio index; truthy (1 = "Yes") to compute the Grad-CAM overlay.

    Returns:
        A tuple of (result text, uint8 RGB image to display).

    Raises:
        gr.Error: If no radiograph was uploaded or no sex was selected.
    """
    if Radiograph is None:
        raise gr.Error("Please upload a hand radiograph.")
    # Sex conditions the model; an unselected radio must not silently mean male.
    if Sex is None:
        raise gr.Error("Please select the patient's sex.")
    female = bool(Sex)  # 0 - male, 1 - female

    image = rearrange(Radiograph, "h w -> h w 1")
    cropped = _crop_to_hand(image)
    model_input = _build_model_input(cropped, female)
    bone_age = _predict_months(model_input)

    closest1, closest2 = _closest_gp_ages(bone_age, female)
    bone_age_str = convert_bone_age_to_string(bone_age)
    closest1 = convert_bone_age_to_string(closest1)
    closest2 = convert_bone_age_to_string(closest2)

    grad_cam_image = _render_output_image(model_input, bone_age, bool(Heatmap))
    return (
        f"Predicted bone age: {bone_age_str}\n\nThe closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
        grad_cam_image,
    )
# Components are created outside the Blocks context and handed to gr.Interface,
# which places them itself.
image = gr.Image(image_mode="L")
sex = gr.Radio(["Male", "Female"], type="index")  # index 0 = male, 1 = female
generate_heatmap = gr.Radio(["No", "Yes"], type="index")  # index 0 = no, 1 = yes
textbox = gr.Textbox(show_label=True, label="Result")
grad_cam_image = gr.Image(image_mode="RGB", label="Heatmap / Image")

# Example rows: (radiograph path, sex choice, heatmap choice).
_EXAMPLES = [
    ["examples/2639.png", "Female", "Yes"],
    ["examples/10043.png", "Female", "No"],
    ["examples/8888.png", "Female", "Yes"],
]

_DESCRIPTION = """
# Deep Learning Model for Pediatric Bone Age
This model predicts the bone age from a single frontal view hand radiograph.
The model was trained on the publicly available
[RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset.
The model achieves a mean absolute error of 4.26 months on the original test set comprising 200 multi-annotated hand radiographs,
which is competitive with [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge.
There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
to make its prediction. However, this takes extra computation and will increase the runtime.
This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes
any and all responsibility regarding their own use of this model and its outputs. Do NOT upload any images containing protected
health information, as this demonstration is not compliant with patient privacy laws.
Created by: Ian Pan, <https://ianpan.me>
Last updated: December 15, 2024
"""

with gr.Blocks() as demo:
    gr.Markdown(_DESCRIPTION)
    gr.Interface(
        fn=predict_bone_age,
        inputs=[image, sex, generate_heatmap],
        outputs=[textbox, grad_cam_image],
        examples=_EXAMPLES,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch(share=True)