import torch
from torch import Tensor


class AllGather(torch.autograd.Function):
    """
    all_gather with gradient back-propagation
    """

    @staticmethod
    def forward(ctx, tensor_list, tensor, group, async_op):
        torch.distributed.all_gather(
            tensor_list, tensor, group=group, async_op=async_op
        )
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_list):
        grad_list = list(grad_list)
        rank = torch.distributed.get_rank()
        # Sum the gradient for each gathered slice onto the rank that produced it,
        # then return only the gradient w.r.t. the local input tensor.
        dist_ops = [
            torch.distributed.reduce(grad_list[i], i, async_op=True)
            for i in range(torch.distributed.get_world_size())
        ]
        for op in dist_ops:
            op.wait()
        return None, grad_list[rank], None, None


all_gather_with_grad = AllGather.apply
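

# Minimal usage sketch: gather per-rank embeddings into a global batch while
# keeping gradients flowing back to the local tensor. Assumes torch.distributed
# is already initialized (e.g. via torchrun); the helper name and the
# `embeddings` argument are illustrative only.
def _example_gather_embeddings(embeddings: Tensor):
    world_size = torch.distributed.get_world_size()
    gathered = [torch.zeros_like(embeddings) for _ in range(world_size)]
    gathered = all_gather_with_grad(gathered, embeddings, None, False)
    # Only the local slice of the concatenated result carries gradients,
    # thanks to the custom backward above.
    return torch.cat(gathered, dim=0)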


def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    if len(a.shape) == 1:
        a = a.unsqueeze(0)
    if len(b.shape) == 1:
        b = b.unsqueeze(0)
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))
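

# Usage sketch (illustrative only): pairwise cosine similarities between two
# small batches of vectors; the result has one row per row of `a` and one
# column per row of `b`, with values in [-1, 1].
def _example_cos_sim():
    a = torch.randn(4, 8)
    b = torch.randn(6, 8)
    sim = cos_sim(a, b)  # shape [4, 6]
    assert sim.shape == (4, 6)
    return sim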


def mismatched_sizes_all_gather(
    tensor: Tensor, group=None, async_op=False, mismatched_axis=0
):
    """
    all_gather for tensors whose size along `mismatched_axis` differs across
    ranks (plain all_gather requires identical shapes on every rank).
    """
    assert torch.distributed.is_initialized(), "torch.distributed not initialized"
    world_size = torch.distributed.get_world_size()
    # collect every rank's size along the mismatched axis
    # (communicated via a CUDA tensor, as required by the NCCL backend)
    mismatched_sizes = torch.tensor(
        [tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda"
    )
    sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)]
    torch.distributed.all_gather(
        sizes, mismatched_sizes, group=group, async_op=async_op
    )
    sizes = torch.cat(sizes).cpu().tolist()
    # pad the local tensor up to the largest size along the mismatched axis
    max_size = max(sizes)
    padded = torch.zeros(
        (
            *tensor.shape[:mismatched_axis],
            max_size,
            *tensor.shape[mismatched_axis + 1 :],
        ),
        device=tensor.device,
        dtype=tensor.dtype,
    )
    # view of the region that holds the real (non-padding) values
    padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis])
    padded_to_fill[...] = tensor
    # gather the padded tensors from all ranks
    tensor_list = [
        torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype)
        for _ in range(world_size)
    ]
    all_gather_with_grad(tensor_list, padded, group, async_op)
    # trim off the padding
    for rank in range(world_size):
        # make sure the region being trimmed really is all zeros
        assert (
            not tensor_list[rank]
            .narrow(
                mismatched_axis,
                sizes[rank],
                padded.shape[mismatched_axis] - sizes[rank],
            )
            .count_nonzero()
            .is_nonzero()
        ), "This would remove non-padding information"
        tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank])
    return tensor_list
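

# Minimal sketch of running mismatched_sizes_all_gather, assuming a CUDA/NCCL
# setup launched with torchrun (one process per GPU); shapes and names below
# are illustrative only.
if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)
    # each rank contributes a different number of rows
    local = torch.randn(rank + 1, 16, device="cuda")
    gathered = mismatched_sizes_all_gather(local)
    # gathered[r] has shape [r + 1, 16] for every rank r
    print(rank, [tuple(t.shape) for t in gathered])
    torch.distributed.destroy_process_group()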