|
# Classifier for Selecting Pathology Images |
|
|
|
|
|
|
|
This is a ConvNeXt-Tiny model trained on 30K annotations that classifies whether an image is a pathology image or a non-pathology image.
|
|
|
## Usage |
|
|
|
> #### Step 1: Download the model checkpoint from [convnext-pathology-classifier](https://huggingface.co/jamessyx/convnext-pathology-classifier).
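If you prefer to fetch the checkpoint programmatically, `huggingface_hub` can download it directly. The snippet below is only a sketch: the `filename` argument is an assumption, so check the repository's file list and adjust it to the actual checkpoint name.

```python
# Optional: download the checkpoint with huggingface_hub instead of the web UI.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    repo_id="jamessyx/convnext-pathology-classifier",
    filename="pytorch_model.bin",  # assumed name; verify in the repo file list
)
print(model_path)
```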
|
|
|
|
|
|
|
> #### Step 2: Load the model
|
|
|
You can use the following code to load the model. |
|
|
|
```python
import timm  # tested with timm version 0.9.7
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image


class CT_SINGLE(nn.Module):
    """ConvNeXt-Tiny backbone followed by a binary classification head."""

    def __init__(self, model_name):
        super(CT_SINGLE, self).__init__()
        print(model_name)
        # num_classes=0 strips the backbone's own head so it returns pooled features.
        self.model_global = timm.create_model(model_name, pretrained=False, num_classes=0)
        # ConvNeXt-Tiny produces 768-dimensional features; map them to 2 classes.
        self.fc = nn.Linear(768, 2)

    def forward(self, x_global):
        features_global = self.model_global(x_global)
        logits = self.fc(features_global)
        return logits


def load_model(checkpoint_path, model):
    """Load the trained weights into the model in place."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % checkpoint_path)


# Load the model
model = CT_SINGLE('convnext_tiny')
model_path = 'Your model path'
load_model(model_path, model)
model.eval().cuda()
```
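As a quick sanity check (not part of the original steps), you can pass a random tensor through the loaded model and confirm the output has shape `(1, 2)`, i.e. one logit per class.

```python
# Optional sanity check: a random 224x224 RGB input should yield logits of shape (1, 2).
with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224).cuda()
    out = model(dummy)
print(out.shape)  # expected: torch.Size([1, 2])
```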
|
|
|
|
|
|
|
> #### Step 3: Construct and predict on your own data
|
|
|
In this step, you'll construct your own dataset. Use PIL to load images and employ `transforms` from torchvision for data preprocessing. |
|
|
|
```python
def default_loader(path):
    """Open an image with PIL and convert it to RGB."""
    img = Image.open(path)
    return img.convert('RGB')


# Preprocessing: resize to the 224x224 input size and normalize with ImageNet statistics.
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def predict(img_path, model):
    img = default_loader(img_path)
    img = data_transforms(img)
    img = img.unsqueeze(0)  # add a batch dimension
    img = img.cuda()
    with torch.no_grad():
        output = model(img)
    _, pred = torch.topk(output, 1, dim=-1)
    pred = pred.data.cpu().numpy()[:, 0]
    return pred  # 0 indicates a non-pathology image, 1 indicates a pathology image


img_path = 'Your image path'
pred = predict(img_path, model)
print(pred)
```
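If you want to classify an entire folder and also see how confident the model is, the sketch below reuses `model`, `default_loader`, and `data_transforms` from above and applies a softmax to the logits. The folder path and image extensions are placeholders you should adapt to your own data.

```python
import os
import torch.nn.functional as F


def predict_folder(folder, model):
    """Classify every image in `folder` and report the softmax confidence.

    Returns a list of (filename, predicted_label, confidence) tuples,
    where label 0 = non-pathology and 1 = pathology.
    """
    results = []
    exts = ('.png', '.jpg', '.jpeg', '.tif', '.tiff')  # adjust as needed
    for name in sorted(os.listdir(folder)):
        if not name.lower().endswith(exts):
            continue
        img = data_transforms(default_loader(os.path.join(folder, name)))
        with torch.no_grad():
            logits = model(img.unsqueeze(0).cuda())
            probs = F.softmax(logits, dim=-1)[0]
        label = int(probs.argmax().item())
        results.append((name, label, float(probs[label])))
    return results


for name, label, conf in predict_folder('Your image folder', model):
    print(f'{name}: {"pathology" if label == 1 else "non-pathology"} ({conf:.3f})')
```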
|
|
|
|