|
""" |
|
utility helpers for distributed checks |
|
""" |
|
import os |
|
import pickle |
|
from contextlib import contextmanager |
|
from datetime import timedelta |
|
|
|
import torch |
|
import torch.distributed as dist |
|
from accelerate import PartialState |
|
|
|
distributed_state = None |
|
|
|
|
|
def is_distributed(): |
|
""" |
|
Check if distributed training is initialized. |
|
""" |
|
global distributed_state |
|
if not distributed_state: |
|
timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800)) |
|
distributed_state = PartialState(timeout=timedelta(seconds=timeout)) |
|
|
|
return distributed_state.use_distributed and distributed_state.initialized |
|
|
|
|
|
def barrier(): |
|
""" |
|
Acts as a barrier to wait for all processes. This ensures that all processes |
|
reach the barrier before proceeding further. |
|
""" |
|
if is_distributed(): |
|
dist.barrier() |
|
|
|
|
|
def is_main_process(): |
|
""" |
|
Check if the current process is the main process. |
|
If not in distributed mode, always return True. |
|
""" |
|
if not is_distributed(): |
|
return True |
|
return dist.get_rank() == 0 |
|
|
|
|
|
def get_world_size(): |
|
return int(os.getenv("WORLD_SIZE", "1")) |
|
|
|
|
|
@contextmanager |
|
def zero_only(): |
|
""" |
|
Context manager that only runs the enclosed block on the main rank. |
|
""" |
|
if is_main_process(): |
|
yield |
|
else: |
|
yield None |
|
|
|
|
|
@contextmanager |
|
def zero_first(is_main): |
|
""" |
|
runs the wrapped context so that rank 0 runs first before other ranks |
|
""" |
|
if not is_main: |
|
barrier() |
|
yield |
|
if is_main: |
|
barrier() |
|
|
|
|
|
def gather_scalar_from_all_ranks(fn, world_size=1): |
|
""" |
|
Run a callable 'fn' on all ranks and gather the results on the specified rank. |
|
|
|
Args: |
|
- fn (callable): A function that computes the value. This should not have any side effects. |
|
- rank (int, optional): The rank that gathers the values. Default is 0. |
|
- world_size (int, optional): Total number of processes in the current distributed setup. |
|
|
|
Returns: |
|
- A list of computed values from all ranks if on the gathering rank, otherwise None. |
|
""" |
|
value_scalar = fn() |
|
if not is_distributed(): |
|
return [value_scalar] |
|
value_tensor = torch.tensor( |
|
value_scalar, device=torch.cuda.current_device() |
|
).float() |
|
|
|
if not is_main_process(): |
|
dist.gather(value_tensor, dst=0) |
|
else: |
|
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] |
|
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) |
|
|
|
|
|
gathered_values = [] |
|
for tensor in gathered_tensors: |
|
if tensor == tensor.int(): |
|
gathered_values.append(int(tensor.item())) |
|
else: |
|
gathered_values.append(float(tensor.item())) |
|
return gathered_values |
|
return None |
|
|
|
|
|
def broadcast_dict(vals: dict): |
|
if not is_distributed(): |
|
return vals |
|
|
|
if is_main_process(): |
|
data_byte = pickle.dumps(vals) |
|
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda") |
|
data_size = torch.IntTensor([len(data_byte)]).to("cuda") |
|
else: |
|
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda") |
|
data_size = torch.IntTensor([0]).to("cuda") |
|
|
|
dist.broadcast(data_size, 0) |
|
if not is_main_process(): |
|
|
|
data_tensor = data_tensor.new_empty([data_size.item()]) |
|
|
|
dist.broadcast(data_tensor, 0) |
|
|
|
if not is_main_process(): |
|
data_list = data_tensor.cpu().tolist() |
|
data_byte = bytes(data_list[: data_size.item()]) |
|
vals = pickle.loads(data_byte) |
|
|
|
return vals |
|
|
|
|
|
def compute_and_broadcast(fn): |
|
""" |
|
Compute a value using the function 'fn' only on the specified rank (default is 0). |
|
The value is then broadcasted to all other ranks. |
|
|
|
Args: |
|
- fn (callable): A function that computes the value. This should not have any side effects. |
|
- rank (int, optional): The rank that computes the value. Default is 0. |
|
|
|
Returns: |
|
- The computed value (int or float). |
|
""" |
|
if is_main_process(): |
|
value_scalar = fn() |
|
value_tensor = torch.tensor( |
|
value_scalar, device=torch.cuda.current_device() |
|
).float() |
|
else: |
|
value_tensor = torch.tensor( |
|
0.0, device=torch.cuda.current_device() |
|
) |
|
|
|
|
|
barrier() |
|
dist.broadcast(value_tensor, src=0) |
|
|
|
|
|
if value_tensor == value_tensor.int(): |
|
return int(value_tensor.item()) |
|
return float(value_tensor.item()) |
|
|
|
|
|
def gather_from_all_ranks(fn, world_size=1): |
|
""" |
|
Run a callable 'fn' on all ranks and gather the results on the specified rank. |
|
|
|
Args: |
|
- fn (callable): A function that computes the value. This should not have any side effects. |
|
- rank (int, optional): The rank that gathers the values. Default is 0. |
|
- world_size (int, optional): Total number of processes in the current distributed setup. |
|
|
|
Returns: |
|
- A list of computed values from all ranks if on the gathering rank, otherwise None. |
|
""" |
|
value_scalar = fn() |
|
value_tensor = torch.tensor( |
|
value_scalar, device=torch.cuda.current_device() |
|
).float() |
|
|
|
|
|
if is_main_process(): |
|
gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)] |
|
else: |
|
gathered_tensors = None |
|
|
|
dist.gather(value_tensor, gather_list=gathered_tensors, dst=0) |
|
|
|
if is_main_process(): |
|
|
|
gathered_values = [] |
|
for tensor in gathered_tensors: |
|
if tensor == tensor.int(): |
|
gathered_values.append(int(tensor.item())) |
|
else: |
|
gathered_values.append(float(tensor.item())) |
|
return gathered_values |
|
return None |
|
|
|
|
|
def reduce_and_broadcast(fn1, fn2): |
|
""" |
|
Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2', |
|
and then broadcast the reduced result to all ranks. |
|
|
|
Args: |
|
- fn1 (callable): A function that computes the value on each rank. |
|
- fn2 (callable): A reduction function that takes a list of values and returns a single value. |
|
- world_size (int, optional): Total number of processes in the current distributed setup. |
|
|
|
Returns: |
|
- The reduced and broadcasted value. |
|
""" |
|
|
|
|
|
if not is_distributed(): |
|
return fn2([fn1()]) |
|
|
|
gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size()) |
|
|
|
|
|
|
|
return compute_and_broadcast(lambda: fn2(gathered_values)) |
|
|