File size: 7,138 Bytes
2bb0b78
 
 
7657632
e30f1e3
fc2d6be
132eb74
fc2d6be
7657632
2bb0b78
132eb74
2bb0b78
132eb74
2bb0b78
 
 
 
 
 
132eb74
 
 
 
 
 
2bb0b78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc2d6be
 
7657632
 
 
 
4c834bf
 
 
 
 
 
 
 
 
 
 
fc2d6be
 
 
 
 
 
 
 
 
 
7657632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09f1543
 
2fe95cd
 
 
7657632
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e30f1e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b15b19e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fe95cd
 
 
b15b19e
2fe95cd
 
 
b15b19e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2fe95cd
 
 
b15b19e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
utility helpers for distributed checks
"""
import os
import pickle  # nosec
from contextlib import contextmanager
from datetime import timedelta

import torch
import torch.distributed as dist
from accelerate import PartialState

# Process-wide accelerate PartialState; stays None until is_distributed() creates it.
distributed_state = None  # pylint: disable=invalid-name


def is_distributed():
    """
    Report whether a distributed run is active.

    The accelerate ``PartialState`` is built on first use and cached in the
    module-level ``distributed_state``. Its timeout is read from the
    ``AXOLOTL_NCCL_TIMEOUT`` environment variable (in seconds, 1800 when unset).
    """
    global distributed_state  # pylint: disable=global-statement
    if not distributed_state:
        nccl_timeout = timedelta(
            seconds=int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
        )
        distributed_state = PartialState(timeout=nccl_timeout)

    state = distributed_state
    return state.use_distributed and state.initialized


def barrier():
    """
    Block until every process has reached this point.

    Does nothing when the run is not distributed.
    """
    if not is_distributed():
        return
    dist.barrier()


def is_main_process():
    """
    Tell whether this process is rank 0.

    Outside of distributed mode the answer is always True.
    """
    # Short-circuit keeps dist.get_rank() from being called without a process group.
    return not is_distributed() or dist.get_rank() == 0


def get_world_size():
    """
    Read the process count from the ``WORLD_SIZE`` environment variable.

    Returns 1 when the variable is not set.
    """
    raw_world_size = os.environ.get("WORLD_SIZE", "1")
    return int(raw_world_size)


@contextmanager
def zero_only():
    """
    Context manager meant to wrap code intended for the main rank.

    NOTE: this does NOT skip the enclosed block on other ranks. A
    ``@contextmanager`` generator has to yield exactly once and cannot prevent
    the ``with`` body from executing, so the block runs on every rank. Callers
    that need main-rank-only behaviour must check ``is_main_process()``
    themselves inside the block.

    Yields:
    - None, on every rank.
    """
    # Still query the rank: the call lazily sets up the distributed state as a
    # side effect, which callers may have come to rely on.
    is_main_process()
    yield


@contextmanager
def zero_first(is_main):
    """
    Let the main rank execute the wrapped block before the remaining ranks do.

    Ranks with a falsy ``is_main`` wait at a barrier before entering the block;
    the main rank runs the block first and only then joins that barrier.

    NOTE: there is no try/finally around the yield, so if the block raises on
    the main rank the trailing barrier is never reached.

    Args:
    - is_main (bool): Whether the calling process is the main rank.
    """
    if is_main:
        yield
        barrier()  # main has finished the block: meet the waiting ranks
    else:
        barrier()  # hold here until main has finished the block
        yield


def gather_scalar_from_all_ranks(
    fn, world_size=1
):  # pylint: disable=invalid-name,unused-argument
    """
    Run a callable 'fn' on all ranks and gather the scalar results on rank 0.

    Values are transported as float32 CUDA tensors, so integers above 2**24 and
    high-precision floats may not round-trip exactly.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Retained for backward compatibility. In distributed
      mode the gather buffer is sized from ``dist.get_world_size()`` instead, because
      ``dist.gather`` needs exactly one slot per rank and the default of 1 would be
      wrong for any multi-process run.

    Returns:
    - Not distributed: a single-element list holding this process's value, unconverted.
    - Rank 0: a list with one value per rank; whole numbers come back as int,
      everything else as float.
    - Other ranks: None.
    """
    value_scalar = fn()
    if not is_distributed():
        return [value_scalar]
    value_tensor = torch.tensor(
        value_scalar, device=torch.cuda.current_device()
    ).float()

    if not is_main_process():
        # Non-destination ranks only contribute their tensor.
        dist.gather(value_tensor, dst=0)
        return None

    # One receive slot per rank in the default process group.
    gathered_tensors = [
        torch.zeros_like(value_tensor) for _ in range(dist.get_world_size())
    ]
    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    # Convert tensors back to their original type (int or float)
    return [
        int(tensor.item()) if tensor == tensor.int() else float(tensor.item())
        for tensor in gathered_tensors
    ]


def broadcast_dict(vals: dict):
    """
    Share the main process's dict with every other rank.

    The main process pickles ``vals`` and ships the bytes as a uint8 CUDA tensor.
    The payload length is broadcast first so the receivers can allocate a buffer
    of matching size. Outside of distributed mode ``vals`` is returned as-is.

    Args:
    - vals (dict): The dict to send. Only the main process's value is transmitted.

    Returns:
    - ``vals`` itself on the main process (or when not distributed), and the
      unpickled copy of the main process's dict on every other rank.
    """
    if not is_distributed():
        return vals

    if is_main_process():
        payload = pickle.dumps(vals)
        payload_tensor = torch.ByteTensor(list(payload)).to("cuda")
        payload_len = torch.IntTensor([len(payload)]).to("cuda")
    else:
        # Placeholders; the real receive buffer is sized once the length is known.
        payload_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
        payload_len = torch.IntTensor([0]).to("cuda")

    # Step 1: every rank learns how many bytes to expect.
    dist.broadcast(payload_len, 0)
    if not is_main_process():
        payload_tensor = payload_tensor.new_empty([payload_len.item()])

    # Step 2: ship the pickled bytes themselves.
    dist.broadcast(payload_tensor, 0)

    if is_main_process():
        return vals

    num_bytes = payload_len.item()
    received = bytes(payload_tensor.cpu().tolist()[:num_bytes])
    # The bytes were produced by rank 0 via the broadcast above.
    return pickle.loads(received)  # nosec


def compute_and_broadcast(fn):  # pylint: disable=invalid-name
    """
    Compute a value using the function 'fn' on rank 0 only and broadcast it
    to all other ranks.

    The value is transported as a float32 CUDA tensor, so integers above 2**24
    and high-precision floats may not round-trip exactly.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
      It is only invoked on the main process.

    Returns:
    - Distributed: the value computed on rank 0, as an int when it is a whole
      number and as a float otherwise.
    - Not distributed: the result of ``fn()``, returned unchanged.
    """
    if not is_distributed():
        # No process group to broadcast over (dist.broadcast would fail), and
        # this process is the only one that needs the value.
        return fn()

    device = torch.cuda.current_device()
    if is_main_process():
        value_scalar = fn()
        value_tensor = torch.tensor(value_scalar, device=device).float()
    else:
        # Placeholder tensor. The dtype is pinned to float32 so it matches the
        # sender's ``.float()`` tensor even if the default dtype was changed.
        value_tensor = torch.tensor(0.0, device=device, dtype=torch.float32)

    # Broadcast the tensor to all processes.
    barrier()
    dist.broadcast(value_tensor, src=0)

    # Convert the tensor back to its original type (int or float)
    if value_tensor == value_tensor.int():
        return int(value_tensor.item())
    return float(value_tensor.item())


def gather_from_all_ranks(
    fn, world_size=1
):  # pylint: disable=invalid-name,unused-argument
    """
    Run a callable 'fn' on all ranks and gather the scalar results on rank 0.

    Values are transported as float32 CUDA tensors, so integers above 2**24 and
    high-precision floats may not round-trip exactly.

    Args:
    - fn (callable): A function that computes the value. This should not have any side effects.
    - world_size (int, optional): Retained for backward compatibility. In distributed
      mode the gather buffer is sized from ``dist.get_world_size()`` instead, because
      ``dist.gather`` needs exactly one slot per rank and the default of 1 would be
      wrong for any multi-process run.

    Returns:
    - Not distributed: a single-element list holding this process's value, unconverted.
    - Rank 0: a list with one value per rank; whole numbers come back as int,
      everything else as float.
    - Other ranks: None.
    """
    value_scalar = fn()
    if not is_distributed():
        # No process group (and possibly no GPU) to gather over; mirror the
        # single-process behaviour of gather_scalar_from_all_ranks.
        return [value_scalar]

    value_tensor = torch.tensor(
        value_scalar, device=torch.cuda.current_device()
    ).float()

    if not is_main_process():
        # Non-destination ranks only contribute their tensor.
        dist.gather(value_tensor, gather_list=None, dst=0)
        return None

    # One receive slot per rank in the default process group.
    gathered_tensors = [
        torch.zeros_like(value_tensor) for _ in range(dist.get_world_size())
    ]
    dist.gather(value_tensor, gather_list=gathered_tensors, dst=0)

    # Convert tensors back to their original type (int or float)
    return [
        int(tensor.item()) if tensor == tensor.int() else float(tensor.item())
        for tensor in gathered_tensors
    ]


def reduce_and_broadcast(fn1, fn2):
    """
    Evaluate 'fn1' on every rank, combine the per-rank results with 'fn2',
    and hand the combined value back to all ranks.

    Args:
    - fn1 (callable): A function that computes the value on each rank.
    - fn2 (callable): A reduction function that takes a list of values and returns a single value.

    Returns:
    - The reduced value. When not distributed this is simply ``fn2([fn1()])``.
    """
    if is_distributed():
        # Collect one value per rank.
        per_rank_values = gather_from_all_ranks(
            fn1, world_size=dist.get_world_size()
        )
        # Reduce through compute_and_broadcast, which shares the result with
        # every rank.
        return compute_and_broadcast(lambda: fn2(per_rank_values))

    # Single process: reduce the lone local value directly.
    return fn2([fn1()])