File size: 1,671 Bytes
6ded986 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# © Recursion Pharmaceuticals 2024
from typing import Tuple, Union
import torch
def transformer_random_masking(
    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Randomly mask tokens independently for each sample in the batch.

    Each sample's tokens are ranked by a per-token noise value; the
    ``len_keep`` tokens with the smallest noise are kept and the rest are
    treated as masked.

    Parameters
    ----------
    x : token tensor (N, L, D)
    mask_ratio : float in [0, 1] - fraction of the L tokens to mask
    constant_noise : None, or a tensor of shape (N, L) used in place of random
        noise so that the same masks are produced on every call

    Returns
    -------
    x_masked : kept tokens of x, shape (N, len_keep, D), ordered by ascending
        noise, where len_keep = int(L * (1 - mask_ratio))
    mask : binary mask in the original token order (1 where masked, 0 where
        kept), shape (N, L)
    ind_restore : inverse of the shuffling permutation, shape (N, L); gathering
        a shuffled sequence with it restores the original token order (needed
        for the decoder)

    Raises
    ------
    ValueError
        If mask_ratio is outside [0, 1], or constant_noise is given with a
        shape other than (N, L).

    Notes
    -----
    len_keep is truncated toward zero from a floating-point product, so e.g.
    L=10 with mask_ratio=0.9 keeps 0 tokens because 10 * (1 - 0.9) evaluates
    slightly below 1.
    """
    N, L, D = x.shape  # batch, length, dim

    # A ratio above 1 would make len_keep negative, and the slices below would
    # then silently keep tokens counted from the end instead of failing.
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
    len_keep = int(L * (1 - mask_ratio))

    # use the supplied noise for reproducible masks, otherwise draw fresh noise
    if constant_noise is not None:
        # A noise tensor with fewer rows than x would silently shrink the batch
        # in the gather below, so enforce the documented shape up front.
        if tuple(constant_noise.shape) != (N, L):
            raise ValueError(
                f"constant_noise must have shape {(N, L)}, "
                f"got {tuple(constant_noise.shape)}"
            )
        noise = constant_noise
    else:
        noise = torch.rand(N, L, device=x.device)

    shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
    ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index

    # get masked input: keep the first len_keep indices of the shuffle
    tokens_to_keep = shuffled_tokens[:, :len_keep]
    # expand gives a broadcast view of the index instead of copying it D times
    gather_index = tokens_to_keep.unsqueeze(-1).expand(-1, -1, D)
    x_masked = torch.gather(x, dim=1, index=gather_index)

    # get binary mask used for loss masking: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    # the mask above is in shuffled order; unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ind_restore)

    return x_masked, mask, ind_restore
|