joselobenitezg committed
Commit 5e4b3a1 · Parent(s): c82f96b
add files
Browse files
- .gitignore +4 -0
- Learn_PyTorch_ImageSegmentation.ipynb +0 -0
- README.md +56 -1
- model.py +30 -0
- requirements.txt +3 -0
- train.py +55 -0
- utils.py +104 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+__pycache__
+flagged
+*.pt
+DS_Store
Learn_PyTorch_ImageSegmentation.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
README.md
CHANGED
@@ -9,4 +9,59 @@ app_file: app.py
 pinned: false
 ---
 
-
+# PyTorch Image Segmentation
+
+This repo contains the code for training a U-Net model for image segmentation on the Human Segmentation Dataset.
+
+<a href="https://colab.research.google.com/github/josebenitezg/Pytorch-Image-Segmentation/blob/main/Learn_PyTorch_ImageSegmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>
+
+## Usage :nut_and_bolt:
+
+1. Clone this repo
+
+```
+git clone https://github.com/josebenitezg/Pytorch-Image-Segmentation
+```
+
+2. Create a virtual environment
+
+```
+python -m venv env
+```
+
+3. Activate the virtual environment
+
+- For Linux:
+
+```
+source env/bin/activate
+```
+
+- For Windows:
+
+```
+env\Scripts\Activate.bat
+```
+
+4. Install the requirements
+
+```
+pip install -r requirements.txt
+```
+
+5. Train the model
+
+```
+python train.py
+```
+
+6. Run the Gradio inference app (a sketch of this script is shown after this diff)
+
+```
+python gradio_inference.py
+```
+
+This repo contains dataset files to train a small model.
+
+Dataset credit: https://github.com/VikramShenoy97/Human-Segmentation-Datasets
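Step 6 of the README refers to `gradio_inference.py`, which is not among the files in this commit. Below is a hypothetical sketch of such an app: it reuses the checkpoint name from train.py (`best_model.pt`) and mirrors the preprocessing in utils.py; the `IMAGE_SIZE` of 320 is an illustrative assumption, since the real value lives in the (absent) config file.

```python
import cv2
import gradio as gr
import numpy as np
import torch

from model import SegmentationModel

IMAGE_SIZE = 320  # assumed; the actual value comes from config/config.yaml

# Load the checkpoint written by train.py and switch to inference mode.
model = SegmentationModel()
model.load_state_dict(torch.load('best_model.pt', map_location='cpu'))
model.eval()

def segment(image):
    # Mirror the training preprocessing: resize, HWC -> CHW, scale to [0, 1].
    resized = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE))
    tensor = torch.from_numpy(np.transpose(resized, (2, 0, 1)).astype(np.float32)) / 255.0
    with torch.no_grad():
        logits = model(tensor.unsqueeze(0))
    # Threshold the sigmoid probabilities into a binary uint8 mask image.
    mask = (torch.sigmoid(logits)[0, 0] > 0.5).numpy().astype(np.uint8) * 255
    return mask

gr.Interface(fn=segment, inputs=gr.Image(), outputs=gr.Image()).launch()
```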
model.py
ADDED
@@ -0,0 +1,30 @@
+from torch import nn
+import segmentation_models_pytorch as smp
+from segmentation_models_pytorch.losses import DiceLoss
+
+ENCODER = 'timm-efficientnet-b0'
+WEIGHTS = 'imagenet'
+
+class SegmentationModel(nn.Module):
+
+    def __init__(self):
+        super(SegmentationModel, self).__init__()
+
+        # U-Net with a pretrained EfficientNet-B0 encoder; one output
+        # channel with no activation, so the model returns raw logits.
+        self.arc = smp.Unet(
+            encoder_name = ENCODER,
+            encoder_weights = WEIGHTS,
+            in_channels = 3,
+            classes = 1,
+            activation = None
+        )
+
+    def forward(self, images, masks = None):
+
+        logits = self.arc(images)
+
+        # When masks are given (training/validation), also return the
+        # combined Dice + BCE loss computed on the logits.
+        if masks is not None:
+            loss1 = DiceLoss(mode='binary')(logits, masks)
+            loss2 = nn.BCEWithLogitsLoss()(logits, masks)
+            return logits, loss1 + loss2
+
+        return logits
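A minimal sketch (not part of the commit) of the model's two calling conventions, with dummy tensors; the 320×320 input size is an illustrative assumption.

```python
import torch
from model import SegmentationModel

model = SegmentationModel()
images = torch.rand(4, 3, 320, 320)                    # (batch, 3, h, w), values in [0, 1]
masks = torch.randint(0, 2, (4, 1, 320, 320)).float()  # binary masks in {0, 1}

logits, loss = model(images, masks)  # with masks: logits plus combined Dice + BCE loss
logits = model(images)               # without masks: raw logits only, shape (4, 1, 320, 320)
print(logits.shape, loss.item())
```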
requirements.txt
ADDED
@@ -0,0 +1,3 @@
+albumentations==1.3.0
+segmentation-models-pytorch==0.3.2
+opencv-contrib-python
train.py
ADDED
@@ -0,0 +1,55 @@
+import torch
+
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+
+from utils import load_config, get_train_augs, get_valid_augs, train_fn, eval_fn, SegmentationDataset
+from model import SegmentationModel
+from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+
+# set device for training
+DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# load config file
+config = load_config()
+
+# load the train files into a dataframe and hold out 20% for validation
+df = pd.read_csv(config['files']['CSV_FILE'])
+
+train_df, valid_df = train_test_split(df, test_size = 0.2, random_state = 42)
+
+trainset = SegmentationDataset(train_df, get_train_augs(config['model']['IMAGE_SIZE']))
+validset = SegmentationDataset(valid_df, get_valid_augs(config['model']['IMAGE_SIZE']))
+
+print(f"Size of Trainset : {len(trainset)}")
+print(f"Size of Validset : {len(validset)}")
+
+trainloader = DataLoader(trainset, batch_size=config['model']['BATCH_SIZE'], shuffle = True)
+validloader = DataLoader(validset, batch_size=config['model']['BATCH_SIZE'])
+
+print(f"Total number of batches in trainloader: {len(trainloader)}")
+print(f"Total number of batches in validloader: {len(validloader)}")
+
+model = SegmentationModel()
+model.to(DEVICE)
+
+optimizer = torch.optim.Adam(model.parameters(), lr = config['model']['LR'])
+
+best_valid_loss = np.inf
+
+for i in tqdm(range(config['model']['EPOCHS'])):
+
+    train_loss = train_fn(trainloader, model, optimizer, DEVICE)
+    valid_loss = eval_fn(validloader, model, DEVICE)
+
+    # checkpoint the model whenever the validation loss improves
+    if valid_loss < best_valid_loss:
+        torch.save(model.state_dict(), 'best_model.pt')
+        print('SAVED-MODEL')
+        best_valid_loss = valid_loss
+
+    print(f"Epoch: {i+1} Train Loss: {train_loss} Valid Loss: {valid_loss}")
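train.py and `load_config` in utils.py expect a `config/config.yaml` that is not included in this commit. Below is a hypothetical sketch of that file: the keys are inferred from the lookups in train.py, while the values are purely illustrative.

```yaml
# config/config.yaml (sketch; not part of this commit)
files:
  CSV_FILE: train.csv   # CSV with `images` and `masks` path columns (see utils.SegmentationDataset)
model:
  IMAGE_SIZE: 320
  BATCH_SIZE: 16
  LR: 0.001
  EPOCHS: 25
```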
utils.py
ADDED
@@ -0,0 +1,104 @@
+import cv2
+import torch
+import yaml
+import numpy as np
+import albumentations as A
+from torch.utils.data import Dataset
+
+
+def get_train_augs(IMAGE_SIZE):
+    # training-time augmentations: resize plus random flips
+    return A.Compose([
+        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
+        A.HorizontalFlip(p = 0.5),
+        A.VerticalFlip(p = 0.5)
+    ])
+
+def get_valid_augs(IMAGE_SIZE):
+    # validation-time transform: resize only
+    return A.Compose([
+        A.Resize(IMAGE_SIZE, IMAGE_SIZE),
+    ])
+
+def train_fn(data_loader, model, optimizer, DEVICE):
+    # one training epoch; returns the mean loss over all batches
+    model.train()
+    total_loss = 0.0
+
+    for images, masks in data_loader:
+        images = images.to(DEVICE)
+        masks = masks.to(DEVICE)
+
+        optimizer.zero_grad()
+        logits, loss = model(images, masks)
+        loss.backward()
+        optimizer.step()
+        total_loss += loss.item()
+
+    return total_loss / len(data_loader)
+
+def eval_fn(data_loader, model, DEVICE):
+    # one validation pass with gradients disabled; returns the mean loss
+    model.eval()
+    total_loss = 0.0
+    with torch.no_grad():
+        for images, masks in data_loader:
+            images = images.to(DEVICE)
+            masks = masks.to(DEVICE)
+
+            logits, loss = model(images, masks)
+            total_loss += loss.item()
+
+    return total_loss / len(data_loader)
+
+def load_config():
+    config_file = 'config/config.yaml'
+
+    with open(config_file, 'r') as file:
+        config = yaml.safe_load(file)
+
+    return config
+
+
+class SegmentationDataset(Dataset):
+
+    def __init__(self, df, augmentations):
+        # df must have `images` and `masks` columns holding file paths
+        self.df = df
+        self.augmentations = augmentations
+
+    def __len__(self):
+        return len(self.df)
+
+    def __getitem__(self, idx):
+        row = self.df.iloc[idx]
+
+        image_path = row.images
+        mask_path = row.masks
+
+        image = cv2.imread(image_path)
+        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)  # (h, w)
+        # Resize the mask to the same dimensions as the image
+        mask = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
+        mask = np.expand_dims(mask, axis = -1)  # (h, w, 1)
+
+        if self.augmentations:
+            data = self.augmentations(image = image, mask = mask)
+            image = data['image']
+            mask = data['mask']
+
+        # (h, w, c) -> (c, h, w)
+        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
+        mask = np.transpose(mask, (2, 0, 1)).astype(np.float32)
+
+        # scale to [0, 1]; masks are rounded back to hard {0, 1} labels
+        image = torch.Tensor(image) / 255.0
+        mask = torch.round(torch.Tensor(mask) / 255.0)
+
+        return image, mask
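A minimal usage sketch (not in the commit) of the dataframe `SegmentationDataset` expects: the column names `images` and `masks` come from `row.images` / `row.masks` in `__getitem__` above, while the paths and the 320 image size are illustrative.

```python
import pandas as pd
from utils import SegmentationDataset, get_train_augs

df = pd.DataFrame({
    'images': ['data/images/0001.jpg'],  # path to an RGB image
    'masks': ['data/masks/0001.png'],    # path to its grayscale mask
})

dataset = SegmentationDataset(df, get_train_augs(320))
image, mask = dataset[0]  # image: float32 (3, 320, 320) in [0, 1]; mask: (1, 320, 320) in {0, 1}
```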