import logging
import os

import torch
import torch.distributed as dist
import yaml

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import flop_count_table
from fvcore.nn import flop_count_str


logger = logging.getLogger(__name__)


NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Class decorator: add `cls` to NORM_MODULES and return it unchanged."""
    NORM_MODULES.append(cls)
    return cls
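
# Illustrative usage (the layer below is a hypothetical example): registering a
# custom norm layer lets any code that consults NORM_MODULES (e.g. weight-decay
# exclusion) treat it like the built-in norm classes.
#
#   @register_norm_module
#   class ChannelRMSNorm(torch.nn.Module):
#       def __init__(self, dim, eps=1e-6):
#           super().__init__()
#           self.weight = torch.nn.Parameter(torch.ones(dim))
#           self.eps = eps
#
#       def forward(self, x):
#           return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)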


def is_main_process():
    """Return True only on the global rank-0 process (or when not launched via MPI)."""
    rank = 0
    # Under an OpenMPI launch (e.g. mpirun), the global rank is exposed through
    # OMPI_COMM_WORLD_RANK; outside MPI we default to rank 0.
    if 'OMPI_COMM_WORLD_SIZE' in os.environ:
        rank = int(os.environ['OMPI_COMM_WORLD_RANK'])

    return rank == 0
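
# Typical pattern (the path is illustrative): gate once-per-job side effects such as
# logging or checkpointing on the main process,
#
#   if is_main_process():
#       torch.save(model.state_dict(), 'checkpoint.pt')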


@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Count FLOPs and parameters of `model` for the given dummy input and log a summary."""
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train()
    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    flops_table = flop_count_table(flops)
    flops_str = flop_count_str(flops)
    logger.info(f"flop count table:\n {flops_table}")
    if verbose:
        logger.info(f"flop count str:\n {flops_str}")
    logger.info(f"Total flops: {total/1000/1000:.3f}M")
    logger.info(f"Total params: {params_total/1000/1000:.3f}M")
    logger.info(f"Learned params: {params_learned/1000/1000:.3f}M")

    return total, flops_table, flops_str
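
# Example call (model and input shape are illustrative); the dummy input only has to
# match what the model's forward() expects:
#
#   import torchvision
#   model = torchvision.models.resnet18()
#   dummy = torch.randn(1, 3, 224, 224)
#   total_flops, table, detail = analysis_model(model, dummy, verbose=True)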


def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key/value pairs from config_dict into opt, overriding any existing
    values in opt. Keys may address nested entries by joining the path with
    `splitter`, e.g. 'MODEL.ENCODER.DIM'.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        # Walk down (creating intermediate dicts as needed) to the parent of the leaf key.
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        if ori_value is not None:
            logger.warning(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")


def load_opt_from_config_file(conf_file):
    """
    Load opt from a YAML config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
    load_config_dict_to_opt(opt, config_dict)

    return opt
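
# Example (file name and keys are illustrative): a flat YAML using dotted keys,
#
#   # exp_config.yaml
#   MODEL.NAME: my_encoder
#   SOLVER.BASE_LR: 0.0001
#
# loads as {'MODEL': {'NAME': 'my_encoder'}, 'SOLVER': {'BASE_LR': 0.0001}} via
#
#   opt = load_opt_from_config_file('exp_config.yaml')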


def cast_batch_to_dtype(batch, dtype):
    """
    Cast the floating-point tensors in a batch to the specified torch dtype.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
        dtype (torch.dtype): target floating-point dtype
    Returns:
        return_batch: same type as the input batch, with internal floating-point tensors cast to the specified dtype
    """
    if torch.is_tensor(batch):
        if torch.is_floating_point(batch):
            return_batch = batch.to(dtype)
        else:
            # Integer/bool tensors (e.g. labels, masks) are left untouched.
            return_batch = batch
    elif isinstance(batch, list):
        return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
    elif isinstance(batch, tuple):
        return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
    elif isinstance(batch, dict):
        return_batch = {}
        for k in batch:
            return_batch[k] = cast_batch_to_dtype(batch[k], dtype)
    else:
        logger.debug(f"Cannot cast type {type(batch)} to {dtype}. Skipping it in the batch.")
        return_batch = batch

    return return_batch
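
# Example of the recursive cast on a nested batch (keys and shapes are illustrative):
#
#   batch = {
#       'image': torch.randn(2, 3, 8, 8),          # float32 -> converted
#       'label': torch.tensor([1, 0]),             # int64   -> left as-is
#       'meta': [{'scale': torch.tensor(1.5)}],    # nested containers are traversed
#   }
#   half_batch = cast_batch_to_dtype(batch, torch.float16)
#   # half_batch['image'].dtype == torch.float16, half_batch['label'].dtype == torch.int64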


def cast_batch_to_half(batch):
    """
    Cast the floating-point tensors in a batch to float16.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
    Returns:
        return_batch: same type as the input batch, with internal floating-point tensors cast to float16
    """
    return cast_batch_to_dtype(batch, torch.float16)