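# NOTE: the following three lines are web-page residue captured from the file
# host (uploader avatar caption, upload message, revision id); they are not code.
#   kadirnar's picture
#   Upload 98 files
#   e7d5680 verified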
import random
from typing import Optional
import numpy as np
import torch
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.cluster import ProcessGroupMesh
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
# Axis indices into the two-dimensional (dp_size, sp_size) ProcessGroupMesh
# that ZeroSeqParallelPlugin builds.
DP_AXIS = 0  # first mesh dimension: data-parallel ranks
SP_AXIS = 1  # second mesh dimension: sequence-parallel ranks
class ZeroSeqParallelPlugin(LowLevelZeroPlugin):
    """LowLevelZeroPlugin that arranges ranks in a (dp_size, sp_size) mesh.

    After the base plugin is initialised, the ``world_size`` ranks are laid
    out as a two-dimensional ``ProcessGroupMesh`` with
    ``dp_size = world_size // sp_size`` rows and ``sp_size`` columns. The
    process groups and this rank's coordinates along each axis are exposed as
    ``dp_group`` / ``dp_rank`` and ``sp_group`` / ``sp_rank``.

    Dataloaders created by :meth:`prepare_dataloader` configure their sampler
    with ``num_replicas=dp_size`` and ``rank=dp_rank``, i.e. the dataset is
    sharded along the data-parallel axis only.
    """

    def __init__(
        self,
        sp_size: int = 1,
        stage: int = 2,
        precision: str = "fp16",
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        reduce_bucket_size_in_m: int = 12,
        communication_dtype: Optional[torch.dtype] = None,
        overlap_communication: bool = True,
        cpu_offload: bool = False,
        master_weights: bool = True,
        verbose: bool = False,
    ) -> None:
        """Initialise the base plugin and build the DP x SP process-group mesh.

        Args:
            sp_size: Size of the mesh's second (``SP_AXIS``) dimension. Must
                divide ``self.world_size`` evenly.
            stage, precision, initial_scale, min_scale, growth_factor,
            backoff_factor, growth_interval, hysteresis, max_scale, max_norm,
            norm_type, reduce_bucket_size_in_m, communication_dtype,
            overlap_communication, cpu_offload, master_weights, verbose:
                Forwarded unchanged to ``LowLevelZeroPlugin.__init__``.

        Raises:
            AssertionError: If ``world_size`` is not divisible by ``sp_size``.
        """
        super().__init__(
            stage=stage,
            precision=precision,
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
            reduce_bucket_size_in_m=reduce_bucket_size_in_m,
            communication_dtype=communication_dtype,
            overlap_communication=overlap_communication,
            cpu_offload=cpu_offload,
            master_weights=master_weights,
            verbose=verbose,
        )
        self.sp_size = sp_size
        # ``world_size`` is not set in this class; presumably the base plugin
        # provides it -- verify against the installed colossalai version.
        assert self.world_size % sp_size == 0, "world_size must be divisible by sp_size"
        self.dp_size = self.world_size // sp_size
        # Mesh shape is (dp_size, sp_size): DP_AXIS indexes the first
        # dimension and SP_AXIS the second.
        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.sp_size)
        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
        self.sp_group = self.pg_mesh.get_group_along_axis(SP_AXIS)
        self.dp_rank = self.pg_mesh.coordinate(DP_AXIS)
        self.sp_rank = self.pg_mesh.coordinate(SP_AXIS)

    def __del__(self):
        """Destroy the process groups in ProcessGroupMesh, if one was created."""
        # The finalizer also runs for instances whose __init__ raised before
        # ``pg_mesh`` was assigned (base-class failure or the divisibility
        # assert). Skip the cleanup in that case instead of raising an
        # AttributeError that would be reported on top of the real error.
        pg_mesh = getattr(self, "pg_mesh", None)
        if pg_mesh is not None:
            pg_mesh.destroy_mesh_process_groups()

    def prepare_dataloader(
        self,
        dataset,
        batch_size,
        shuffle=False,
        seed=1024,
        drop_last=False,
        pin_memory=False,
        num_workers=0,
        distributed_sampler_cls=None,
        **kwargs,
    ):
        """Build a DataLoader whose sampler shards ``dataset`` over the DP axis.

        Args:
            dataset: Dataset handed to both the sampler and the DataLoader.
            batch_size: Batch size passed to the DataLoader.
            shuffle: Passed to the sampler (not to the DataLoader).
            seed: Seed applied to ``numpy``, ``torch`` and ``random`` inside
                each worker process. It is not forwarded to the sampler.
            drop_last: Passed to the DataLoader.
            pin_memory: Passed to the DataLoader.
            num_workers: Passed to the DataLoader.
            distributed_sampler_cls: Sampler class to instantiate; falls back
                to ``torch.utils.data.distributed.DistributedSampler`` when
                falsy. It is called with ``(dataset, num_replicas=, rank=,
                shuffle=)``.
            **kwargs: Extra keyword arguments forwarded to the DataLoader.

        Returns:
            The constructed ``torch.utils.data.DataLoader``.
        """
        distributed_sampler_cls = distributed_sampler_cls or DistributedSampler
        # Shard by data-parallel coordinates rather than by global rank.
        sampler = distributed_sampler_cls(dataset, num_replicas=self.dp_size, rank=self.dp_rank, shuffle=shuffle)

        # Deterministic dataloader
        def seed_worker(worker_id):
            # NOTE: ``worker_id`` is intentionally unused -- every worker is
            # seeded with the same value, so all workers start from identical
            # RNG state.
            worker_seed = seed
            np.random.seed(worker_seed)
            torch.manual_seed(worker_seed)
            random.seed(worker_seed)

        return DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            worker_init_fn=seed_worker,
            drop_last=drop_last,
            pin_memory=pin_memory,
            num_workers=num_workers,
            **kwargs,
        )