import torch
from torch import Tensor


class AllGather(torch.autograd.Function):
    """
    all_gather with gradient back-propagation
    """

    @staticmethod
    def forward(ctx, tensor_list, tensor, group, async_op):
        torch.distributed.all_gather(
            tensor_list, tensor, group=group, async_op=async_op
        )
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_list):
        grad_list = list(grad_list)
        rank = torch.distributed.get_rank()
        # each gathered slice's gradient is reduced (summed) back onto the rank that
        # produced it; note that the reduction uses the default process group
        dist_ops = [
            torch.distributed.reduce(grad_list[i], i, async_op=True)
            for i in range(torch.distributed.get_world_size())
        ]
        for op in dist_ops:
            op.wait()
        # only the gradient w.r.t. this rank's input tensor is returned
        return None, grad_list[rank], None, None


all_gather_with_grad = AllGather.apply
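
# Minimal usage sketch for all_gather_with_grad (assumes torch.distributed.init_process_group
# has already been called and that every rank passes a tensor of the same shape; `local` is a
# placeholder name used only for illustration and should require grad for backprop to matter):
#
#   world_size = torch.distributed.get_world_size()
#   gathered = [torch.zeros_like(local) for _ in range(world_size)]
#   gathered = all_gather_with_grad(gathered, local, None, False)
#   torch.cat(gathered).sum().backward()  # gradients flow back into `local` on each rank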


def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    if len(a.shape) == 1:
        a = a.unsqueeze(0)
    if len(b.shape) == 1:
        b = b.unsqueeze(0)
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))
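
# Minimal usage sketch for cos_sim (the values below are illustrative only):
#
#   a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
#   b = torch.tensor([[1.0, 0.0], [1.0, 1.0]])
#   cos_sim(a, b)
#   # tensor([[1.0000, 0.7071],
#   #         [0.0000, 0.7071]])  # shape (2, 2): rows index a, columns index b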


def mismatched_sizes_all_gather(
    tensor: Tensor, group=None, async_op=False, mismatched_axis=0
):
    # all_gather doesn't support tensor lists whose size along one axis differs across
    # ranks. This wrapper does, by padding to the largest size before gathering.
    assert torch.distributed.is_initialized(), "torch.distributed not initialized"
    world_size = torch.distributed.get_world_size()
    # first, gather every rank's size along the mismatched axis; this gather is done
    # synchronously because the sizes are needed immediately below
    # (the NCCL backend requires CUDA tensors, hence device="cuda")
    mismatched_sizes = torch.tensor(
        [tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda"
    )
    sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)]
    torch.distributed.all_gather(sizes, mismatched_sizes, group=group)
    sizes = torch.cat(sizes).cpu().tolist()
    # now pad to the max size along the mismatched axis
    max_size = max(sizes)
    padded = torch.zeros(
        (
            *tensor.shape[:mismatched_axis],
            max_size,
            *tensor.shape[mismatched_axis + 1 :],
        ),
        device=tensor.device,
        dtype=tensor.dtype,
    )
    # view into the slice of the padded tensor that receives this rank's data
    padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis])
    padded_to_fill[...] = tensor
# gather the padded tensors | |
tensor_list = [ | |
torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype) | |
for _ in range(world_size) | |
] | |
all_gather_with_grad(tensor_list, padded, group, async_op) | |
# trim off the padding | |
for rank in range(world_size): | |
# checks that the rest is 0 | |
assert ( | |
not tensor_list[rank] | |
.narrow( | |
mismatched_axis, | |
sizes[rank], | |
padded.shape[mismatched_axis] - sizes[rank], | |
) | |
.count_nonzero() | |
.is_nonzero() | |
), "This would remove non-padding information" | |
tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank]) | |
return tensor_list | |
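
# Minimal usage sketch for mismatched_sizes_all_gather (assumes a running distributed job
# where each rank holds a different number of rows along dim 0; `local_emb` is a placeholder
# name used only for illustration, e.g. per-rank sentence embeddings):
#
#   gathered = mismatched_sizes_all_gather(local_emb)  # list of per-rank tensors
#   all_emb = torch.cat(gathered, dim=0)               # total rows = sum of per-rank sizes
#   scores = cos_sim(local_emb, all_emb)               # e.g. in-batch negatives for a contrastive loss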