File size: 5,402 Bytes
85e3d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import os
import pandas as pd
from tqdm import tqdm
from encode import rle_encode, list_to_string

def dice_score(y_p, y_t, smooth=1e-6, padding=2):
    """Return the mean Dice coefficient between hard predictions and a mask.

    The logits are cropped by ``padding`` pixels on every spatial border,
    turned into a hard class map with ``argmax`` over ``dim=1`` and then
    compared with ``y_t``.

    Args:
        y_p: Logits indexed as ``[batch, class, row, col]``.
        y_t: Ground-truth mask; it must broadcast against the
            ``[batch, 1, row, col]`` class map derived from the cropped logits.
        smooth: Constant added to numerator and denominator, so an empty
            prediction paired with an empty mask scores 1 instead of 0/0.
        padding: Border pixels stripped from each spatial side of ``y_p``
            before scoring. Defaults to 2, the crop this function previously
            hard-coded; 0 disables cropping.

    Returns:
        A 0-dim tensor holding the Dice score averaged over the batch.
    """
    if padding > 0:
        # Guarded because a literal ``0:-0`` slice would be empty.
        y_p = y_p[:, :, padding:-padding, padding:-padding]
    # argmax is taken on the raw logits: softmax is monotonic, so it cannot
    # change the winning class, and skipping it avoids ties from rounding.
    # NOTE(review): the class index is used directly as a foreground
    # indicator, which is only meaningful for two-class output.
    y_p = torch.argmax(y_p, dim=1, keepdim=True)
    spatial_dims = (2, 3)
    intersection = torch.sum(y_p * y_t, dim=spatial_dims)
    total = torch.sum(y_p, dim=spatial_dims) + torch.sum(y_t, dim=spatial_dims)
    score = (2 * intersection + smooth) / (total + smooth)
    return torch.mean(score)

def ce_loss(y_p, y_t, padding=2, class_weights=(0.57, 4.17)):
    """Return the class-weighted cross-entropy between logits and a mask.

    Args:
        y_p: Logits indexed as ``[batch, class, row, col]``.
        y_t: Integer class mask with a singleton dimension at ``dim=1``,
            which is squeezed away before the loss is computed.
        padding: Border pixels stripped from each spatial side of ``y_p``
            before the loss is computed. Defaults to 2, the crop this
            function previously hard-coded; 0 disables cropping.
        class_weights: One weight per class, forwarded to the loss as its
            ``weight`` argument. Defaults to the previously hard-coded
            ``(0.57, 4.17)``.

    Returns:
        A 0-dim tensor holding the weighted mean cross-entropy.
    """
    if padding > 0:
        # Guarded because a literal ``0:-0`` slice would be empty.
        y_p = y_p[:, :, padding:-padding, padding:-padding]
    target = y_t.squeeze(dim=1)
    # The weights follow the logits' dtype and device so that non-float32
    # logits do not hit a dtype mismatch inside the loss.
    weight = torch.as_tensor(class_weights, dtype=y_p.dtype, device=y_p.device)
    return F.cross_entropy(y_p, target, weight=weight)

def false_color(band11, band14, band15):
    """Combine three bands into a clipped three-channel composite.

    Each channel is rescaled linearly so that its hard-coded lower bound maps
    to 0 and its upper bound maps to 1:

    * channel 0: ``band15 - band14`` over the bounds ``(-4, 2)``
    * channel 1: ``band14 - band11`` over the bounds ``(-4, 5)``
    * channel 2: ``band14`` over the bounds ``(243, 303)``

    Args:
        band11: Array for band 11.
        band14: Array for band 14, same shape as ``band11``.
        band15: Array for band 15, same shape as ``band11``.

    Returns:
        The three channels stacked along a new axis 2, with every value
        clipped to the range ``[0, 1]``.
    """
    tdiff_lo, tdiff_hi = -4, 2
    cloud_top_lo, cloud_top_hi = -4, 5
    t11_lo, t11_hi = 243, 303

    red = ((band15 - band14) - tdiff_lo) / (tdiff_hi - tdiff_lo)
    green = ((band14 - band11) - cloud_top_lo) / (cloud_top_hi - cloud_top_lo)
    blue = (band14 - t11_lo) / (t11_hi - t11_lo)

    composite = np.stack((red, green, blue), axis=2)
    return np.clip(composite, 0, 1)

