|
import cv2 |
|
import gradio as gr |
|
import json |
|
import numpy as np |
|
import spaces |
|
import torch |
|
import torch.nn as nn |
|
|
|
from einops import rearrange |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from skimage.exposure import match_histograms |
|
from transformers import AutoModel |
|
|
|
|
|
class ModelForGradCAM(nn.Module):
    """Single-input adapter around the bone-age network for GradCAM.

    GradCAM calls its model with one tensor only. This wrapper captures the
    sex tensor when it is constructed and supplies it on every forward call,
    always requesting logits from the wrapped model.

    Args:
        model: the network to wrap; called as ``model(x, female, return_logits=True)``.
        female: sex tensor forwarded unchanged as the second argument.
    """

    def __init__(self, model, female):
        super().__init__()
        self.model = model
        self.female = female

    def forward(self, x):
        """Run the wrapped model on ``x`` with the stored sex tensor, returning logits."""
        wrapped, sex_tensor = self.model, self.female
        return wrapped(x, sex_tensor, return_logits=True)
|
|
|
|
|
def convert_bone_age_to_string(bone_age: float):
    """Format an age expressed in months as "X years, Y months".

    The age is split into whole years (floor division by 12) and the
    remaining months rounded to the nearest integer. A remainder that rounds
    up to 12 is carried into the years. Singular units ("1 year",
    "1 month") are used when a count is exactly one, and zero-valued parts
    are omitted, except that an age under one year always reports months
    (e.g. "0 months").

    Args:
        bone_age: age in months. Python numbers and NumPy scalars are both
            accepted; the value is coerced to ``float`` first, so ``round``
            always yields a plain ``int`` (older NumPy versions return a NumPy
            float from ``round`` on a scalar, which would print as "2.0").

    Returns:
        The human-readable age string.
    """

    def _with_unit(count, unit):
        # Singular for exactly one, plural otherwise (including 0).
        return f"{count} {unit}" if count == 1 else f"{count} {unit}s"

    total_months = float(bone_age)
    years = int(total_months // 12)
    months = round(total_months - years * 12)

    # e.g. 11.6 months rounds to 12 -> carry into the year count.
    if months == 12:
        years += 1
        months = 0

    if years == 0:
        return _with_unit(months, "month")
    if months == 0:
        return _with_unit(years, "year")
    return f"{_with_unit(years, 'year')}, {_with_unit(months, 'month')}"
|
|
|
|
|
def _closest_gp_ages(bone_age, female_flag):
    """Return the two Greulich & Pyle reference ages nearest to ``bone_age``.

    Args:
        bone_age: predicted age (same units as the reference table entries).
        female_flag: truthy selects the "female" table, falsy the "male" one.

    Returns:
        A ``(closest, second_closest)`` pair of reference ages.
    """
    gp_ages = greulich_and_pyle_ages["female" if female_flag else "male"]
    order = np.argsort(np.abs(bone_age - gp_ages))
    return gp_ages[order[0]], gp_ages[order[1]]


def _grad_cam_overlay(input_tensor, female, bone_age, base_image, image_weight=0.6):
    """Blend a GradCAM heatmap for the first sub-network onto ``base_image``.

    Args:
        input_tensor: the exact tensor fed to the bone-age model.
        female: sex tensor passed through to the model.
        bone_age: predicted age; ``round(bone_age)`` is used as the GradCAM
            class target.
        base_image: RGB rendering of ``input_tensor`` to draw the heatmap on.
        image_weight: weight of ``base_image`` in the blend; the heatmap gets
            ``1 - image_weight``.

    Returns:
        A float array with the same shape as ``base_image``.
    """
    model_grad_cam = ModelForGradCAM(model.net1, female)
    target_layers = [model_grad_cam.model.backbone.stages[-1]]
    # NOTE(review): assumes the logits are indexed by the rounded bone age,
    # so the predicted age doubles as a class index -- confirm against the model.
    targets = [ClassifierOutputTarget(round(bone_age))]
    with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets, eigen_smooth=True)

    heatmap = cv2.applyColorMap(
        (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
    )
    # applyColorMap produces BGR; reverse the channel axis to match RGB.
    return (1 - image_weight) * heatmap[..., ::-1] + image_weight * base_image


@spaces.GPU
def predict_bone_age(Radiograph, Sex, Heatmap):
    """Predict bone age from a hand radiograph.

    The parameter names are kept as-is because the Gradio Interface is built
    directly on this function.

    Pipeline:
      1. ``crop_model`` predicts a box around the hand.
      2. The crop is histogram-matched to ``ref_img``.
      3. The crop is preprocessed and scored by ``model``.
      4. Optionally, a GradCAM heatmap is overlaid.

    Args:
        Radiograph: image array from the ``gr.Image(image_mode="L")`` input;
            presumably a 2-D grayscale array (the code indexes it as [y, x]).
        Sex: radio index (0 = "Male", 1 = "Female").
        Heatmap: radio index (0 = "No", 1 = "Yes"); truthy requests GradCAM.

    Returns:
        Three values:
        - the bone-age string;
        - a text listing the two closest Greulich & Pyle ages;
        - a uint8 RGB image, either the heatmap overlay or the model input
          rendered in RGB.

    Raises:
        gr.Error: if the radiograph or the sex selection is missing.
    """
    if Radiograph is None:
        raise gr.Error("Please upload a hand radiograph.")
    if Sex is None:
        raise gr.Error("Please select the patient's sex.")

    # Stage 1: locate the hand and crop to it.
    crop_input = crop_model.preprocess(Radiograph)
    crop_input = torch.from_numpy(crop_input).float().to(device)
    crop_input = rearrange(crop_input, "h w -> 1 1 h w")
    img_shape = torch.tensor([Radiograph.shape[:2]]).to(device)
    with torch.inference_mode():
        box = crop_model(crop_input, img_shape=img_shape).to("cpu").numpy()
    # Cast to int: NumPy rejects float slice indices, so a float box would crash.
    x0, y0, w, h = (int(v) for v in box[0])
    cropped = Radiograph[y0 : y0 + h, x0 : x0 + w]

    # Stage 2: normalise intensities to the reference image and predict.
    matched = match_histograms(cropped, ref_img)
    x = model.preprocess(matched)
    x = torch.from_numpy(x).float().to(device)
    x = rearrange(x, "h w -> 1 1 h w")
    female = torch.tensor([Sex]).to(device)
    with torch.inference_mode():
        bone_age = model(x, female)[0].item()

    closest1, closest2 = _closest_gp_ages(bone_age, Sex)

    # RGB rendering of the model input; used as-is or as the heatmap backdrop.
    base_image = cv2.cvtColor(
        x[0, 0].to("cpu").numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
    )
    if Heatmap:
        output_image = _grad_cam_overlay(x, female, bone_age, base_image)
    else:
        output_image = base_image

    gp_text = (
        "The closest Greulich & Pyle bone ages are:\n"
        f" 1) {convert_bone_age_to_string(closest1)}\n"
        f" 2) {convert_bone_age_to_string(closest2)}"
    )
    return (
        convert_bone_age_to_string(bone_age),
        gp_text,
        output_image.astype("uint8"),
    )
|
|
|
|
|
# --- Input widgets ------------------------------------------------------
# The radios use type="index", so the callback receives the position of the
# chosen option rather than its label:
#   Sex: 0 = Male, 1 = Female
#   Heatmap: 0 = No, 1 = Yes
image = gr.Image(image_mode="L")
sex = gr.Radio(choices=["Male", "Female"], type="index")
generate_heatmap = gr.Radio(choices=["No", "Yes"], type="index")

# --- Output widgets, in the order predict_bone_age returns them ----------
label = gr.Label(show_label=False)
textbox = gr.Textbox(show_label=False)
grad_cam_image = gr.Image(label="Heatmap / Image", image_mode="RGB")
|
|
|
# Page layout: an introductory Markdown panel followed by the prediction UI.
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Deep Learning Model for Pediatric Bone Age

This model predicts the bone age from a single frontal view hand radiograph. Read more about the model here:
<https://huggingface.co/ianpan/bone-age>

There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
to make its prediction. However, this takes extra computation and will increase the runtime.

This model is for demonstration purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes
any and all responsibility regarding their own use of this model and its outputs. Do NOT upload any images containing protected
health information, as this demonstration is not compliant with patient privacy laws.

Created by: Ian Pan, <https://ianpan.me>

Last updated: December 16, 2024
"""
    )
    # NOTE(review): the input components carry no explicit labels. Gradio likely
    # falls back to predict_bone_age's parameter names (Radiograph, Sex, Heatmap),
    # so confirm this before renaming those parameters.
    gr.Interface(
        fn=predict_bone_age,
        inputs=[image, sex, generate_heatmap],
        # Order must match predict_bone_age's return tuple:
        # (age string, G&P text, image).
        outputs=[label, textbox, grad_cam_image],
        # Examples give option labels ("Female", "Yes"); presumably Gradio maps
        # them to indices because the radios use type="index" -- confirm.
        examples=[
            ["examples/2639.png", "Female", "Yes"],
            ["examples/10043.png", "Female", "No"],
            ["examples/8888.png", "Female", "Yes"],
        ],
        # NOTE(review): caching runs predict_bone_age, which reads globals
        # (device, crop_model, model, ref_img, greulich_and_pyle_ages) that are
        # only assigned in the __main__ block. Confirm the installed Gradio
        # version caches at launch time rather than here.
        cache_examples=True,
    )
|
|
|
if __name__ == "__main__":
    # All globals read by predict_bone_age (device, crop_model, model,
    # ref_img, greulich_and_pyle_ages) are initialised here. The app therefore
    # only works when this file is executed as a script.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device `{device}` ...")

    # NOTE(review): trust_remote_code=True executes model code downloaded from
    # the Hub. Keep these repo ids pinned to sources you trust.
    crop_model = AutoModel.from_pretrained(
        "ianpan/bone-age-crop", trust_remote_code=True
    )
    model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)

    crop_model, model = crop_model.eval().to(device), model.eval().to(device)

    # Reference image for histogram matching, loaded as grayscale.
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail fast here rather than later inside match_histograms.
    ref_img = cv2.imread("ref_img.png", cv2.IMREAD_GRAYSCALE)
    if ref_img is None:
        raise FileNotFoundError("Could not read reference image 'ref_img.png'")

    with open("greulich_and_pyle_ages.json", "r", encoding="utf-8") as f:
        greulich_and_pyle_ages = json.load(f)["bone_ages"]

    # Convert each per-sex list to an array so predict_bone_age can use
    # vectorised distance computations on it.
    greulich_and_pyle_ages = {
        k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()
    }

    demo.launch(share=True)
|
|