"""Benchmarking and measurement utilities""" import functools import pynvml import torch from pynvml.nvml import NVMLError def check_cuda_device(default_value): """ wraps a function and returns the default value instead of running the wrapped function if cuda isn't available or the device is auto :param default_value: :return: """ def deco(func): @functools.wraps(func) def wrapper(*args, **kwargs): device = kwargs.get("device", args[0] if args else None) if ( device is None or not torch.cuda.is_available() or device == "auto" or torch.device(device).type == "cpu" or torch.device(device).type == "meta" ): return default_value return func(*args, **kwargs) return wrapper return deco @check_cuda_device(0.0) def gpu_memory_usage(device=0): return torch.cuda.memory_allocated(device) / 1024.0**3 @check_cuda_device((0.0, 0.0, 0.0)) def gpu_memory_usage_all(device=0): usage = torch.cuda.memory_allocated(device) / 1024.0**3 reserved = torch.cuda.memory_reserved(device) / 1024.0**3 smi = gpu_memory_usage_smi(device) return usage, reserved - usage, max(0, smi - reserved) def mps_memory_usage_all(): usage = torch.mps.current_allocated_memory() / 1024.0**3 reserved = torch.mps.driver_allocated_memory() / 1024.0**3 return usage, reserved - usage, 0 @check_cuda_device(0.0) def gpu_memory_usage_smi(device=0): if isinstance(device, torch.device): device = device.index if isinstance(device, str) and device.startswith("cuda:"): device = int(device[5:]) try: pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(device) info = pynvml.nvmlDeviceGetMemoryInfo(handle) return info.used / 1024.0**3 except NVMLError: return 0.0 def log_gpu_memory_usage(log, msg, device): if torch.backends.mps.is_available(): usage, cache, misc = mps_memory_usage_all() else: usage, cache, misc = gpu_memory_usage_all(device) extras = [] if cache > 0: extras.append(f"+{cache:.03f}GB cache") if misc > 0: extras.append(f"+{misc:.03f}GB misc") log.info( f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2 ) return usage, cache, misc