This model performs both segmentation and classification on chest radiographs (X-rays).
The model uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification.
For frontal radiographs, the model segments 1) the right lung, 2) the left lung, and 3) the heart.
The model also predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.
The CheXpert (small version) and NIH Chest X-ray datasets were used to train the model.
Segmentation masks were obtained from the CheXmask dataset (paper).
The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training/20% validation. A holdout test set was not used since minimal tuning was performed.
The view classifier was trained only on CheXpert images (NIH images were excluded from the view loss), because lateral radiographs are present only in CheXpert.
This avoids unwanted bias: if one class comes from only a single dataset, the model can learn to recognize that dataset instead of the view itself.
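The card does not include its training code. As a hypothetical sketch of how NIH images could be excluded from the view loss in PyTorch, the function below computes a per-sample cross-entropy and averages it only over CheXpert rows. The name `masked_view_loss` and the boolean `from_chexpert` flag are assumptions, not part of the released model.

```python
import torch
import torch.nn.functional as F


def masked_view_loss(view_logits, view_targets, from_chexpert):
    """Cross-entropy for the view head that ignores images without view labels.

    view_logits:   (N, 3) raw scores for AP, PA, lateral
    view_targets:  (N,) integer class indices (values in masked rows are ignored)
    from_chexpert: (N,) bool, True where the image comes from CheXpert
    """
    # Per-sample loss so individual rows can be masked out
    per_sample = F.cross_entropy(view_logits, view_targets, reduction="none")
    weights = from_chexpert.float()
    # Average over CheXpert rows only; the clamp avoids dividing by zero when a
    # batch contains no CheXpert images (the loss is then 0)
    return (per_sample * weights).sum() / weights.sum().clamp(min=1.0)
```

An alternative is to set the NIH targets to `-100` and rely on the default `ignore_index` of `F.cross_entropy`, although that version returns NaN for a batch with no CheXpert images.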
Validation performance is as follows:
- Segmentation (Dice similarity coefficient):
  - Right lung: 0.957
  - Left lung: 0.948
  - Heart: 0.943
- Age prediction:
  - Mean absolute error: 5.25 years
- Classification:
  - View (AP, PA, lateral): 99.42% accuracy
  - Female: 0.999 AUC
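The Dice similarity coefficient for a predicted region A and a reference region B is 2|A ∩ B| / (|A| + |B|), ranging from 0 (no overlap) to 1 (identical regions). A minimal NumPy sketch for one label of two integer masks follows. The `dice` helper is hypothetical, and the card does not say how per-image scores were aggregated or how images with an empty structure were scored; returning 1.0 when the label is absent from both masks is an assumption.

```python
import numpy as np


def dice(pred, target, label):
    """Dice similarity coefficient for one label in two integer label masks."""
    a = pred == label
    b = target == label
    denom = a.sum() + b.sum()
    if denom == 0:
        # Label absent from both masks: treated as perfect agreement (assumption)
        return 1.0
    return 2.0 * np.logical_and(a, b).sum() / denom
```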
To use the model:
```python
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)

img = cv2.imread(..., 0)  # replace ... with the image path; flag 0 loads as grayscale
x = model.preprocess(img)  # only takes a single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel, batch dims
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))
```
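Because `model.preprocess` takes one image at a time, several images would have to be preprocessed individually and then stacked into a batch. This sketch assumes that `preprocess` returns a 2D array of the same size for every image and that the model's forward pass accepts a batch larger than one; neither is confirmed by this card, and `stack_preprocessed` is a hypothetical helper.

```python
import numpy as np


def stack_preprocessed(images):
    """Stack preprocessed 2D arrays into an (N, 1, H, W) float32 batch."""
    if any(img.ndim != 2 for img in images):
        raise ValueError("expected 2D (height, width) arrays")
    shapes = {img.shape for img in images}
    if len(shapes) != 1:
        raise ValueError(f"all images must share one shape, got {shapes}")
    # np.stack adds the batch axis; [:, None] inserts the single channel axis
    return np.stack(images).astype(np.float32)[:, None]
```

With a hypothetical list of file paths `paths`, the batch would be built with `stack_preprocessed([model.preprocess(cv2.imread(p, 0)) for p in paths])` and passed to `model(torch.from_numpy(batch).to(device))` inside `torch.inference_mode()`.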
The output is a dictionary which contains 4 keys:
- `mask`: has 3 channels containing the segmentation masks. Take the argmax over the channel dimension to create a single image mask (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart.
- `age`: predicted age, in years.
- `view`: 3 classes, one for each possible view. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
- `female`: binarize with `out["female"] >= 0.5`.
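If the values in `out` are PyTorch tensors (an assumption; the card does not state their exact shapes), they can be converted with `{k: v.cpu().numpy() for k, v in out.items()}` and decoded per image as sketched below. The `decode_outputs` helper and the `VIEW_NAMES` tuple are hypothetical; only the view index order and the 0.5 threshold come from this card.

```python
import numpy as np

# Index order for the view head as documented in this card
VIEW_NAMES = ("AP", "PA", "lateral")


def decode_outputs(out):
    """Turn the output dict (as NumPy arrays) into one prediction dict per image.

    Assumes out["mask"] is (N, C, H, W), out["view"] is (N, 3), and
    out["age"] / out["female"] flatten to length N.
    """
    masks = out["mask"].argmax(1)  # (N, H, W) label map per image
    views = out["view"].argmax(1)
    ages = np.asarray(out["age"]).reshape(-1)
    female = np.asarray(out["female"]).reshape(-1) >= 0.5
    return [
        {
            "mask": masks[i],
            "view": VIEW_NAMES[views[i]],
            "age_years": float(ages[i]),
            "female": bool(female[i]),
        }
        for i in range(len(views))
    ]
```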
You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray. You can also calculate the cardiothoracic ratio (CTR) using this function:
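```python
import numpy as np


def calculate_ctr(mask):
    """Cardiothoracic ratio from a single label mask with dims (height, width).

    Labels: 1 = right lung, 2 = left lung, 3 = heart. A NumPy mask can be
    obtained with e.g. out["mask"].argmax(1)[0].cpu().numpy().
    """
    lungs = np.isin(mask, (1, 2))
    heart = mask == 3
    if not lungs.any() or not heart.any():
        raise ValueError("mask must contain both lung and heart pixels")
    # Horizontal (column) extent of both lungs approximates the thoracic width
    _, lung_x = np.nonzero(lungs)
    # Horizontal extent of the heart gives the cardiac width
    _, heart_x = np.nonzero(heart)
    lung_range = lung_x.max() - lung_x.min()
    heart_range = heart_x.max() - heart_x.min()
    return heart_range / lung_range
```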
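One possible way to do the lung crop mentioned above is to take the bounding box of all lung pixels. This sketch assumes the mask is a NumPy label map at the same resolution as `img`; if the model's mask is at the preprocessed resolution, it would first need nearest-neighbor resizing, e.g. `cv2.resize(mask.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST)`. The `crop_to_lungs` helper and its `margin` parameter are hypothetical.

```python
import numpy as np


def crop_to_lungs(img, mask, margin=0):
    """Crop img to the bounding box of lung pixels (labels 1 and 2) in mask.

    img and mask must share height and width; margin (in pixels) is added on
    every side and clipped to the image borders.
    """
    rows, cols = np.nonzero(np.isin(mask, (1, 2)))
    if rows.size == 0:
        raise ValueError("mask contains no lung pixels")
    top = max(rows.min() - margin, 0)
    bottom = min(rows.max() + 1 + margin, img.shape[0])
    left = max(cols.min() - margin, 0)
    right = min(cols.max() + 1 + margin, img.shape[1])
    return img[top:bottom, left:right]
```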
If you have `pydicom` installed, you can also load a DICOM image directly:

```python
img = model.load_image_from_dicom(path_to_dicom)
```
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.
Base model: `timm/tf_efficientnetv2_s.in21k_ft_in1k`