import torch
from torch import Tensor


class AllGather(torch.autograd.Function):
    """
    all_gather with gradient back-propagation
    """

    @staticmethod
    def forward(ctx, tensor_list, tensor, group, async_op):
        torch.distributed.all_gather(
            tensor_list, tensor, group=group, async_op=async_op
        )
        return tuple(tensor_list)

    @staticmethod
    def backward(ctx, *grad_list):
        grad_list = list(grad_list)
        rank = torch.distributed.get_rank()
        dist_ops = [
            torch.distributed.reduce(grad_list[i], i, async_op=True)
            for i in range(torch.distributed.get_world_size())
        ]
        for op in dist_ops:
            op.wait()
        return None, grad_list[rank], None, None


all_gather_with_grad = AllGather.apply


def cos_sim(a: Tensor, b: Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.

    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def mismatched_sizes_all_gather(
    tensor: Tensor, group=None, async_op=False, mismatched_axis=0
):
    # all_gather doesn't support tensor lists where the first dimension is mismatched. This does.
    assert torch.distributed.is_initialized(), "torch.distributed not initialized"
    world_size = torch.distributed.get_world_size()
    # let's get the sizes for everyone
    mismatched_sizes = torch.tensor(
        [tensor.shape[mismatched_axis]], dtype=torch.int64, device="cuda"
    )
    sizes = [torch.zeros_like(mismatched_sizes) for _ in range(world_size)]
    torch.distributed.all_gather(
        sizes, mismatched_sizes, group=group, async_op=async_op
    )
    sizes = torch.cat(sizes).cpu().tolist()
    # now pad to the max dim-0 size
    max_size = max(sizes)
    padded = torch.zeros(
        (
            *tensor.shape[:mismatched_axis],
            max_size,
            *tensor.shape[mismatched_axis + 1 :],
        ),
        device=tensor.device,
        dtype=tensor.dtype,
    )
    # selects the place where we're adding information
    padded_to_fill = padded.narrow(mismatched_axis, 0, tensor.shape[mismatched_axis])
    padded_to_fill[...] = tensor
    # gather the padded tensors
    tensor_list = [
        torch.zeros(padded.shape, device=padded.device, dtype=padded.dtype)
        for _ in range(world_size)
    ]
    all_gather_with_grad(tensor_list, padded, group, async_op)
    # trim off the padding
    for rank in range(world_size):
        # checks that the rest is 0
        assert (
            not tensor_list[rank]
            .narrow(
                mismatched_axis,
                sizes[rank],
                padded.shape[mismatched_axis] - sizes[rank],
            )
            .count_nonzero()
            .is_nonzero()
        ), "This would remove non-padding information"
        tensor_list[rank] = tensor_list[rank].narrow(mismatched_axis, 0, sizes[rank])
    return tensor_list
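

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). Assumes
# torch.distributed is initialized with the NCCL backend and one GPU per
# process, e.g. launched with `torchrun --nproc_per_node=2 <this file>`, and
# that each rank produces a batch whose size differs across ranks. The names
# below (`local_embeddings`, `all_embeddings`, `sims`) are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)  # single-node assumption: rank == local rank

    # each rank builds a differently sized batch (rank + 1 rows)
    local_embeddings = torch.randn(rank + 1, 8, device="cuda", requires_grad=True)

    # gather the mismatched per-rank batches, then concatenate into one batch
    gathered = mismatched_sizes_all_gather(local_embeddings)
    all_embeddings = torch.cat(gathered, dim=0)

    # cosine-similarity matrix over the full cross-rank batch; gradients are
    # expected to flow back to the local slice through AllGather.backward
    sims = cos_sim(all_embeddings, all_embeddings)
    sims.sum().backward()

    print(
        f"rank {rank}: gathered {all_embeddings.shape[0]} rows, "
        f"local grad norm {local_embeddings.grad.norm().item():.4f}"
    )
    torch.distributed.destroy_process_group()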