import os
import typing

import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.nn.parallel import DistributedDataParallel

from ..data.datasets import ResumableDistributedSampler as DistributedSampler
from ..data.datasets import ResumableSequentialSampler as SequentialSampler


class Accelerator:
    """Prepares models and dataloaders for use with DistributedDataParallel
    (DDP) or DataParallel (DP). Use ``prepare_model`` and ``prepare_dataloader``
    to prepare the respective objects: models are moved to the appropriate GPU
    and converted to use SyncBatchNorm, and dataloaders are initialized with a
    sampler that matches the parallelism mode.

    If the world size is 1, ``prepare_model`` and ``prepare_dataloader`` are
    no-ops. If the world size (number of GPUs) is greater than 1 but the
    environment variable ``LOCAL_RANK`` is not set, the script was launched
    without ``torchrun``, and ``DataParallel`` is used instead of
    ``DistributedDataParallel`` (not recommended).

    Parameters
    ----------
    amp : bool, optional
        Whether to enable automatic mixed precision, by default False
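
    Examples
    --------
    A minimal training-loop sketch; ``MyModel``, ``my_dataset``, and the
    optimizer settings are illustrative placeholders, not part of this module::

        accel = Accelerator(amp=True)
        with accel:
            model = accel.prepare_model(MyModel())
            loader = accel.prepare_dataloader(my_dataset, batch_size=16)
            optimizer = torch.optim.Adam(model.parameters())

            for batch in loader:
                optimizer.zero_grad()
                with accel.autocast():
                    loss = model(batch).mean()
                accel.backward(loss)
                accel.step(optimizer)
                accel.update()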
    """

    def __init__(self, amp: bool = False):
        local_rank = os.getenv("LOCAL_RANK", None)
        self.world_size = torch.cuda.device_count()

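        # Use DDP when launched via torchrun (LOCAL_RANK is set) with more than
        # one GPU, fall back to DataParallel when multiple GPUs are visible but
        # LOCAL_RANK is absent, and run on CPU when no GPUs are available.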
        self.use_ddp = self.world_size > 1 and local_rank is not None
        self.use_dp = self.world_size > 1 and local_rank is None
        self.device = "cpu" if self.world_size == 0 else "cuda"

        if self.use_ddp:
            local_rank = int(local_rank)
            dist.init_process_group(
                "nccl",
                init_method="env://",
                world_size=self.world_size,
                rank=local_rank,
            )

        self.local_rank = 0 if local_rank is None else local_rank
        self.amp = amp

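        # No-op stand-in that mirrors the torch.cuda.amp.GradScaler interface,
        # so backward/step/update can call the scaler unconditionally even when
        # amp is disabled.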
        class DummyScaler:
            def __init__(self):
                pass

            def step(self, optimizer):
                optimizer.step()

            def scale(self, loss):
                return loss

            def unscale_(self, optimizer):
                return optimizer

            def update(self):
                pass

        self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
        self.device_ctx = (
            torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
        )

    def __enter__(self):
        if self.device_ctx is not None:
            self.device_ctx.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.device_ctx is not None:
            self.device_ctx.__exit__(exc_type, exc_value, traceback)

    def prepare_model(self, model: torch.nn.Module, **kwargs):
        """Prepares a model for DDP or DP. The model is moved to
        the device of the correct rank, and any keyword arguments are
        forwarded to the ``DistributedDataParallel`` or ``DataParallel``
        constructor.

        Parameters
        ----------
        model : torch.nn.Module
            Model that is converted for DDP or DP.

        Returns
        -------
        torch.nn.Module
            Wrapped model, or the original model if DDP and DP are turned off.
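
        Examples
        --------
        A minimal sketch; ``MyModel`` is a placeholder for your own module::

            accel = Accelerator()
            model = accel.prepare_model(MyModel())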
        """
        model = model.to(self.device)
        if self.use_ddp:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = DistributedDataParallel(
                model, device_ids=[self.local_rank], **kwargs
            )
        elif self.use_dp:
            model = DataParallel(model, **kwargs)
        return model

    def autocast(self, *args, **kwargs):
        """Context manager for autocasting. Arguments
        go to ``torch.cuda.amp.autocast``.
        """
        return torch.cuda.amp.autocast(self.amp, *args, **kwargs)

    def backward(self, loss: torch.Tensor):
        """Backward pass, scaling the loss first if ``amp`` is enabled.

        Parameters
        ----------
        loss : torch.Tensor
            Loss value.
        """
        self.scaler.scale(loss).backward()

    def step(self, optimizer: torch.optim.Optimizer):
        """Steps the optimizer, using the gradient ``scaler`` if ``amp`` is
        enabled.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to step forward.
        """
        self.scaler.step(optimizer)

    def update(self):
        """Updates the scale factor of the gradient scaler."""
        self.scaler.update()

    def prepare_dataloader(
        self,
        dataset: typing.Iterable,
        start_idx: typing.Optional[int] = None,
        **kwargs,
    ):
        """Wraps a dataset with a DataLoader, using the correct sampler if DDP
        is enabled.

        Parameters
        ----------
        dataset : typing.Iterable
            Dataset to build the DataLoader around.
        start_idx : int, optional
            Start index of the sampler, useful when resuming from some epoch,
            by default None

        Returns
        -------
        torch.utils.data.DataLoader
            DataLoader wrapping ``dataset`` with the appropriate sampler.
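
        Examples
        --------
        A sketch of resuming mid-epoch; ``my_dataset`` and the keyword
        arguments are illustrative placeholders forwarded to
        ``torch.utils.data.DataLoader``::

            accel = Accelerator()
            loader = accel.prepare_dataloader(
                my_dataset, start_idx=1000, batch_size=32, num_workers=8
            )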
        """

        if self.use_ddp:
            sampler = DistributedSampler(
                dataset,
                start_idx,
                num_replicas=self.world_size,
                rank=self.local_rank,
            )
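            # Each DDP process sees 1/world_size of the data, so scale the
            # global num_workers and batch_size down to per-process values.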
if "num_workers" in kwargs: |
|
kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1) |
|
kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1) |
|
else: |
|
sampler = SequentialSampler(dataset, start_idx) |
|
|
|
dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs) |
|
return dataloader |
|
|
|
    @staticmethod
    def unwrap(model):
        """Unwraps the model if it was wrapped in DDP or DP, otherwise
        just returns the model. Use this to unwrap the model returned by
        :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
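
        Examples
        --------
        A sketch of saving the underlying weights of a wrapped model; the
        checkpoint path is a placeholder::

            torch.save(Accelerator.unwrap(model).state_dict(), "checkpoint.pth")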
        """
        if hasattr(model, "module"):
            return model.module
        return model