from __future__ import annotations

import argparse
import datetime
import glob
import inspect
import os
import sys
from inspect import Parameter

import imageio
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from einops import rearrange
from matplotlib import pyplot as plt
from natsort import natsorted
from omegaconf import OmegaConf
from packaging import version
from PIL import Image
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities import rank_zero_only
from safetensors.torch import load_file as load_safetensors

# absolute import: this script is run as `python train.py` (see the __main__ comment below),
# so a relative `.vwm` import would fail outside a package context
from vwm.util import instantiate_from_config, isheatmap

MULTINODE_HACKS = True

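# default_trainer_args() mirrors the defaults of pytorch_lightning.Trainer.__init__ so that
# every Trainer option can also be overridden from the command line via `--key value` flags.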
def default_trainer_args():
    argspec = dict(inspect.signature(Trainer.__init__).parameters)
    argspec.pop("self")
    default_args = {
        param: argspec[param].default
        for param in argspec
        if argspec[param].default is not Parameter.empty
    }
    return default_args

def get_parser(**parser_kwargs):
    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        else:
            raise argparse.ArgumentTypeError("Boolean value expected")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n",
        "--name",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="postfix for logdir"
    )
    parser.add_argument(
        "--no_date",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="if True, skip date generation for logdir and only use naming via opt.base or opt.name (+ opt.postfix, optionally)"
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        const=True,
        default="",
        nargs="?",
        help="resume from logdir or checkpoint in logdir"
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. "
             "Loaded from left-to-right. "
             "Parameters can be overwritten or added with command-line options of the form `--key value`",
        default=list()
    )
    parser.add_argument(
        "-t",
        "--train",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="train"
    )
    parser.add_argument(
        "--no_test",
        type=str2bool,
        const=True,
        default=True,
        nargs="?",
        help="disable test"
    )
    parser.add_argument(
        "-p",
        "--project",
        help="name of new or path to existing project"
    )
    parser.add_argument(
        "-d",
        "--debug",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable post-mortem debugging"
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=23,
        help="seed for seed_everything"
    )
    parser.add_argument(
        "-f",
        "--postfix",
        type=str,
        default="",
        help="post-postfix for default name"
    )
    parser.add_argument(
        "-l",
        "--logdir",
        type=str,
        default="logs",
        help="directory for logging data"
    )
    parser.add_argument(
        "--scale_lr",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="scale base-lr by ngpu * batch_size * n_accumulate"
    )
    parser.add_argument(
        "--legacy_naming",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="name run based on config file name if true, else by whole path"
    )
    parser.add_argument(
        "--enable_tf32",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="enable the TensorFloat32 format for matmul and cuDNN (PyTorch >= 1.12)"
    )
    parser.add_argument(
        "--no_base_name",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="do not include the config name in the run name"
    )
    if version.parse(pl.__version__) >= version.parse("2.0.0"):
        parser.add_argument(
            "--resume_from_checkpoint",
            type=str,
            default=None,
            help="single checkpoint file to resume from"
        )
    parser.add_argument(
        "--n_devices",
        type=int,
        default=8,
        help="number of GPUs used for training"
    )
    parser.add_argument(
        "--finetune",
        type=str,
        default="ckpts/pytorch_model.bin",
        help="path to checkpoint to finetune from"
    )

    default_args = default_trainer_args()
    for key in default_args:
        parser.add_argument("--" + key, default=default_args[key])
    return parser

def get_checkpoint_name(logdir):
    ckpt = os.path.join(logdir, "checkpoints", "last**.ckpt")
    ckpt = natsorted(glob.glob(ckpt))
    print("Available last checkpoints:", ckpt)
    if len(ckpt) > 1:
        print("Got most recent checkpoint")
        ckpt = sorted(ckpt, key=lambda x: os.path.getmtime(x))[-1]
        print(f"Most recent ckpt is {ckpt}")
        with open(os.path.join(logdir, "most_recent_ckpt.txt"), "w") as f:
            f.write(ckpt + "\n")
        try:
            version = int(ckpt.split("/")[-1].split("-v")[-1].split(".")[0])
        except Exception as e:
            # version confusion but not bad
            print(e)
            version = 1
        # version = last_version + 1
    else:
        # in this case, we only have one "last.ckpt"
        ckpt = ckpt[0]
        version = 1
    melk_ckpt_name = f"last-v{version}.ckpt"
    print(f"Current melk ckpt name: {melk_ckpt_name}")
    return ckpt, melk_ckpt_name

def save_img_seq_to_video(out_path, img_seq, fps):
    # img_seq: np array
    writer = imageio.get_writer(out_path, fps=fps)
    for img in img_seq:
        writer.append_data(img)
    writer.close()

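# SetupCallback runs on rank 0 at the start of fit: it creates the log/checkpoint/config
# directories and saves the merged project and lightning configs alongside the run.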
class SetupCallback(Callback):
    def __init__(
        self,
        resume,
        now,
        logdir,
        ckptdir,
        cfgdir,
        config,
        lightning_config,
        debug,
        ckpt_name=None
    ):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        self.debug = debug
        self.ckpt_name = ckpt_name

    def on_exception(self, trainer: pl.Trainer, pl_module, exception):
        if not self.debug and trainer.global_rank == 0:
            # print("Summoning checkpoint")
            # if self.ckpt_name is None:
            #     ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
            # else:
            #     ckpt_path = os.path.join(self.ckptdir, self.ckpt_name)
            # trainer.save_checkpoint(ckpt_path)
            print("Exiting")

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank == 0:
            # create logdirs and save configs
            os.makedirs(self.logdir, exist_ok=True)
            os.makedirs(self.ckptdir, exist_ok=True)
            os.makedirs(self.cfgdir, exist_ok=True)

            if "callbacks" in self.lightning_config:
                if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
                    os.makedirs(
                        os.path.join(self.ckptdir, "trainstep_checkpoints"),
                        exist_ok=True
                    )

            print("Project config")
            print(OmegaConf.to_yaml(self.config))
            if MULTINODE_HACKS:
                import time
                time.sleep(5)
            OmegaConf.save(
                self.config,
                os.path.join(self.cfgdir, f"{self.now}-project.yaml")
            )

            print("Lightning config")
            print(OmegaConf.to_yaml(self.lightning_config))
            OmegaConf.save(
                OmegaConf.create({"lightning": self.lightning_config}),
                os.path.join(self.cfgdir, f"{self.now}-lightning.yaml")
            )
        else:
            # ModelCheckpoint callback created log directory, remove it
            if not MULTINODE_HACKS and not self.resume and os.path.exists(self.logdir):
                dst, name = os.path.split(self.logdir)
                dst = os.path.join(dst, "child_runs", name)
                os.makedirs(os.path.split(dst)[0], exist_ok=True)
                try:
                    os.rename(self.logdir, dst)
                except FileNotFoundError:
                    pass

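# ImageLogger periodically calls the module's `log_images` and writes the results to disk:
# heatmaps via matplotlib, keys containing "mp4" as videos, everything else as image grids.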
class ImageLogger(Callback):
    def __init__(
        self,
        batch_frequency,
        clamp=True,
        increase_log_steps=True,
        rescale=True,
        disabled=False,
        log_on_batch_idx=False,
        log_first_step=False,
        log_images_kwargs=None,
        log_before_first_step=False,
        enable_autocast=True,
        num_frames=25
    ):
        super().__init__()
        self.enable_autocast = enable_autocast
        self.rescale = rescale
        self.batch_freq = batch_frequency
        # also log at exponentially spaced steps (1, 2, 4, ...) below the regular logging frequency
        self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp
        self.disabled = disabled
        self.log_on_batch_idx = log_on_batch_idx
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else dict()
        self.log_first_step = log_first_step
        self.log_before_first_step = log_before_first_step
        self.num_frames = num_frames
    def log_local(
        self,
        save_dir,
        split,
        images,
        global_step,
        current_epoch,
        batch_idx
    ):
        root = os.path.join(save_dir, "images", split)
        for log_type in images:
            if isheatmap(images[log_type]):
                _fig, ax = plt.subplots()
                ax = ax.matshow(
                    images[log_type].cpu().numpy(), cmap="hot", interpolation="lanczos"
                )
                plt.colorbar(ax)
                plt.axis("off")

                filename = f"{log_type}_epoch{current_epoch:03}_batch{batch_idx:06}_step{global_step:06}.png"
                os.makedirs(os.path.join(root, log_type), exist_ok=True)
                path = os.path.join(root, log_type, filename)
                plt.savefig(path)
                plt.close()
            elif "mp4" in log_type:
                dir_path = os.path.join(root, log_type)
                os.makedirs(dir_path, exist_ok=True)
                img_seq = images[log_type]
                if self.rescale:
                    img_seq = (img_seq + 1.0) / 2.0
                img_seq = rearrange(img_seq, "(b t) c h w -> b t h w c", t=self.num_frames)
                B, _T = img_seq.shape[:2]
                for b_i in range(B):
                    cur_img_seq = img_seq[b_i].numpy()  # [t h w c]
                    cur_img_seq = (cur_img_seq * 255).astype(np.uint8)  # [t h w c]
                    filename = f"{log_type}_epoch{current_epoch:02}_batch{batch_idx:04}_step{global_step:06}.mp4"
                    save_img_seq_to_video(os.path.join(dir_path, filename), cur_img_seq, fps=10)
            else:
                grid = torchvision.utils.make_grid(images[log_type], nrow=int(images[log_type].shape[0] ** 0.5))
                if self.rescale:
                    grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
                grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
                grid = grid.numpy()
                grid = (grid * 255).astype(np.uint8)
                filename = f"{log_type}_epoch{current_epoch:02}_batch{batch_idx:04}_step{global_step:06}.png"
                dir_path = os.path.join(root, log_type)
                os.makedirs(dir_path, exist_ok=True)
                path = os.path.join(dir_path, filename)
                img = Image.fromarray(grid)
                img.save(path)
    def log_img(self, pl_module, batch, batch_idx, split="train"):
        check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
        if (
            self.check_frequency(check_idx)
            and hasattr(pl_module, "log_images")  # batch_idx % self.batch_freq == 0
            and callable(pl_module.log_images)
        ) or split == "test":
            is_train = pl_module.training
            if is_train:
                pl_module.eval()

            gpu_autocast_kwargs = {
                "enabled": self.enable_autocast,  # torch.is_autocast_enabled(),
                "dtype": torch.get_autocast_gpu_dtype(),
                "cache_enabled": torch.is_autocast_cache_enabled()
            }
            with torch.no_grad(), torch.cuda.amp.autocast(**gpu_autocast_kwargs):
                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

            for log_type in images:
                if isinstance(images[log_type], torch.Tensor):
                    images[log_type] = images[log_type].detach().float().cpu()
                    if self.clamp and not isheatmap(images[log_type]):
                        images[log_type] = torch.clamp(images[log_type], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx
            )

            if is_train:
                pl_module.train()
    def check_frequency(self, check_idx):
        if (check_idx % self.batch_freq == 0 or check_idx in self.log_steps) and (check_idx > 0 or self.log_first_step):
            try:
                self.log_steps.pop(0)
            except IndexError as e:
                print(e)
                pass
            return True
        else:
            return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if self.log_before_first_step and pl_module.global_step == 0:
            print(f"{self.__class__.__name__}: logging before training")
            self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs):
        if not self.disabled and pl_module.global_step > 0:
            self.log_img(pl_module, batch, batch_idx, split="val")

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.log_img(pl_module, batch, batch_idx, split="test")

if __name__ == "__main__":
    # custom parser to specify config files, train, test and debug mode, postfix, resume
    # `--key value` arguments are interpreted as arguments to the trainer
    # `nested.key=value` arguments are interpreted as config parameters
    # configs are merged from left-to-right followed by command line parameters

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: train.DataModuleFromConfig
    #   params:
    #       batch_size: int
    #       wrap: bool
    #       train:
    #           target: path to train dataset
    #           params:
    #               key: value
    #       validation:
    #           target: path to validation dataset
    #           params:
    #               key: value
    #       test:
    #           target: path to test dataset
    #           params:
    #               key: value
    # lightning: (optional, has sane defaults and can be specified on cmd line)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       modelcheckpoint to instantiate
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value
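    # For illustration only (the target paths and values below are placeholders, not the
    # shipped configs), a minimal base config following the structure above could look like:
    #
    #   model:
    #     base_learning_rate: 1.0e-5
    #     target: vwm.models.diffusion.DiffusionEngine   # placeholder module path
    #     params:
    #       num_frames: 25
    #   data:
    #     target: train.DataModuleFromConfig
    #     params:
    #       batch_size: 1
    #       train:
    #         target: my_dataset.TrainDataset            # placeholder module path
    #         params:
    #           data_root: data/
    #   lightning:
    #     trainer:
    #       max_epochs: 100
    #
    # and would be launched with something like `python train.py --base configs/example.yaml --n_devices 8`.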
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python train.py`
    # (in particular `train.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot both be specified. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )

    melk_ckpt_name = None
    name = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError(f"Cannot find {opt.resume}")
        if os.path.isfile(opt.resume):
            paths = opt.resume.split("/")
            # idx = len(paths)-paths[::-1].index("logs")+1
            # logdir = "/".join(paths[:idx])
            logdir = "/".join(paths[:-2])
            ckpt = opt.resume
            _, melk_ckpt_name = get_checkpoint_name(logdir)
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt, melk_ckpt_name = get_checkpoint_name(logdir)

        print("#" * 100)
        print(f"Resuming from checkpoint `{ckpt}`")
        print("#" * 100)

        opt.resume_from_checkpoint = ckpt
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[-1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            if opt.no_base_name:
                name = ""
            else:
                if opt.legacy_naming:
                    cfg_fname = os.path.split(opt.base[0])[-1]
                    cfg_name = os.path.splitext(cfg_fname)[0]
                else:
                    assert "configs" in os.path.split(opt.base[0])[0], os.path.split(
                        opt.base[0]
                    )[0]
                    cfg_path = os.path.split(opt.base[0])[0].split(os.sep)[
                        os.path.split(opt.base[0])[0].split(os.sep).index("configs")
                        + 1:
                    ]  # cut away the first one (we assert all configs are in "configs")
                    cfg_name = os.path.splitext(os.path.split(opt.base[0])[-1])[0]
                    cfg_name = "-".join(cfg_path) + f"-{cfg_name}"
                name = "_" + cfg_name
        else:
            name = ""

        if opt.no_date:
            nowname = name + opt.postfix
            if nowname.startswith("_"):
                nowname = nowname[1:]
        else:
            nowname = now + name + opt.postfix
        logdir = os.path.join(opt.logdir, nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed, workers=True)

    # move before model init, in case a torch.compile(...) is called somewhere
    if opt.enable_tf32:
        # pt_version = version.parse(torch.__version__)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print(f"Enabling TF32 for PyTorch {torch.__version__}")
    else:
        print(f"Using default TF32 settings for PyTorch {torch.__version__}:")
        print(f"torch.backends.cuda.matmul.allow_tf32={torch.backends.cuda.matmul.allow_tf32}")
        print(f"torch.backends.cudnn.allow_tf32={torch.backends.cudnn.allow_tf32}")
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())

        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())

        # default to gpu
        trainer_config["accelerator"] = "gpu"
        standard_args = default_trainer_args()
        for k in standard_args:
            if getattr(opt, k) != standard_args[k]:
                trainer_config[k] = getattr(opt, k)

        n_devices = getattr(opt, "n_devices", None)
        if n_devices is not None:
            assert isinstance(n_devices, int) and n_devices > 0
            devices = [str(i) for i in range(n_devices)]
            trainer_config["devices"] = ",".join(devices) + ","
        else:
            assert "devices" in trainer_config, "Must specify either n_devices or devices"

        ckpt_resume_path = opt.resume_from_checkpoint

        if "devices" not in trainer_config and trainer_config["accelerator"] != "gpu":
            del trainer_config["accelerator"]
            cpu = True
        else:
            gpuinfo = trainer_config["devices"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config
        # model
        model = instantiate_from_config(config.model)

        # use pretrained model
        if not opt.resume or opt.finetune:
            if not opt.finetune or not os.path.exists(opt.finetune):
                default_ckpt = "ckpts/svd_xt.safetensors"
                print(f"Loading pretrained model from {default_ckpt}")
                svd = load_safetensors(default_ckpt)
                for k in list(svd.keys()):
                    if "time_embed" in k:  # duplicate a new timestep embedding from the pretrained weights
                        svd[k.replace("time_embed", "cond_time_stack_embed")] = svd[k]
            else:
                ckpt_path = opt.finetune
                print(f"Loading pretrained model from {ckpt_path}")
                if ckpt_path.endswith("ckpt"):
                    svd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
                elif ckpt_path.endswith("bin"):  # for deepspeed merged checkpoints
                    svd = torch.load(ckpt_path, map_location="cpu")
                    for k in list(svd.keys()):  # remove the prefix
                        if "_forward_module" in k:
                            svd[k.replace("_forward_module.", "")] = svd[k]
                            del svd[k]
                elif ckpt_path.endswith("safetensors"):
                    svd = load_safetensors(ckpt_path)
                else:
                    raise NotImplementedError

                missing, unexpected = model.load_state_dict(svd, strict=False)
                # avoid empty weights when resuming from EMA weights
                for miss_k in missing:
                    ema_name = miss_k.replace(".", "").replace("modeldiffusion_model", "model_ema.diffusion_model")
                    svd[miss_k] = svd[ema_name]
                    print("Fill", miss_k, "with", ema_name)

            missing, unexpected = model.load_state_dict(svd, strict=False)
            if len(missing) > 0:
                if not opt.finetune or not os.path.exists(opt.finetune):
                    model.reinit_ema()
                    missing = [model_key for model_key in missing if "model_ema" not in model_key]
                # print(f"Missing keys: {missing}")
                print(f"Missing keys: {missing}")
            # if len(unexpected) > 0:
            #     print(f"Unexpected keys: {unexpected}")
            print(f"Unexpected keys: {unexpected}")
        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs
        default_logger_cfgs = {
            "csv": {
                "target": "pytorch_lightning.loggers.CSVLogger",
                "params": {
                    "name": "testtube",  # hack for sbord fanatics
                    "save_dir": logdir
                }
            }
        }
        default_logger_cfg = default_logger_cfgs["csv"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
            logger_cfg = OmegaConf.create()
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # use TrainResult/EvalResult(checkpoint_on=metric) to specify which metric is used to determine best models
        default_modelckpt_cfg = {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch:02}",
                "verbose": True,
                "save_last": True,
                "save_top_k": -1
            }
        }
        # if hasattr(model, "monitor"):
        #     print(f"Monitoring {model.monitor} as checkpoint metric")
        #     default_modelckpt_cfg["params"]["monitor"] = model.monitor
        #     default_modelckpt_cfg["params"]["save_top_k"] = 3

        if "modelcheckpoint" in lightning_config:
            modelckpt_cfg = lightning_config.modelcheckpoint
        else:
            modelckpt_cfg = OmegaConf.create()
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
        print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")

        # default to ddp if not further specified
        default_strategy_config = {"target": "pytorch_lightning.strategies.DDPStrategy"}
        if "strategy" in lightning_config:
            strategy_cfg = lightning_config.strategy
        else:
            strategy_cfg = OmegaConf.create()
            default_strategy_config["params"] = {
                "find_unused_parameters": True
            }
        strategy_cfg = OmegaConf.merge(default_strategy_config, strategy_cfg)
        print(
            f"strategy config: \n ++++++++++++++ \n {strategy_cfg} \n ++++++++++++++ "
        )
        trainer_kwargs["strategy"] = instantiate_from_config(strategy_cfg)
        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "train.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                    "debug": opt.debug,
                    "ckpt_name": melk_ckpt_name
                }
            },
            "image_logger": {
                "target": "train.ImageLogger",
                "params": {
                    "batch_frequency": 1000,
                    "clamp": True
                }
            },
            "learning_rate_logger": {
                "target": "pytorch_lightning.callbacks.LearningRateMonitor",
                "params": {
                    "logging_interval": "step"
                }
            }
        }
        if version.parse(pl.__version__) >= version.parse("1.4.0"):
            default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})

        if "callbacks" in lightning_config:
            callbacks_cfg = lightning_config.callbacks
        else:
            callbacks_cfg = OmegaConf.create()

        # if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
        #     print(
        #         "WARNING: saving checkpoints every n train steps without deleting, this might require some free space"
        #     )
        #     default_metrics_over_trainsteps_ckpt_dict = {
        #         "metrics_over_trainsteps_checkpoint": {
        #             "target": "pytorch_lightning.callbacks.ModelCheckpoint",
        #             "params": {
        #                 "dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
        #                 "filename": "{epoch:06}-{step:09}",
        #                 "verbose": True,
        #                 "save_top_k": -1,
        #                 "every_n_train_steps": 10000,
        #                 "save_weights_only": True
        #             }
        #         }
        #     }
        #     default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        if "ignore_keys_callback" in callbacks_cfg and ckpt_resume_path is not None:
            callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = ckpt_resume_path
        elif "ignore_keys_callback" in callbacks_cfg:
            del callbacks_cfg["ignore_keys_callback"]

        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]
        if "plugins" not in trainer_kwargs:
            trainer_kwargs["plugins"] = list()

        # cmd line trainer args (which are in trainer_opt) always have priority over
        # config-trainer-args (which are in trainer_kwargs)
        trainer_opt = vars(trainer_opt)
        trainer_kwargs = {
            key: val for key, val in trainer_kwargs.items() if key not in trainer_opt
        }
        trainer = Trainer(**trainer_opt, **trainer_kwargs)
        trainer.logdir = logdir
        # data
        data = instantiate_from_config(config.data)
        # calling these ourselves should not be necessary, but it is
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        # data.setup()
        print("#### Data #####")
        try:
            for k in data.datasets:
                print(
                    f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}"
                )
        except:
            print("Datasets not yet initialized")

        # configure learning rate
        if "batch_size" in config.data.params:
            bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        else:
            bs, base_lr = (
                config.data.params.train.loader.batch_size,
                config.model.base_learning_rate
            )
        if cpu:
            ngpu = 1
        else:
            ngpu = len(lightning_config.trainer.devices.strip(",").split(","))

        if "accumulate_grad_batches" in lightning_config.trainer:
            accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
        else:
            accumulate_grad_batches = 1
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches

        if opt.scale_lr:
            model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
            print(
                "Setting learning rate to "
                f"{model.learning_rate:.2e} = {accumulate_grad_batches} (accumulate_grad_batches) * {ngpu} (num_gpus) * {bs} (batch_size) * {base_lr:.2e} (base_lr)"
            )
        else:
            model.learning_rate = base_lr
            print("++++ NOT USING LR SCALING ++++")
            print(f"Setting learning rate to {model.learning_rate:.2e}")
        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                # print("Summoning checkpoint")
                # if melk_ckpt_name is None:
                #     ckpt_path = os.path.join(ckptdir, "last.ckpt")
                # else:
                #     ckpt_path = os.path.join(ckptdir, melk_ckpt_name)
                # trainer.save_checkpoint(ckpt_path)
                print("Exiting")

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb
                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)
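        # With these handlers registered, the hooks can be triggered from outside the process,
        # e.g. `kill -USR1 <pid>` calls melk() (checkpoint hook; the actual saving is currently
        # commented out) and `kill -USR2 <pid>` calls divein() to drop rank 0 into pudb.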
        # run
        if opt.train:
            trainer.fit(model, data, ckpt_path=ckpt_resume_path)
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except RuntimeError as error:
        # if MULTINODE_HACKS:
        #     import datetime
        #     import os
        #     import socket
        #
        #     import requests
        #
        #     device = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
        #     hostname = socket.gethostname()
        #     ts = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
        #     resp = requests.get("http://169.254.169.254/latest/meta-data/instance-id")
        #     print(
        #         f"ERROR at {ts} "
        #         f"on {hostname}/{resp.text} (CUDA_VISIBLE_DEVICES={device}): {type(err).__name__}: {err}",
        #         flush=True
        #     )
        raise error
    except Exception:
        if opt.debug and trainer.global_rank == 0:
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if opt.debug and not opt.resume and trainer.global_rank == 0:
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)
        # if trainer.global_rank == 0:
        #     print(trainer.profiler.summary())