Spaces:
Runtime error
Runtime error
import torch | |
import cv2 | |
import numpy as np | |
import pandas as pd | |
from tqdm import tqdm | |
import matplotlib.pyplot as plt | |
from utils import load_config, get_train_augs, get_valid_augs, train_fn, eval_fn, SegmentationDataset | |
from model import SegmentationModel | |
from sklearn.model_selection import train_test_split | |
from torch.utils.data import DataLoader | |
# Select the compute device: CUDA when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Read the project configuration through the helper imported from utils.
config = load_config()

# Load the CSV named in the config and hold out 20% of its rows for
# validation. random_state is fixed so the split is the same on every run.
df = pd.read_csv(config['files']['CSV_FILE'])
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=42)

# Wrap each split in a dataset, pairing it with its own augmentation pipeline.
# Both pipelines are built from the same configured image size.
model_cfg = config['model']
image_size = model_cfg['IMAGE_SIZE']
train_augs = get_train_augs(image_size)
trainset = SegmentationDataset(train_df, train_augs)
valid_augs = get_valid_augs(image_size)
validset = SegmentationDataset(valid_df, valid_augs)
print(f"Size of Trainset : {len(trainset)}")
print(f"Size of Validset : {len(validset)}")

# Batch both datasets with the configured batch size; only the training
# loader shuffles.
batch_size = model_cfg['BATCH_SIZE']
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
validloader = DataLoader(validset, batch_size=batch_size)
print(f"Total n of batches in trainloader: {len(trainloader)}")
print(f"Total n of batches in validloader: {len(validloader)}")

# Build the model, move it to the selected device, and create an Adam
# optimizer over its parameters with the configured learning rate.
model = SegmentationModel()
model.to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=model_cfg['LR'])
# Lowest validation loss seen so far. Starting at +infinity means any finite
# loss from the first epoch counts as an improvement.
# NOTE: the `np.Inf` alias was removed in NumPy 2.0 (AttributeError there);
# the lowercase `np.inf` works on every NumPy version.
best_valid_loss = np.inf

# Run one training pass and one evaluation pass per configured epoch.
# train_fn / eval_fn are project helpers from utils; their return values are
# treated here as comparable, printable loss values.
for i in tqdm(range(config['model']['EPOCHS'])):
    train_loss = train_fn(trainloader, model, optimizer, DEVICE)
    valid_loss = eval_fn(validloader, model, DEVICE)

    # Checkpoint only on a strict improvement of the validation loss, so
    # 'best_model.pt' always holds the weights of the best epoch so far.
    if valid_loss < best_valid_loss:
        torch.save(model.state_dict(), 'best_model.pt')
        print('SAVED-MODEL')
        best_valid_loss = valid_loss

    # Per-epoch summary; epochs are reported 1-based.
    print(f"Epoch: {i+1} Train Loss: {train_loss} Valid Loss: {valid_loss}")