import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
import torch
try:
from tap import Tap
except ImportError as e:
print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
time.sleep(5)
raise e
import dist
class Args(Tap):
data_path: str = '/path/to/imagenet'
exp_name: str = 'text'
# VAE
vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'; 3: compile with 'default'
# VAR
tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'; 3: compile with 'default'
depth: int = 16 # VAR depth
# VAR initialization
ini: float = -1 # -1: automated model parameter initialization
hd: float = 0.02 # head.w *= hd
aln: float = 0.5 # the multiplier of ada_lin.w's initialization
alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
# VAR optimization
fp16: int = 0 # 1: using fp16, 2: bf16
tblr: float = 1e-4 # base lr
tlr: float = None # lr = base lr * (bs / 256)
twd: float = 0.05 # initial wd
twde: float = 0 # final wd, =twde or twd
tclip: float = 2. # <=0 for not using grad clip
ls: float = 0.0 # label smooth
bs: int = 768 # global batch size
batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
ac: int = 1 # gradient accumulation
ep: int = 250
wp: float = 0 # lr warmup epochs; 0: automatically set to ep/50
wp0: float = 0.005 # initial lr ratio at the beginning of lr warmup
wpe: float = 0.01 # final lr ratio at the end of training
sche: str = 'lin0' # lr schedule
opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
afuse: bool = True # fused adamw
# other hps
saln: bool = False # whether to use shared adaln
anorm: bool = True # whether to use L2 normalized attention
fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
# data
pn: str = '1_2_3_4_5_6_8_10_13_16'
patch_size: int = 16
patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
hflip: bool = False # augmentation: horizontal flip
workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
# progressive training
pg: float = 0.0 # >0: use progressive training during the first pg fraction of training (e.g. 0.8 = the first 80%)
pg0: int = 4 # initial progressive stage; 0: start from the 1st token map, 1: from the 2nd token map, etc.
pgwp: float = 0 # num of warmup epochs at each progressive stage; 0: automatically set to ep/300
# would be automatically set in runtime
cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
acc_mean: float = None # [automatically set; don't specify this]
acc_tail: float = None # [automatically set; don't specify this]
L_mean: float = None # [automatically set; don't specify this]
L_tail: float = None # [automatically set; don't specify this]
vacc_mean: float = None # [automatically set; don't specify this]
vacc_tail: float = None # [automatically set; don't specify this]
vL_mean: float = None # [automatically set; don't specify this]
vL_tail: float = None # [automatically set; don't specify this]
grad_norm: float = None # [automatically set; don't specify this]
cur_lr: float = None # [automatically set; don't specify this]
cur_wd: float = None # [automatically set; don't specify this]
cur_it: str = '' # [automatically set; don't specify this]
cur_ep: str = '' # [automatically set; don't specify this]
remain_time: str = '' # [automatically set; don't specify this]
finish_time: str = '' # [automatically set; don't specify this]
# environment
local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
log_txt_path: str = '...' # [automatically set; don't specify this]
last_ckpt_path: str = '...' # [automatically set; don't specify this]
tf32: bool = True # whether to use TensorFloat32
device: str = 'cpu' # [automatically set; don't specify this]
seed: int = None # random seed; None: do not fix the seed (cudnn non-deterministic)
def seed_everything(self, benchmark: bool):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = benchmark
if self.seed is None:
torch.backends.cudnn.deterministic = False
else:
torch.backends.cudnn.deterministic = True
seed = self.seed * dist.get_world_size() + dist.get_rank()
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
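# note: each rank derives its own seed as self.seed * world_size + rank, so random
# augmentation and dropout differ across ranks while the whole run stays reproducible;
# e.g. with --seed=1 on 8 GPUs, ranks 0..7 are seeded with 8..15 respectively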
same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
if self.seed is None:
return None
g = torch.Generator()
g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
return g
local_debug: bool = 'KEVIN_LOCAL' in os.environ
dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
def compile_model(self, m, fast):
if fast == 0 or self.local_debug:
return m
return torch.compile(m, mode={
1: 'reduce-overhead',
2: 'max-autotune',
3: 'default',
}[fast]) if hasattr(torch, 'compile') else m
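# compile_model is a no-op when fast == 0, when running in local-debug mode, or on
# torch builds without torch.compile; otherwise fast selects the compile mode above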
def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
d = (OrderedDict if key_ordered else dict)()
# self.as_dict() would contain methods, but we only need variables
for k in self.class_variables.keys():
if k not in {'device'}: # these are not serializable
d[k] = getattr(self, k)
return d
def load_state_dict(self, d: Union[OrderedDict, dict, str]):
if isinstance(d, str): # for compatibility with old version
d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
for k in d.keys():
try:
setattr(self, k, d[k])
except Exception as e:
print(f'k={k}, v={d[k]}')
raise e
@staticmethod
def set_tf32(tf32: bool):
if torch.cuda.is_available():
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
if hasattr(torch, 'set_float32_matmul_precision'):
torch.set_float32_matmul_precision('high' if tf32 else 'highest')
print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
def dump_log(self):
if not dist.is_local_master():
return
if '1/' in self.cur_ep: # first time to dump log
with open(self.log_txt_path, 'w') as fp:
json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
fp.write('\n')
log_dict = {}
for k, v in {
'it': self.cur_it, 'ep': self.cur_ep,
'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
'remain_time': self.remain_time, 'finish_time': self.finish_time,
}.items():
if hasattr(v, 'item'): v = v.item()
log_dict[k] = v
with open(self.log_txt_path, 'a') as fp:
fp.write(f'{log_dict}\n')
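# the first dump writes the run metadata as JSON (indent=0, i.e. one key per line);
# every later call appends one Python-dict repr per line, e.g. (values illustrative):
# {'it': '50/1250', 'ep': '1/250', 'lr': 0.0003, 'wd': 0.05, ...}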
def __str__(self):
s = []
for k in self.class_variables.keys():
if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
s.append(f' {k:20s}: {getattr(self, k)}')
s = '\n'.join(s)
return f'{{\n{s}\n}}\n'
def init_dist_and_get_args():
for i in range(len(sys.argv)):
if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
del sys.argv[i]
break
args = Args(explicit_bool=True).parse_args(known_only=True)
if args.local_debug:
args.pn = '1_2_3'
args.seed = 1
args.aln = 1e-2
args.alng = 1e-5
args.saln = False
args.afuse = False
args.pg = 0.8
args.pg0 = 1
else:
if args.data_path == '/path/to/imagenet':
raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
# warn args.extra_args
if len(args.extra_args) > 0:
print(f'======================================================================================')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
print(f'======================================================================================\n\n')
# init torch distributed
from utils import misc
os.makedirs(args.local_out_dir_path, exist_ok=True)
misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
# set env
args.set_tf32(args.tf32)
args.seed_everything(benchmark=args.pg == 0)
# update args: data loading
args.device = dist.get_device()
if args.pn == '256':
args.pn = '1_2_3_4_5_6_8_10_13_16'
elif args.pn == '512':
args.pn = '1_2_3_4_6_9_13_18_24_32'
elif args.pn == '1024':
args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
args.data_load_reso = max(args.resos)
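# e.g. with the default pn='1_2_3_4_5_6_8_10_13_16' and patch_size=16:
# resos = (16, 32, 48, 64, 80, 96, 128, 160, 208, 256) and data_load_reso = 256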
# update args: bs and lr
bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
args.batch_size = bs_per_gpu
args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
args.workers = min(max(0, args.workers), args.batch_size)
args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
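# linear lr scaling: e.g. with the defaults tblr=1e-4, ac=1 and a global batch size of 768,
# tlr = 1 * 1e-4 * 768 / 256 = 3e-4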
args.twde = args.twde or args.twd
if args.wp == 0:
args.wp = args.ep * 1/50
# update args: progressive training
if args.pgwp == 0:
args.pgwp = args.ep * 1/300
if args.pg > 0:
args.sche = f'lin{args.pg:g}'
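# with the default ep=250 this yields wp = 5 warmup epochs and pgwp ~= 0.83 epochs per
# progressive stage; when progressive training is enabled (pg > 0), the schedule name
# becomes e.g. 'lin0.8' for pg=0.8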
# update args: paths
args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
_reg_valid_name = re.compile(r'[^\w\-+,.]')
tb_name = _reg_valid_name.sub(
'_',
f'tb-VARd{args.depth}'
f'__pn{args.pn}'
f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
)
args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
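# e.g. with all defaults (and a world size that divides 768 evenly, so bs stays 768),
# tb_name is 'tb-VARd16__pn1_2_3_4_5_6_8_10_13_16__b768ep250adamlr0.0001wd0.05'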
return args
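# Minimal usage sketch (entry-point name and launch flags below are illustrative, not
# guaranteed by this file): a training script would typically do
#   args = init_dist_and_get_args()
#   print(args)                                  # logs every resolved hyperparameter
#   var = args.compile_model(var, args.tfast)    # assuming `var` was built elsewhere
# and be launched with something like
#   torchrun --nproc_per_node=8 train.py --data_path=/path/to/imagenet --depth=16 --bs=768 --ep=250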