Upload README.md with huggingface_hub

README.md CHANGED

@@ -11,30 +11,19 @@ Load any ESM2 models into a FastEsm model to dramatically speed up training and
 Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
 Various other optimizations also make the base implementation slightly different than the one in transformers.
 
-
+# FastESM2-650
 
-
-
-
-
-'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
-# Synthyra/ESM2-35M
-'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
-# Synthyra/ESM2-150M
-'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
-# Synthyra/ESM2-650M
-'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
-# Synthyra/ESM2-3B
-'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
-}
-```
+## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
+To enhance the weights with longer context and better fp16 support, we trained ESM2-650 50000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to sequence length of **2048**.
+
+## Use with 🤗 transformers
 
 ### For working with embeddings
 ```python
 import torch
 from transformers import AutoModel, AutoTokenizer
 
-model_path = 'Synthyra/
+model_path = 'Synthyra/FastESM2_650'
 model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
 tokenizer = model.tokenizer
 
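The embedding example above continues in an unchanged stretch of the README that this diff does not display. A minimal sketch of how such a snippet is typically completed with the standard 🤗 transformers API, assuming the model follows the usual `last_hidden_state` output convention (the example sequences and the mean-pooling step are illustrative, not from the diff):

```python
import torch
from transformers import AutoModel

# Load as in the diff; trust_remote_code pulls in the FastEsm implementation.
model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

# Illustrative protein sequences (fp16 inference is intended for GPU; move model and inputs to CUDA for speed).
sequences = ['MALWMRLLPLLALLALWGPDPAAA', 'MKTAYIAKQRQISFVKSHFSRQLE']
inputs = tokenizer(sequences, padding=True, return_tensors='pt')

with torch.no_grad():
    outputs = model(**inputs)

residue_embeddings = outputs.last_hidden_state   # (batch, seq_len, hidden_dim)
mask = inputs['attention_mask'].unsqueeze(-1)    # ignore padding when pooling
sequence_embeddings = (residue_embeddings * mask).sum(1) / mask.sum(1)
print(sequence_embeddings.shape)                 # (batch, hidden_dim)
```

Mean pooling is only one choice; the probing results later in the diff are described in terms of pooled hidden states, and any pooling strategy can be swapped in.
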
@@ -70,6 +59,7 @@ with torch.no_grad():
 print(attentions[-1].shape) # (2, 20, 11, 11)
 ```
 
+
 ## Embed entire datasets with no new code
 To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time.
 ```python
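As noted in the first hunk, SDPA cannot return attention maps, so `output_attentions` switches to a manually computed attention path. A minimal sketch of that call, assuming the standard `outputs.attentions` convention (the sequences are illustrative; two 9-residue sequences give the `(2, 20, 11, 11)` shape shown above once special tokens are counted):

```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

# Two 9-residue sequences -> 11 tokens each once special tokens are added (illustrative).
sequences = ['MALWMRLLP', 'MKTAYIAKQ']
inputs = tokenizer(sequences, padding=True, return_tensors='pt')

with torch.no_grad():
    # output_attentions forces the manual attention path, since SDPA cannot return the maps.
    outputs = model(**inputs, output_attentions=True)

attentions = outputs.attentions   # tuple with one (batch, heads, seq_len, seq_len) tensor per layer
print(attentions[-1].shape)       # expected: torch.Size([2, 20, 11, 11])
```
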
@@ -97,6 +87,24 @@ _ = model.embed_dataset(
 )
 ```
 
+## Model probes
+We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
+
+The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression tasks are averaged between Spearman rho and R2.
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
+
+## Comparison of half precisions
+Presumably because we trained in mixed-precision fp16, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.
+
+When summing the MSE of 1000 sequences vs. the fp32 weights:
+
+Average MSE for FP16: 0.00000140
+
+Average MSE for BF16: 0.00004125
+
+### Inference speed
+We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 with longer sequences. Requires PyTorch 2.5+ for the most savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
+![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)
 
 ### Citation
 If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
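The fp16/bf16 figures above come from an evaluation script that is not part of this diff. A minimal sketch of how such an MSE-against-fp32 comparison could be run, assuming `last_hidden_state` is the compared output and using a stand-in for the 1000 sequences:

```python
import torch
import torch.nn.functional as F
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# fp32 reference and a half-precision copy; swap torch.float16 for torch.bfloat16 to test bf16.
model_fp32 = AutoModel.from_pretrained(model_path, trust_remote_code=True).eval().to(device)
model_half = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval().to(device)
tokenizer = model_fp32.tokenizer

# Stand-in for the 1000 evaluation sequences (illustrative).
sequences = ['MALWMRLLPLLALLALWGPDPAAA', 'MKTAYIAKQRQISFVKSHFSRQLE']

total_mse = 0.0
with torch.no_grad():
    for seq in sequences:
        inputs = tokenizer(seq, return_tensors='pt').to(device)
        ref = model_fp32(**inputs).last_hidden_state.float()
        half = model_half(**inputs).last_hidden_state.float()
        total_mse += F.mse_loss(half, ref).item()

print(f'Average MSE vs. fp32: {total_mse / len(sequences):.8f}')
```
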
@@ -109,4 +117,4 @@ If you use any of this implementation or work please cite it (as well as the [ES
 doi = { 10.57967/hf/3729 },
 publisher = { Hugging Face }
 }
-```
+```