class ICRGWDataset(Dataset):
    """Dataset that reads one folder of ``.npy`` arrays per record id.

    Every item is an ``(image, label)`` pair:

    * ``image`` is the ``false_color`` composite of bands 11, 14 and 15,
      moved to channel-first layout and reflect-padded by ``padding_size``
      pixels on each spatial side.
    * ``label`` is the content of ``human_pixel_masks.npy`` as an int64
      tensor with its last axis moved to the front. It is NOT padded. When
      the file is missing, an all-zero int64 mask with the unpadded spatial
      size of the image is returned instead.

    Args:
        tar_path: Directory that contains one sub-folder per record id.
        ids: Sequence of record ids (sub-folder names).
        padding_size: Reflect padding added to each spatial side of the image.
    """

    def __init__(self, tar_path, ids, padding_size):
        self.tar_path = tar_path
        self.ids = ids
        self.padding_size = padding_size

    def __len__(self):
        """Return the number of record ids."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Load, colourise and pad record ``idx``; return ``(image, label)``."""
        # Index selected on the last axis of every band array
        # (presumably the frame that the mask annotates -- TODO confirm).
        N_TIMES_BEFORE = 4
        sample_path = f"{self.tar_path}/{self.ids[idx]}"
        band11 = np.load(f"{sample_path}/band_11.npy")[..., N_TIMES_BEFORE]
        band14 = np.load(f"{sample_path}/band_14.npy")[..., N_TIMES_BEFORE]
        band15 = np.load(f"{sample_path}/band_15.npy")[..., N_TIMES_BEFORE]

        image = torch.Tensor(false_color(band11, band14, band15))
        image = image.permute(2, 0, 1)  # channel-last -> channel-first
        # Remember the unpadded size so a fallback mask can match real labels.
        height, width = image.shape[1], image.shape[2]
        pad = self.padding_size
        image = F.pad(image, (pad, pad, pad, pad), mode='reflect')

        try:
            mask = np.load(f"{sample_path}/human_pixel_masks.npy")
        except FileNotFoundError:
            # No annotation for this record: return an empty mask with the
            # same dtype and unpadded spatial size as a real label.
            label = torch.zeros((1, height, width), dtype=torch.int64)
        else:
            label = torch.Tensor(mask).to(torch.int64)
            label = label.permute(2, 0, 1)
        return image, label


if __name__ == "__main__":
    # Script entry point: train on ./train, report the validation Dice score,
    # then write run-length-encoded predictions for ./test to submission.csv.

    data_path = "./train" 
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # os.listdir returns entries in arbitrary order; sorting makes the seeded
    # split below pick the same records on every machine.
    image_ids = sorted(os.listdir(data_path))
        
    ids_train, ids_valid = train_test_split(image_ids, test_size=0.1, random_state=42)
    print(f"TrainSize: {len(ids_train)}, ValidSize: {len(ids_valid)}")

    batch_size = 8
    epochs = 1
    lr = 1e-5
    
    # The padding of 2 must stay in sync with the 2-pixel crop applied by
    # ce_loss / dice_score and by the test-time crop further down.
    train_dataset = ICRGWDataset(data_path, ids_train, 2)
    valid_dataset = ICRGWDataset(data_path, ids_valid, 2)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=1)
    valid_dataloader = DataLoader(valid_dataset, 1, shuffle=False, num_workers=1)

    # Define model
    model = nn.Conv2d(3, 2, 1) # TODO: replace with your model
    model = model.to(device)
    model.train()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # train model
    bst_dice = 0
    for epoch in range(epochs):
        model.train()
        bar = tqdm(train_dataloader)
        tot_loss = 0
        tot_score = 0
        count = 0
        for X, y in bar:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            loss = ce_loss(pred, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            tot_loss += loss.item()
            # The metric is for monitoring only: detach it from the graph and
            # accumulate a plain float instead of a device tensor.
            tot_score += dice_score(pred.detach(), y).item()
            count += 1
            bar.set_postfix(TrainLoss=f'{tot_loss/count:.4f}', TrainDice=f'{tot_score/count:.4f}')

        model.eval()
        bar = tqdm(valid_dataloader)
        tot_score = 0
        count = 0
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for X, y in bar:
                X, y = X.to(device), y.to(device)
                pred = model(X)
                tot_score += dice_score(pred, y).item()
                count += 1
                bar.set_postfix(ValidDice=f'{tot_score/count:.4f}')

        if tot_score/count > bst_dice:
            bst_dice = tot_score/count
            torch.save(model.state_dict(), 'u-net.pth')
            print("current model saved!")

    # evaluate model on validation set and print results
    # NOTE(review): this evaluates the weights from the last epoch, not the
    # best checkpoint saved above -- confirm that this is intended.
    model.eval()
    tot_score = 0
    with torch.no_grad():
        for X, y in valid_dataloader:
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            tot_score += dice_score(pred, y).item()
    print(f"Validation Dice Score: {tot_score/len(valid_dataloader)}")


    # save predictions on test set to csv file suitable for submission
    submission = pd.read_csv('sample_submission.csv', index_col='record_id')

    test_dataset = ICRGWDataset("test/", os.listdir('test'), 2)
    with torch.no_grad():
        for idx, (X, y) in enumerate(test_dataset):
            X = X.to(device)
            pred = model(X.unsqueeze(0))[:, :, 2:-2, 2:-2] # remove padding
            pred = torch.argmax(pred, dim=1)[0]
            pred = pred.detach().cpu().numpy()
            submission.loc[int(test_dataset.ids[idx]), 'encoded_pixels'] = list_to_string(rle_encode(pred))
    submission.to_csv('submission.csv')