This model performs both segmentation and classification on chest radiographs (X-rays). It uses a tf_efficientnetv2_s backbone with a U-Net decoder for segmentation and a linear layer for classification. For frontal radiographs, the model segments the right lung, the left lung, and the heart. It also predicts the chest X-ray view (AP, PA, or lateral), patient age, and patient sex.

The model was trained on the CheXpert (small version) and NIH Chest X-ray datasets. Segmentation masks were obtained from the CheXmask dataset (paper). The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training and 20% validation. No holdout test set was used because minimal tuning was performed.

The view classifier was trained only on CheXpert images (NIH images were excluded from its loss function), because lateral radiographs are present only in CheXpert. If one class came from a single dataset, the model could learn to recognize the source dataset rather than the view itself; excluding NIH images from the view loss avoids this bias.
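The view-loss masking described above can be sketched as follows. This is a minimal NumPy illustration, not the author's training code: the function name masked_view_loss and the is_chexpert flag are hypothetical, and it assumes view labels are integer indices (0 = AP, 1 = PA, 2 = lateral).

```python
import numpy as np


def masked_view_loss(logits, labels, is_chexpert):
    """Mean cross-entropy over CheXpert samples only; NIH samples contribute nothing.

    logits:      (N, 3) float array of view scores (AP, PA, lateral)
    labels:      (N,) int array of view indices; values on NIH rows are ignored
    is_chexpert: (N,) bool array, True for CheXpert images (hypothetical flag)
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    keep = np.asarray(is_chexpert, dtype=bool)
    if not keep.any():
        return 0.0  # no CheXpert images in this batch, so no view loss
    z = logits[keep]
    y = labels[keep]
    # Numerically stable log-softmax over the 3 view classes
    z = z - z.max(axis=1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # Negative log-likelihood of the true view, averaged over kept rows only
    return float(-log_probs[np.arange(len(y)), y].mean())
```

In PyTorch, one common way to get the same effect is to give NIH images a placeholder label such as -1 and call torch.nn.functional.cross_entropy(logits, labels, ignore_index=-1), which averages only over the non-ignored samples. Unlike the sketch above, PyTorch returns NaN for a batch in which every sample is ignored, so such batches may need special handling.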

Validation performance is as follows:

Segmentation (Dice similarity coefficient):
  Right Lung: 0.957
   Left Lung: 0.948
       Heart: 0.943

Age Prediction:
  Mean Absolute Error: 5.25 years

Classification:
  View (AP, PA, lateral): 99.42% accuracy
  Female: 0.999 AUC
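For reference, the Dice and age metrics above can be computed as follows. This is a hedged sketch: the card does not say whether Dice was averaged per image or pooled over the whole validation set, so the helper names (dice_coefficient, per_structure_dice, mean_absolute_error) and the per-mask computation shown here are illustrative assumptions. Label values follow the mask convention used in this card (1 = right lung, 2 = left lung, 3 = heart).

```python
import numpy as np

# Label values for each segmented structure, as used in this card
STRUCTURES = {1: "Right Lung", 2: "Left Lung", 3: "Heart"}


def dice_coefficient(pred, target):
    """Dice = 2 * |A and B| / (|A| + |B|) for two binary masks of the same shape."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (a common convention)
    return float(2.0 * np.logical_and(pred, target).sum() / denom)


def per_structure_dice(pred_labels, true_labels):
    """Dice for each structure, given two (H, W) integer label maps."""
    pred_labels = np.asarray(pred_labels)
    true_labels = np.asarray(true_labels)
    return {
        name: dice_coefficient(pred_labels == k, true_labels == k)
        for k, name in STRUCTURES.items()
    }


def mean_absolute_error(pred_age, true_age):
    """Mean absolute difference between predicted and true ages, in years."""
    pred_age = np.asarray(pred_age, dtype=np.float64)
    true_age = np.asarray(true_age, dtype=np.float64)
    return float(np.mean(np.abs(pred_age - true_age)))
```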

To use the model:

import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True is required because the model repo contains custom code
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)

img = cv2.imread(..., cv2.IMREAD_GRAYSCALE)  # pass the image path; loads a single-channel image
x = model.preprocess(img)  # only takes a single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel and batch dims
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))

The output is a dictionary with 4 keys:

  • "mask": 3 channels containing the segmentation masks. Take the argmax over the channel dimension to create a single label mask (i.e., out["mask"].argmax(1)): 1 = right lung, 2 = left lung, 3 = heart.
  • "age": predicted patient age, in years.
  • "view": scores for the 3 possible views. Take the argmax to select the predicted view (i.e., out["view"].argmax(1)): 0 = AP, 1 = PA, 2 = lateral.
  • "female": predicted score for female sex; binarize with out["female"] >= 0.5.
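A minimal sketch of decoding these outputs for a single image, assuming the tensors have first been moved to NumPy (for example {k: v.cpu().numpy() for k, v in out.items()}). The helper decode_outputs and the VIEW_NAMES tuple are hypothetical; the exact shapes of "age" and "female" are not documented here, so the code flattens them before reading the single value.

```python
import numpy as np

# View names in the index order listed in this card (hypothetical constant)
VIEW_NAMES = ("AP", "PA", "lateral")


def decode_outputs(out):
    """Turn one image's raw outputs into readable values.

    `out` maps "mask", "age", "view", "female" to NumPy arrays that carry a
    leading batch dimension of size 1.
    """
    mask = out["mask"].argmax(1)[0]  # (H, W) integer label map
    view_idx = int(out["view"].argmax(1)[0])  # index of the highest-scoring view
    return {
        "mask": mask,
        "age": float(np.ravel(out["age"])[0]),
        "view": VIEW_NAMES[view_idx],
        "female": bool(np.ravel(out["female"])[0] >= 0.5),  # threshold from this card
    }
```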

You can use the segmentation mask to crop the lung region from the rest of the X-ray. You can also calculate the cardiothoracic ratio (CTR), here measured as the horizontal width of the heart divided by the horizontal width spanned by both lungs, using this function:

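import numpy as np

def calculate_ctr(mask):  # single label mask (NumPy array) with dims (height, width)
    # e.g. mask = out["mask"].argmax(1)[0].cpu().numpy()
    # Column indices of all lung pixels (right lung = 1, left lung = 2)
    lung_cols = np.where((mask == 1) | (mask == 2))[1]
    # Column indices of all heart pixels (heart = 3)
    heart_cols = np.where(mask == 3)[1]
    if lung_cols.size == 0 or heart_cols.size == 0:
        raise ValueError("mask must contain both lung and heart labels")
    # Horizontal extents, measured as max column minus min column
    lung_range = lung_cols.max() - lung_cols.min()
    heart_range = heart_cols.max() - heart_cols.min()
    # CTR = cardiac width / thoracic (lung) width
    return heart_range / lung_range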
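Cropping to the lungs can be done with the bounding box of the lung labels. A minimal sketch: crop_to_lungs and its margin parameter are hypothetical, and the rescaling step assumes the mask covers the same field of view as the original image (i.e., that model.preprocess resizes without padding or cropping, which this card does not confirm).

```python
import numpy as np


def crop_to_lungs(img, mask, margin=0.05):
    """Crop `img` to the bounding box of the lung labels (1 and 2) in `mask`.

    `mask` may have a different resolution than `img`; box coordinates are
    rescaled proportionally (assumes the same field of view). `margin` pads
    the box by a fraction of its height and width.
    """
    rows, cols = np.nonzero((mask == 1) | (mask == 2))
    if rows.size == 0:
        raise ValueError("no lung pixels in mask")
    # Scale factors from mask coordinates to image coordinates
    sy = img.shape[0] / mask.shape[0]
    sx = img.shape[1] / mask.shape[1]
    # Box edges in image coordinates (end edges are exclusive)
    y0, y1 = rows.min() * sy, (rows.max() + 1) * sy
    x0, x1 = cols.min() * sx, (cols.max() + 1) * sx
    pad_y = margin * (y1 - y0)
    pad_x = margin * (x1 - x0)
    # Expand by the margin and clip to the image bounds
    y0 = max(int(np.floor(y0 - pad_y)), 0)
    x0 = max(int(np.floor(x0 - pad_x)), 0)
    y1 = min(int(np.ceil(y1 + pad_y)), img.shape[0])
    x1 = min(int(np.ceil(x1 + pad_x)), img.shape[1])
    return img[y0:y1, x0:x1]
```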

If you have pydicom installed, you can also load a DICOM image directly:

img = model.load_image_from_dicom(path_to_dicom)
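If you prefer to read DICOM files yourself, the conversion typically involves applying the rescale slope and intercept, inverting MONOCHROME1 images, and scaling to 8 bits so the result resembles what cv2.imread(..., 0) returns. The sketch below is an assumption about a reasonable conversion, not a description of what load_image_from_dicom does internally: it ignores DICOM windowing (VOI LUT), and both function names are hypothetical. Only pydicom.dcmread and the standard DICOM attributes it reads are real APIs.

```python
import numpy as np


def dicom_pixels_to_uint8(pixels, slope=1.0, intercept=0.0, photometric="MONOCHROME2"):
    """Apply the rescale, invert MONOCHROME1, and min-max scale to 0-255."""
    arr = np.asarray(pixels, dtype=np.float64) * slope + intercept
    if photometric == "MONOCHROME1":
        # MONOCHROME1 displays low values as white; invert so higher = brighter
        arr = arr.max() - arr
    lo, hi = arr.min(), arr.max()
    if hi > lo:
        arr = (arr - lo) / (hi - lo)
    else:
        arr = np.zeros_like(arr)  # constant image: map everything to 0
    return (arr * 255).round().astype(np.uint8)


def load_dicom_grayscale(path):
    import pydicom  # optional dependency, imported only when this function is called

    ds = pydicom.dcmread(path)
    return dicom_pixels_to_uint8(
        ds.pixel_array,
        slope=float(getattr(ds, "RescaleSlope", 1.0)),
        intercept=float(getattr(ds, "RescaleIntercept", 0.0)),
        photometric=str(getattr(ds, "PhotometricInterpretation", "MONOCHROME2")),
    )
```

Producing an 8-bit single-channel array keeps the input format consistent with the cv2.imread path shown earlier, on the assumption that model.preprocess expects that format.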

This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.

Model size: 22.2M parameters (Safetensors, F32 tensors)