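"""Train a diffusion model with PyTorch Lightning.

Command-line entry point: parses the training hyperparameters, builds the
dataset-specific DataModule (MNIST, CIFAR-10, or CelebA), and fits
`diffusion.DiffusionModel`, optionally logging to Weights & Biases.

Example invocation (assuming this file is saved as train.py):

    python train.py --dataset cifar10 --batch_size 64 --max_epochs 200
"""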
import argparse
import os

import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

import diffusion
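
# Use the 'file_system' sharing strategy so that many DataLoader workers do not
# exhaust the open-file-descriptor limit ("Too many open files" errors).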
torch.multiprocessing.set_sharing_strategy('file_system')


def main():
    # ARGUMENT PARSER
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', '-d', type=str, default='mnist',
        help='dataset to train on: mnist, cifar10, or celeba'
    )
    parser.add_argument(
        '--data_dir', '-dd', type=str, default='./data/',
        help='data directory'
    )
    parser.add_argument(
        '--mode', type=str, default='ddim',
        help='sampling mode'
    )
    parser.add_argument(
        '--max_epochs', '-me', type=int, default=200,
        help='maximum number of training epochs'
    )
    parser.add_argument(
        '--batch_size', '-bs', type=int, default=32,
        help='batch size'
    )
    parser.add_argument(
        '--train_ratio', '-tr', type=float, default=0.99,
        help='fraction of the data used for training (rest for validation)'
    )
    parser.add_argument(
        '--timesteps', '-ts', type=int, default=1000,
        help='number of diffusion timesteps'
    )
    parser.add_argument(
        '--max_batch_size', '-mbs', type=int, default=32,
        help='effective batch size, reached via gradient accumulation'
    )
    parser.add_argument(
        '--lr', '-l', type=float, default=1e-4,
        help='learning rate'
    )
    parser.add_argument(
        '--num_workers', '-nw', type=int, default=4,
        help='number of DataLoader workers'
    )
    parser.add_argument(
        '--seed', '-s', type=int, default=42,
        help='random seed'
    )
    parser.add_argument(
        '--name', '-n', type=str, default=None,
        help='name of the experiment'
    )
    parser.add_argument(
        '--pbar', action='store_true',
        help='show the progress bar'
    )
    parser.add_argument(
        '--precision', '-p', type=str, default='32',
        help='numerical precision'
    )
    parser.add_argument(
        '--sample_per_epochs', '-spe', type=int, default=25,
        help='sample every n epochs'
    )
    parser.add_argument(
        '--n_samples', '-ns', type=int, default=4,
        help='number of images to sample'
    )
    parser.add_argument(
        '--monitor', '-m', type=str, default='val_loss',
        help='metric monitored by the checkpoint callback'
    )
    parser.add_argument(
        '--wandb', '-wk', type=str, default=None,
        help='wandb API key'
    )
    args = parser.parse_args()
    # SEED
    pl.seed_everything(args.seed, workers=True)
    # WANDB (OPTIONAL)
    if args.wandb is not None:
        wandb.login(key=args.wandb)  # API key
        name = args.name or f"diffusion-{args.max_epochs}-{args.batch_size}-{args.lr}"
        logger = WandbLogger(
            project="diffusion-model",
            name=name,
            log_model=False
        )
    else:
        logger = None
    # DATAMODULE
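    # Each dataset maps to its LightningDataModule, the image resolution used
    # for training, and the number of class labels (None for CelebA, i.e. no
    # class conditioning).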
    if args.dataset == "mnist":
        DATAMODULE = diffusion.MNISTDataModule
        img_dim = 32
        num_classes = 10
    elif args.dataset == "cifar10":
        DATAMODULE = diffusion.CIFAR10DataModule
        img_dim = 32
        num_classes = 10
    elif args.dataset == "celeba":
        DATAMODULE = diffusion.CelebADataModule
        img_dim = 64
        num_classes = None
    else:
        raise ValueError(f"unknown dataset: {args.dataset!r}")
    datamodule = DATAMODULE(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        seed=args.seed,
        train_ratio=args.train_ratio,
        img_dim=img_dim
    )
    # MODEL
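    # MNIST is grayscale (1 channel); CIFAR-10 and CelebA are RGB (3 channels).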
    in_channels = 1 if args.dataset == "mnist" else 3
    model = diffusion.DiffusionModel(
        lr=args.lr,
        in_channels=in_channels,
        sample_per_epochs=args.sample_per_epochs,
        max_timesteps=args.timesteps,
        dim=img_dim,
        num_classes=num_classes,
        n_samples=args.n_samples,
        mode=args.mode
    )
    # CALLBACK
    root_path = os.path.join(os.getcwd(), "checkpoints")
    callback = diffusion.ModelCallback(
        root_path=root_path,
        ckpt_monitor=args.monitor
    )
    # STRATEGY
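    # Use DDP when CUDA is available; 'ddp_find_unused_parameters_true'
    # tolerates parameters that receive no gradient in a step, at some extra
    # synchronization overhead. On CPU, let Lightning choose automatically.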
    strategy = 'ddp_find_unused_parameters_true' if torch.cuda.is_available() else 'auto'
    # TRAINER
    trainer = pl.Trainer(
        default_root_dir=root_path,
        logger=logger,
        callbacks=callback.get_callback(),
        gradient_clip_val=0.5,
        max_epochs=args.max_epochs,
        enable_progress_bar=args.pbar,
        deterministic=False,
        precision=args.precision,
        strategy=strategy,
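        # Accumulate gradients so the effective batch size reaches
        # max_batch_size even when the per-step batch_size is smaller.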
        accumulate_grad_batches=max(args.max_batch_size // args.batch_size, 1)
    )
    # FIT MODEL
    trainer.fit(model=model, datamodule=datamodule)


if __name__ == '__main__':
    main()