import logging
import os
import torch
import yaml
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count_table

logger = logging.getLogger(__name__)

NORM_MODULES = [
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
    # NaiveSyncBatchNorm inherits from BatchNorm2d
    torch.nn.GroupNorm,
    torch.nn.InstanceNorm1d,
    torch.nn.InstanceNorm2d,
    torch.nn.InstanceNorm3d,
    torch.nn.LayerNorm,
    torch.nn.LocalResponseNorm,
]


def register_norm_module(cls):
    """Class decorator: register a custom normalization module in NORM_MODULES."""
    NORM_MODULES.append(cls)
    return cls
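

# Usage sketch (the subclass below is hypothetical, not defined in this module):
# decorating a custom norm adds it to NORM_MODULES, so code that special-cases
# norm layers (e.g. to skip weight decay) recognizes it like the torch.nn built-ins.
#
#   @register_norm_module
#   class NaiveSyncBatchNorm(torch.nn.BatchNorm2d):
#       ...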


def is_main_process():
    # Under OpenMPI the per-process rank is exported as OMPI_COMM_WORLD_RANK;
    # fall back to rank 0 for single-process (non-MPI) runs.
    rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', 0))
    return rank == 0
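

# Typical guard (sketch): only the rank-0 process logs or writes artifacts.
#
#   if is_main_process():
#       logger.info("saving checkpoint ...")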


@torch.no_grad()
def analysis_model(model, dump_input, verbose=False):
    """Count FLOPs and parameters of `model` on `dump_input` and log a summary."""
    model.eval()
    flops = FlopCountAnalysis(model, dump_input)
    total = flops.total()
    model.train()
    params_total = sum(p.numel() for p in model.parameters())
    params_learned = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    logger.info(f"flop count table:\n {flop_count_table(flops)}")
    if verbose:
        logger.info(f"flop count str:\n {flop_count_str(flops)}")
    logger.info(f"  Total flops: {total / 1e6:.3f}M")
    logger.info(f"  Total params: {params_total / 1e6:.3f}M")
    logger.info(f"  Learned params: {params_learned / 1e6:.3f}M")
    return total, flop_count_table(flops), flop_count_str(flops)
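

# Usage sketch: fvcore's FlopCountAnalysis takes the model plus a tensor (or a
# tuple of tensors) matching the model's forward() signature. The model and
# input shape below are illustrative only.
#
#   model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
#   dump_input = torch.randn(1, 3, 224, 224)
#   total_flops, table, detail = analysis_model(model, dump_input, verbose=False)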


def load_config_dict_to_opt(opt, config_dict, splitter='.'):
    """
    Load the key-value pairs from config_dict into opt, overriding existing
    values in opt if there are any.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        k_parts = k.split(splitter)
        pointer = opt
        for k_part in k_parts[:-1]:
            if k_part not in pointer:
                pointer[k_part] = {}
            pointer = pointer[k_part]
            assert isinstance(pointer, dict), "Overriding key needs to be inside a Python dict."
        ori_value = pointer.get(k_parts[-1])
        pointer[k_parts[-1]] = v
        if ori_value is not None:
            print(f"Overrode {k} from {ori_value} to {pointer[k_parts[-1]]}")


def load_opt_from_config_file(conf_file):
    """
    Load opt from the config file.

    Args:
        conf_file: config file path

    Returns:
        dict: a dictionary of opt settings
    """
    opt = {}
    with open(conf_file, encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
        load_config_dict_to_opt(opt, config_dict)
    return opt


def cast_batch_to_dtype(batch, dtype):
    """
    Cast the float32 tensors in a batch to a specified torch dtype.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch
        dtype (torch.dtype): target dtype for floating-point tensors

    Returns:
        return_batch: same type as the input batch, with internal float32
        tensors cast to the specified dtype.
    """
    if torch.is_tensor(batch):
        if torch.is_floating_point(batch):
            return_batch = batch.to(dtype)
        else:
            return_batch = batch
    elif isinstance(batch, list):
        return_batch = [cast_batch_to_dtype(t, dtype) for t in batch]
    elif isinstance(batch, tuple):
        return_batch = tuple(cast_batch_to_dtype(t, dtype) for t in batch)
    elif isinstance(batch, dict):
        return_batch = {k: cast_batch_to_dtype(v, dtype) for k, v in batch.items()}
    else:
        logger.debug(f"Cannot cast type {type(batch)} to {dtype}. Skipping it in the batch.")
        return_batch = batch
    return return_batch
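

# Behavior sketch on a nested batch: floating-point tensors are converted,
# while integer tensors (e.g. labels or token ids) pass through unchanged.
#
#   batch = {'image': torch.randn(2, 3, 8, 8), 'labels': torch.tensor([0, 1])}
#   half_batch = cast_batch_to_dtype(batch, torch.float16)
#   # half_batch['image'].dtype == torch.float16
#   # half_batch['labels'].dtype == torch.int64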


def cast_batch_to_half(batch):
    """
    Cast the float32 tensors in a batch to float16.
    It should be called before feeding the batch to an FP16 DeepSpeed model.

    Args:
        batch (torch.Tensor or container of torch.Tensor): input batch

    Returns:
        return_batch: same type as the input batch, with internal float32
        tensors cast to float16.
    """
    return cast_batch_to_dtype(batch, torch.float16)