recursionaut committed on
Commit 49f08fd · 1 Parent(s): e635625

huggingface compatible changes

MODELCARD.md ADDED
@@ -0,0 +1,128 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Phenom CA-MAE-S/16
+
+ Channel-agnostic image encoding model designed for microscopy image featurization.
+ The model uses a vision transformer backbone with channelwise cross-attention over patch tokens to create contextualized representations separately for each channel.
+
+
+ ## Model Details
+
+ ### Model Description
+
+ This model is a [channel-agnostic masked autoencoder](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html) trained to reconstruct microscopy images over three datasets:
+ 1. RxRx3
+ 2. JUMP-CP overexpression
+ 3. JUMP-CP gene-knockouts
+
+ - **Developed, funded, and shared by:** Recursion
+ - **Model type:** Vision transformer CA-MAE
+ - **Image modality:** Optimized for microscopy images from the CellPainting assay
+ - **License:**
+
+
+ ### Model Sources
+
+ - **Repository:** [https://github.com/recursionpharma/maes_microscopy](https://github.com/recursionpharma/maes_microscopy)
+ - **Paper:** [Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology](https://openaccess.thecvf.com/content/CVPR2024/html/Kraus_Masked_Autoencoders_for_Microscopy_are_Scalable_Learners_of_Cellular_Biology_CVPR_2024_paper.html)
+
+
+ ## Uses
+
+ NOTE: model embeddings tend to yield useful features only after standard batch-correction post-processing. **We recommend**, at a *minimum*, that after running inference over your images you apply the standard `PCA-CenterScale` pattern or, better yet, Typical Variation Normalization:
+
+ 1. Fit a PCA kernel on all the *control images* (or all images if no controls) from across all experimental batches (e.g. the plates of wells from your assay),
+ 2. Transform all the embeddings with that PCA kernel,
+ 3. For each experimental batch, fit a separate StandardScaler on the transformed embeddings of the controls from step 2, then transform the rest of the embeddings from that batch with that StandardScaler.
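
The three `PCA-CenterScale` steps above can be sketched with scikit-learn as follows. This is a minimal illustration, not code from the repository: the helper name `pca_center_scale`, its arguments, and the fallback to all images in a batch that has no controls are assumptions made for the example.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler


def pca_center_scale(embeddings, batch_ids, is_control, n_components=None):
    """Hypothetical helper: PCA fit on controls, then a per-batch StandardScaler."""
    embeddings = np.asarray(embeddings, dtype=np.float64)
    batch_ids = np.asarray(batch_ids)
    is_control = np.asarray(is_control, dtype=bool)

    # Step 1: fit PCA on the controls pooled across all batches
    # (or on all embeddings when there are no controls at all).
    fit_mask = is_control if is_control.any() else np.ones(len(embeddings), dtype=bool)
    pca = PCA(n_components=n_components).fit(embeddings[fit_mask])

    # Step 2: transform every embedding with that PCA.
    transformed = pca.transform(embeddings)

    # Step 3: for each batch, fit a scaler on that batch's controls and
    # apply it to every embedding of the batch.
    corrected = np.empty_like(transformed)
    for batch in np.unique(batch_ids):
        in_batch = batch_ids == batch
        controls = in_batch & is_control
        # Assumption: a batch without controls is scaled on all of its embeddings.
        fit_rows = controls if controls.any() else in_batch
        scaler = StandardScaler().fit(transformed[fit_rows])
        corrected[in_batch] = scaler.transform(transformed[in_batch])
    return corrected
```

After this correction, the controls of every batch have zero mean and unit variance along each retained principal component, which is what makes embeddings from different plates comparable. Typical Variation Normalization is a different technique and is not shown here.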
+
+ ### Direct Use
+
+ - Create biologically useful embeddings of microscopy images
+ - Create contextualized embeddings of each channel of a microscopy image (set `return_channelwise_embeddings=True`)
+ - Leverage the full MAE encoder + decoder to predict new channels / stains for images without all 6 CellPainting channels
+
+ ### Downstream Use
+
+ - A determined ML expert could fine-tune the encoder for downstream tasks such as classification
+
+ ### Out-of-Scope Use
+
+ - Unlikely to be especially performant on brightfield microscopy images
+ - Out-of-domain medical images, such as H&E (maybe it would be a decent baseline though)
+
+ ## Bias, Risks, and Limitations
+
+ - The primary limitation is that the embeddings tend to be more useful at scale. For example, if you only have 1 plate of microscopy images, the embeddings might underperform compared to a supervised bespoke model.
+
+ ## How to Get Started with the Model
+
+ You should be able to successfully run the tests below, which demonstrate how to use the model at inference time.
+
+ ```python
+ import pytest
+ import torch
+
+ from huggingface_mae import MAEModel
+
+ huggingface_phenombeta_model_dir = "."
+ # huggingface_modelpath = "recursionpharma/test-pb-model"
+
+
+ @pytest.fixture
+ def huggingface_model():
+     # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
+     # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
+     huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
+     huggingface_model.eval()
+     return huggingface_model
+
+
+ @pytest.mark.parametrize("C", [1, 4, 6, 11])
+ @pytest.mark.parametrize("return_channelwise_embeddings", [True, False])
+ def test_model_predict(huggingface_model, C, return_channelwise_embeddings):
+     example_input_array = torch.randint(
+         low=0,
+         high=255,
+         size=(2, C, 256, 256),
+         dtype=torch.uint8,
+         device=huggingface_model.device,
+     )
+     huggingface_model.return_channelwise_embeddings = return_channelwise_embeddings
+     embeddings = huggingface_model.predict(example_input_array)
+     expected_output_dim = 384 * C if return_channelwise_embeddings else 384
+     assert embeddings.shape == (2, expected_output_dim)
+ ```
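
The shape contract asserted in `test_model_predict` can be restated as a small standalone function. This is only a sketch of the arithmetic in that test: the name `expected_embedding_dim` is hypothetical, and the default of 384 is taken from the test's assertion rather than read from the model config.

```python
def expected_embedding_dim(num_channels, return_channelwise_embeddings, embed_dim=384):
    """Hypothetical helper mirroring the assertion in test_model_predict."""
    if num_channels < 1:
        raise ValueError("num_channels must be at least 1")
    # Channelwise mode yields one embed_dim-sized vector per input channel,
    # flattened into a single row; otherwise one vector per image.
    if return_channelwise_embeddings:
        return embed_dim * num_channels
    return embed_dim


# The channel counts exercised by the parametrized test.
for c in (1, 4, 6, 11):
    print(c, expected_embedding_dim(c, True), expected_embedding_dim(c, False))
```

With channelwise embeddings enabled, the output width therefore grows linearly with the number of input channels, which matters when sizing downstream storage or PCA fits.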
+
+
+ ## Training, evaluation and testing details
+
+ See the paper linked above for details on model training and evaluation. Primary hyperparameters are included in the repo linked above.
+
+
+ ## Environmental Impact
+
+ - **Hardware Type:** Nvidia H100 Hopper nodes
+ - **Hours used:** 400
+ - **Cloud Provider:** private cloud
+ - **Carbon Emitted:** 138.24 kg CO2 (roughly the equivalent of one car driving from Toronto to Montreal)
+
+ **BibTeX:**
+
+ ```TeX
+ @inproceedings{kraus2024masked,
+   title={Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology},
+   author={Kraus, Oren and Kenyon-Dean, Kian and Saberian, Saber and Fallah, Maryam and McLean, Peter and Leung, Jess and Sharma, Vasudev and Khan, Ayla and Balakrishnan, Jia and Celik, Safiye and others},
+   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+   pages={11757--11768},
+   year={2024}
+ }
+ ```
+
+ ## Model Card Contact
+
+ - Kian Kenyon-Dean: [email protected]
+ - Oren Kraus: [email protected]
+ - Or, email: [email protected]
models/phenom_beta_huggingface/config.json → config.json RENAMED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,35 @@
+ [build-system]
+ requires = ["setuptools >= 61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "maes_microscopy_project"
+ version = "0.1.0"
+ authors = [
+     {name = "kian-kd", email = "[email protected]"},
+     {name = "Laksh47", email = "[email protected]"},
+ ]
+ requires-python = ">=3.10.4"
+
+ dependencies = [
+     "huggingface-hub",
+     "timm",
+     "torch>=2.3",
+     "torchmetrics",
+     "torchvision",
+     "tqdm",
+     "transformers",
+     "xformers",
+     "zarr",
+     "hydra-core",
+     "pytorch-lightning>=2.1",
+     "matplotlib",
+     "scikit-image",
+     "ipykernel",
+     "isort",
+     "ruff",
+     "pytest",
+ ]
+
+ [tool.setuptools]
+ py-modules = []
requirements.in DELETED
@@ -1,17 +0,0 @@
- huggingface-hub
- timm
- torch>=2.3
- torchmetrics
- torchvision
- tqdm
- transformers
- xformers
- zarr
- hydra-core
- pytorch-lightning>=2.1
- matplotlib
- scikit-image
- ipykernel
- isort
- ruff
- pytest
requirements.txt DELETED
@@ -1,326 +0,0 @@
- #
- # This file is autogenerated by pip-compile with Python 3.10
- # by the following command:
- #
- #    pip-compile --no-emit-index-url --output-file=requirements.txt requirements.in
- #
- --trusted-host pypi.ngc.nvidia.com
-
- aiohappyeyeballs==2.4.3
-     # via aiohttp
- aiohttp==3.10.10
-     # via fsspec
- aiosignal==1.3.1
-     # via aiohttp
- antlr4-python3-runtime==4.9.3
-     # via
-     #   hydra-core
-     #   omegaconf
- asciitree==0.3.3
-     # via zarr
- asttokens==2.4.1
-     # via stack-data
- async-timeout==4.0.3
-     # via aiohttp
- attrs==24.2.0
-     # via aiohttp
- certifi==2024.8.30
-     # via requests
- charset-normalizer==3.4.0
-     # via requests
- comm==0.2.2
-     # via ipykernel
- contourpy==1.3.0
-     # via matplotlib
- cycler==0.12.1
-     # via matplotlib
- debugpy==1.8.7
-     # via ipykernel
- decorator==5.1.1
-     # via ipython
- exceptiongroup==1.2.2
-     # via
-     #   ipython
-     #   pytest
- executing==2.1.0
-     # via stack-data
- fasteners==0.19
-     # via zarr
- filelock==3.16.1
-     # via
-     #   huggingface-hub
-     #   torch
-     #   transformers
-     #   triton
- fonttools==4.54.1
-     # via matplotlib
- frozenlist==1.5.0
-     # via
-     #   aiohttp
-     #   aiosignal
- fsspec[http]==2024.10.0
-     # via
-     #   huggingface-hub
-     #   pytorch-lightning
-     #   torch
- huggingface-hub==0.26.2
-     # via
-     #   -r requirements.in
-     #   timm
-     #   tokenizers
-     #   transformers
- hydra-core==1.3.2
-     # via -r requirements.in
- idna==3.10
-     # via
-     #   requests
-     #   yarl
- imageio==2.36.0
-     # via scikit-image
- iniconfig==2.0.0
-     # via pytest
- ipykernel==6.29.5
-     # via -r requirements.in
- ipython==8.29.0
-     # via ipykernel
- isort==5.13.2
-     # via -r requirements.in
- jedi==0.19.1
-     # via ipython
- jinja2==3.1.4
-     # via torch
- jupyter-client==8.6.3
-     # via ipykernel
- jupyter-core==5.7.2
-     # via
-     #   ipykernel
-     #   jupyter-client
- kiwisolver==1.4.7
-     # via matplotlib
- lazy-loader==0.4
-     # via scikit-image
- lightning-utilities==0.11.8
-     # via
-     #   pytorch-lightning
-     #   torchmetrics
- markupsafe==3.0.2
-     # via jinja2
- matplotlib==3.9.2
-     # via -r requirements.in
- matplotlib-inline==0.1.7
-     # via
-     #   ipykernel
-     #   ipython
- mpmath==1.3.0
-     # via sympy
- multidict==6.1.0
-     # via
-     #   aiohttp
-     #   yarl
- nest-asyncio==1.6.0
-     # via ipykernel
- networkx==3.2.1
-     # via
-     #   scikit-image
-     #   torch
- numcodecs==0.12.1
-     # via zarr
- numpy==1.26.4
-     # via
-     #   contourpy
-     #   imageio
-     #   matplotlib
-     #   numcodecs
-     #   scikit-image
-     #   scipy
-     #   tifffile
-     #   torchmetrics
-     #   torchvision
-     #   transformers
-     #   xformers
-     #   zarr
- nvidia-cublas-cu12==12.4.5.8
-     # via
-     #   nvidia-cudnn-cu12
-     #   nvidia-cusolver-cu12
-     #   torch
- nvidia-cuda-cupti-cu12==12.4.127
-     # via torch
- nvidia-cuda-nvrtc-cu12==12.4.127
-     # via torch
- nvidia-cuda-runtime-cu12==12.4.127
-     # via torch
- nvidia-cudnn-cu12==9.1.0.70
-     # via torch
- nvidia-cufft-cu12==11.2.1.3
-     # via torch
- nvidia-curand-cu12==10.3.5.147
-     # via torch
- nvidia-cusolver-cu12==11.6.1.9
-     # via torch
- nvidia-cusparse-cu12==12.3.1.170
-     # via
-     #   nvidia-cusolver-cu12
-     #   torch
- nvidia-nccl-cu12==2.21.5
-     # via torch
- nvidia-nvjitlink-cu12==12.4.127
-     # via
-     #   nvidia-cusolver-cu12
-     #   nvidia-cusparse-cu12
-     #   torch
- nvidia-nvtx-cu12==12.4.127
-     # via torch
- omegaconf==2.3.0
-     # via hydra-core
- packaging==24.1
-     # via
-     #   huggingface-hub
-     #   hydra-core
-     #   ipykernel
-     #   lazy-loader
-     #   lightning-utilities
-     #   matplotlib
-     #   pytest
-     #   pytorch-lightning
-     #   scikit-image
-     #   torchmetrics
-     #   transformers
- parso==0.8.4
-     # via jedi
- pexpect==4.9.0
-     # via ipython
- pillow==11.0.0
-     # via
-     #   imageio
-     #   matplotlib
-     #   scikit-image
-     #   torchvision
- platformdirs==4.3.6
-     # via jupyter-core
- pluggy==1.5.0
-     # via pytest
- prompt-toolkit==3.0.48
-     # via ipython
- propcache==0.2.0
-     # via yarl
- psutil==6.1.0
-     # via ipykernel
- ptyprocess==0.7.0
-     # via pexpect
- pure-eval==0.2.3
-     # via stack-data
- pygments==2.18.0
-     # via ipython
- pyparsing==3.2.0
-     # via matplotlib
- pytest==8.3.3
-     # via -r requirements.in
- python-dateutil==2.9.0.post0
-     # via
-     #   jupyter-client
-     #   matplotlib
- pytorch-lightning==2.4.0
-     # via -r requirements.in
- pyyaml==6.0.2
-     # via
-     #   huggingface-hub
-     #   omegaconf
-     #   pytorch-lightning
-     #   timm
-     #   transformers
- pyzmq==26.2.0
-     # via
-     #   ipykernel
-     #   jupyter-client
- regex==2024.9.11
-     # via transformers
- requests==2.32.3
-     # via
-     #   huggingface-hub
-     #   transformers
- ruff==0.7.2
-     # via -r requirements.in
- safetensors==0.4.5
-     # via
-     #   timm
-     #   transformers
- scikit-image==0.24.0
-     # via -r requirements.in
- scipy==1.13.1
-     # via scikit-image
- six==1.16.0
-     # via
-     #   asttokens
-     #   python-dateutil
- stack-data==0.6.3
-     # via ipython
- sympy==1.13.1
-     # via torch
- tifffile==2024.8.30
-     # via scikit-image
- timm==1.0.11
-     # via -r requirements.in
- tokenizers==0.20.2
-     # via transformers
- tomli==2.0.2
-     # via pytest
- torch==2.5.1
-     # via
-     #   -r requirements.in
-     #   pytorch-lightning
-     #   timm
-     #   torchmetrics
-     #   torchvision
-     #   xformers
- torchmetrics==1.5.1
-     # via
-     #   -r requirements.in
-     #   pytorch-lightning
- torchvision==0.20.1
-     # via
-     #   -r requirements.in
-     #   timm
- tornado==6.4.1
-     # via
-     #   ipykernel
-     #   jupyter-client
- tqdm==4.66.6
-     # via
-     #   -r requirements.in
-     #   huggingface-hub
-     #   pytorch-lightning
-     #   transformers
- traitlets==5.14.3
-     # via
-     #   comm
-     #   ipykernel
-     #   ipython
-     #   jupyter-client
-     #   jupyter-core
-     #   matplotlib-inline
- transformers==4.46.1
-     # via -r requirements.in
- triton==3.1.0
-     # via torch
- typing-extensions==4.12.2
-     # via
-     #   huggingface-hub
-     #   ipython
-     #   lightning-utilities
-     #   multidict
-     #   pytorch-lightning
-     #   torch
- urllib3==2.2.3
-     # via requests
- wcwidth==0.2.13
-     # via prompt-toolkit
- xformers==0.0.28.post3
-     # via -r requirements.in
- yarl==1.17.1
-     # via aiohttp
- zarr==2.18.2
-     # via -r requirements.in
-
- # The following packages are considered to be unsafe in a requirements file:
- # setuptools
test_huggingface_mae.py CHANGED
@@ -3,14 +3,14 @@ import torch
 
  from huggingface_mae import MAEModel
 
- huggingface_phenombeta_model_dir = "models/phenom_beta_huggingface"
+ huggingface_phenombeta_model_dir = "."
  # huggingface_modelpath = "recursionpharma/test-pb-model"
 
 
  @pytest.fixture
  def huggingface_model():
      # Make sure you have the model/config downloaded from https://huggingface.co/recursionpharma/test-pb-model to this directory
-     # huggingface-cli download recursionpharma/test-pb-model --local-dir=models/phenom_beta_huggingface
+     # huggingface-cli download recursionpharma/test-pb-model --local-dir=.
      huggingface_model = MAEModel.from_pretrained(huggingface_phenombeta_model_dir)
      huggingface_model.eval()
      return huggingface_model