import os
import glob
from collections import OrderedDict

import torch


class Saver(object):
    """Creates a numbered experiment directory and saves checkpoints and
    training parameters into it."""

    def __init__(self, args):
        self.args = args
        self.directory = os.path.join('run', args.train_dataset, args.checkname)
        self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
        # A plain lexicographic sort puts 'experiment_10' before 'experiment_2',
        # so derive the next run id from the numeric suffixes instead of the
        # last sorted entry.
        run_id = max(int(r.split('_')[-1]) for r in self.runs) + 1 if self.runs else 0

        self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(run_id))
        if not os.path.exists(self.experiment_dir):
            os.makedirs(self.experiment_dir)

    def save_checkpoint(self, state, filename='checkpoint.pth.tar'):
        """Saves checkpoint to disk."""
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)

    def save_experiment_config(self):
        """Writes the main training parameters to 'parameters.txt' in the experiment directory."""
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        p = OrderedDict()
        p['train_dataset'] = self.args.train_dataset
        p['lr'] = self.args.lr
        p['epoch'] = self.args.epochs

        # A context manager ensures the file is closed even if a write fails.
        with open(logfile, 'w') as log_file:
            for key, val in p.items():
                log_file.write(key + ':' + str(val) + '\n')
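

# Minimal usage sketch (illustrative only): the argument values below are
# placeholders, and a real run would build `args` from the project's argparse
# setup; the contents of `state` depend on the model and optimizer being
# trained, so the dictionary here is just an assumed example.
if __name__ == '__main__':
    from argparse import Namespace
    import torch.nn as nn

    args = Namespace(train_dataset='pascal', checkname='deeplab', lr=0.007, epochs=50)
    saver = Saver(args)            # creates run/pascal/deeplab/experiment_N
    saver.save_experiment_config() # writes parameters.txt

    model = nn.Linear(4, 2)  # stand-in for the real network
    state = {'epoch': 1, 'state_dict': model.state_dict()}
    saver.save_checkpoint(state)   # writes checkpoint.pth.tar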