|
from torch.utils.data import Dataset |
|
|
|
class CustomDataset(Dataset):
    """Map-style dataset that exposes an existing indexable collection.

    Args:
        data: Any object supporting ``len()`` and indexing. It is stored by
            reference; nothing is copied or converted.
    """

    def __init__(self, data) -> None:
        super().__init__()
        # Keep the caller's object as-is; items are looked up lazily on access.
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``self.data[index]`` unchanged."""
        return self.data[index]
|
|
|
class EarlyStopping:
    """Set a stop flag once the loss has failed to improve often enough.

    Each call compares the latest loss with a reference loss passed in by the
    caller (``min_loss``; presumably the best loss seen so far, tracked by the
    caller — confirm at call sites). A call is *counted* when the latest loss
    exceeds ``min_loss`` by more than ``min_delta``. Once ``tolerance`` calls
    have been counted, ``early_stop`` becomes ``True``. This class never
    clears the flag again.

    By default, counted calls accumulate over the whole run, which matches the
    original behaviour. Pass ``reset_on_improvement=True`` to get the common
    "patience" semantics instead: any call that is not counted resets the
    counter to zero, so only *consecutive* counted calls trigger a stop.

    Args:
        tolerance: Number of counted calls after which ``early_stop`` is set.
        min_delta: Margin by which the loss must exceed ``min_loss`` for a
            call to be counted.
        reset_on_improvement: If ``True``, reset ``counter`` on every call
            that is not counted. Keyword-only; defaults to ``False``.

    Attributes:
        counter: Number of calls currently counted.
        early_stop: ``True`` once ``counter`` has reached ``tolerance``.
    """

    def __init__(self, tolerance=10, min_delta=0, *, reset_on_improvement=False):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.reset_on_improvement = reset_on_improvement
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, min_loss):
        """Record one loss observation and update ``early_stop``.

        Args:
            train_loss: Loss for the current step/epoch.
            min_loss: Reference loss to compare against (presumably the best
                loss so far, tracked by the caller).

        Returns:
            The current value of ``early_stop``, so callers can write
            ``if stopper(loss, best): break``. Callers that ignore the
            return value are unaffected.
        """
        if train_loss - min_loss > self.min_delta:
            self.counter += 1
            if self.counter >= self.tolerance:
                self.early_stop = True
        elif self.reset_on_improvement:
            # Opt-in: only consecutive non-improving calls count toward a stop.
            self.counter = 0
        return self.early_stop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|