"""
utility helpers for distributed checks
"""
import os
import pickle  # nosec
from contextlib import contextmanager
from datetime import timedelta

import torch
import torch.distributed as dist
from accelerate import PartialState

distributed_state = None  # pylint: disable=invalid-name


def is_distributed():
    """
    Check if distributed training is initialized.
    """
    global distributed_state  # pylint: disable=global-statement
    if not distributed_state:
        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
        distributed_state = PartialState(timeout=timedelta(seconds=timeout))

    return distributed_state.use_distributed and distributed_state.initialized
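
# Usage note (a sketch): the first call constructs the shared PartialState, so
# AXOLOTL_NCCL_TIMEOUT must be set before anything touches this module.
#
#     os.environ["AXOLOTL_NCCL_TIMEOUT"] = "7200"  # e.g. for slow preprocessing
#     if is_distributed():
#         barrier()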


def barrier():
    """
    Acts as a barrier to wait for all processes. This ensures that all processes
    reach the barrier before proceeding further.
    """
    if is_distributed():
        dist.barrier()


def is_main_process():
    """
    Check if the current process is the main process.
    If not in distributed mode, always return True.
    """
    if not is_distributed():
        return True
    return dist.get_rank() == 0
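
# A minimal usage sketch: gate side effects such as checkpoint writes so they
# happen once per job, not once per rank. `model` and `save_path` are
# hypothetical names, not part of this module.
#
#     if is_main_process():
#         torch.save(model.state_dict(), save_path)
#     barrier()  # keep the other ranks from racing ahead of the save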


def get_world_size():
    """
    Return the world size from the WORLD_SIZE env var, defaulting to 1.
    """
    return int(os.getenv("WORLD_SIZE", "1"))


@contextmanager
def zero_only():
    """
    Context manager whose yielded value is truthy only on the main rank.
    The enclosed block still executes on every rank, so callers should gate
    rank-0-only work on the yielded flag.
    """
    if is_main_process():
        yield True
    else:
        yield None
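
# A minimal usage sketch; `push_to_hub` is a hypothetical side effect:
#
#     with zero_only() as is_zero:
#         if is_zero:
#             push_to_hub()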


@contextmanager
def zero_first(is_main):
    """
    Runs the wrapped context so that rank 0 runs first before the other ranks.
    """
    if not is_main:  # other ranks wait first
        barrier()
    yield
    if is_main:  # then rank 0 waits after it has run the context
        barrier()
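
# A typical pattern (a sketch, not taken from this file): let rank 0 populate
# an on-disk cache so the other ranks only ever read it. `build_or_load_cache`
# is hypothetical.
#
#     with zero_first(is_main_process()):
#         dataset = build_or_load_cache()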


def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on rank 0.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup.

    Returns:
    - A list of computed values from all ranks if on rank 0, otherwise None.
    """
    value_scalar = fn()
    if not is_distributed():
        return [value_scalar]
    value_tensor = torch.tensor(
        value_scalar, device=torch.cuda.current_device()
    ).float()

    if not is_main_process():
        dist.gather(value_tensor, dst=0)
    else:
        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
        dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

        # Convert tensors back to their original type (int or float)
        gathered_values = []
        for tensor in gathered_tensors:
            if tensor == tensor.int():
                gathered_values.append(int(tensor.item()))
            else:
                gathered_values.append(float(tensor.item()))
        return gathered_values
    return None
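
# Usage sketch, assuming every rank passes the true world size; the result is
# a list on rank 0 and None elsewhere. `train_dataset` is hypothetical.
#
#     lengths = gather_scalar_from_all_ranks(
#         lambda: len(train_dataset), world_size=get_world_size()
#     )
#     if is_main_process():
#         print(lengths)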


def broadcast_dict(vals: dict):
    """
    Pickle a dict on rank 0 and broadcast it to all other ranks, replacing
    their local copies.
    """
    if not is_distributed():
        return vals

    if is_main_process():
        data_byte = pickle.dumps(vals)
        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
    else:
        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
        data_size = torch.IntTensor([0]).to("cuda")

    dist.broadcast(data_size, 0)
    if not is_main_process():
        # resize the receive buffer to the broadcast payload size
        data_tensor = data_tensor.new_empty([data_size.item()])

    dist.broadcast(data_tensor, 0)

    if not is_main_process():
        data_list = data_tensor.cpu().tolist()
        data_byte = bytes(data_list[: data_size.item()])
        vals = pickle.loads(data_byte)  # nosec

    return vals
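
# Usage sketch: build a dict once on rank 0 and share it; every other rank's
# input is discarded and replaced by the broadcast copy. `find_free_port` is
# hypothetical; any picklable dict works.
#
#     metadata = {"port": find_free_port() if is_main_process() else 0}
#     metadata = broadcast_dict(metadata)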


def compute_and_broadcast(fn):  # pylint: disable=invalid-name
    """
    Compute a value using the function 'fn' only on rank 0.
    The value is then broadcast to all other ranks.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.

    Returns:
    - The computed value (int or float).
    """
    if is_main_process():
        value_scalar = fn()
        value_tensor = torch.tensor(
            value_scalar, device=torch.cuda.current_device()
        ).float()
    else:
        value_tensor = torch.tensor(
            0.0, device=torch.cuda.current_device()
        )  # Placeholder tensor

    # Broadcast the tensor to all processes.
    barrier()
    dist.broadcast(value_tensor, src=0)

    # Convert the tensor back to its original type (int or float)
    if value_tensor == value_tensor.int():
        return int(value_tensor.item())
    return float(value_tensor.item())
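
# Usage sketch: draw a random seed once on rank 0 so every rank shuffles the
# same way (assumes `import random` in the caller).
#
#     seed = compute_and_broadcast(lambda: random.randint(0, 2**31 - 1))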


def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
    """
    Run a callable 'fn' on all ranks and gather the results on rank 0.
    Unlike gather_scalar_from_all_ranks, this assumes torch.distributed is
    already initialized.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Total number of processes in the current distributed setup.

    Returns:
    - A list of computed values from all ranks if on rank 0, otherwise None.
    """
    value_scalar = fn()
    value_tensor = torch.tensor(
        value_scalar, device=torch.cuda.current_device()
    ).float()

    # Placeholder tensor for gathering results
    if is_main_process():
        gathered_tensors = [torch.zeros_like(value_tensor) for _ in range(world_size)]
    else:
        gathered_tensors = None

    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    if is_main_process():
        # Convert tensors back to their original type (int or float)
        gathered_values = []
        for tensor in gathered_tensors:
            if tensor == tensor.int():
                gathered_values.append(int(tensor.item()))
            else:
                gathered_values.append(float(tensor.item()))
        return gathered_values
    return None
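
# Usage sketch: gather a per-rank scalar such as a loss value; requires an
# initialized process group. `loss` is hypothetical.
#
#     losses = gather_from_all_ranks(lambda: loss.item(), world_size=get_world_size())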


def reduce_and_broadcast(fn1, fn2):
    """
    Run a callable 'fn1' on all ranks, gather the results, reduce them using 'fn2',
    and then broadcast the reduced result to all ranks.

    Args:
    - fn1 (callable): A function that computes the value on each rank.
    - fn2 (callable): A reduction function that takes a list of values and returns a single value.

    Returns:
    - The reduced and broadcast value.
    """
    # Gather values from all ranks using fn1
    if not is_distributed():
        return fn2([fn1()])

    gathered_values = gather_from_all_ranks(fn1, world_size=dist.get_world_size())

    # Use compute_and_broadcast to compute the reduced value on the main process
    # and then broadcast it to all ranks
    return compute_and_broadcast(lambda: fn2(gathered_values))
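
# Usage sketch: compute a global mean of per-rank scalars; fn2 sees the full
# list only on rank 0, and its result is broadcast everywhere. `local_sum`
# and `local_count` are hypothetical.
#
#     global_mean = reduce_and_broadcast(
#         lambda: local_sum / local_count,
#         lambda vals: sum(vals) / len(vals),
#     )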