jamessyx commited on
Commit
bc8d78e
·
verified ·
1 Parent(s): 05994c7

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +79 -0
README.md ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Classifier for Selecting Pathology Images
2
+
3
+
4
+
5
+ This is a ConvNext-tiny model trained on 30K annotations on if image is belongs to the pathology image or non-pathology image.
6
+
7
+ ## Usage
8
+
9
+ > #### Step1: Download model checkpoint in [convnext-pathology-classifier](https://huggingface.co/jamessyx/convnext-pathology-classifier) .
10
+
11
+
12
+
13
+ > #### Step2: Load the model
14
+
15
+ You can use the following code to load the model.
16
+
17
+ ```python
18
+ import timm ##timm version 0.9.7
19
+ import torch.nn as nn
20
+ import torch
21
+ from torchvision import transforms
22
+ from PIL import Image
23
+
24
+ class CT_SINGLE(nn.Module):
25
+ def __init__(self, model_name):
26
+ super(CT_SINGLE, self).__init__()
27
+ print(model_name)
28
+ self.model_global = timm.create_model(model_name, pretrained=False, num_classes=0)
29
+ self.fc = nn.Linear(768, 2)
30
+
31
+ def forward(self, x_global):
32
+ features_global = self.model_global(x_global)
33
+ logits = self.fc(features_global)
34
+ return logits
35
+
36
+ def load_model(checkpoint_path, model):
37
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
38
+ model.load_state_dict(checkpoint['model'])
39
+ print("Resume checkpoint %s" % checkpoint_path)
40
+
41
+ ##load the model
42
+ model = CT_SINGLE('convnext_tiny')
43
+ model_path = 'Your model path'
44
+ load_model(model_path, model)
45
+ model.eval().cuda()
46
+
47
+ ```
48
+
49
+
50
+
51
+ > ### Step3: Construct and predict your own data
52
+
53
+ In this step, you'll construct your own dataset. Use PIL to load images and employ `transforms` from torchvision for data preprocessing.
54
+
55
+ ```python
56
+ def default_loader(path):
57
+ img = Image.open(path)
58
+ return img.convert('RGB')
59
+
60
+ data_transforms = transforms.Compose([
61
+ transforms.Resize((224,224)),
62
+ transforms.ToTensor(),
63
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
64
+
65
+ def predict(img_path, model):
66
+ img = default_loader(img_path)
67
+ img = data_transforms(img)
68
+ img = img.unsqueeze(0)
69
+ img = img.cuda()
70
+ output = model(img)
71
+ _, pred = torch.topk(output, 1, dim=-1)
72
+ pred = pred.data.cpu().numpy()[:, 0]
73
+ return pred ## 0 indicates non-pathology image and 1 indicates pathology image
74
+
75
+ img_path = 'Your image path'
76
+ pred = predict(img_path, model)
77
+ print(pred)
78
+ ```
79
+