Kian Kenyon-Dean
committed
Reformat and add comments (#9)

- README.md +1 -1
- config.yaml +1 -0
- loss.py +13 -4
- mae_modules.py +1 -0
- mae_utils.py +8 -2
- masking.py +7 -2
- vit.py +34 -9
- vit_encoder.py +1 -0
README.md
CHANGED
@@ -34,7 +34,7 @@ def vit_base_patch16_256(**kwargs):
 ```
 
 ## Provided models
-A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling
+A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling: https://www.rxrx.ai/phenom
 
 We have partnered with Nvidia to host a publicly-available smaller and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
 - https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
config.yaml
CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 loss:
   _target_: torch.nn.MSELoss # combine with fourier loss weighted at 0.01 mixing factor for best results
   reduction: none
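The `_target_` key indicates the loss is built from config by an instantiation utility such as Hydra's `instantiate`; that consumer is not shown in this diff, so the loading code below is a minimal sketch under that assumption, not part of the repo.

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Load the YAML above and build the configured loss object from its _target_.
cfg = OmegaConf.load("config.yaml")
loss_fn = instantiate(cfg.loss)  # -> torch.nn.MSELoss(reduction="none")

pred, target = torch.rand(2, 6, 256, 256), torch.rand(2, 6, 256, 256)
per_pixel = loss_fn(pred, target)  # unreduced: same shape as the inputs
print(per_pixel.shape)  # torch.Size([2, 6, 256, 256])
```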
loss.py
CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 import torch
 import torch.nn as nn
 
@@ -16,7 +17,9 @@ class FourierLoss(nn.Module):
         output of this loss be managed by the model under question.
         """
         super().__init__()
-        self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+        self.loss = (
+            nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+        )
         self.num_modalities = num_multimodal_modalities
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -34,7 +37,9 @@ class FourierLoss(nn.Module):
         H_W = h * w
 
         if len(input.shape) != len(target.shape) != 4:
-            raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")
+            raise ValueError(
+                f"Invalid input shape: got {input.shape} and {target.shape}."
+            )
 
         fft_reconstructed = torch.fft.fft2(input)
         fft_original = torch.fft.fft2(target)
@@ -42,9 +47,13 @@ class FourierLoss(nn.Module):
         magnitude_reconstructed = torch.abs(fft_reconstructed)
         magnitude_original = torch.abs(fft_original)
 
-        loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)
+        loss_tensor: torch.Tensor = self.loss(
+            magnitude_reconstructed, magnitude_original
+        )
 
-        if flattened_images and not self.num_bins:  # then output loss should be reshaped
+        if (
+            flattened_images and not self.num_bins
+        ):  # then output loss should be reshaped
             loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
 
         return loss_tensor
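The reformatted lines spell out the core of `FourierLoss.forward`: an unreduced L1 or MSE loss between the FFT magnitudes of the reconstruction and the target. Below is a standalone sketch of that idea with the binning and reshape branches omitted; the function name and signature are illustrative, not the repo's API.

```python
import torch
import torch.nn as nn


def fourier_magnitude_loss(
    input: torch.Tensor, target: torch.Tensor, use_l1_loss: bool = False
) -> torch.Tensor:
    # elementwise loss so the caller decides how to reduce/mask it
    loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
    # compare images in frequency space: 2D FFT, then magnitude, then elementwise loss
    magnitude_reconstructed = torch.abs(torch.fft.fft2(input))
    magnitude_original = torch.abs(torch.fft.fft2(target))
    return loss(magnitude_reconstructed, magnitude_original)


pred, target = torch.rand(2, 6, 32, 32), torch.rand(2, 6, 32, 32)
print(fourier_magnitude_loss(pred, target).shape)  # torch.Size([2, 6, 32, 32])
```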
mae_modules.py
CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from functools import partial
 from typing import Tuple, Union
 
mae_utils.py
CHANGED
@@ -1,9 +1,12 @@
+# © Recursion Pharmaceuticals 2024
 import math
 
 import torch
 
 
-def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
+def flatten_images(
+    img: torch.Tensor, patch_size: int, channel_agnostic: bool = False
+) -> torch.Tensor:
     """
     Flattens 2D images into tokens with the same pixel values
 
@@ -33,7 +36,10 @@ def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool =
 
 
 def unflatten_tokens(
-    tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
+    tokens: torch.Tensor,
+    patch_size: int,
+    num_modalities: int = 1,
+    channel_agnostic: bool = False,
 ) -> torch.Tensor:
     """
     Unflattens tokens (N,L,patch_size**2 * C) into image tensor (N,C,H,W) with the pixel values
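`flatten_images` and `unflatten_tokens` are the patchify/unpatchify pair used by the MAE: images `(N, C, H, W)` become tokens `(N, L, patch_size**2 * C)` and back. The round trip below is a minimal sketch of that operation assuming a plain (non-channel-agnostic, single-modality) layout; the names and exact dimension ordering are illustrative, not the repo's implementation.

```python
import torch


def patchify(img: torch.Tensor, patch_size: int) -> torch.Tensor:
    N, C, H, W = img.shape
    h, w = H // patch_size, W // patch_size
    x = img.reshape(N, C, h, patch_size, w, patch_size)
    x = x.permute(0, 2, 4, 3, 5, 1)                 # N, h, w, p, p, C
    return x.reshape(N, h * w, patch_size**2 * C)   # N, L, p*p*C


def unpatchify(tokens: torch.Tensor, patch_size: int, channels: int) -> torch.Tensor:
    N, L, _ = tokens.shape
    h = w = int(L**0.5)                             # assumes a square token grid
    x = tokens.reshape(N, h, w, patch_size, patch_size, channels)
    x = x.permute(0, 5, 1, 3, 2, 4)                 # N, C, h, p, w, p
    return x.reshape(N, channels, h * patch_size, w * patch_size)


img = torch.rand(2, 6, 64, 64)
assert torch.allclose(unpatchify(patchify(img, 16), 16, 6), img)
```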
masking.py
CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from typing import Tuple, Union
 
 import torch
@@ -36,11 +37,15 @@ def transformer_random_masking(
 
     # get masked input
    tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
-    x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
+    x_masked = torch.gather(
+        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
+    )
 
     # get binary mask used for loss masking: 0 is keep, 1 is remove
     mask = torch.ones([N, L], device=x.device)
     mask[:, :len_keep] = 0
-    mask = torch.gather(mask, dim=1, index=ind_restore)  # unshuffle to get the binary mask
+    mask = torch.gather(
+        mask, dim=1, index=ind_restore
+    )  # unshuffle to get the binary mask
 
     return x_masked, mask, ind_restore
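The gathers above implement the standard MAE random-masking pattern: shuffle token indices, keep the first `len_keep`, and unshuffle a binary mask for the loss. A self-contained sketch of that pattern follows; the names mirror the diff, while the shuffle setup (random noise plus `argsort`) is an assumption since it sits outside the shown hunk.

```python
import torch


def random_masking(x: torch.Tensor, mask_ratio: float):
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))

    noise = torch.rand(N, L, device=x.device)             # one random score per token
    shuffled_tokens = torch.argsort(noise, dim=1)         # random permutation of indices
    ind_restore = torch.argsort(shuffled_tokens, dim=1)   # inverse permutation

    # get masked input: keep the first len_keep shuffled indices
    tokens_to_keep = shuffled_tokens[:, :len_keep]
    x_masked = torch.gather(
        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
    )

    # binary mask in original token order: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=x.device)
    mask[:, :len_keep] = 0
    mask = torch.gather(mask, dim=1, index=ind_restore)   # unshuffle the mask
    return x_masked, mask, ind_restore


x = torch.rand(2, 196, 768)
x_masked, mask, ind_restore = random_masking(x, mask_ratio=0.75)
print(x_masked.shape, mask.shape)  # torch.Size([2, 49, 768]) torch.Size([2, 196])
```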
vit.py
CHANGED
@@ -1,9 +1,14 @@
+# © Recursion Pharmaceuticals 2024
 import timm.models.vision_transformer as vit
 import torch
 
 
 def generate_2d_sincos_pos_embeddings(
-    embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
+    embedding_dim: int,
+    length: int,
+    scale: float = 10000.0,
+    use_class_token: bool = True,
+    num_modality: int = 1,
 ) -> torch.nn.Parameter:
     """
     Generate 2Dimensional sin/cosine positional embeddings
@@ -30,16 +35,25 @@ def generate_2d_sincos_pos_embeddings(
     """
 
     linear_positions = torch.arange(length, dtype=torch.float32)
-    height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
+    height_mesh, width_mesh = torch.meshgrid(
+        linear_positions, linear_positions, indexing="ij"
+    )
     positional_dim = embedding_dim // 4  # accomodate h and w x cos and sin embeddings
-    positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+    positional_weights = (
+        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+    )
     positional_weights = 1.0 / (scale**positional_weights)
 
     height_weights = torch.outer(height_mesh.flatten(), positional_weights)
     width_weights = torch.outer(width_mesh.flatten(), positional_weights)
 
     positional_encoding = torch.cat(
-        [torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
+        [
+            torch.sin(height_weights),
+            torch.cos(height_weights),
+            torch.sin(width_weights),
+            torch.cos(width_weights),
+        ],
         dim=1,
     )[None, :, :]
 
@@ -73,11 +87,15 @@ class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
             bias=bias,
         )
         # channel-agnostic MAE has a single projection for all chans
-        self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        self.proj = torch.nn.Conv2d(
+            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         in_chans = x.shape[1]
-        x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2)  # single project for all chans
+        x = torch.stack(
+            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
+        )  # single project for all chans
         x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
         return x
 
@@ -106,7 +124,9 @@ class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
         return self.pos_drop(x)  # type: ignore[no-any-return]
 
 
-def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
+def channel_agnostic_vit(
+    vit_backbone: vit.VisionTransformer, max_in_chans: int
+) -> vit.VisionTransformer:
     # replace patch embedding with channel-agnostic version
     vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
         img_size=vit_backbone.patch_embed.img_size[0],
@@ -145,9 +165,14 @@ def sincos_positional_encoding_vit(
         the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
     """
     # length: number of tokens along height or width of image after patching (assuming square)
-    length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
+    length = (
+        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
+    )
     pos_embeddings = generate_2d_sincos_pos_embeddings(
-        vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None
+        vit_backbone.embed_dim,
+        length=length,
+        scale=scale,
+        use_class_token=vit_backbone.cls_token is not None,
    )
     # note, if the model had weight_init == 'skip', this might get overwritten
     vit_backbone.pos_embed = pos_embeddings
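Most of the vit.py churn is formatting of `generate_2d_sincos_pos_embeddings`, which builds fixed 2D sin/cos positional embeddings from outer products of grid positions and inverse-frequency weights. A standalone sketch of that core construction, with the `use_class_token` and `num_modality` handling omitted (the wrapper name is illustrative):

```python
import torch


def sincos_2d(embedding_dim: int, length: int, scale: float = 10000.0) -> torch.Tensor:
    positions = torch.arange(length, dtype=torch.float32)
    height_mesh, width_mesh = torch.meshgrid(positions, positions, indexing="ij")

    positional_dim = embedding_dim // 4  # h and w, each with a sin and a cos part
    freqs = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
    freqs = 1.0 / (scale**freqs)        # inverse-frequency weights

    height_weights = torch.outer(height_mesh.flatten(), freqs)
    width_weights = torch.outer(width_mesh.flatten(), freqs)
    return torch.cat(
        [
            torch.sin(height_weights),
            torch.cos(height_weights),
            torch.sin(width_weights),
            torch.cos(width_weights),
        ],
        dim=1,
    )[None, :, :]  # (1, length * length, embedding_dim)


print(sincos_2d(768, 16).shape)  # torch.Size([1, 256, 768])
```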
vit_encoder.py
CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from typing import Dict
 
 import timm.models.vision_transformer as vit