Kian Kenyon-Dean committed on
Commit 560d738 · unverified · 1 Parent(s): 3d9eac1

Reformat and add comments (#9)

Files changed (8)
  1. README.md +1 -1
  2. config.yaml +1 -0
  3. loss.py +13 -4
  4. mae_modules.py +1 -0
  5. mae_utils.py +8 -2
  6. masking.py +7 -2
  7. vit.py +34 -9
  8. vit_encoder.py +1 -0
README.md CHANGED
@@ -34,7 +34,7 @@ def vit_base_patch16_256(**kwargs):
 ```
 
 ## Provided models
-A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling for you: https://www.rxrx.ai/phenom
+A publicly available model for research can be found via Nvidia's BioNemo platform, which handles inference and auto-scaling: https://www.rxrx.ai/phenom
 
 We have partnered with Nvidia to host a publicly-available smaller and more flexible version of the MAE phenomics foundation model, called Phenom-Beta. Interested parties can access it directly through the Nvidia BioNemo API:
 - https://blogs.nvidia.com/blog/drug-discovery-bionemo-generative-ai/
config.yaml CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 loss:
   _target_: torch.nn.MSELoss # combine with fourier loss weighted at 0.01 mixing factor for best results
   reduction: none
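The `_target_` key above follows the Hydra/OmegaConf instantiation convention, so the loss block can be turned directly into a module. A minimal sketch, assuming the config is consumed via `hydra.utils.instantiate` (not shown in this diff):

```python
# Minimal sketch: materializing the config block above into a loss module.
# Assumption: Hydra-style instantiation, which the `_target_` key suggests.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {"loss": {"_target_": "torch.nn.MSELoss", "reduction": "none"}}
)
loss_fn = instantiate(cfg.loss)  # equivalent to torch.nn.MSELoss(reduction="none")
```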
loss.py CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 import torch
 import torch.nn as nn
 
@@ -16,7 +17,9 @@ class FourierLoss(nn.Module):
         output of this loss be managed by the model under question.
         """
         super().__init__()
-        self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+        self.loss = (
+            nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+        )
         self.num_modalities = num_multimodal_modalities
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
@@ -34,7 +37,9 @@ class FourierLoss(nn.Module):
         H_W = h * w
 
         if len(input.shape) != len(target.shape) != 4:
-            raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")
+            raise ValueError(
+                f"Invalid input shape: got {input.shape} and {target.shape}."
+            )
 
         fft_reconstructed = torch.fft.fft2(input)
         fft_original = torch.fft.fft2(target)
@@ -42,9 +47,13 @@
         magnitude_reconstructed = torch.abs(fft_reconstructed)
         magnitude_original = torch.abs(fft_original)
 
-        loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)
+        loss_tensor: torch.Tensor = self.loss(
+            magnitude_reconstructed, magnitude_original
+        )
 
-        if flattened_images and not self.num_bins:  # then output loss should be reshaped
+        if (
+            flattened_images and not self.num_bins
+        ):  # then output loss should be reshaped
             loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
 
         return loss_tensor
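As the config comment suggests, the Fourier term is meant to be mixed with a pixel-space MSE at a 0.01 weight. A self-contained sketch of that combination (illustrative only, not the repo's `FourierLoss` class; the input shapes and the L1 magnitude comparison are assumptions):

```python
import torch
import torch.nn.functional as F

def combined_reconstruction_loss(
    recon: torch.Tensor, target: torch.Tensor, fourier_weight: float = 0.01
) -> torch.Tensor:
    # pixel-space term
    pixel_loss = F.mse_loss(recon, target)
    # Fourier term: compare magnitude spectra, mirroring FourierLoss.forward above
    mag_recon = torch.abs(torch.fft.fft2(recon))
    mag_target = torch.abs(torch.fft.fft2(target))
    fourier_loss = F.l1_loss(mag_recon, mag_target)
    return pixel_loss + fourier_weight * fourier_loss

loss = combined_reconstruction_loss(
    torch.randn(2, 6, 256, 256), torch.randn(2, 6, 256, 256)
)
```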
mae_modules.py CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from functools import partial
 from typing import Tuple, Union
 
mae_utils.py CHANGED
@@ -1,9 +1,12 @@
+# © Recursion Pharmaceuticals 2024
 import math
 
 import torch
 
 
-def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
+def flatten_images(
+    img: torch.Tensor, patch_size: int, channel_agnostic: bool = False
+) -> torch.Tensor:
     """
     Flattens 2D images into tokens with the same pixel values
 
@@ -33,7 +36,10 @@ def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool =
 
 
 def unflatten_tokens(
-    tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
+    tokens: torch.Tensor,
+    patch_size: int,
+    num_modalities: int = 1,
+    channel_agnostic: bool = False,
 ) -> torch.Tensor:
     """
     Unflattens tokens (N,L,patch_size**2 * C) into image tensor (N,C,H,W) with the pixel values
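The two helpers reformatted above are inverses of each other. A round-trip sketch under assumed inputs (the 6-channel 256×256 batch and patch size 16 are illustrative values, not taken from this diff):

```python
import torch
from mae_utils import flatten_images, unflatten_tokens  # assumed import path

img = torch.randn(2, 6, 256, 256)                # (N, C, H, W)
tokens = flatten_images(img, patch_size=16)      # (N, L, patch_size**2 * C) per the docstring
recon = unflatten_tokens(tokens, patch_size=16)  # back to (N, C, H, W)
assert recon.shape == img.shape
```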
masking.py CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from typing import Tuple, Union
 
 import torch
@@ -36,11 +37,15 @@ def transformer_random_masking(
 
     # get masked input
     tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
-    x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
+    x_masked = torch.gather(
+        x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D)
+    )
 
     # get binary mask used for loss masking: 0 is keep, 1 is remove
     mask = torch.ones([N, L], device=x.device)
     mask[:, :len_keep] = 0
-    mask = torch.gather(mask, dim=1, index=ind_restore)  # unshuffle to get the binary mask
+    mask = torch.gather(
+        mask, dim=1, index=ind_restore
+    )  # unshuffle to get the binary mask
 
     return x_masked, mask, ind_restore
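For readers skimming the diff, the logic above is standard MAE-style random masking: shuffle token indices, keep the first `len_keep`, and unshuffle a binary keep/remove mask. A standalone sketch of the same idea (not the repo's exact function or signature):

```python
import torch

def random_masking_sketch(x: torch.Tensor, mask_ratio: float = 0.75):
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = torch.rand(N, L, device=x.device)
    ids_shuffle = torch.argsort(noise, dim=1)        # random permutation per sample
    ids_restore = torch.argsort(ids_shuffle, dim=1)  # inverse permutation
    tokens_to_keep = ids_shuffle[:, :len_keep]       # keep the first len_keep indices
    x_masked = torch.gather(x, 1, tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
    mask = torch.ones(N, L, device=x.device)         # 1 = removed
    mask[:, :len_keep] = 0                           # 0 = kept
    mask = torch.gather(mask, 1, ids_restore)        # unshuffle to original token order
    return x_masked, mask, ids_restore

x_masked, mask, ids_restore = random_masking_sketch(torch.randn(2, 196, 768))
```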
vit.py CHANGED
@@ -1,9 +1,14 @@
+# © Recursion Pharmaceuticals 2024
 import timm.models.vision_transformer as vit
 import torch
 
 
 def generate_2d_sincos_pos_embeddings(
-    embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
+    embedding_dim: int,
+    length: int,
+    scale: float = 10000.0,
+    use_class_token: bool = True,
+    num_modality: int = 1,
 ) -> torch.nn.Parameter:
     """
     Generate 2Dimensional sin/cosine positional embeddings
@@ -30,16 +35,25 @@ def generate_2d_sincos_pos_embeddings(
     """
 
     linear_positions = torch.arange(length, dtype=torch.float32)
-    height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
+    height_mesh, width_mesh = torch.meshgrid(
+        linear_positions, linear_positions, indexing="ij"
+    )
     positional_dim = embedding_dim // 4  # accomodate h and w x cos and sin embeddings
-    positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+    positional_weights = (
+        torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+    )
     positional_weights = 1.0 / (scale**positional_weights)
 
     height_weights = torch.outer(height_mesh.flatten(), positional_weights)
     width_weights = torch.outer(width_mesh.flatten(), positional_weights)
 
     positional_encoding = torch.cat(
-        [torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
+        [
+            torch.sin(height_weights),
+            torch.cos(height_weights),
+            torch.sin(width_weights),
+            torch.cos(width_weights),
+        ],
         dim=1,
     )[None, :, :]
 
@@ -73,11 +87,15 @@ class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
             bias=bias,
         )
         # channel-agnostic MAE has a single projection for all chans
-        self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+        self.proj = torch.nn.Conv2d(
+            1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         in_chans = x.shape[1]
-        x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2)  # single project for all chans
+        x = torch.stack(
+            [self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2
+        )  # single project for all chans
         x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
         return x
 
@@ -106,7 +124,9 @@ class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
         return self.pos_drop(x)  # type: ignore[no-any-return]
 
 
-def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
+def channel_agnostic_vit(
+    vit_backbone: vit.VisionTransformer, max_in_chans: int
+) -> vit.VisionTransformer:
     # replace patch embedding with channel-agnostic version
     vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
         img_size=vit_backbone.patch_embed.img_size[0],
@@ -145,9 +165,14 @@ def sincos_positional_encoding_vit(
         the same ViT but with fixed no-grad positional encodings to add to vit patch encodings
     """
     # length: number of tokens along height or width of image after patching (assuming square)
-    length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
+    length = (
+        vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0]
+    )
     pos_embeddings = generate_2d_sincos_pos_embeddings(
-        vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None
+        vit_backbone.embed_dim,
+        length=length,
+        scale=scale,
+        use_class_token=vit_backbone.cls_token is not None,
     )
     # note, if the model had weight_init == 'skip', this might get overwritten
    vit_backbone.pos_embed = pos_embeddings
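A minimal usage sketch of the two wrappers touched above, assuming they are importable from `vit.py` and applied to a timm ViT backbone (the timm model name and the `max_in_chans` value are illustrative assumptions):

```python
import timm
from vit import channel_agnostic_vit, sincos_positional_encoding_vit  # assumed import path

backbone = timm.create_model("vit_base_patch16_224", pretrained=False)
backbone = sincos_positional_encoding_vit(backbone)        # fixed 2D sin/cos positional embeddings
backbone = channel_agnostic_vit(backbone, max_in_chans=6)  # shared single-channel patch projection
```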
vit_encoder.py CHANGED
@@ -1,3 +1,4 @@
+# © Recursion Pharmaceuticals 2024
 from typing import Dict
 
 import timm.models.vision_transformer as vit