import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import torch

try:
    from tap import Tap
except ImportError as e:
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    time.sleep(5)
    raise e

import dist


class Args(Tap):
    data_path: str = '/path/to/imagenet'
    exp_name: str = 'text'
    
    # VAE
    vfast: int = 0      # torch.compile VAE; 0: don't compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    # VAR
    tfast: int = 0      # torch.compile VAR; 0: don't compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    depth: int = 16     # VAR depth
    # VAR initialization
    ini: float = -1     # -1: automated model parameter initialization
    hd: float = 0.02    # head.w *= hd
    aln: float = 0.5    # the multiplier of ada_lin.w's initialization
    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization
    # VAR optimization
    fp16: int = 0           # 1: use fp16, 2: use bf16
    tblr: float = 1e-4      # base lr
    tlr: float = None       # lr = base lr * (bs / 256)
    twd: float = 0.05       # initial wd
    twde: float = 0         # final wd; falls back to twd if 0
    tclip: float = 2.       # <=0: don't use grad clip
    ls: float = 0.0         # label smoothing
    
    bs: int = 768           # global batch size
    batch_size: int = 0     # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size() / 8) * 8
    glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1             # gradient accumulation
    
    ep: int = 250
    wp: float = 0
    wp0: float = 0.005      # initial lr ratio at the beginning of lr warm-up
    wpe: float = 0.01       # final lr ratio at the end of training
    sche: str = 'lin0'      # lr schedule
    
    opt: str = 'adamw'      # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
    afuse: bool = True      # fused adamw
    
    # other hps
    saln: bool = False      # whether to use shared adaln
    anorm: bool = True      # whether to use L2-normalized attention
    fuse: bool = True       # whether to use fused ops like flash attn, xformers, fused MLP, fused LayerNorm, etc.
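    
    # Worked example (illustrative, not executed): with the defaults above on 32 GPUs,
    # the runtime setup in init_dist_and_get_args() would yield
    #   batch_size     = round(768 / 1 / 32)   = 24 per GPU
    #   glb_batch_size = 24 * 32               = 768
    #   tlr            = 1 * 1e-4 * 768 / 256  = 3e-4   (the `lr = base lr * (bs / 256)` rule above)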
    
    # data
    pn: str = '1_2_3_4_5_6_8_10_13_16'
    patch_size: int = 16
    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
    
    data_load_reso: int = None  # [automatically set; don't specify this] = max(patch_nums) * patch_size
    mid_reso: float = 1.125     # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
    hflip: bool = False         # augmentation: horizontal flip
    workers: int = 0            # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
    
    # progressive training
    pg: float = 0.0         # >0: use progressive training during the first [0%, this] portion of training
    pg0: int = 4            # progressive initial stage; 0: start from the 1st token map, 1: from the 2nd token map, etc.
    pgwp: float = 0         # num of warm-up epochs at each progressive stage
    
    # would be automatically set at runtime
    cmd: str = ' '.join(sys.argv[1:])   # [automatically set; don't specify this]
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'    # [automatically set; don't specify this]
    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()    # [automatically set; don't specify this]
    
    acc_mean: float = None      # [automatically set; don't specify this]
    acc_tail: float = None      # [automatically set; don't specify this]
    L_mean: float = None        # [automatically set; don't specify this]
    L_tail: float = None        # [automatically set; don't specify this]
    vacc_mean: float = None     # [automatically set; don't specify this]
    vacc_tail: float = None     # [automatically set; don't specify this]
    vL_mean: float = None       # [automatically set; don't specify this]
    vL_tail: float = None       # [automatically set; don't specify this]
    grad_norm: float = None     # [automatically set; don't specify this]
    cur_lr: float = None        # [automatically set; don't specify this]
    cur_wd: float = None        # [automatically set; don't specify this]
    cur_it: str = ''            # [automatically set; don't specify this]
    cur_ep: str = ''            # [automatically set; don't specify this]
    remain_time: str = ''       # [automatically set; don't specify this]
    finish_time: str = ''       # [automatically set; don't specify this]
    
    # environment
    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]
    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]
    log_txt_path: str = '...'           # [automatically set; don't specify this]
    last_ckpt_path: str = '...'         # [automatically set; don't specify this]
    
    tf32: bool = True       # whether to use TensorFloat32
    device: str = 'cpu'     # [automatically set; don't specify this]
    seed: int = None        # seed; None: non-deterministic
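    
    # Note (illustrative): seed_everything() below derives a distinct per-rank seed as
    # seed * world_size + rank, so e.g. --seed=1 on 8 GPUs seeds rank 0 with 8,
    # rank 1 with 9, ..., rank 7 with 15; with --seed unset, cuDNN stays non-deterministic.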
    def seed_everything(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True
            seed = self.seed * dist.get_world_size() + dist.get_rank()
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
    
    same_seed_for_all_ranks: int = 0    # this is only for the distributed sampler
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:   # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g
    
    local_debug: bool = 'KEVIN_LOCAL' in os.environ
    dbg_nan: bool = False   # 'KEVIN_LOCAL' in os.environ
    
    def compile_model(self, m, fast):
        if fast == 0 or self.local_debug:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m
    
    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device'}:     # this is not serializable
                d[k] = getattr(self, k)
        return d
    
    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old versions
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e
    
    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')


def init_dist_and_get_args():
    args = Args(explicit_bool=True).parse_args(known_only=True)
    
    # warn on unexpected extra args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')
    
    # init torch distributed
    from utils import misc
    os.makedirs(args.local_out_dir_path, exist_ok=True)
    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
    
    # set env
    args.set_tf32(args.tf32)
    args.seed_everything(benchmark=args.pg == 0)
    
    # update args: data loading
    args.device = dist.get_device()
    if args.pn == '256':
        args.pn = '1_2_3_4_5_6_8_10_13_16'
    elif args.pn == '512':
        args.pn = '1_2_3_4_6_9_13_18_24_32'
    elif args.pn == '1024':
        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
    args.data_load_reso = max(args.resos)
    
    # update args: bs and lr
    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(max(0, args.workers), args.batch_size)
    
    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.twde = args.twde or args.twd
    
    if args.wp == 0:
        args.wp = args.ep * 1/50
    
    # update args: progressive training
    if args.pgwp == 0:
        args.pgwp = args.ep * 1/300
    if args.pg > 0:
        args.sche = f'lin{args.pg:g}'
    
    # update args: paths
    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
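    # Illustrative example (not executed): with depth=16, pn='1_2_3_4_5_6_8_10_13_16',
    # bs=768, ep=250, opt='adamw', tblr=1e-4, twd=0.05, the sanitized name built below is
    #   'tb-VARd16__pn1_2_3_4_5_6_8_10_13_16__b768ep250adamlr0.0001wd0.05'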
    _reg_valid_name = re.compile(r'[^\w\-+,.]')
    tb_name = _reg_valid_name.sub(
        '_',
        f'tb-VARd{args.depth}'
        f'__pn{args.pn}'
        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
    )
    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
    
    return args
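

# Usage sketch (illustrative only; assumes this module lives at utils/arg_util.py,
# alongside the `dist` helper it imports): a training entry point might do
#
#     from utils.arg_util import Args, init_dist_and_get_args
#     args: Args = init_dist_and_get_args()           # parse CLI flags, init torch.distributed
#     model = args.compile_model(model, args.tfast)   # optional torch.compile, per --tfast
#     torch.save({'args': args.state_dict()}, args.last_ckpt_path)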