EdwardoSunny committed
Commit 85ab89d · Parent(s): 470480a
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. LICENSE.md +14 -0
  2. LICENSE_Lavis.md +14 -0
  3. README.md +110 -13
  4. app.py +125 -62
  5. assets/distribution.png +0 -0
  6. configs/evaluation.yaml +25 -0
  7. configs/minigpt4.yaml +35 -0
  8. configs/train_instruction_tuning.yaml +54 -0
  9. configs/train_modality_alignment.yaml +60 -0
  10. dataset.json +0 -0
  11. dataset/README.md +15 -0
  12. demo.sh +5 -0
  13. demo_esm.py +101 -0
  14. deprecate/inference.py +129 -0
  15. environment.yml +63 -0
  16. esm/__init__.py +12 -0
  17. esm/__pycache__/__init__.cpython-310.pyc +0 -0
  18. esm/__pycache__/axial_attention.cpython-310.pyc +0 -0
  19. esm/__pycache__/constants.cpython-310.pyc +0 -0
  20. esm/__pycache__/data.cpython-310.pyc +0 -0
  21. esm/__pycache__/modules.cpython-310.pyc +0 -0
  22. esm/__pycache__/multihead_attention.cpython-310.pyc +0 -0
  23. esm/__pycache__/pretrained.cpython-310.pyc +0 -0
  24. esm/__pycache__/rotary_embedding.cpython-310.pyc +0 -0
  25. esm/__pycache__/version.cpython-310.pyc +0 -0
  26. esm/axial_attention.py +239 -0
  27. esm/constants.py +10 -0
  28. esm/data.py +493 -0
  29. esm/esmfold/v1/__init__.py +0 -0
  30. esm/esmfold/v1/categorical_mixture.py +43 -0
  31. esm/esmfold/v1/esmfold.py +364 -0
  32. esm/esmfold/v1/misc.py +309 -0
  33. esm/esmfold/v1/pretrained.py +181 -0
  34. esm/esmfold/v1/tri_self_attn_block.py +160 -0
  35. esm/esmfold/v1/trunk.py +243 -0
  36. esm/inverse_folding/__init__.py +11 -0
  37. esm/inverse_folding/features.py +356 -0
  38. esm/inverse_folding/gvp_encoder.py +56 -0
  39. esm/inverse_folding/gvp_modules.py +475 -0
  40. esm/inverse_folding/gvp_transformer.py +144 -0
  41. esm/inverse_folding/gvp_transformer_encoder.py +189 -0
  42. esm/inverse_folding/gvp_utils.py +68 -0
  43. esm/inverse_folding/multichain_util.py +152 -0
  44. esm/inverse_folding/transformer_decoder.py +228 -0
  45. esm/inverse_folding/transformer_layer.py +304 -0
  46. esm/inverse_folding/util.py +323 -0
  47. esm/model/__init__.py +1 -0
  48. esm/model/__pycache__/__init__.cpython-310.pyc +0 -0
  49. esm/model/__pycache__/esm1.cpython-310.pyc +0 -0
  50. esm/model/__pycache__/esm2.cpython-310.pyc +0 -0
LICENSE.md ADDED
@@ -0,0 +1,14 @@
+ BSD 3-Clause License
+
+ Copyright 2023 Deyao Zhu
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
LICENSE_Lavis.md ADDED
@@ -0,0 +1,14 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2022 Salesforce, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ 3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,13 +1,110 @@
- ---
- title: ProteinGPT Llama3
- emoji: 💬
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 4.36.1
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- An example chatbot using [Gradio](https://gradio.app), [`huggingface_hub`](https://huggingface.co/docs/huggingface_hub/v0.22.2/en/index), and the [Hugging Face Inference API](https://huggingface.co/docs/api-inference/index).
+ # ProteinChat: Towards Enabling ChatGPT-Like Capabilities on Protein 3D Structures
+
+ This repository holds the code and data of ProteinChat: Towards Enabling ChatGPT-Like Capabilities on Protein 3D Structures.
+
+ ## Technical report is available [here](https://www.techrxiv.org/articles/preprint/ProteinChat_Towards_Achieving_ChatGPT-Like_Functionalities_on_Protein_3D_Structures/23120606)
+
+ ## Examples
+
+ ![Eg1](fig/protein-eg.png)
+
+
+ ## Introduction
+ - In this work, we make an initial attempt towards enabling ChatGPT-like capabilities on protein 3D structures by developing a prototype system, ProteinChat.
+ - ProteinChat works in a similar way to ChatGPT. Users upload a protein 3D structure and ask various questions about this protein. ProteinChat answers these questions in a multi-turn, interactive manner.
+ - The ProteinChat system consists of a protein 3D structure encoder (based on [ESM inverse folding](https://github.com/facebookresearch/esm/tree/main/examples/inverse_folding)), a large language model (LLM), and an adaptor. The protein encoder takes a protein 3D structure as input and learns a representation for this protein. The adaptor transforms the protein representation produced by the protein encoder into another representation that is acceptable to the LLM. The LLM takes the representation transformed by the adaptor and users' questions about this protein as inputs and generates answers. All these components are trained end-to-end.
+ - To train ProteinChat, we collected an instruction tuning dataset that contains 143508 proteins and 143508 instructions.
+
+
+ ![overview](fig/proteinchat_overview.png)
+
+ ## Datasets
+
+ The dataset contains 143508 proteins (represented using 3D structures) with 143508 instructions.
+ The instruction set is available at [this link](https://drive.google.com/file/d/1iMgPyiIzpvXdKiNsXnRKn2YpmP92Xyub/view?usp=share_link).
+ The processed protein files (83G in total) are available at [this link](https://drive.google.com/file/d/1AeJW5BY5C-d8mKJjAULTax6WA4hzWS0N/view?usp=share_link).
+ The data is curated from the [Protein Data Bank](https://www.rcsb.org/). More details can be found [here](dataset/README.md).
+
+ ## Getting Started
+ ### Installation
+ These instructions largely follow those in MiniGPT-4.
+
+ **1. Prepare the code and the environment**
+
+ Clone our repository, create a Python environment, and activate it via the following commands:
+
+ ```bash
+ git clone https://github.com/UCSD-AI4H/proteinchat
+ cd proteinchat
+ conda env create -f environment.yml
+ conda activate proteinchat
+ pip install einops
+ ```
+
+ Verify that the installation of `torch` and `torchvision` succeeded by running `python -c "import torchvision; print(torchvision.__version__)"`. If it outputs the version number without any warnings or errors, you are good to go. __If it outputs any warnings or errors__, uninstall `torch` with `conda uninstall pytorch torchvision torchaudio cudatoolkit` and then reinstall it following [these instructions](https://pytorch.org/get-started/previous-versions/#v1121). You need to find the correct command for the CUDA version your GPU driver supports (check `nvidia-smi`).
+
+ **2. Prepare the pretrained Vicuna weights**
+
+ The current version of ProteinChat is built on the v0 version of Vicuna-13B.
+ Please refer to our instructions [here](PrepareVicuna.md)
+ to prepare the Vicuna weights.
+ The final weights should sit in a single folder with a structure similar to the following:
+
+ ```
+ vicuna_weights
+ ├── config.json
+ ├── generation_config.json
+ ├── pytorch_model.bin.index.json
+ ├── pytorch_model-00001-of-00003.bin
+ ...
+ ```
+
+ Then, set the path to the Vicuna weights in the model config file
+ [here](minigpt4/configs/models/minigpt4.yaml#L16) at Line 16.
+
+ ### Training
+ **You need roughly 45 GB of GPU memory for training.**
+
+ The training configuration file is [configs/train_instruction_tuning.yaml](configs/train_instruction_tuning.yaml). You may want to change the number of epochs and other hyper-parameters there, such as `max_epoch`, `init_lr`, `min_lr`, `warmup_steps`, and `batch_size_train`. Please adjust `iters_per_epoch` so that `iters_per_epoch` * `batch_size_train` equals your training set size (for example, with the full 143508-sample instruction set and `batch_size_train=1`, `iters_per_epoch` would be 143508). Due to GPU memory constraints, we set `batch_size_train=1`.
+
+ Start training the LLaMA model on the protein dataset by running [finetune.sh](finetune.sh): `bash finetune.sh`.
+
+ ### Demo
+ **The demo takes around 24 GB of GPU memory.**
+
+ Find the checkpoint saved during the training process above; by default it is located under the folder `minigpt4/output/minigpt4_stage2_esm/`. Copy it to the folder `ckpt` by running `cp minigpt4/output/minigpt4_stage2_esm/.../checkpoint_xxx.pth ckpt/`, and modify the `ckpt` entry in [configs/evaluation.yaml](configs/evaluation.yaml) to point to your checkpoint.
+
+ Now launch the demo in the same environment: start [demo.sh](demo.sh) on your local machine by running `bash demo.sh`, then open the URL created by the demo and try it out!
+
+
+ ## Acknowledgement
+
+ + [ProteinChat](https://github.com/UCSD-AI4H/proteinchat)
+ + [MiniGPT-4](https://minigpt-4.github.io/)
+ + [Lavis](https://github.com/salesforce/LAVIS)
+ + [Vicuna](https://github.com/lm-sys/FastChat)
+ + [ESM-IF1](https://github.com/facebookresearch/esm/tree/main/examples/inverse_folding)
+
+
+
+ ## License
+ This repository is under the [BSD 3-Clause License](LICENSE.md).
+ Much of the code is based on [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) (BSD 3-Clause License, [here](LICENSE_MiniGPT4.md)), which in turn is based on [Lavis](https://github.com/salesforce/LAVIS) (BSD 3-Clause License, [here](LICENSE_Lavis.md)).
+
+
+ ## Disclaimer
+
+ This is a prototype system that has not yet been systematically and comprehensively validated by biologists. Please use it with caution.
+
+ Trained models and demo websites will be released after we thoroughly validate the system with biologists.
+
+
+ ## Citation
+
+ If you're using ProteinChat in your research or applications, please cite it using this BibTeX:
+ ```bibtex
+ @article{guo2023proteinchat,
+   title={ProteinChat: Towards Enabling ChatGPT-Like Capabilities on Protein 3D Structures},
+   author={Guo, Han and Huo, Mingjia and Xie, Pengtao},
+   year={2023}
+ }
+ ```
app.py CHANGED
@@ -1,63 +1,126 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch(share=True)
  import gradio as gr
+ import argparse
+ import os
+ import random
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ from minigpt4.common.config import Config
+ from minigpt4.common.dist_utils import get_rank
+ from minigpt4.common.registry import registry
+ from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
+ import esm
+
+ # ProteinGPT initialization
+ def initialize_chat(args):
+     cfg = Config(args)
+     model_config = cfg.model_cfg
+     model_config.device_8bit = 0
+     model_cls = registry.get_model_class(model_config.arch)
+     model = model_cls.from_config(model_config).to('cpu')
+     vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+     vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+     chat = Chat(model, vis_processor, device='cpu')
+     return chat
+
+ # Reset the conversation state and uploaded files
+ def gradio_reset(chat_state, img_list):
+     if chat_state is not None:
+         chat_state.messages = []
+     if img_list is not None:
+         img_list = []
+     return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your protein structure and sequence first', interactive=False), gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
+
+ # Upload the protein structure and sequence embeddings
+ def upload_protein(structure, sequence, text_input, chat_state):
+     # Check that the structure and sequence files are valid .pt tensors
+     if structure is None or not structure.endswith(".pt"):
+         return (None, None, None, gr.update(placeholder="Invalid structure file, must be a .pt file.", interactive=True), chat_state, None)
+     if sequence is None or not sequence.endswith(".pt"):
+         return (None, None, None, gr.update(placeholder="Invalid sequence file, must be a .pt file.", interactive=True), chat_state, None)
+
+     # Load the protein structure and sequence embeddings
+     pdb_embedding = torch.load(structure, map_location=torch.device('cpu'))
+     sample_pdb = pdb_embedding.to('cpu')
+
+     seq_embedding = torch.load(sequence, map_location=torch.device('cpu'))
+     sample_seq = seq_embedding.to('cpu')
+
+     # Initialize the conversation state
+     chat_state = CONV_VISION.copy()
+     img_list = []
+
+     # Upload the protein data
+     llm_message = chat.upload_protein(sample_pdb, sample_seq, chat_state, img_list)
+
+     # Return the required outputs
+     return (gr.update(interactive=False),  # Disable the structure file input
+             gr.update(interactive=False),  # Disable the sequence file input
+             gr.update(interactive=True, placeholder='Type and press Enter'),  # Enable the text input box
+             gr.update(value="Start Chatting", interactive=False),  # Update the upload button state
+             chat_state,  # Return the conversation state
+             img_list)  # Return the list of protein embeddings
+
+ # Queue the user's question
+ def gradio_ask(user_message, chatbot, chat_state):
+     if len(user_message) == 0:
+         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
+     chat.ask(user_message, chat_state)
+     chatbot = chatbot + [[user_message, None]]
+     return '', chatbot, chat_state
+
+ # Generate the model's answer
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
+     img_list = [mat.half() for mat in img_list]
+     llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=300, num_beams=num_beams, temperature=temperature, max_length=2000)[0]
+     chatbot[-1][1] = llm_message
+     return chatbot, chat_state, img_list
+
+ # Command-line argument parsing
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--cfg-path", help="path to configuration file.", default='configs/evaluation.yaml')
+     parser.add_argument(
+         "--options",
+         nargs="+",
+         help="override some settings in the used config, the key-value pair "
+         "in xxx=yyy format will be merged into config file (deprecate), "
+         "change to --cfg-options instead.",
+     )
+     args = parser.parse_args()
+     return args
+
+ # Gradio demo interface
+ title = """<h1 align="center">Demo of ProteinGPT</h1>"""
+ description = """<h3>Upload your protein sequence and structure and start chatting with your protein!</h3>"""
+ article = """<div style='display:flex; gap: 0.25rem; '><a href='https://huggingface.co/AI-BIO/ProteinGPT-Llama3'><img src='https://img.shields.io/badge/Project-Page-Green'></a><a href='https://github.com'><img src='https://img.shields.io/badge/Github-Code-blue'></a><a href='https://arxiv.org/abs/2408.11363'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></div>"""
+
+ args = parse_args()  # Parse arguments to get the config and model info
+ chat = initialize_chat(args)  # Initialize the ProteinGPT model
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.Markdown(article)
+
+     with gr.Row():
+         with gr.Column(scale=0.5):
+             structure = gr.File(type="filepath", label="Upload Protein Structure", show_label=True)
+             sequence = gr.File(type="filepath", label="Upload Protein Sequence", show_label=True)
+             upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
+             clear = gr.Button("Restart")
+             num_beams = gr.Slider(minimum=1, maximum=5, value=1, step=1, interactive=True, label="Beam search numbers")
+             temperature = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, interactive=True, label="Temperature")
+
+         with gr.Column():
+             chat_state = gr.State()
+             img_list = gr.State()
+             chatbot = gr.Chatbot(label='ProteinGPT')
+             text_input = gr.Textbox(label='User', placeholder='Please upload your protein structure and sequence first', interactive=False)
+
+     upload_button.click(upload_protein,
+                         [structure, sequence, text_input, chat_state],
+                         [structure, sequence, text_input, upload_button, chat_state, img_list])
+     text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list])
+     clear.click(gradio_reset, [chat_state, img_list], [chatbot, structure, sequence, text_input, upload_button, chat_state, img_list], queue=False)
+
+ demo.launch(share=True)
assets/distribution.png ADDED
configs/evaluation.yaml ADDED
@@ -0,0 +1,25 @@
+ model:
+   arch: mini_gpt4
+   model_type: pretrain_vicuna
+   freeze_vit: True
+   freeze_qformer: True
+   max_txt_len: 256
+   end_sym: "###"
+   low_resource: False
+   prompt_template: '###Human: {} ###Assistant: '
+   # ckpt: '/home/ubuntu/proteinchat/minigpt4/ft/Llama-2-7b-chat-hf/20240610191/checkpoint_5.pth'
+   ckpt: 'minigpt4/ft/Meta-Llama-3-8B-Instruct-hf/20240609203/checkpoint_5.pth'
+
+ datasets:
+   cc_sbu_align:
+     vis_processor:
+       train:
+         name: "blip2_image_eval"
+         image_size: 224
+     text_processor:
+       train:
+         name: "blip_caption"
+
+ run:
+   task: image_text_pretrain
+
configs/minigpt4.yaml ADDED
@@ -0,0 +1,35 @@
+ model:
+   arch: mini_gpt4
+
+   # vit encoder
+   image_size: 224
+   drop_path_rate: 0
+   use_grad_checkpoint: False
+   vit_precision: "fp16"
+   freeze_vit: True
+   freeze_qformer: True
+
+   # Q-Former
+   num_query_token: 32
+
+   # LLM backbone
+   # llama_model: "/home/ubuntu/ckpt/hf/Meta-Llama-3-8B-Instruct-hf/"
+   llama_model: "meta-llama/Meta-Llama-3-8B-Instruct"
+
+
+   # generation configs
+   prompt: ""
+
+ preprocess:
+   vis_processor:
+     train:
+       name: "blip2_image_train"
+       image_size: 224
+     eval:
+       name: "blip2_image_eval"
+       image_size: 224
+   text_processor:
+     train:
+       name: "blip_caption"
+     eval:
+       name: "blip_caption"
configs/train_instruction_tuning.yaml ADDED
@@ -0,0 +1,54 @@
+ model:
+   arch: mini_gpt4
+   model_type: pretrain_vicuna
+   freeze_vit: True
+   freeze_qformer: True
+   # low_resource: True
+   max_txt_len: 256
+   end_sym: "###"
+   prompt_template: '###Human: {} ###Assistant: '
+   # ckpt: '/home/ubuntu/proteinchat/minigpt4/output/Meta-Llama-3-8B-Instruct-hf/20240606190/checkpoint_2.pth'
+   ckpt: '/home/ubuntu/proteinchat/minigpt4/output/Llama-2-7b-chat-hf/20240606005/checkpoint_2.pth'
+
+
+ datasets:
+   cc_sbu_align:
+     vis_processor:
+       train:
+         name: "blip2_image_train"
+         image_size: 224
+     text_processor:
+       train:
+         name: "blip_caption"
+
+ run:
+   task: image_text_pretrain
+   # optimizer
+   lr_sched: "linear_warmup_cosine_lr"
+   init_lr: 1e-5
+   min_lr: 1e-6
+   warmup_lr: 1e-6
+
+   weight_decay: 0.05
+   max_epoch: 10
+   # iters_per_epoch: 762
+   batch_size_train: 1
+   batch_size_eval: 1
+   num_workers: 12
+   warmup_steps: 5000
+
+   seed: 42
+   # output_dir: "ft/Meta-Llama-3-8B-Instruct-hf/"
+   output_dir: "ft/Llama-2-7b-chat-hf/"
+
+   amp: True
+   resume_ckpt_path: null
+
+   evaluate: False
+   train_splits: ["train"]
+
+   device: "cuda"
+   world_size: 1
+   dist_url: "env://"
+   distributed: True
+   stage: 2
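The README asks that `iters_per_epoch` in this file be set so that `iters_per_epoch * batch_size_train` matches the training set size (the commented-out `iters_per_epoch: 762` above was evidently set for a different split). A small sketch of that bookkeeping; the dataset size below is the 143508-protein figure from the README, so substitute the size of your own training split:

```python
import math

# Relation from the README: iters_per_epoch * batch_size_train ≈ training set size.
train_set_size = 143508   # full instruction set from the README; use your own split size
batch_size_train = 1      # as set in configs/train_instruction_tuning.yaml

iters_per_epoch = math.ceil(train_set_size / batch_size_train)
print(iters_per_epoch)    # 143508 -> value to put under `run:` in the YAML
```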
configs/train_modality_alignment.yaml ADDED
@@ -0,0 +1,60 @@
+ model:
+   arch: mini_gpt4
+   model_type: pretrain_vicuna
+   freeze_vit: True
+   freeze_qformer: True
+   # low_resource: True
+   max_txt_len: 384
+
+
+ datasets:
+   laion:
+     vis_processor:
+       train:
+         name: "blip2_image_train"
+         image_size: 224
+     text_processor:
+       train:
+         name: "blip_caption"
+     sample_ratio: 115
+   cc_sbu:
+     vis_processor:
+       train:
+         name: "blip2_image_train"
+         image_size: 224
+     text_processor:
+       train:
+         name: "blip_caption"
+     sample_ratio: 14
+
+ run:
+   task: image_text_pretrain
+   # optimizer
+   lr_sched: "linear_warmup_cosine_lr"
+   init_lr: 1e-4
+   min_lr: 8e-5
+   warmup_lr: 1e-6
+
+   weight_decay: 0.05
+   max_epoch: 3
+   batch_size_train: 1
+   batch_size_eval: 1
+   num_workers: 12
+   warmup_steps: 5000
+
+   seed: 42
+   output_dir: "output/Meta-Llama-3-8B-Instruct-hf/"
+   # output_dir: "output/Llama-2-7b-chat-hf/"
+
+   amp: True
+   resume_ckpt_path: null
+
+   evaluate: False
+   train_splits: ["train"]
+
+   device: "cuda"
+   world_size: 1
+   dist_url: "env://"
+   distributed: True
+
+   stage: 1
dataset.json ADDED
The diff for this file is too large to render.
dataset/README.md ADDED
@@ -0,0 +1,15 @@
+ # Dataset
+
+ ## Modality Alignment
+ - ProteinChat: PDB embeddings and abstract information.
+
+ ## Instruction Tuning
+ - GPT-4o.
+
+
+ ## Citation
+ @article{guo2023proteinchat,
+   title={ProteinChat: Towards Enabling ChatGPT-Like Capabilities on Protein 3D Structures},
+   author={Guo, Han and Huo, Mingjia and Xie, Pengtao},
+   year={2023}
+ }
demo.sh ADDED
@@ -0,0 +1,5 @@
+ # PDB_ID=7rvu
+ # PDB_ID=5x1y
+ PDB_ID=6o7q
+
+ python demo_esm.py --cfg-path configs/evaluation.yaml --gpu-id 0 --pdb /home/ubuntu/pt/$PDB_ID.pt --seq /home/ubuntu/seq/$PDB_ID.pt
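demo.sh expects two precomputed tensors per protein: a structure embedding (`--pdb`) and a sequence embedding (`--seq`), both saved as `.pt` files. This diff does not show how those tensors are produced; the sketch below is one plausible recipe using the bundled ESM code, and the specific checkpoints (`esm_if1_gvp4_t16_142M_UR50`, `esm2_t33_650M_UR50D`), chain choice, and layer index are assumptions rather than the authors' actual preprocessing, so treat it only as a starting point.

```python
import torch
import esm
import esm.inverse_folding

pdb_id, chain_id = "6o7q", "A"   # placeholders matching demo.sh

# Structure embedding via ESM-IF1 (model choice is an assumption).
if1_model, if1_alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
if1_model = if1_model.eval()
coords, native_seq = esm.inverse_folding.util.load_coords(f"{pdb_id}.pdb", chain_id)
with torch.no_grad():
    struct_rep = esm.inverse_folding.util.get_encoder_output(if1_model, if1_alphabet, coords)
torch.save(struct_rep, f"pt/{pdb_id}.pt")        # --pdb input

# Sequence embedding via an ESM-2 language model (model and layer are assumptions).
seq_model, seq_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
seq_model = seq_model.eval()
batch_converter = seq_alphabet.get_batch_converter()
_, _, tokens = batch_converter([(pdb_id, native_seq)])
with torch.no_grad():
    out = seq_model(tokens, repr_layers=[33])
seq_rep = out["representations"][33][0, 1:-1]    # drop BOS/EOS positions
torch.save(seq_rep, f"seq/{pdb_id}.pt")          # --seq input
```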
demo_esm.py ADDED
@@ -0,0 +1,101 @@
+ import argparse
+ import os
+ import random
+
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+
+ from minigpt4.common.config import Config
+ from minigpt4.common.dist_utils import get_rank
+ from minigpt4.common.registry import registry
+ from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
+
+ # imports modules for registration
+ from minigpt4.datasets.builders import *
+ from minigpt4.models import *
+ from minigpt4.processors import *
+ from minigpt4.runners import *
+ from minigpt4.tasks import *
+ import sys
+
+ import esm
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+     parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+     parser.add_argument("--pdb", help="specify where the protein structure file is (.pt)")
+     parser.add_argument("--seq", help="specify where the sequence file is (.pt)")
+     parser.add_argument(
+         "--options",
+         nargs="+",
+         help="override some settings in the used config, the key-value pair "
+         "in xxx=yyy format will be merged into config file (deprecate), "
+         "change to --cfg-options instead.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def setup_seeds(config):
+     seed = config.run_cfg.seed + get_rank()
+
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     cudnn.benchmark = False
+     cudnn.deterministic = True
+
+
+ # ========================================
+ # Model Initialization
+ # ========================================
+
+ print('Initializing Chat')
+ args = parse_args()
+ cfg = Config(args)
+
+ model_config = cfg.model_cfg
+ model_config.device_8bit = args.gpu_id
+ model_cls = registry.get_model_class(model_config.arch)
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+
+ vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+ print('Initialization Finished')
+
+ chat_state = CONV_VISION.copy()
+ img_list = []
+
+ pdb_path = args.pdb
+ seq_path = args.seq
+ if pdb_path[-3:] == ".pt":
+     pdb_embedding = torch.load(pdb_path, map_location=torch.device('cpu'))
+     sample_pdb = pdb_embedding.to('cuda:{}'.format(args.gpu_id))
+ if seq_path[-3:] == ".pt":
+     seq_embedding = torch.load(seq_path, map_location=torch.device('cpu'))
+     sample_seq = seq_embedding.to('cuda:{}'.format(args.gpu_id))
+
+ llm_message = chat.upload_protein(sample_pdb, sample_seq, chat_state, img_list)
+ print(llm_message)
+
+ img_list = [mat.half() for mat in img_list]
+ while True:
+     user_input = input(">")
+     if len(user_input) == 0:
+         print("USER INPUT CANNOT BE EMPTY!")
+         continue
+     elif user_input.lower() == "exit()":
+         break
+     chat.ask(user_input, chat_state)
+     llm_message = chat.answer(conv=chat_state,
+                               img_list=img_list,
+                               num_beams=1,
+                               temperature=0.7,
+                               max_new_tokens=300,
+                               max_length=2000)[0]
+     print("B: ", llm_message)
deprecate/inference.py ADDED
@@ -0,0 +1,129 @@
+ import argparse
+ import os
+ import random
+ import time
+
+ import numpy as np
+ import torch
+ import torch.backends.cudnn as cudnn
+ import gradio as gr
+
+ import esm
+ from minigpt4.common.config import Config
+ from minigpt4.common.dist_utils import get_rank
+ from minigpt4.common.registry import registry
+ from minigpt4.conversation.conversation_esm import Chat, CONV_VISION
+
+ import json
+
+ # Imports PIL module
+ from PIL import Image
+
+ # imports modules for registration
+ from minigpt4.datasets.builders import *
+ from minigpt4.models import *
+ from minigpt4.processors import *
+ from minigpt4.runners import *
+ from minigpt4.tasks import *
+
+ import esm
+ import esm.inverse_folding
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Demo")
+     parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+     parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+     # parser.add_argument("--json-path", default='/home/h5guo/shared/Mini-GPT4/coco_json/cocoval2014_img_prompt.json', help="path to the classification json file")
+     # parser.add_argument("--caption-save-path", default='/home/h5guo/shared/Mini-GPT4/coco_json_result/results.json', help="path to saved generated captions")
+     parser.add_argument(
+         "--options",
+         nargs="+",
+         help="override some settings in the used config, the key-value pair "
+         "in xxx=yyy format will be merged into config file (deprecate), "
+         "change to --cfg-options instead.",
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def setup_seeds(config):
+     seed = config.run_cfg.seed + get_rank()
+
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+
+     cudnn.benchmark = False
+     cudnn.deterministic = True
+
+
+ # ========================================
+ # Model Initialization
+ # ========================================
+
+ print('Initializing Chat')
+ args = parse_args()
+ cfg = Config(args)
+
+ model_config = cfg.model_cfg
+ model_config.device_8bit = args.gpu_id
+ model_cls = registry.get_model_class(model_config.arch)
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+
+ vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+ print('Initialization Finished')
+
+ # ========================================
+ # Gradio Setting
+ # ========================================
+
+ def gradio_reset(chat_state, img_list):
+     if chat_state is not None:
+         chat_state.messages = []
+     if img_list is not None:
+         img_list = []
+     return chat_state, img_list
+
+ def upload_protein(gr_img):
+     chat_state = CONV_VISION.copy()
+     img_list = []
+     llm_message = chat.upload_protein(gr_img, chat_state, img_list)
+     return chat_state, img_list
+
+ def gradio_ask(user_message, chat_state):
+     chat.ask(user_message, chat_state)
+     return chat_state
+
+
+ def gradio_answer(chat_state, img_list, num_beams=1, temperature=1e-3):
+     llm_message = chat.answer(conv=chat_state,
+                               img_list=img_list,
+                               num_beams=num_beams,
+                               temperature=temperature,
+                               max_new_tokens=300,
+                               max_length=2000)[0]
+     return llm_message, chat_state, img_list
+
+ if __name__ == "__main__":
+     start = time.time()
+     print("******************")
+     protein_embedding_path = "/home/h5guo/data/esm_subset/pt/2wge.pt"
+     protein_embedding = torch.load(protein_embedding_path, map_location=torch.device('cpu'))
+     sample_protein = protein_embedding.to('cuda:{}'.format(args.gpu_id))
+
+     user_message = "Describe this protein in a short paragraph."
+     chat_state, img_list = upload_protein(sample_protein)
+     chat_state = gradio_ask(user_message, chat_state)
+     llm_message, chat_state, img_list = gradio_answer(chat_state, img_list)
+
+     print(f"llm_message: {llm_message}")
+     end = time.time()
+     print(end - start)
+     # i += 1
+     print("******************")
+     # f.close()
+
environment.yml ADDED
@@ -0,0 +1,63 @@
+ name: proteinchat
+ channels:
+   - pytorch
+   - defaults
+   - anaconda
+ dependencies:
+   - python=3.9
+   - cudatoolkit
+   - pip
+   - pytorch=1.12.1
+   - pytorch-mutex=1.0=cuda
+   - torchaudio=0.12.1
+   - torchvision=0.13.1
+   - pip:
+     - accelerate==0.16.0
+     - aiohttp==3.8.4
+     - aiosignal==1.3.1
+     - async-timeout==4.0.2
+     - attrs==22.2.0
+     - bitsandbytes==0.37.0
+     - cchardet==2.1.7
+     - chardet==5.1.0
+     - contourpy==1.0.7
+     - cycler==0.11.0
+     - filelock==3.9.0
+     - fonttools==4.38.0
+     - frozenlist==1.3.3
+     - huggingface-hub==0.13.4
+     - importlib-resources==5.12.0
+     - kiwisolver==1.4.4
+     - matplotlib==3.7.0
+     - multidict==6.0.4
+     - openai==0.27.0
+     - packaging==23.0
+     - psutil==5.9.4
+     - pycocotools==2.0.6
+     - pyparsing==3.0.9
+     - python-dateutil==2.8.2
+     - pyyaml==6.0
+     - regex==2022.10.31
+     - tokenizers==0.13.2
+     - tqdm==4.64.1
+     - transformers==4.28.0
+     - timm==0.6.13
+     - spacy==3.5.1
+     - webdataset==0.2.48
+     - scikit-learn==1.2.2
+     - scipy==1.10.1
+     - yarl==1.8.2
+     - zipp==3.14.0
+     - omegaconf==2.3.0
+     - opencv-python==4.7.0.72
+     - iopath==0.1.10
+     - decord==0.6.0
+     - tenacity==8.2.2
+     - peft
+     - pycocoevalcap
+     - sentence-transformers
+     - umap-learn
+     - notebook
+     - gradio==3.24.1
+     - gradio-client==0.0.8
+     - wandb
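After creating this environment, a quick sanity check (mirroring the torchvision check in the README) can confirm that the pinned PyTorch stack imports and sees the GPU; the expected version strings below simply restate the pins in this file.

```python
import torch
import torchvision
import transformers

print(torch.__version__)          # expected 1.12.1 per the pin above
print(torchvision.__version__)    # expected 0.13.1
print(transformers.__version__)   # expected 4.28.0
print(torch.cuda.is_available())  # should be True on a correctly configured GPU machine
```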
esm/__init__.py ADDED
@@ -0,0 +1,12 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .version import version as __version__  # noqa
+
+ from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
+ from .model.esm1 import ProteinBertModel  # noqa
+ from .model.esm2 import ESM2  # noqa
+ from .model.msa_transformer import MSATransformer  # noqa
+ from . import pretrained  # noqa
esm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (490 Bytes).
esm/__pycache__/axial_attention.cpython-310.pyc ADDED
Binary file (5.43 kB).
esm/__pycache__/constants.cpython-310.pyc ADDED
Binary file (288 Bytes).
esm/__pycache__/data.cpython-310.pyc ADDED
Binary file (15.5 kB).
esm/__pycache__/modules.cpython-310.pyc ADDED
Binary file (13 kB).
esm/__pycache__/multihead_attention.cpython-310.pyc ADDED
Binary file (12 kB).
esm/__pycache__/pretrained.cpython-310.pyc ADDED
Binary file (19.5 kB).
esm/__pycache__/rotary_embedding.cpython-310.pyc ADDED
Binary file (2.73 kB).
esm/__pycache__/version.cpython-310.pyc ADDED
Binary file (186 Bytes).
esm/axial_attention.py ADDED
@@ -0,0 +1,239 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class RowSelfAttention(nn.Module):
12
+ """Compute self-attention over rows of a 2D input."""
13
+
14
+ def __init__(
15
+ self,
16
+ embed_dim,
17
+ num_heads,
18
+ dropout=0.0,
19
+ max_tokens_per_msa: int = 2 ** 16,
20
+ ):
21
+ super().__init__()
22
+ self.num_heads = num_heads
23
+ self.dropout = dropout
24
+ self.head_dim = embed_dim // num_heads
25
+ self.scaling = self.head_dim ** -0.5
26
+ self.max_tokens_per_msa = max_tokens_per_msa
27
+ self.attn_shape = "hnij"
28
+
29
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
30
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
31
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
32
+
33
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
34
+ self.dropout_module = nn.Dropout(dropout)
35
+
36
+ def align_scaling(self, q):
37
+ num_rows = q.size(0)
38
+ return self.scaling / math.sqrt(num_rows)
39
+
40
+ def _batched_forward(
41
+ self,
42
+ x,
43
+ self_attn_mask=None,
44
+ self_attn_padding_mask=None,
45
+ ):
46
+ num_rows, num_cols, batch_size, embed_dim = x.size()
47
+ max_rows = max(1, self.max_tokens_per_msa // num_cols)
48
+ attns = 0
49
+ scaling = self.align_scaling(x)
50
+ for start in range(0, num_rows, max_rows):
51
+ attn_weights = self.compute_attention_weights(
52
+ x[start : start + max_rows],
53
+ scaling,
54
+ self_attn_mask=self_attn_mask,
55
+ self_attn_padding_mask=self_attn_padding_mask[:, start : start + max_rows]
56
+ if self_attn_padding_mask is not None
57
+ else None,
58
+ )
59
+ attns += attn_weights
60
+ attn_probs = attns.softmax(-1)
61
+ attn_probs = self.dropout_module(attn_probs)
62
+
63
+ outputs = []
64
+ for start in range(0, num_rows, max_rows):
65
+ output = self.compute_attention_update(x[start : start + max_rows], attn_probs)
66
+ outputs.append(output)
67
+
68
+ output = torch.cat(outputs, 0)
69
+ return output, attn_probs
70
+
71
+ def compute_attention_weights(
72
+ self,
73
+ x,
74
+ scaling: float,
75
+ self_attn_mask=None,
76
+ self_attn_padding_mask=None,
77
+ ):
78
+ num_rows, num_cols, batch_size, embed_dim = x.size()
79
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
80
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
81
+ q *= scaling
82
+ if self_attn_padding_mask is not None:
83
+ # Zero out any padded aligned positions - this is important since
84
+ # we take a sum across the alignment axis.
85
+ q *= 1 - self_attn_padding_mask.permute(1, 2, 0).unsqueeze(3).unsqueeze(4).to(q)
86
+
87
+ attn_weights = torch.einsum(f"rinhd,rjnhd->{self.attn_shape}", q, k)
88
+
89
+ if self_attn_mask is not None:
90
+ raise NotImplementedError
91
+ # Mask Size: [B x R x C], Weights Size: [H x B x C x C]
92
+
93
+ if self_attn_padding_mask is not None:
94
+ attn_weights = attn_weights.masked_fill(
95
+ self_attn_padding_mask[:, 0].unsqueeze(0).unsqueeze(2),
96
+ -10000,
97
+ )
98
+
99
+ return attn_weights
100
+
101
+ def compute_attention_update(
102
+ self,
103
+ x,
104
+ attn_probs,
105
+ ):
106
+ num_rows, num_cols, batch_size, embed_dim = x.size()
107
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
108
+ context = torch.einsum(f"{self.attn_shape},rjnhd->rinhd", attn_probs, v)
109
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
110
+ output = self.out_proj(context)
111
+ return output
112
+
113
+ def forward(
114
+ self,
115
+ x,
116
+ self_attn_mask=None,
117
+ self_attn_padding_mask=None,
118
+ ):
119
+ num_rows, num_cols, batch_size, embed_dim = x.size()
120
+ if (num_rows * num_cols > self.max_tokens_per_msa) and not torch.is_grad_enabled():
121
+ return self._batched_forward(x, self_attn_mask, self_attn_padding_mask)
122
+ else:
123
+ scaling = self.align_scaling(x)
124
+ attn_weights = self.compute_attention_weights(
125
+ x, scaling, self_attn_mask, self_attn_padding_mask
126
+ )
127
+ attn_probs = attn_weights.softmax(-1)
128
+ attn_probs = self.dropout_module(attn_probs)
129
+ output = self.compute_attention_update(x, attn_probs)
130
+ return output, attn_probs
131
+
132
+
133
+ class ColumnSelfAttention(nn.Module):
134
+ """Compute self-attention over columns of a 2D input."""
135
+
136
+ def __init__(
137
+ self,
138
+ embed_dim,
139
+ num_heads,
140
+ dropout=0.0,
141
+ max_tokens_per_msa: int = 2 ** 16,
142
+ ):
143
+ super().__init__()
144
+
145
+ self.num_heads = num_heads
146
+ self.dropout = dropout
147
+ self.head_dim = embed_dim // num_heads
148
+ self.scaling = self.head_dim ** -0.5
149
+ self.max_tokens_per_msa = max_tokens_per_msa
150
+
151
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
152
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
153
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
154
+
155
+ self.out_proj = nn.Linear(embed_dim, embed_dim)
156
+ self.dropout_module = nn.Dropout(dropout)
157
+
158
+ def _batched_forward(
159
+ self,
160
+ x,
161
+ self_attn_mask=None,
162
+ self_attn_padding_mask=None,
163
+ ):
164
+ num_rows, num_cols, batch_size, embed_dim = x.size()
165
+ max_cols = max(1, self.max_tokens_per_msa // num_rows)
166
+ outputs = []
167
+ attns = []
168
+ for start in range(0, num_cols, max_cols):
169
+ output, attn = self(
170
+ x[:, start : start + max_cols],
171
+ self_attn_mask=self_attn_mask,
172
+ self_attn_padding_mask=self_attn_padding_mask[:, :, start : start + max_cols]
173
+ if self_attn_padding_mask is not None
174
+ else None,
175
+ )
176
+ outputs.append(output)
177
+ attns.append(attn)
178
+ output = torch.cat(outputs, 1)
179
+ attns = torch.cat(attns, 1)
180
+ return output, attns
181
+
182
+ def compute_attention_update(
183
+ self,
184
+ x,
185
+ self_attn_mask=None,
186
+ self_attn_padding_mask=None,
187
+ ):
188
+ num_rows, num_cols, batch_size, embed_dim = x.size()
189
+ if num_rows == 1:
190
+ # if there is only 1 position, this is equivalent and doesn't break with padding
191
+ attn_probs = torch.ones(
192
+ self.num_heads,
193
+ num_cols,
194
+ batch_size,
195
+ num_rows,
196
+ num_rows,
197
+ device=x.device,
198
+ dtype=x.dtype,
199
+ )
200
+ output = self.out_proj(self.v_proj(x))
201
+ else:
202
+ q = self.q_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
203
+ k = self.k_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
204
+ v = self.v_proj(x).view(num_rows, num_cols, batch_size, self.num_heads, self.head_dim)
205
+ q *= self.scaling
206
+
207
+ attn_weights = torch.einsum("icnhd,jcnhd->hcnij", q, k)
208
+
209
+ if self_attn_mask is not None:
210
+ raise NotImplementedError
211
+ if self_attn_padding_mask is not None:
212
+ attn_weights = attn_weights.masked_fill(
213
+ self_attn_padding_mask.permute(2, 0, 1).unsqueeze(0).unsqueeze(3),
214
+ -10000,
215
+ )
216
+
217
+ attn_probs = attn_weights.softmax(-1)
218
+ attn_probs = self.dropout_module(attn_probs)
219
+ context = torch.einsum("hcnij,jcnhd->icnhd", attn_probs, v)
220
+ context = context.contiguous().view(num_rows, num_cols, batch_size, embed_dim)
221
+ output = self.out_proj(context)
222
+ return output, attn_probs
223
+
224
+ def forward(
225
+ self,
226
+ x,
227
+ self_attn_mask=None,
228
+ self_attn_padding_mask=None,
229
+ ):
230
+ num_rows, num_cols, batch_size, embed_dim = x.size()
231
+ # if False and num_rows * num_cols > 2 ** 14 and not torch.is_grad_enabled():
232
+ if (num_rows * num_cols) > self.max_tokens_per_msa and not torch.is_grad_enabled():
233
+ return self._batched_forward(
234
+ x,
235
+ self_attn_mask,
236
+ self_attn_padding_mask,
237
+ )
238
+ else:
239
+ return self.compute_attention_update(x, self_attn_mask, self_attn_padding_mask)
esm/constants.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # fmt: off
+ proteinseq_toks = {
+     'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-']
+ }
+ # fmt: on
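This token list is consumed by `Alphabet.from_architecture` in `esm/data.py`, the next file in this diff, which in turn builds the batch converter used to tokenize protein sequences. A minimal usage sketch of that tokenization path (the sequences are made up; `"ESM-1b"` is one of the architecture strings handled by `from_architecture`):

```python
from esm.data import Alphabet

# Build the ESM-1b alphabet and its batch converter, as defined in esm/data.py.
alphabet = Alphabet.from_architecture("ESM-1b")
batch_converter = alphabet.get_batch_converter()

data = [
    ("protein_1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"),
    ("protein_2", "KALTARQQEVFDLIRD"),
]
labels, strs, tokens = batch_converter(data)

print(tokens.shape)                # [2, longest sequence + BOS/EOS]
print(tokens[1])                   # shorter sequence is padded with alphabet.padding_idx
print(alphabet.get_idx("<mask>"))  # token index used for masked-language-model training
```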
esm/data.py ADDED
@@ -0,0 +1,493 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import os
8
+ from typing import Sequence, Tuple, List, Union
9
+ import pickle
10
+ import re
11
+ import shutil
12
+ import torch
13
+ from pathlib import Path
14
+ from esm.constants import proteinseq_toks
15
+
16
+ RawMSA = Sequence[Tuple[str, str]]
17
+
18
+
19
+ class FastaBatchedDataset(object):
20
+ def __init__(self, sequence_labels, sequence_strs):
21
+ self.sequence_labels = list(sequence_labels)
22
+ self.sequence_strs = list(sequence_strs)
23
+
24
+ @classmethod
25
+ def from_file(cls, fasta_file):
26
+ sequence_labels, sequence_strs = [], []
27
+ cur_seq_label = None
28
+ buf = []
29
+
30
+ def _flush_current_seq():
31
+ nonlocal cur_seq_label, buf
32
+ if cur_seq_label is None:
33
+ return
34
+ sequence_labels.append(cur_seq_label)
35
+ sequence_strs.append("".join(buf))
36
+ cur_seq_label = None
37
+ buf = []
38
+
39
+ with open(fasta_file, "r") as infile:
40
+ for line_idx, line in enumerate(infile):
41
+ if line.startswith(">"): # label line
42
+ _flush_current_seq()
43
+ line = line[1:].strip()
44
+ if len(line) > 0:
45
+ cur_seq_label = line
46
+ else:
47
+ cur_seq_label = f"seqnum{line_idx:09d}"
48
+ else: # sequence line
49
+ buf.append(line.strip())
50
+
51
+ _flush_current_seq()
52
+
53
+ assert len(set(sequence_labels)) == len(
54
+ sequence_labels
55
+ ), "Found duplicate sequence labels"
56
+
57
+ return cls(sequence_labels, sequence_strs)
58
+
59
+ def __len__(self):
60
+ return len(self.sequence_labels)
61
+
62
+ def __getitem__(self, idx):
63
+ return self.sequence_labels[idx], self.sequence_strs[idx]
64
+
65
+ def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
66
+ sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
67
+ sizes.sort()
68
+ batches = []
69
+ buf = []
70
+ max_len = 0
71
+
72
+ def _flush_current_buf():
73
+ nonlocal max_len, buf
74
+ if len(buf) == 0:
75
+ return
76
+ batches.append(buf)
77
+ buf = []
78
+ max_len = 0
79
+
80
+ for sz, i in sizes:
81
+ sz += extra_toks_per_seq
82
+ if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
83
+ _flush_current_buf()
84
+ max_len = max(max_len, sz)
85
+ buf.append(i)
86
+
87
+ _flush_current_buf()
88
+ return batches
89
+
90
+
91
+ class Alphabet(object):
92
+ def __init__(
93
+ self,
94
+ standard_toks: Sequence[str],
95
+ prepend_toks: Sequence[str] = ("<null_0>", "<pad>", "<eos>", "<unk>"),
96
+ append_toks: Sequence[str] = ("<cls>", "<mask>", "<sep>"),
97
+ prepend_bos: bool = True,
98
+ append_eos: bool = False,
99
+ use_msa: bool = False,
100
+ ):
101
+ self.standard_toks = list(standard_toks)
102
+ self.prepend_toks = list(prepend_toks)
103
+ self.append_toks = list(append_toks)
104
+ self.prepend_bos = prepend_bos
105
+ self.append_eos = append_eos
106
+ self.use_msa = use_msa
107
+
108
+ self.all_toks = list(self.prepend_toks)
109
+ self.all_toks.extend(self.standard_toks)
110
+ for i in range((8 - (len(self.all_toks) % 8)) % 8):
111
+ self.all_toks.append(f"<null_{i + 1}>")
112
+ self.all_toks.extend(self.append_toks)
113
+
114
+ self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}
115
+
116
+ self.unk_idx = self.tok_to_idx["<unk>"]
117
+ self.padding_idx = self.get_idx("<pad>")
118
+ self.cls_idx = self.get_idx("<cls>")
119
+ self.mask_idx = self.get_idx("<mask>")
120
+ self.eos_idx = self.get_idx("<eos>")
121
+ self.all_special_tokens = ['<eos>', '<unk>', '<pad>', '<cls>', '<mask>']
122
+ self.unique_no_split_tokens = self.all_toks
123
+
124
+ def __len__(self):
125
+ return len(self.all_toks)
126
+
127
+ def get_idx(self, tok):
128
+ return self.tok_to_idx.get(tok, self.unk_idx)
129
+
130
+ def get_tok(self, ind):
131
+ return self.all_toks[ind]
132
+
133
+ def to_dict(self):
134
+ return self.tok_to_idx.copy()
135
+
136
+ def get_batch_converter(self, truncation_seq_length: int = None):
137
+ if self.use_msa:
138
+ return MSABatchConverter(self, truncation_seq_length)
139
+ else:
140
+ return BatchConverter(self, truncation_seq_length)
141
+
142
+ @classmethod
143
+ def from_architecture(cls, name: str) -> "Alphabet":
144
+ if name in ("ESM-1", "protein_bert_base"):
145
+ standard_toks = proteinseq_toks["toks"]
146
+ prepend_toks: Tuple[str, ...] = ("<null_0>", "<pad>", "<eos>", "<unk>")
147
+ append_toks: Tuple[str, ...] = ("<cls>", "<mask>", "<sep>")
148
+ prepend_bos = True
149
+ append_eos = False
150
+ use_msa = False
151
+ elif name in ("ESM-1b", "roberta_large"):
152
+ standard_toks = proteinseq_toks["toks"]
153
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
154
+ append_toks = ("<mask>",)
155
+ prepend_bos = True
156
+ append_eos = True
157
+ use_msa = False
158
+ elif name in ("MSA Transformer", "msa_transformer"):
159
+ standard_toks = proteinseq_toks["toks"]
160
+ prepend_toks = ("<cls>", "<pad>", "<eos>", "<unk>")
161
+ append_toks = ("<mask>",)
162
+ prepend_bos = True
163
+ append_eos = False
164
+ use_msa = True
165
+ elif "invariant_gvp" in name.lower():
166
+ standard_toks = proteinseq_toks["toks"]
167
+ prepend_toks = ("<null_0>", "<pad>", "<eos>", "<unk>")
168
+ append_toks = ("<mask>", "<cath>", "<af2>")
169
+ prepend_bos = True
170
+ append_eos = False
171
+ use_msa = False
172
+ else:
173
+ raise ValueError("Unknown architecture selected")
174
+ return cls(standard_toks, prepend_toks, append_toks, prepend_bos, append_eos, use_msa)
175
+
176
+ def _tokenize(self, text) -> str:
177
+ return text.split()
178
+
179
+ def tokenize(self, text, **kwargs) -> List[str]:
180
+ """
181
+ Inspired by https://github.com/huggingface/transformers/blob/master/src/transformers/tokenization_utils.py
182
+ Converts a string in a sequence of tokens, using the tokenizer.
183
+
184
+ Args:
185
+ text (:obj:`str`):
186
+ The sequence to be encoded.
187
+
188
+ Returns:
189
+ :obj:`List[str]`: The list of tokens.
190
+ """
191
+
192
+ def split_on_token(tok, text):
193
+ result = []
194
+ split_text = text.split(tok)
195
+ for i, sub_text in enumerate(split_text):
196
+ # AddedToken can control whitespace stripping around them.
197
+ # We use them for GPT2 and Roberta to have different behavior depending on the special token
198
+ # Cf. https://github.com/huggingface/transformers/pull/2778
199
+ # and https://github.com/huggingface/transformers/issues/3788
200
+ # We strip left and right by default
201
+ if i < len(split_text) - 1:
202
+ sub_text = sub_text.rstrip()
203
+ if i > 0:
204
+ sub_text = sub_text.lstrip()
205
+
206
+ if i == 0 and not sub_text:
207
+ result.append(tok)
208
+ elif i == len(split_text) - 1:
209
+ if sub_text:
210
+ result.append(sub_text)
211
+ else:
212
+ pass
213
+ else:
214
+ if sub_text:
215
+ result.append(sub_text)
216
+ result.append(tok)
217
+ return result
218
+
219
+ def split_on_tokens(tok_list, text):
220
+ if not text.strip():
221
+ return []
222
+
223
+ tokenized_text = []
224
+ text_list = [text]
225
+ for tok in tok_list:
226
+ tokenized_text = []
227
+ for sub_text in text_list:
228
+ if sub_text not in self.unique_no_split_tokens:
229
+ tokenized_text.extend(split_on_token(tok, sub_text))
230
+ else:
231
+ tokenized_text.append(sub_text)
232
+ text_list = tokenized_text
233
+
234
+ return list(
235
+ itertools.chain.from_iterable(
236
+ (
237
+ self._tokenize(token)
238
+ if token not in self.unique_no_split_tokens
239
+ else [token]
240
+ for token in tokenized_text
241
+ )
242
+ )
243
+ )
244
+
245
+ no_split_token = self.unique_no_split_tokens
246
+ tokenized_text = split_on_tokens(no_split_token, text)
247
+ return tokenized_text
248
+
249
+ def encode(self, text):
250
+ return [self.tok_to_idx[tok] for tok in self.tokenize(text)]
251
+
252
+
253
+ class BatchConverter(object):
254
+ """Callable to convert an unprocessed (labels + strings) batch to a
255
+ processed (labels + tensor) batch.
256
+ """
257
+
258
+ def __init__(self, alphabet, truncation_seq_length: int = None):
259
+ self.alphabet = alphabet
260
+ self.truncation_seq_length = truncation_seq_length
261
+
262
+ def __call__(self, raw_batch: Sequence[Tuple[str, str]]):
263
+ # RoBERTa uses an eos token, while ESM-1 does not.
264
+ batch_size = len(raw_batch)
265
+ batch_labels, seq_str_list = zip(*raw_batch)
266
+ seq_encoded_list = [self.alphabet.encode(seq_str) for seq_str in seq_str_list]
267
+ if self.truncation_seq_length:
268
+ seq_encoded_list = [seq_str[:self.truncation_seq_length] for seq_str in seq_encoded_list]
269
+ max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
270
+ tokens = torch.empty(
271
+ (
272
+ batch_size,
273
+ max_len + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
274
+ ),
275
+ dtype=torch.int64,
276
+ )
277
+ tokens.fill_(self.alphabet.padding_idx)
278
+ labels = []
279
+ strs = []
280
+
281
+ for i, (label, seq_str, seq_encoded) in enumerate(
282
+ zip(batch_labels, seq_str_list, seq_encoded_list)
283
+ ):
284
+ labels.append(label)
285
+ strs.append(seq_str)
286
+ if self.alphabet.prepend_bos:
287
+ tokens[i, 0] = self.alphabet.cls_idx
288
+ seq = torch.tensor(seq_encoded, dtype=torch.int64)
289
+ tokens[
290
+ i,
291
+ int(self.alphabet.prepend_bos) : len(seq_encoded)
292
+ + int(self.alphabet.prepend_bos),
293
+ ] = seq
294
+ if self.alphabet.append_eos:
295
+ tokens[i, len(seq_encoded) + int(self.alphabet.prepend_bos)] = self.alphabet.eos_idx
296
+
297
+ return labels, strs, tokens
298
+
299
+
300
+ class MSABatchConverter(BatchConverter):
301
+ def __call__(self, inputs: Union[Sequence[RawMSA], RawMSA]):
302
+ if isinstance(inputs[0][0], str):
303
+ # Input is a single MSA
304
+ raw_batch: Sequence[RawMSA] = [inputs] # type: ignore
305
+ else:
306
+ raw_batch = inputs # type: ignore
307
+
308
+ batch_size = len(raw_batch)
309
+ max_alignments = max(len(msa) for msa in raw_batch)
310
+ max_seqlen = max(len(msa[0][1]) for msa in raw_batch)
311
+
312
+ tokens = torch.empty(
313
+ (
314
+ batch_size,
315
+ max_alignments,
316
+ max_seqlen + int(self.alphabet.prepend_bos) + int(self.alphabet.append_eos),
317
+ ),
318
+ dtype=torch.int64,
319
+ )
320
+ tokens.fill_(self.alphabet.padding_idx)
321
+ labels = []
322
+ strs = []
323
+
324
+ for i, msa in enumerate(raw_batch):
325
+ msa_seqlens = set(len(seq) for _, seq in msa)
326
+ if not len(msa_seqlens) == 1:
327
+ raise RuntimeError(
328
+ "Received unaligned sequences for input to MSA, all sequence "
329
+ "lengths must be equal."
330
+ )
331
+ msa_labels, msa_strs, msa_tokens = super().__call__(msa)
332
+ labels.append(msa_labels)
333
+ strs.append(msa_strs)
334
+ tokens[i, : msa_tokens.size(0), : msa_tokens.size(1)] = msa_tokens
335
+
336
+ return labels, strs, tokens
337
+
338
+
339
+ def read_fasta(
340
+ path,
341
+ keep_gaps=True,
342
+ keep_insertions=True,
343
+ to_upper=False,
344
+ ):
345
+ with open(path, "r") as f:
346
+ for result in read_alignment_lines(
347
+ f, keep_gaps=keep_gaps, keep_insertions=keep_insertions, to_upper=to_upper
348
+ ):
349
+ yield result
350
+
351
+
352
+ def read_alignment_lines(
353
+ lines,
354
+ keep_gaps=True,
355
+ keep_insertions=True,
356
+ to_upper=False,
357
+ ):
358
+ seq = desc = None
359
+
360
+ def parse(s):
361
+ if not keep_gaps:
362
+ s = re.sub("-", "", s)
363
+ if not keep_insertions:
364
+ s = re.sub("[a-z]", "", s)
365
+ return s.upper() if to_upper else s
366
+
367
+ for line in lines:
368
+ # Line may be empty if seq % file_line_width == 0
369
+ if len(line) > 0 and line[0] == ">":
370
+ if seq is not None:
371
+ yield desc, parse(seq)
372
+ desc = line.strip().lstrip(">")
373
+ seq = ""
374
+ else:
375
+ assert isinstance(seq, str)
376
+ seq += line.strip()
377
+ assert isinstance(seq, str) and isinstance(desc, str)
378
+ yield desc, parse(seq)
379
+
380
+
381
+ class ESMStructuralSplitDataset(torch.utils.data.Dataset):
382
+ """
383
+ Structural Split Dataset as described in section A.10 of the supplement of our paper.
384
+ https://doi.org/10.1101/622803
385
+
386
+ We use the full version of SCOPe 2.07, clustered at 90% sequence identity,
387
+ generated on January 23, 2020.
388
+
389
+ For each SCOPe domain:
390
+ - We extract the sequence from the corresponding PDB file
391
+ - We extract the 3D coordinates of the Carbon beta atoms, aligning them
392
+ to the sequence. We put NaN where Cb atoms are missing.
393
+ - From the 3D coordinates, we calculate a pairwise distance map, based
394
+ on L2 distance
395
+ - We use DSSP to generate secondary structure labels for the corresponding
396
+ PDB file. This is also aligned to the sequence. We put - where SSP
397
+ labels are missing.
398
+
399
+ For each SCOPe classification level of family/superfamily/fold (in order of difficulty),
400
+ we have split the data into 5 partitions for cross validation. These are provided
401
+ in a downloaded splits folder, in the format:
402
+ splits/{split_level}/{cv_partition}/{train|valid}.txt
403
+ where train is the partition and valid is the concatenation of the remaining 4.
404
+
405
+ For each SCOPe domain, we provide a pkl dump that contains:
406
+ - seq : The domain sequence, stored as an L-length string
407
+ - ssp : The secondary structure labels, stored as an L-length string
408
+ - dist : The distance map, stored as an LxL numpy array
409
+ - coords : The 3D coordinates, stored as an Lx3 numpy array
410
+
411
+ """
412
+
413
+ base_folder = "structural-data"
414
+ file_list = [
415
+ # url tar filename filename MD5 Hash
416
+ (
417
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/splits.tar.gz",
418
+ "splits.tar.gz",
419
+ "splits",
420
+ "456fe1c7f22c9d3d8dfe9735da52411d",
421
+ ),
422
+ (
423
+ "https://dl.fbaipublicfiles.com/fair-esm/structural-data/pkl.tar.gz",
424
+ "pkl.tar.gz",
425
+ "pkl",
426
+ "644ea91e56066c750cd50101d390f5db",
427
+ ),
428
+ ]
429
+
430
+ def __init__(
431
+ self,
432
+ split_level,
433
+ cv_partition,
434
+ split,
435
+ root_path=os.path.expanduser("~/.cache/torch/data/esm"),
436
+ download=False,
437
+ ):
438
+ super().__init__()
439
+ assert split in [
440
+ "train",
441
+ "valid",
442
+ ], "train_valid must be 'train' or 'valid'"
443
+ self.root_path = root_path
444
+ self.base_path = os.path.join(self.root_path, self.base_folder)
445
+
446
+ # check if root path has what you need or else download it
447
+ if download:
448
+ self.download()
449
+
450
+ self.split_file = os.path.join(
451
+ self.base_path, "splits", split_level, cv_partition, f"{split}.txt"
452
+ )
453
+ self.pkl_dir = os.path.join(self.base_path, "pkl")
454
+ self.names = []
455
+ with open(self.split_file) as f:
456
+ self.names = f.read().splitlines()
457
+
458
+ def __len__(self):
459
+ return len(self.names)
460
+
461
+ def _check_exists(self) -> bool:
462
+ for (_, _, filename, _) in self.file_list:
463
+ fpath = os.path.join(self.base_path, filename)
464
+ if not os.path.exists(fpath) or not os.path.isdir(fpath):
465
+ return False
466
+ return True
467
+
468
+ def download(self):
469
+
470
+ if self._check_exists():
471
+ print("Files already downloaded and verified")
472
+ return
473
+
474
+ from torchvision.datasets.utils import download_url
475
+
476
+ for url, tar_filename, filename, md5_hash in self.file_list:
477
+ download_path = os.path.join(self.base_path, tar_filename)
478
+ download_url(url=url, root=self.base_path, filename=tar_filename, md5=md5_hash)
479
+ shutil.unpack_archive(download_path, self.base_path)
480
+
481
+ def __getitem__(self, idx):
482
+ """
483
+ Returns a dict with the following entries
484
+ - seq : Str (domain sequence)
485
+ - ssp : Str (SSP labels)
486
+ - dist : np.array (distance map)
487
+ - coords : np.array (3D coordinates)
488
+ """
489
+ name = self.names[idx]
490
+ pkl_fname = os.path.join(self.pkl_dir, name[1:3], f"{name}.pkl")
491
+ with open(pkl_fname, "rb") as f:
492
+ obj = pickle.load(f)
493
+ return obj
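A minimal usage sketch for the tokenization utilities above (not part of this commit). It assumes the standard fair-esm `Alphabet.from_architecture` constructor and `get_batch_converter` helper defined earlier in `esm/data.py`; the sequences and labels are illustrative placeholders.

```python
from esm.data import Alphabet

# Build an alphabet and a batch converter (the truncation length is optional).
alphabet = Alphabet.from_architecture("ESM-1b")
batch_converter = alphabet.get_batch_converter(truncation_seq_length=1022)

data = [
    ("protein_1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT"),
    ("protein_2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIE"),
]
labels, strs, tokens = batch_converter(data)

# tokens is padded to the longest sequence, plus BOS/EOS where the alphabet uses them.
print(labels, tokens.shape, tokens.dtype)  # int64 tensor of token indices
```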
esm/esmfold/v1/__init__.py ADDED
File without changes
esm/esmfold/v1/categorical_mixture.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ import torch
+
+
+ class CategoricalMixture:
+     def __init__(self, param, bins=50, start=0, end=1):
+         # All tensors are of shape ..., bins.
+         self.logits = param
+         bins = torch.linspace(
+             start, end, bins + 1, device=self.logits.device, dtype=self.logits.dtype
+         )
+         self.v_bins = (bins[:-1] + bins[1:]) / 2
+
+     def log_prob(self, true):
+         # Shapes are:
+         #     self.probs: ... x bins
+         #     true      : ...
+         true_index = (
+             (
+                 true.unsqueeze(-1)
+                 - self.v_bins[
+                     [
+                         None,
+                     ]
+                     * true.ndim
+                 ]
+             )
+             .abs()
+             .argmin(-1)
+         )
+         nll = self.logits.log_softmax(-1)
+         return torch.take_along_dim(nll, true_index.unsqueeze(-1), dim=-1).squeeze(-1)
+
+     def mean(self):
+         return (self.logits.softmax(-1) @ self.v_bins.unsqueeze(1)).squeeze(-1)
+
+
+ def categorical_lddt(logits, bins=50):
+     # Logits are ..., 37, bins.
+     return CategoricalMixture(logits, bins=bins).mean()
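For orientation, a tiny shape sketch of how `categorical_lddt` is used by the lDDT head in `esmfold.py` below; the tensor sizes are hypothetical.

```python
import torch
from esm.esmfold.v1.categorical_mixture import categorical_lddt

# Per-position, per-atom logits over 50 lDDT bins: (batch, length, 37 atoms, 50 bins).
logits = torch.randn(2, 128, 37, 50)

# Collapses each bin distribution to its expected value, one scalar per atom position.
plddt = categorical_lddt(logits, bins=50)
print(plddt.shape)  # torch.Size([2, 128, 37]); values lie in [0, 1]
```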
esm/esmfold/v1/esmfold.py ADDED
@@ -0,0 +1,364 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import typing as T
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from torch import nn
12
+ from torch.nn import LayerNorm
13
+
14
+ import esm
15
+ from esm import Alphabet
16
+ from esm.esmfold.v1.categorical_mixture import categorical_lddt
17
+ from esm.esmfold.v1.misc import (
18
+ batch_encode_sequences,
19
+ collate_dense_tensors,
20
+ output_to_pdb,
21
+ )
22
+ from esm.esmfold.v1.trunk import FoldingTrunk, FoldingTrunkConfig
23
+ from openfold.data.data_transforms import make_atom14_masks
24
+ from openfold.np import residue_constants
25
+ from openfold.utils.loss import compute_predicted_aligned_error, compute_tm
26
+
27
+
28
+ @dataclass
29
+ class ESMFoldConfig:
30
+ trunk: T.Any = FoldingTrunkConfig()
31
+ lddt_head_hid_dim: int = 128
32
+
33
+
34
+ load_fn = esm.pretrained.load_model_and_alphabet
35
+ esm_registry = {
36
+ "esm2_8M": partial(load_fn, "esm2_t6_8M_UR50D_500K"),
37
+ "esm2_8M_270K": esm.pretrained.esm2_t6_8M_UR50D,
38
+ "esm2_35M": partial(load_fn, "esm2_t12_35M_UR50D_500K"),
39
+ "esm2_35M_270K": esm.pretrained.esm2_t12_35M_UR50D,
40
+ "esm2_150M": partial(load_fn, "esm2_t30_150M_UR50D_500K"),
41
+ "esm2_150M_270K": partial(load_fn, "esm2_t30_150M_UR50D_270K"),
42
+ "esm2_650M": esm.pretrained.esm2_t33_650M_UR50D,
43
+ "esm2_650M_270K": partial(load_fn, "esm2_t33_650M_270K_UR50D"),
44
+ "esm2_3B": esm.pretrained.esm2_t36_3B_UR50D,
45
+ "esm2_3B_270K": partial(load_fn, "esm2_t36_3B_UR50D_500K"),
46
+ "esm2_15B": esm.pretrained.esm2_t48_15B_UR50D,
47
+ }
48
+
49
+
50
+ class ESMFold(nn.Module):
51
+ def __init__(self, esmfold_config=None, **kwargs):
52
+ super().__init__()
53
+
54
+ self.cfg = esmfold_config if esmfold_config else ESMFoldConfig(**kwargs)
55
+ cfg = self.cfg
56
+
57
+ self.distogram_bins = 64
58
+
59
+ self.esm, self.esm_dict = esm_registry.get(cfg.esm_type)()
60
+
61
+ self.esm.requires_grad_(False)
62
+ self.esm.half()
63
+
64
+ self.esm_feats = self.esm.embed_dim
65
+ self.esm_attns = self.esm.num_layers * self.esm.attention_heads
66
+ self.register_buffer("af2_to_esm", ESMFold._af2_to_esm(self.esm_dict))
67
+ self.esm_s_combine = nn.Parameter(torch.zeros(self.esm.num_layers + 1))
68
+
69
+ c_s = cfg.trunk.sequence_state_dim
70
+ c_z = cfg.trunk.pairwise_state_dim
71
+
72
+ self.esm_s_mlp = nn.Sequential(
73
+ LayerNorm(self.esm_feats),
74
+ nn.Linear(self.esm_feats, c_s),
75
+ nn.ReLU(),
76
+ nn.Linear(c_s, c_s),
77
+ )
78
+ if cfg.use_esm_attn_map:
79
+ self.esm_z_mlp = nn.Sequential(
80
+ LayerNorm(self.esm_attns),
81
+ nn.Linear(self.esm_attns, c_z),
82
+ nn.ReLU(),
83
+ nn.Linear(c_z, c_z),
84
+ )
85
+
86
+ # 0 is padding, N is unknown residues, N + 1 is mask.
87
+ self.n_tokens_embed = residue_constants.restype_num + 3
88
+ self.pad_idx = 0
89
+ self.unk_idx = self.n_tokens_embed - 2
90
+ self.mask_idx = self.n_tokens_embed - 1
91
+ self.embedding = nn.Embedding(self.n_tokens_embed, c_s, padding_idx=0)
92
+
93
+ self.trunk = FoldingTrunk(**cfg.trunk)
94
+
95
+ self.distogram_head = nn.Linear(c_z, self.distogram_bins)
96
+ self.ptm_head = nn.Linear(c_z, self.distogram_bins)
97
+ self.lm_head = nn.Linear(c_s, self.n_tokens_embed)
98
+ self.lddt_bins = 50
99
+ self.lddt_head = nn.Sequential(
100
+ nn.LayerNorm(cfg.trunk.structure_module.c_s),
101
+ nn.Linear(cfg.trunk.structure_module.c_s, cfg.lddt_head_hid_dim),
102
+ nn.Linear(cfg.lddt_head_hid_dim, cfg.lddt_head_hid_dim),
103
+ nn.Linear(cfg.lddt_head_hid_dim, 37 * self.lddt_bins),
104
+ )
105
+
106
+ @staticmethod
107
+ def _af2_to_esm(d: Alphabet):
108
+ # Remember that t is shifted from residue_constants by 1 (0 is padding).
109
+ esm_reorder = [d.padding_idx] + [
110
+ d.get_idx(v) for v in residue_constants.restypes_with_x
111
+ ]
112
+ return torch.tensor(esm_reorder)
113
+
114
+ def _af2_idx_to_esm_idx(self, aa, mask):
115
+ aa = (aa + 1).masked_fill(mask != 1, 0)
116
+ return self.af2_to_esm[aa]
117
+
118
+ def _compute_language_model_representations(
119
+ self, esmaa: torch.Tensor
120
+ ) -> torch.Tensor:
121
+ """Adds bos/eos tokens for the language model, since the structure module doesn't use these."""
122
+ batch_size = esmaa.size(0)
123
+
124
+ bosi, eosi = self.esm_dict.cls_idx, self.esm_dict.eos_idx
125
+ bos = esmaa.new_full((batch_size, 1), bosi)
126
+ eos = esmaa.new_full((batch_size, 1), self.esm_dict.padding_idx)
127
+ esmaa = torch.cat([bos, esmaa, eos], dim=1)
128
+ # Use the first padding index as eos during inference.
129
+ esmaa[range(batch_size), (esmaa != 1).sum(1)] = eosi
130
+
131
+ res = self.esm(
132
+ esmaa,
133
+ repr_layers=range(self.esm.num_layers + 1),
134
+ need_head_weights=self.cfg.use_esm_attn_map,
135
+ )
136
+ esm_s = torch.stack(
137
+ [v for _, v in sorted(res["representations"].items())], dim=2
138
+ )
139
+ esm_s = esm_s[:, 1:-1] # B, L, nLayers, C
140
+ esm_z = (
141
+ res["attentions"].permute(0, 4, 3, 1, 2).flatten(3, 4)[:, 1:-1, 1:-1, :]
142
+ if self.cfg.use_esm_attn_map
143
+ else None
144
+ )
145
+ return esm_s, esm_z
146
+
147
+ def _mask_inputs_to_esm(self, esmaa, pattern):
148
+ new_esmaa = esmaa.clone()
149
+ new_esmaa[pattern == 1] = self.esm_dict.mask_idx
150
+ return new_esmaa
151
+
152
+ def forward(
153
+ self,
154
+ aa: torch.Tensor,
155
+ mask: T.Optional[torch.Tensor] = None,
156
+ residx: T.Optional[torch.Tensor] = None,
157
+ masking_pattern: T.Optional[torch.Tensor] = None,
158
+ num_recycles: T.Optional[int] = None,
159
+ ):
160
+ """Runs a forward pass given input tokens. Use `model.infer` to
161
+ run inference from a sequence.
162
+
163
+ Args:
164
+ aa (torch.Tensor): Tensor containing indices corresponding to amino acids. Indices match
165
+ openfold.np.residue_constants.restype_order_with_x.
166
+ mask (torch.Tensor): Binary tensor with 1 meaning position is unmasked and 0 meaning position is masked.
167
+ residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
168
+ masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
169
+ as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
170
+ different masks are provided.
171
+ num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
172
+ recycles, which is 3.
173
+ """
174
+
175
+ if mask is None:
176
+ mask = torch.ones_like(aa)
177
+
178
+ B = aa.shape[0]
179
+ L = aa.shape[1]
180
+ device = aa.device
181
+
182
+ if residx is None:
183
+ residx = torch.arange(L, device=device).expand_as(aa)
184
+
185
+ # === ESM ===
186
+ esmaa = self._af2_idx_to_esm_idx(aa, mask)
187
+
188
+ if masking_pattern is not None:
189
+ esmaa = self._mask_inputs_to_esm(esmaa, masking_pattern)
190
+
191
+ esm_s, esm_z = self._compute_language_model_representations(esmaa)
192
+
193
+ # Convert esm_s to the precision used by the trunk and
194
+ # the structure module. These tensors may be a lower precision if, for example,
195
+ # we're running the language model in fp16 precision.
196
+ esm_s = esm_s.to(self.esm_s_combine.dtype)
197
+ esm_s = esm_s.detach()
198
+
199
+ # === preprocessing ===
200
+ esm_s = (self.esm_s_combine.softmax(0).unsqueeze(0) @ esm_s).squeeze(2)
201
+
202
+ s_s_0 = self.esm_s_mlp(esm_s)
203
+ if self.cfg.use_esm_attn_map:
204
+ esm_z = esm_z.to(self.esm_s_combine.dtype)
205
+ esm_z = esm_z.detach()
206
+ s_z_0 = self.esm_z_mlp(esm_z)
207
+ else:
208
+ s_z_0 = s_s_0.new_zeros(B, L, L, self.cfg.trunk.pairwise_state_dim)
209
+
210
+ s_s_0 += self.embedding(aa)
211
+
212
+ structure: dict = self.trunk(
213
+ s_s_0, s_z_0, aa, residx, mask, no_recycles=num_recycles
214
+ )
215
+ # Documenting what we expect:
216
+ structure = {
217
+ k: v
218
+ for k, v in structure.items()
219
+ if k
220
+ in [
221
+ "s_z",
222
+ "s_s",
223
+ "frames",
224
+ "sidechain_frames",
225
+ "unnormalized_angles",
226
+ "angles",
227
+ "positions",
228
+ "states",
229
+ ]
230
+ }
231
+
232
+ disto_logits = self.distogram_head(structure["s_z"])
233
+ disto_logits = (disto_logits + disto_logits.transpose(1, 2)) / 2
234
+ structure["distogram_logits"] = disto_logits
235
+
236
+ lm_logits = self.lm_head(structure["s_s"])
237
+ structure["lm_logits"] = lm_logits
238
+
239
+ structure["aatype"] = aa
240
+ make_atom14_masks(structure)
241
+
242
+ for k in [
243
+ "atom14_atom_exists",
244
+ "atom37_atom_exists",
245
+ ]:
246
+ structure[k] *= mask.unsqueeze(-1)
247
+ structure["residue_index"] = residx
248
+
249
+ lddt_head = self.lddt_head(structure["states"]).reshape(
250
+ structure["states"].shape[0], B, L, -1, self.lddt_bins
251
+ )
252
+ structure["lddt_head"] = lddt_head
253
+ plddt = categorical_lddt(lddt_head[-1], bins=self.lddt_bins)
254
+ structure["plddt"] = (
255
+ 100 * plddt
256
+ ) # we predict plDDT between 0 and 1, scale to be between 0 and 100.
257
+
258
+ ptm_logits = self.ptm_head(structure["s_z"])
259
+
260
+ seqlen = mask.type(torch.int64).sum(1)
261
+ structure["ptm_logits"] = ptm_logits
262
+ structure["ptm"] = torch.stack(
263
+ [
264
+ compute_tm(
265
+ batch_ptm_logits[None, :sl, :sl],
266
+ max_bins=31,
267
+ no_bins=self.distogram_bins,
268
+ )
269
+ for batch_ptm_logits, sl in zip(ptm_logits, seqlen)
270
+ ]
271
+ )
272
+ structure.update(
273
+ compute_predicted_aligned_error(
274
+ ptm_logits, max_bin=31, no_bins=self.distogram_bins
275
+ )
276
+ )
277
+
278
+ return structure
279
+
280
+ @torch.no_grad()
281
+ def infer(
282
+ self,
283
+ sequences: T.Union[str, T.List[str]],
284
+ residx=None,
285
+ masking_pattern: T.Optional[torch.Tensor] = None,
286
+ num_recycles: T.Optional[int] = None,
287
+ residue_index_offset: T.Optional[int] = 512,
288
+ chain_linker: T.Optional[str] = "G" * 25,
289
+ ):
290
+ """Runs a forward pass given input sequences.
291
+
292
+ Args:
293
+ sequences (Union[str, List[str]]): A list of sequences to make predictions for. Multimers can also be passed in,
294
+ each chain should be separated by a ':' token (e.g. "<chain1>:<chain2>:<chain3>").
295
+ residx (torch.Tensor): Residue indices of amino acids. Will assume contiguous if not provided.
296
+ masking_pattern (torch.Tensor): Optional masking to pass to the input. Binary tensor of the same size
297
+ as `aa`. Positions with 1 will be masked. ESMFold sometimes produces different samples when
298
+ different masks are provided.
299
+ num_recycles (int): How many recycle iterations to perform. If None, defaults to training max
300
+ recycles (cfg.trunk.max_recycles), which is 4.
301
+ residue_index_offset (int): Residue index separation between chains if predicting a multimer. Has no effect on
302
+ single chain predictions. Default: 512.
303
+ chain_linker (str): Linker to use between chains if predicting a multimer. Has no effect on single chain
304
+ predictions. Default: length-25 poly-G ("G" * 25).
305
+ """
306
+ if isinstance(sequences, str):
307
+ sequences = [sequences]
308
+
309
+ aatype, mask, _residx, linker_mask, chain_index = batch_encode_sequences(
310
+ sequences, residue_index_offset, chain_linker
311
+ )
312
+
313
+ if residx is None:
314
+ residx = _residx
315
+ elif not isinstance(residx, torch.Tensor):
316
+ residx = collate_dense_tensors(residx)
317
+
318
+ aatype, mask, residx, linker_mask = map(
319
+ lambda x: x.to(self.device), (aatype, mask, residx, linker_mask)
320
+ )
321
+
322
+ output = self.forward(
323
+ aatype,
324
+ mask=mask,
325
+ residx=residx,
326
+ masking_pattern=masking_pattern,
327
+ num_recycles=num_recycles,
328
+ )
329
+
330
+ output["atom37_atom_exists"] = output[
331
+ "atom37_atom_exists"
332
+ ] * linker_mask.unsqueeze(2)
333
+
334
+ output["mean_plddt"] = (output["plddt"] * output["atom37_atom_exists"]).sum(
335
+ dim=(1, 2)
336
+ ) / output["atom37_atom_exists"].sum(dim=(1, 2))
337
+ output["chain_index"] = chain_index
338
+
339
+ return output
340
+
341
+ def output_to_pdb(self, output: T.Dict) -> T.List[str]:
342
+ """Returns the pdb (file) string from the model given the model output."""
343
+ return output_to_pdb(output)
344
+
345
+ def infer_pdbs(self, seqs: T.List[str], *args, **kwargs) -> T.List[str]:
346
+ """Returns list of pdb (files) strings from the model given a list of input sequences."""
347
+ output = self.infer(seqs, *args, **kwargs)
348
+ return self.output_to_pdb(output)
349
+
350
+ def infer_pdb(self, sequence: str, *args, **kwargs) -> str:
351
+ """Returns the pdb (file) string from the model given an input sequence."""
352
+ return self.infer_pdbs([sequence], *args, **kwargs)[0]
353
+
354
+ def set_chunk_size(self, chunk_size: T.Optional[int]):
355
+ # This parameter means the axial attention will be computed
356
+ # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
357
+ # It's equivalent to running a for loop over chunks of the dimension we're iterating over,
+ # where chunk_size is the size of each chunk, so 128 would mean processing 128-length chunks.
+ # Setting the value to None restores the default behavior and disables chunking.
360
+ self.trunk.set_chunk_size(chunk_size)
361
+
362
+ @property
363
+ def device(self):
364
+ return self.esm_s_combine.device
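End-to-end usage goes through the constructors in `esm/esmfold/v1/pretrained.py` below. A hedged sketch (assumes a GPU, the openfold dependency, and network access to download the 3B-parameter checkpoint; the sequence is a placeholder):

```python
import torch
from esm.esmfold.v1.pretrained import esmfold_v1

model = esmfold_v1().eval().cuda()   # downloads and builds the ESMFold v1 checkpoint
model.set_chunk_size(128)            # optional: chunk axial attention to reduce memory

sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVAT"
with torch.no_grad():
    pdb_str = model.infer_pdb(sequence)   # wraps infer() and output_to_pdb()

with open("prediction.pdb", "w") as f:
    f.write(pdb_str)
```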
esm/esmfold/v1/misc.py ADDED
@@ -0,0 +1,309 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import typing as T
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+ from openfold.np import residue_constants
13
+ from openfold.np.protein import Protein as OFProtein
14
+ from openfold.np.protein import to_pdb
15
+ from openfold.utils.feats import atom14_to_atom37
16
+
17
+
18
+ def encode_sequence(
19
+ seq: str,
20
+ residue_index_offset: T.Optional[int] = 512,
21
+ chain_linker: T.Optional[str] = "G" * 25,
22
+ ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
23
+ if chain_linker is None:
24
+ chain_linker = ""
25
+ if residue_index_offset is None:
26
+ residue_index_offset = 0
27
+
28
+ chains = seq.split(":")
29
+ seq = chain_linker.join(chains)
30
+
31
+ unk_idx = residue_constants.restype_order_with_x["X"]
32
+ encoded = torch.tensor(
33
+ [residue_constants.restype_order_with_x.get(aa, unk_idx) for aa in seq]
34
+ )
35
+ residx = torch.arange(len(encoded))
36
+
37
+ if residue_index_offset > 0:
38
+ start = 0
39
+ for i, chain in enumerate(chains):
40
+ residx[start : start + len(chain) + len(chain_linker)] += (
41
+ i * residue_index_offset
42
+ )
43
+ start += len(chain) + len(chain_linker)
44
+
45
+ linker_mask = torch.ones_like(encoded, dtype=torch.float32)
46
+ chain_index = []
47
+ offset = 0
48
+ for i, chain in enumerate(chains):
49
+ if i > 0:
50
+ chain_index.extend([i - 1] * len(chain_linker))
51
+ chain_index.extend([i] * len(chain))
52
+ offset += len(chain)
53
+ linker_mask[offset : offset + len(chain_linker)] = 0
54
+ offset += len(chain_linker)
55
+
56
+ chain_index = torch.tensor(chain_index, dtype=torch.int64)
57
+
58
+ return encoded, residx, linker_mask, chain_index
59
+
60
+
61
+ def batch_encode_sequences(
62
+ sequences: T.Sequence[str],
63
+ residue_index_offset: T.Optional[int] = 512,
64
+ chain_linker: T.Optional[str] = "G" * 25,
65
+ ) -> T.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
66
+
67
+ aatype_list = []
68
+ residx_list = []
69
+ linker_mask_list = []
70
+ chain_index_list = []
71
+ for seq in sequences:
72
+ aatype_seq, residx_seq, linker_mask_seq, chain_index_seq = encode_sequence(
73
+ seq,
74
+ residue_index_offset=residue_index_offset,
75
+ chain_linker=chain_linker,
76
+ )
77
+ aatype_list.append(aatype_seq)
78
+ residx_list.append(residx_seq)
79
+ linker_mask_list.append(linker_mask_seq)
80
+ chain_index_list.append(chain_index_seq)
81
+
82
+ aatype = collate_dense_tensors(aatype_list)
83
+ mask = collate_dense_tensors(
84
+ [aatype.new_ones(len(aatype_seq)) for aatype_seq in aatype_list]
85
+ )
86
+ residx = collate_dense_tensors(residx_list)
87
+ linker_mask = collate_dense_tensors(linker_mask_list)
88
+ chain_index_list = collate_dense_tensors(chain_index_list, -1)
89
+
90
+ return aatype, mask, residx, linker_mask, chain_index_list
91
+
92
+
93
+ def output_to_pdb(output: T.Dict) -> T.List[str]:
94
+ """Returns the pdb (file) string from the model given the model output."""
95
+ # atom14_to_atom37 must be called first, as it fails on latest numpy if the
96
+ # input is a numpy array. It will work if the input is a torch tensor.
97
+ final_atom_positions = atom14_to_atom37(output["positions"][-1], output)
98
+ output = {k: v.to("cpu").numpy() for k, v in output.items()}
99
+ final_atom_positions = final_atom_positions.cpu().numpy()
100
+ final_atom_mask = output["atom37_atom_exists"]
101
+ pdbs = []
102
+ for i in range(output["aatype"].shape[0]):
103
+ aa = output["aatype"][i]
104
+ pred_pos = final_atom_positions[i]
105
+ mask = final_atom_mask[i]
106
+ resid = output["residue_index"][i] + 1
107
+ pred = OFProtein(
108
+ aatype=aa,
109
+ atom_positions=pred_pos,
110
+ atom_mask=mask,
111
+ residue_index=resid,
112
+ b_factors=output["plddt"][i],
113
+ chain_index=output["chain_index"][i] if "chain_index" in output else None,
114
+ )
115
+ pdbs.append(to_pdb(pred))
116
+ return pdbs
117
+
118
+
119
+ def collate_dense_tensors(
120
+ samples: T.List[torch.Tensor], pad_v: float = 0
121
+ ) -> torch.Tensor:
122
+ """
123
+ Takes a list of tensors with the following dimensions:
124
+ [(d_11, ..., d_1K),
125
+ (d_21, ..., d_2K),
126
+ ...,
127
+ (d_N1, ..., d_NK)]
128
+ and stack + pads them into a single tensor of:
129
+ (N, max_{i=1..N} d_i1, ..., max_{i=1..N} d_iK)
130
+ """
131
+ if len(samples) == 0:
132
+ return torch.Tensor()
133
+ if len(set(x.dim() for x in samples)) != 1:
134
+ raise RuntimeError(
135
+ f"Samples has varying dimensions: {[x.dim() for x in samples]}"
136
+ )
137
+ (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
138
+ max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
139
+ result = torch.empty(
140
+ len(samples), *max_shape, dtype=samples[0].dtype, device=device
141
+ )
142
+ result.fill_(pad_v)
143
+ for i in range(len(samples)):
144
+ result_i = result[i]
145
+ t = samples[i]
146
+ result_i[tuple(slice(0, k) for k in t.shape)] = t
147
+ return result
148
+
149
+
150
+ class Attention(nn.Module):
151
+ def __init__(self, embed_dim, num_heads, head_width, gated=False):
152
+ super().__init__()
153
+ assert embed_dim == num_heads * head_width
154
+
155
+ self.embed_dim = embed_dim
156
+ self.num_heads = num_heads
157
+ self.head_width = head_width
158
+
159
+ self.proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
160
+ self.o_proj = nn.Linear(embed_dim, embed_dim, bias=True)
161
+ self.gated = gated
162
+ if gated:
163
+ self.g_proj = nn.Linear(embed_dim, embed_dim)
164
+ torch.nn.init.zeros_(self.g_proj.weight)
165
+ torch.nn.init.ones_(self.g_proj.bias)
166
+
167
+ self.rescale_factor = self.head_width**-0.5
168
+
169
+ torch.nn.init.zeros_(self.o_proj.bias)
170
+
171
+ def forward(self, x, mask=None, bias=None, indices=None):
172
+ """
173
+ Basic self attention with optional mask and external pairwise bias.
174
+ To handle sequences of different lengths, use mask.
175
+
176
+ Inputs:
177
+ x: batch of input sequences (.. x L x C)
178
+ mask: batch of boolean masks where 1=valid, 0=padding position (.. x L_k). optional.
179
+ bias: batch of scalar pairwise attention biases (.. x Lq x Lk x num_heads). optional.
180
+
181
+ Outputs:
182
+ sequence projection (B x L x embed_dim), attention maps (B x L x L x num_heads)
183
+ """
184
+
185
+ t = rearrange(self.proj(x), "... l (h c) -> ... h l c", h=self.num_heads)
186
+ q, k, v = t.chunk(3, dim=-1)
187
+
188
+ q = self.rescale_factor * q
189
+ a = torch.einsum("...qc,...kc->...qk", q, k)
190
+
191
+ # Add external attention bias.
192
+ if bias is not None:
193
+ a = a + rearrange(bias, "... lq lk h -> ... h lq lk")
194
+
195
+ # Do not attend to padding tokens.
196
+ if mask is not None:
197
+ mask = repeat(
198
+ mask, "... lk -> ... h lq lk", h=self.num_heads, lq=q.shape[-2]
199
+ )
200
+ a = a.masked_fill(mask == False, -np.inf)
201
+
202
+ a = F.softmax(a, dim=-1)
203
+
204
+ y = torch.einsum("...hqk,...hkc->...qhc", a, v)
205
+ y = rearrange(y, "... h c -> ... (h c)", h=self.num_heads)
206
+
207
+ if self.gated:
208
+ y = self.g_proj(x).sigmoid() * y
209
+ y = self.o_proj(y)
210
+
211
+ return y, rearrange(a, "... lq lk h -> ... h lq lk")
212
+
213
+
214
+ class Dropout(nn.Module):
215
+ """
216
+ Implementation of dropout with the ability to share the dropout mask
217
+ along a particular dimension.
218
+ """
219
+
220
+ def __init__(self, r: float, batch_dim: T.Union[int, T.List[int]]):
221
+ super(Dropout, self).__init__()
222
+
223
+ self.r = r
224
+ if type(batch_dim) == int:
225
+ batch_dim = [batch_dim]
226
+ self.batch_dim = batch_dim
227
+ self.dropout = nn.Dropout(self.r)
228
+
229
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
230
+ shape = list(x.shape)
231
+ if self.batch_dim is not None:
232
+ for bd in self.batch_dim:
233
+ shape[bd] = 1
234
+ return x * self.dropout(x.new_ones(shape))
235
+
236
+
237
+ class SequenceToPair(nn.Module):
238
+ def __init__(self, sequence_state_dim, inner_dim, pairwise_state_dim):
239
+ super().__init__()
240
+
241
+ self.layernorm = nn.LayerNorm(sequence_state_dim)
242
+ self.proj = nn.Linear(sequence_state_dim, inner_dim * 2, bias=True)
243
+ self.o_proj = nn.Linear(2 * inner_dim, pairwise_state_dim, bias=True)
244
+
245
+ torch.nn.init.zeros_(self.proj.bias)
246
+ torch.nn.init.zeros_(self.o_proj.bias)
247
+
248
+ def forward(self, sequence_state):
249
+ """
250
+ Inputs:
251
+ sequence_state: B x L x sequence_state_dim
252
+
253
+ Output:
254
+ pairwise_state: B x L x L x pairwise_state_dim
255
+
256
+ Intermediate state:
257
+ B x L x L x 2*inner_dim
258
+ """
259
+
260
+ assert len(sequence_state.shape) == 3
261
+
262
+ s = self.layernorm(sequence_state)
263
+ s = self.proj(s)
264
+ q, k = s.chunk(2, dim=-1)
265
+
266
+ prod = q[:, None, :, :] * k[:, :, None, :]
267
+ diff = q[:, None, :, :] - k[:, :, None, :]
268
+
269
+ x = torch.cat([prod, diff], dim=-1)
270
+ x = self.o_proj(x)
271
+
272
+ return x
273
+
274
+
275
+ class PairToSequence(nn.Module):
276
+ def __init__(self, pairwise_state_dim, num_heads):
277
+ super().__init__()
278
+
279
+ self.layernorm = nn.LayerNorm(pairwise_state_dim)
280
+ self.linear = nn.Linear(pairwise_state_dim, num_heads, bias=False)
281
+
282
+ def forward(self, pairwise_state):
283
+ """
284
+ Inputs:
285
+ pairwise_state: B x L x L x pairwise_state_dim
286
+
287
+ Output:
288
+ pairwise_bias: B x L x L x num_heads
289
+ """
290
+ assert len(pairwise_state.shape) == 4
291
+ z = self.layernorm(pairwise_state)
292
+ pairwise_bias = self.linear(z)
293
+ return pairwise_bias
294
+
295
+
296
+ class ResidueMLP(nn.Module):
297
+ def __init__(self, embed_dim, inner_dim, norm=nn.LayerNorm, dropout=0):
298
+ super().__init__()
299
+
300
+ self.mlp = nn.Sequential(
301
+ norm(embed_dim),
302
+ nn.Linear(embed_dim, inner_dim),
303
+ nn.ReLU(),
304
+ nn.Linear(inner_dim, embed_dim),
305
+ nn.Dropout(dropout),
306
+ )
307
+
308
+ def forward(self, x):
309
+ return x + self.mlp(x)
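A small illustration of the multimer encoding convention implemented above: chains separated by ':' are joined with a poly-G linker, linker positions are zeroed in `linker_mask`, and residue indices jump by `residue_index_offset` between chains. The two 10-residue chains are toy inputs (importing this module requires the repo's openfold dependency).

```python
from esm.esmfold.v1.misc import encode_sequence

aatype, residx, linker_mask, chain_index = encode_sequence(
    "MKTAYIAKQR:QISFVKSHFS",      # two chains separated by ':'
    residue_index_offset=512,
    chain_linker="G" * 25,
)

print(aatype.shape)                      # torch.Size([45]): 10 + 25 linker + 10 residues
print(int(linker_mask.sum()))            # 20: only the real residues keep mask value 1
print(int(residx[0]), int(residx[-1]))   # 0 and 556: the second chain is offset by 512
print(chain_index.tolist()[:3], chain_index.tolist()[-3:])  # [0, 0, 0] ... [1, 1, 1]
```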
esm/esmfold/v1/pretrained.py ADDED
@@ -0,0 +1,181 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from pathlib import Path
7
+
8
+ import torch
9
+
10
+ from esm.esmfold.v1.esmfold import ESMFold
11
+
12
+
13
+ def _load_model(model_name):
14
+ if model_name.endswith(".pt"): # local, treat as filepath
15
+ model_path = Path(model_name)
16
+ model_data = torch.load(str(model_path), map_location="cpu")
17
+ else: # load from hub
18
+ url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
19
+ model_data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
20
+
21
+ cfg = model_data["cfg"]["model"]
22
+ model_state = model_data["model"]
23
+ model = ESMFold(esmfold_config=cfg)
24
+
25
+ expected_keys = set(model.state_dict().keys())
26
+ found_keys = set(model_state.keys())
27
+
28
+ missing_essential_keys = []
29
+ for missing_key in expected_keys - found_keys:
30
+ if not missing_key.startswith("esm."):
31
+ missing_essential_keys.append(missing_key)
32
+
33
+ if missing_essential_keys:
34
+ raise RuntimeError(f"Keys '{', '.join(missing_essential_keys)}' are missing.")
35
+
36
+ model.load_state_dict(model_state, strict=False)
37
+
38
+ return model
39
+
40
+
41
+ def esmfold_v0():
42
+ """
43
+ ESMFold v0 model with 3B ESM-2, 48 folding blocks.
44
+ This version was used for the paper (Lin et al, 2022). It was trained
45
+ on all PDB chains until 2020-05, to ensure temporal holdout with CASP14
46
+ and the CAMEO validation and test set reported there.
47
+ """
48
+ return _load_model("esmfold_3B_v0")
49
+
50
+
51
+ def esmfold_v1():
52
+ """
53
+ ESMFold v1 model using 3B ESM-2, 48 folding blocks.
54
+ ESMFold provides fast high accuracy atomic level structure prediction
55
+ directly from the individual sequence of a protein. ESMFold uses the ESM2
56
+ protein language model to extract meaningful representations from the
57
+ protein sequence.
58
+ """
59
+ return _load_model("esmfold_3B_v1")
60
+
61
+
62
+ def esmfold_structure_module_only_8M():
63
+ """
64
+ ESMFold baseline model using 8M ESM-2, 0 folding blocks.
65
+ ESM-2 here is trained out to 500K updates.
66
+ This is a model designed to test the capabilities of the language model
67
+ when ablated for number of parameters in the language model.
68
+ See table S1 in (Lin et al, 2022).
69
+ """
70
+ return _load_model("esmfold_structure_module_only_8M")
71
+
72
+
73
+ def esmfold_structure_module_only_8M_270K():
74
+ """
75
+ ESMFold baseline model using 8M ESM-2, 0 folding blocks.
76
+ ESM-2 here is trained out to 270K updates.
77
+ This is a model designed to test the capabilities of the language model
78
+ when ablated for number of parameters in the language model.
79
+ See table S1 in (Lin et al, 2022).
80
+ """
81
+ return _load_model("esmfold_structure_module_only_8M_270K")
82
+
83
+
84
+ def esmfold_structure_module_only_35M():
85
+ """
86
+ ESMFold baseline model using 35M ESM-2, 0 folding blocks.
87
+ ESM-2 here is trained out to 500K updates.
88
+ This is a model designed to test the capabilities of the language model
89
+ when ablated for number of parameters in the language model.
90
+ See table S1 in (Lin et al, 2022).
91
+ """
92
+ return _load_model("esmfold_structure_module_only_35M")
93
+
94
+
95
+ def esmfold_structure_module_only_35M_270K():
96
+ """
97
+ ESMFold baseline model using 35M ESM-2, 0 folding blocks.
98
+ ESM-2 here is trained out to 270K updates.
99
+ This is a model designed to test the capabilities of the language model
100
+ when ablated for number of parameters in the language model.
101
+ See table S1 in (Lin et al, 2022).
102
+ """
103
+ return _load_model("esmfold_structure_module_only_35M_270K")
104
+
105
+
106
+ def esmfold_structure_module_only_150M():
107
+ """
108
+ ESMFold baseline model using 150M ESM-2, 0 folding blocks.
109
+ ESM-2 here is trained out to 500K updates.
110
+ This is a model designed to test the capabilities of the language model
111
+ when ablated for number of parameters in the language model.
112
+ See table S1 in (Lin et al, 2022).
113
+ """
114
+ return _load_model("esmfold_structure_module_only_150M")
115
+
116
+
117
+ def esmfold_structure_module_only_150M_270K():
118
+ """
119
+ ESMFold baseline model using 150M ESM-2, 0 folding blocks.
120
+ ESM-2 here is trained out to 270K updates.
121
+ This is a model designed to test the capabilities of the language model
122
+ when ablated for number of parameters in the language model.
123
+ See table S1 in (Lin et al, 2022).
124
+ """
125
+ return _load_model("esmfold_structure_module_only_150M_270K")
126
+
127
+
128
+ def esmfold_structure_module_only_650M():
129
+ """
130
+ ESMFold baseline model using 650M ESM-2, 0 folding blocks.
131
+ ESM-2 here is trained out to 500K updates.
132
+ This is a model designed to test the capabilities of the language model
133
+ when ablated for number of parameters in the language model.
134
+ See table S1 in (Lin et al, 2022).
135
+ """
136
+ return _load_model("esmfold_structure_module_only_650M")
137
+
138
+
139
+ def esmfold_structure_module_only_650M_270K():
140
+ """
141
+ ESMFold baseline model using 650M ESM-2, 0 folding blocks.
142
+ ESM-2 here is trained out to 270K updates.
143
+ This is a model designed to test the capabilities of the language model
144
+ when ablated for number of parameters in the language model.
145
+ See table S1 in (Lin et al, 2022).
146
+ """
147
+ return _load_model("esmfold_structure_module_only_650M_270K")
148
+
149
+
150
+ def esmfold_structure_module_only_3B():
151
+ """
152
+ ESMFold baseline model using 3B ESM-2, 0 folding blocks.
153
+ ESM-2 here is trained out to 500K updates.
154
+ This is a model designed to test the capabilities of the language model
155
+ when ablated for number of parameters in the language model.
156
+ See table S1 in (Lin et al, 2022).
157
+ """
158
+ return _load_model("esmfold_structure_module_only_3B")
159
+
160
+
161
+ def esmfold_structure_module_only_3B_270K():
162
+ """
163
+ ESMFold baseline model using 3B ESM-2, 0 folding blocks.
164
+ ESM-2 here is trained out to 270K updates.
165
+ This is a model designed to test the capabilities of the language model
166
+ when ablated for number of parameters in the language model.
167
+ See table S1 in (Lin et al, 2022).
168
+ """
169
+ return _load_model("esmfold_structure_module_only_3B_270K")
170
+
171
+
172
+ def esmfold_structure_module_only_15B():
173
+ """
174
+ ESMFold baseline model using 15B ESM-2, 0 folding blocks.
175
+ ESM-2 here is trained out to 270K updates.
176
+ The 15B parameter ESM-2 was not trained out to 500K updates.
177
+ This is a model designed to test the capabilities of the language model
178
+ when ablated for number of parameters in the language model.
179
+ See table S1 in (Lin et al, 2022).
180
+ """
181
+ return _load_model("esmfold_structure_module_only_15B")
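All of these constructors funnel through `_load_model`, which treats any model name ending in `.pt` as a local file path instead of a hub download. A hedged sketch (the checkpoint path is a placeholder):

```python
from esm.esmfold.v1.pretrained import _load_model, esmfold_structure_module_only_8M

# By name: fetched from the fair-esm hub URL and cached by torch.hub.
baseline = esmfold_structure_module_only_8M()

# By path: names ending in ".pt" are loaded directly from disk.
local_model = _load_model("/path/to/esmfold_3B_v1.pt")
```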
esm/esmfold/v1/tri_self_attn_block.py ADDED
@@ -0,0 +1,160 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import torch
6
+ from openfold.model.triangular_attention import (
7
+ TriangleAttentionEndingNode,
8
+ TriangleAttentionStartingNode,
9
+ )
10
+ from openfold.model.triangular_multiplicative_update import (
11
+ TriangleMultiplicationIncoming,
12
+ TriangleMultiplicationOutgoing,
13
+ )
14
+ from torch import nn
15
+
16
+ from esm.esmfold.v1.misc import (
17
+ Attention,
18
+ Dropout,
19
+ PairToSequence,
20
+ ResidueMLP,
21
+ SequenceToPair,
22
+ )
23
+
24
+
25
+ class TriangularSelfAttentionBlock(nn.Module):
26
+ def __init__(
27
+ self,
28
+ sequence_state_dim,
29
+ pairwise_state_dim,
30
+ sequence_head_width,
31
+ pairwise_head_width,
32
+ dropout=0,
33
+ **__kwargs,
34
+ ):
35
+ super().__init__()
36
+
37
+ assert sequence_state_dim % sequence_head_width == 0
38
+ assert pairwise_state_dim % pairwise_head_width == 0
39
+ sequence_num_heads = sequence_state_dim // sequence_head_width
40
+ pairwise_num_heads = pairwise_state_dim // pairwise_head_width
41
+ assert sequence_state_dim == sequence_num_heads * sequence_head_width
42
+ assert pairwise_state_dim == pairwise_num_heads * pairwise_head_width
43
+ assert pairwise_state_dim % 2 == 0
44
+
45
+ self.sequence_state_dim = sequence_state_dim
46
+ self.pairwise_state_dim = pairwise_state_dim
47
+
48
+ self.layernorm_1 = nn.LayerNorm(sequence_state_dim)
49
+
50
+ self.sequence_to_pair = SequenceToPair(
51
+ sequence_state_dim, pairwise_state_dim // 2, pairwise_state_dim
52
+ )
53
+ self.pair_to_sequence = PairToSequence(pairwise_state_dim, sequence_num_heads)
54
+
55
+ self.seq_attention = Attention(
56
+ sequence_state_dim, sequence_num_heads, sequence_head_width, gated=True
57
+ )
58
+ self.tri_mul_out = TriangleMultiplicationOutgoing(
59
+ pairwise_state_dim,
60
+ pairwise_state_dim,
61
+ )
62
+ self.tri_mul_in = TriangleMultiplicationIncoming(
63
+ pairwise_state_dim,
64
+ pairwise_state_dim,
65
+ )
66
+ self.tri_att_start = TriangleAttentionStartingNode(
67
+ pairwise_state_dim,
68
+ pairwise_head_width,
69
+ pairwise_num_heads,
70
+ inf=1e9,
71
+ ) # type: ignore
72
+ self.tri_att_end = TriangleAttentionEndingNode(
73
+ pairwise_state_dim,
74
+ pairwise_head_width,
75
+ pairwise_num_heads,
76
+ inf=1e9,
77
+ ) # type: ignore
78
+
79
+ self.mlp_seq = ResidueMLP(sequence_state_dim, 4 * sequence_state_dim, dropout=dropout)
80
+ self.mlp_pair = ResidueMLP(pairwise_state_dim, 4 * pairwise_state_dim, dropout=dropout)
81
+
82
+ assert dropout < 0.4
83
+ self.drop = nn.Dropout(dropout)
84
+ self.row_drop = Dropout(dropout * 2, 2)
85
+ self.col_drop = Dropout(dropout * 2, 1)
86
+
87
+ torch.nn.init.zeros_(self.tri_mul_in.linear_z.weight)
88
+ torch.nn.init.zeros_(self.tri_mul_in.linear_z.bias)
89
+ torch.nn.init.zeros_(self.tri_mul_out.linear_z.weight)
90
+ torch.nn.init.zeros_(self.tri_mul_out.linear_z.bias)
91
+ torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.weight)
92
+ torch.nn.init.zeros_(self.tri_att_start.mha.linear_o.bias)
93
+ torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.weight)
94
+ torch.nn.init.zeros_(self.tri_att_end.mha.linear_o.bias)
95
+
96
+ torch.nn.init.zeros_(self.sequence_to_pair.o_proj.weight)
97
+ torch.nn.init.zeros_(self.sequence_to_pair.o_proj.bias)
98
+ torch.nn.init.zeros_(self.pair_to_sequence.linear.weight)
99
+ torch.nn.init.zeros_(self.seq_attention.o_proj.weight)
100
+ torch.nn.init.zeros_(self.seq_attention.o_proj.bias)
101
+ torch.nn.init.zeros_(self.mlp_seq.mlp[-2].weight)
102
+ torch.nn.init.zeros_(self.mlp_seq.mlp[-2].bias)
103
+ torch.nn.init.zeros_(self.mlp_pair.mlp[-2].weight)
104
+ torch.nn.init.zeros_(self.mlp_pair.mlp[-2].bias)
105
+
106
+ def forward(self, sequence_state, pairwise_state, mask=None, chunk_size=None, **__kwargs):
107
+ """
108
+ Inputs:
109
+ sequence_state: B x L x sequence_state_dim
110
+ pairwise_state: B x L x L x pairwise_state_dim
111
+ mask: B x L boolean tensor of valid positions
112
+
113
+ Output:
114
+ sequence_state: B x L x sequence_state_dim
115
+ pairwise_state: B x L x L x pairwise_state_dim
116
+ """
117
+ assert len(sequence_state.shape) == 3
118
+ assert len(pairwise_state.shape) == 4
119
+ if mask is not None:
120
+ assert len(mask.shape) == 2
121
+
122
+ batch_dim, seq_dim, sequence_state_dim = sequence_state.shape
123
+ pairwise_state_dim = pairwise_state.shape[3]
124
+ assert sequence_state_dim == self.sequence_state_dim
125
+ assert pairwise_state_dim == self.pairwise_state_dim
126
+ assert batch_dim == pairwise_state.shape[0]
127
+ assert seq_dim == pairwise_state.shape[1]
128
+ assert seq_dim == pairwise_state.shape[2]
129
+
130
+ # Update sequence state
131
+ bias = self.pair_to_sequence(pairwise_state)
132
+
133
+ # Self attention with bias + mlp.
134
+ y = self.layernorm_1(sequence_state)
135
+ y, _ = self.seq_attention(y, mask=mask, bias=bias)
136
+ sequence_state = sequence_state + self.drop(y)
137
+ sequence_state = self.mlp_seq(sequence_state)
138
+
139
+ # Update pairwise state
140
+ pairwise_state = pairwise_state + self.sequence_to_pair(sequence_state)
141
+
142
+ # Axial attention with triangular bias.
143
+ tri_mask = mask.unsqueeze(2) * mask.unsqueeze(1) if mask is not None else None
144
+ pairwise_state = pairwise_state + self.row_drop(
145
+ self.tri_mul_out(pairwise_state, mask=tri_mask)
146
+ )
147
+ pairwise_state = pairwise_state + self.col_drop(
148
+ self.tri_mul_in(pairwise_state, mask=tri_mask)
149
+ )
150
+ pairwise_state = pairwise_state + self.row_drop(
151
+ self.tri_att_start(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
152
+ )
153
+ pairwise_state = pairwise_state + self.col_drop(
154
+ self.tri_att_end(pairwise_state, mask=tri_mask, chunk_size=chunk_size)
155
+ )
156
+
157
+ # MLP over pairs.
158
+ pairwise_state = self.mlp_pair(pairwise_state)
159
+
160
+ return sequence_state, pairwise_state
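A quick shape check for `TriangularSelfAttentionBlock`, with dimensions chosen to satisfy the divisibility asserts in `__init__`. This is only a sketch, not part of the commit, and it assumes the openfold triangular attention/multiplication modules are installed with the call signatures used above.

```python
import torch
from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock

block = TriangularSelfAttentionBlock(
    sequence_state_dim=1024,   # must be divisible by sequence_head_width
    pairwise_state_dim=128,    # must be divisible by pairwise_head_width and even
    sequence_head_width=32,
    pairwise_head_width=32,
)

s = torch.randn(1, 64, 1024)       # B x L x sequence_state_dim
z = torch.randn(1, 64, 64, 128)    # B x L x L x pairwise_state_dim
mask = torch.ones(1, 64)           # all positions valid

s_out, z_out = block(s, z, mask=mask)
print(s_out.shape, z_out.shape)    # shapes are preserved: (1, 64, 1024) and (1, 64, 64, 128)
```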
esm/esmfold/v1/trunk.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import typing as T
6
+ from contextlib import ExitStack
7
+ from dataclasses import dataclass
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from openfold.model.structure_module import StructureModule
12
+
13
+ from esm.esmfold.v1.tri_self_attn_block import TriangularSelfAttentionBlock
14
+
15
+
16
+ @dataclass
17
+ class StructureModuleConfig:
18
+ c_s: int = 384
19
+ c_z: int = 128
20
+ c_ipa: int = 16
21
+ c_resnet: int = 128
22
+ no_heads_ipa: int = 12
23
+ no_qk_points: int = 4
24
+ no_v_points: int = 8
25
+ dropout_rate: float = 0.1
26
+ no_blocks: int = 8
27
+ no_transition_layers: int = 1
28
+ no_resnet_blocks: int = 2
29
+ no_angles: int = 7
30
+ trans_scale_factor: int = 10
31
+ epsilon: float = 1e-8
32
+ inf: float = 1e5
33
+
34
+
35
+ @dataclass
36
+ class FoldingTrunkConfig:
37
+ _name: str = "FoldingTrunkConfig"
38
+ num_blocks: int = 48
39
+ sequence_state_dim: int = 1024
40
+ pairwise_state_dim: int = 128
41
+ sequence_head_width: int = 32
42
+ pairwise_head_width: int = 32
43
+ position_bins: int = 32
44
+ dropout: float = 0
45
+ layer_drop: float = 0
46
+ cpu_grad_checkpoint: bool = False
47
+
48
+ max_recycles: int = 4
49
+ chunk_size: T.Optional[int] = None
50
+
51
+ structure_module: StructureModuleConfig = StructureModuleConfig()
52
+
53
+
54
+ def get_axial_mask(mask):
55
+ """
56
+ Helper to convert B x L mask of valid positions to axial mask used
57
+ in row column attentions.
58
+
59
+ Input:
60
+ mask: B x L tensor of booleans
61
+
62
+ Output:
63
+ mask: B x L x L tensor of booleans
64
+ """
65
+
66
+ if mask is None:
67
+ return None
68
+ assert len(mask.shape) == 2
69
+ batch_dim, seq_dim = mask.shape
70
+ m = mask.unsqueeze(1).expand(batch_dim, seq_dim, seq_dim)
71
+ m = m.reshape(batch_dim * seq_dim, seq_dim)
72
+ return m
73
+
74
+
75
+ class RelativePosition(nn.Module):
76
+ def __init__(self, bins, pairwise_state_dim):
77
+ super().__init__()
78
+ self.bins = bins
79
+
80
+ # Note an additional offset is used so that the 0th position
81
+ # is reserved for masked pairs.
82
+ self.embedding = torch.nn.Embedding(2 * bins + 2, pairwise_state_dim)
83
+
84
+ def forward(self, residue_index, mask=None):
85
+ """
86
+ Input:
87
+ residue_index: B x L tensor of indices (dtype=torch.long)
88
+ mask: B x L tensor of booleans
89
+
90
+ Output:
91
+ pairwise_state: B x L x L x pairwise_state_dim tensor of embeddings
92
+ """
93
+
94
+ assert residue_index.dtype == torch.long
95
+ if mask is not None:
96
+ assert residue_index.shape == mask.shape
97
+
98
+ diff = residue_index[:, None, :] - residue_index[:, :, None]
99
+ diff = diff.clamp(-self.bins, self.bins)
100
+ diff = diff + self.bins + 1 # Add 1 to adjust for padding index.
101
+
102
+ if mask is not None:
103
+ mask = mask[:, None, :] * mask[:, :, None]
104
+ diff[mask == False] = 0
105
+
106
+ output = self.embedding(diff)
107
+ return output
108
+
109
+
110
+ class FoldingTrunk(nn.Module):
111
+ def __init__(self, **kwargs):
112
+ super().__init__()
113
+ self.cfg = FoldingTrunkConfig(**kwargs)
114
+ assert self.cfg.max_recycles > 0
115
+
116
+ c_s = self.cfg.sequence_state_dim
117
+ c_z = self.cfg.pairwise_state_dim
118
+
119
+ assert c_s % self.cfg.sequence_head_width == 0
120
+ assert c_z % self.cfg.pairwise_head_width == 0
121
+ block = TriangularSelfAttentionBlock
122
+
123
+ self.pairwise_positional_embedding = RelativePosition(self.cfg.position_bins, c_z)
124
+
125
+ self.blocks = nn.ModuleList(
126
+ [
127
+ block(
128
+ sequence_state_dim=c_s,
129
+ pairwise_state_dim=c_z,
130
+ sequence_head_width=self.cfg.sequence_head_width,
131
+ pairwise_head_width=self.cfg.pairwise_head_width,
132
+ dropout=self.cfg.dropout,
133
+ )
134
+ for i in range(self.cfg.num_blocks)
135
+ ]
136
+ )
137
+
138
+ self.recycle_bins = 15
139
+ self.recycle_s_norm = nn.LayerNorm(c_s)
140
+ self.recycle_z_norm = nn.LayerNorm(c_z)
141
+ self.recycle_disto = nn.Embedding(self.recycle_bins, c_z)
142
+ self.recycle_disto.weight[0].detach().zero_()
143
+
144
+ self.structure_module = StructureModule(**self.cfg.structure_module) # type: ignore
145
+ self.trunk2sm_s = nn.Linear(c_s, self.structure_module.c_s)
146
+ self.trunk2sm_z = nn.Linear(c_z, self.structure_module.c_z)
147
+
148
+ self.chunk_size = self.cfg.chunk_size
149
+
150
+ def set_chunk_size(self, chunk_size):
151
+ # This parameter means the axial attention will be computed
152
+ # in a chunked manner. This should make the memory used more or less O(L) instead of O(L^2).
153
+ # It's equivalent to running a for loop over chunks of the dimension we're iterative over,
154
+ # where the chunk_size is the size of the chunks, so 128 would mean to parse 128-lengthed chunks.
155
+ self.chunk_size = chunk_size
156
+
157
+ def forward(self, seq_feats, pair_feats, true_aa, residx, mask, no_recycles: T.Optional[int] = None):
158
+ """
159
+ Inputs:
160
+ seq_feats: B x L x C tensor of sequence features
161
+ pair_feats: B x L x L x C tensor of pair features
162
+ residx: B x L long tensor giving the position in the sequence
163
+ mask: B x L boolean tensor indicating valid residues
164
+
165
+ Output:
166
+ predicted_structure: B x L x (num_atoms_per_residue * 3) tensor wrapped in a Coordinates object
167
+ """
168
+
169
+ device = seq_feats.device
170
+ s_s_0 = seq_feats
171
+ s_z_0 = pair_feats
172
+
173
+ if no_recycles is None:
174
+ no_recycles = self.cfg.max_recycles
175
+ else:
176
+ assert no_recycles >= 0, "Number of recycles must not be negative."
177
+ no_recycles += 1 # First 'recycle' is just the standard forward pass through the model.
178
+
179
+ def trunk_iter(s, z, residx, mask):
180
+ z = z + self.pairwise_positional_embedding(residx, mask=mask)
181
+
182
+ for block in self.blocks:
183
+ s, z = block(s, z, mask=mask, residue_index=residx, chunk_size=self.chunk_size)
184
+ return s, z
185
+
186
+ s_s = s_s_0
187
+ s_z = s_z_0
188
+ recycle_s = torch.zeros_like(s_s)
189
+ recycle_z = torch.zeros_like(s_z)
190
+ recycle_bins = torch.zeros(*s_z.shape[:-1], device=device, dtype=torch.int64)
191
+
192
+ assert no_recycles > 0
193
+ for recycle_idx in range(no_recycles):
194
+ with ExitStack() if recycle_idx == no_recycles - 1 else torch.no_grad():
195
+ # === Recycling ===
196
+ recycle_s = self.recycle_s_norm(recycle_s.detach())
197
+ recycle_z = self.recycle_z_norm(recycle_z.detach())
198
+ recycle_z += self.recycle_disto(recycle_bins.detach())
199
+
200
+ s_s, s_z = trunk_iter(s_s_0 + recycle_s, s_z_0 + recycle_z, residx, mask)
201
+
202
+ # === Structure module ===
203
+ structure = self.structure_module(
204
+ {"single": self.trunk2sm_s(s_s), "pair": self.trunk2sm_z(s_z)},
205
+ true_aa,
206
+ mask.float(),
207
+ )
208
+
209
+ recycle_s = s_s
210
+ recycle_z = s_z
211
+ # Distogram needs the N, CA, C coordinates, and bin constants same as alphafold.
212
+ recycle_bins = FoldingTrunk.distogram(
213
+ structure["positions"][-1][:, :, :3],
214
+ 3.375,
215
+ 21.375,
216
+ self.recycle_bins,
217
+ )
218
+
219
+ assert isinstance(structure, dict) # type: ignore
220
+ structure["s_s"] = s_s
221
+ structure["s_z"] = s_z
222
+
223
+ return structure
224
+
225
+ @staticmethod
226
+ def distogram(coords, min_bin, max_bin, num_bins):
227
+ # Coords are [... L x 3 x 3], where it's [N, CA, C] x 3 coordinates.
228
+ boundaries = torch.linspace(
229
+ min_bin,
230
+ max_bin,
231
+ num_bins - 1,
232
+ device=coords.device,
233
+ )
234
+ boundaries = boundaries**2
235
+ N, CA, C = [x.squeeze(-2) for x in coords.chunk(3, dim=-2)]
236
+ # Infer CB coordinates.
237
+ b = CA - N
238
+ c = C - CA
239
+ a = b.cross(c, dim=-1)
240
+ CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
241
+ dists = (CB[..., None, :, :] - CB[..., :, None, :]).pow(2).sum(dim=-1, keepdims=True)
242
+ bins = torch.sum(dists > boundaries, dim=-1) # [..., L, L]
243
+ return bins
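For reference, `FoldingTrunk.distogram` infers a virtual C-beta from the N/CA/C backbone atoms and discretizes pairwise C-beta distances into `num_bins` bins between `min_bin` and `max_bin` Angstroms; the recycling embedding above uses 15 bins from 3.375 to 21.375. A toy sketch with random coordinates (importing this module requires openfold):

```python
import torch
from esm.esmfold.v1.trunk import FoldingTrunk

coords = torch.randn(1, 10, 3, 3) * 5.0   # B x L x [N, CA, C] x xyz, toy values
bins = FoldingTrunk.distogram(coords, 3.375, 21.375, num_bins=15)
print(bins.shape)                            # torch.Size([1, 10, 10])
print(bins.min().item(), bins.max().item())  # bin indices fall in 0..14
```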
esm/inverse_folding/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+ print("1")
+ from . import gvp_transformer
+ print("2")
+ from . import util
+ print("3")
+ from . import multichain_util
+ print("4")
@@ -0,0 +1,356 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # Portions of this file were adapted from the open source code for the following
7
+ # two papers:
8
+ #
9
+ # Ingraham, J., Garg, V., Barzilay, R., & Jaakkola, T. (2019). Generative
10
+ # models for graph-based protein design. Advances in Neural Information
11
+ # Processing Systems, 32.
12
+ #
13
+ # Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
14
+ # Learning from Protein Structure with Geometric Vector Perceptrons. In
15
+ # International Conference on Learning Representations.
16
+ #
17
+ # MIT License
18
+ #
19
+ # Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
20
+ #
21
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
22
+ # of this software and associated documentation files (the "Software"), to deal
23
+ # in the Software without restriction, including without limitation the rights
24
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
25
+ # copies of the Software, and to permit persons to whom the Software is
26
+ # furnished to do so, subject to the following conditions:
27
+ #
28
+ # The above copyright notice and this permission notice shall be included in all
29
+ # copies or substantial portions of the Software.
30
+ #
31
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
32
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
33
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
34
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
35
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
36
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
37
+ # SOFTWARE.
38
+ #
39
+ # ================================================================
40
+ # The below license applies to the portions of the code (parts of
41
+ # src/datasets.py and src/models.py) adapted from Ingraham, et al.
42
+ # ================================================================
43
+ #
44
+ # MIT License
45
+ #
46
+ # Copyright (c) 2019 John Ingraham, Vikas Garg, Regina Barzilay, Tommi Jaakkola
47
+ #
48
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
49
+ # of this software and associated documentation files (the "Software"), to deal
50
+ # in the Software without restriction, including without limitation the rights
51
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
52
+ # copies of the Software, and to permit persons to whom the Software is
53
+ # furnished to do so, subject to the following conditions:
54
+ #
55
+ # The above copyright notice and this permission notice shall be included in all
56
+ # copies or substantial portions of the Software.
57
+ #
58
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
59
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
60
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
61
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
62
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
63
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
64
+ # SOFTWARE.
65
+
66
+ import math
67
+ import numpy as np
68
+ import torch
69
+ import torch.nn as nn
70
+ import torch.nn.functional as F
71
+
72
+ print("features1")
73
+ from .gvp_utils import flatten_graph
74
+ print("features2")
75
+ from .gvp_modules import GVP, LayerNorm
76
+ print("features3")
77
+ from .util import normalize, norm, nan_to_num, rbf
78
+ print("features4")
79
+
80
+
81
+ class GVPInputFeaturizer(nn.Module):
82
+
83
+ @staticmethod
84
+ def get_node_features(coords, coord_mask, with_coord_mask=True):
85
+ # scalar features
86
+ node_scalar_features = GVPInputFeaturizer._dihedrals(coords)
87
+ if with_coord_mask:
88
+ node_scalar_features = torch.cat([
89
+ node_scalar_features,
90
+ coord_mask.float().unsqueeze(-1)
91
+ ], dim=-1)
92
+ # vector features
93
+ X_ca = coords[:, :, 1]
94
+ orientations = GVPInputFeaturizer._orientations(X_ca)
95
+ sidechains = GVPInputFeaturizer._sidechains(coords)
96
+ node_vector_features = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2)
97
+ return node_scalar_features, node_vector_features
98
+
99
+ @staticmethod
100
+ def _orientations(X):
101
+ forward = normalize(X[:, 1:] - X[:, :-1])
102
+ backward = normalize(X[:, :-1] - X[:, 1:])
103
+ forward = F.pad(forward, [0, 0, 0, 1])
104
+ backward = F.pad(backward, [0, 0, 1, 0])
105
+ return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2)
106
+
107
+ @staticmethod
108
+ def _sidechains(X):
109
+ n, origin, c = X[:, :, 0], X[:, :, 1], X[:, :, 2]
110
+ c, n = normalize(c - origin), normalize(n - origin)
111
+ bisector = normalize(c + n)
112
+ perp = normalize(torch.cross(c, n, dim=-1))
113
+ vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3)
114
+ return vec
115
+
116
+ @staticmethod
117
+ def _dihedrals(X, eps=1e-7):
118
+ X = torch.flatten(X[:, :, :3], 1, 2)
119
+ bsz = X.shape[0]
120
+ dX = X[:, 1:] - X[:, :-1]
121
+ U = normalize(dX, dim=-1)
122
+ u_2 = U[:, :-2]
123
+ u_1 = U[:, 1:-1]
124
+ u_0 = U[:, 2:]
125
+
126
+ # Backbone normals
127
+ n_2 = normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
128
+ n_1 = normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
129
+
130
+ # Angle between normals
131
+ cosD = torch.sum(n_2 * n_1, -1)
132
+ cosD = torch.clamp(cosD, -1 + eps, 1 - eps)
133
+ D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD)
134
+
135
+ # This scheme will remove phi[0], psi[-1], omega[-1]
136
+ D = F.pad(D, [1, 2])
137
+ D = torch.reshape(D, [bsz, -1, 3])
138
+ # Lift angle representations to the circle
139
+ D_features = torch.cat([torch.cos(D), torch.sin(D)], -1)
140
+ return D_features
141
+
142
+ @staticmethod
143
+ def _positional_embeddings(edge_index,
144
+ num_embeddings=None,
145
+ num_positional_embeddings=16,
146
+ period_range=[2, 1000]):
147
+ # From https://github.com/jingraham/neurips19-graph-protein-design
148
+ num_embeddings = num_embeddings or num_positional_embeddings
149
+ d = edge_index[0] - edge_index[1]
150
+
151
+ frequency = torch.exp(
152
+ torch.arange(0, num_embeddings, 2, dtype=torch.float32,
153
+ device=edge_index.device)
154
+ * -(np.log(10000.0) / num_embeddings)
155
+ )
156
+ angles = d.unsqueeze(-1) * frequency
157
+ E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
158
+ return E
159
+
160
+ @staticmethod
161
+ def _dist(X, coord_mask, padding_mask, top_k_neighbors, eps=1e-8):
162
+ """ Pairwise euclidean distances """
163
+ bsz, maxlen = X.size(0), X.size(1)
164
+ coord_mask_2D = torch.unsqueeze(coord_mask,1) * torch.unsqueeze(coord_mask,2)
165
+ residue_mask = ~padding_mask
166
+ residue_mask_2D = torch.unsqueeze(residue_mask,1) * torch.unsqueeze(residue_mask,2)
167
+ dX = torch.unsqueeze(X,1) - torch.unsqueeze(X,2)
168
+ D = coord_mask_2D * norm(dX, dim=-1)
169
+
170
+ # sorting preference: first those with coords, then among the residues that
171
+ # exist but are masked use distance in sequence as tie breaker, and then the
172
+ # residues that came from padding are last
173
+ seqpos = torch.arange(maxlen, device=X.device)
174
+ Dseq = torch.abs(seqpos.unsqueeze(1) - seqpos.unsqueeze(0)).repeat(bsz, 1, 1)
175
+ D_adjust = nan_to_num(D) + (~coord_mask_2D) * (1e8 + Dseq*1e6) + (
176
+ ~residue_mask_2D) * (1e10)
177
+
178
+ if top_k_neighbors == -1:
179
+ D_neighbors = D_adjust
180
+ E_idx = seqpos.repeat(
181
+ *D_neighbors.shape[:-1], 1)
182
+ else:
183
+ # Identify k nearest neighbors (including self)
184
+ k = min(top_k_neighbors, X.size(1))
185
+ D_neighbors, E_idx = torch.topk(D_adjust, k, dim=-1, largest=False)
186
+
187
+ coord_mask_neighbors = (D_neighbors < 5e7)
188
+ residue_mask_neighbors = (D_neighbors < 5e9)
189
+ return D_neighbors, E_idx, coord_mask_neighbors, residue_mask_neighbors
190
+
191
+
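The `_dist` helper above never drops entries when picking the k nearest neighbours; instead it adds large penalties so that residues with coordinates sort first, coordinate-less residues second (with sequence distance as a tie-breaker), and padding last. A toy illustration of that ordering with made-up distances:

import torch

D = torch.tensor([[0.0, 2.0, 4.0, 6.0]])                     # raw distances
coord_missing = torch.tensor([[False, False, True, False]])   # residue 2 has no coords
padding = torch.tensor([[False, False, False, True]])         # residue 3 is padding
D_adjust = D + coord_missing * 1e8 + padding * 1e10
_, idx = torch.topk(D_adjust, k=3, dim=-1, largest=False)
print(idx)  # tensor([[0, 1, 2]]): observed residues first, then the coord-less one; padding loses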
192
+ class Normalize(nn.Module):
193
+ def __init__(self, features, epsilon=1e-6):
194
+ super(Normalize, self).__init__()
195
+ self.gain = nn.Parameter(torch.ones(features))
196
+ self.bias = nn.Parameter(torch.zeros(features))
197
+ self.epsilon = epsilon
198
+
199
+ def forward(self, x, dim=-1):
200
+ mu = x.mean(dim, keepdim=True)
201
+ sigma = torch.sqrt(x.var(dim, keepdim=True) + self.epsilon)
202
+ gain = self.gain
203
+ bias = self.bias
204
+ # Reshape
205
+ if dim != -1:
206
+ shape = [1] * len(mu.size())
207
+ shape[dim] = self.gain.size()[0]
208
+ gain = gain.view(shape)
209
+ bias = bias.view(shape)
210
+ return gain * (x - mu) / (sigma + self.epsilon) + bias
211
+
212
+
213
+ class DihedralFeatures(nn.Module):
214
+ def __init__(self, node_embed_dim):
215
+ """ Embed dihedral angle features. """
216
+ super(DihedralFeatures, self).__init__()
217
+ # 3 dihedral angles; sin and cos of each angle
218
+ node_in = 6
219
+ # Normalization and embedding
220
+ self.node_embedding = nn.Linear(node_in, node_embed_dim, bias=True)
221
+ self.norm_nodes = Normalize(node_embed_dim)
222
+
223
+ def forward(self, X):
224
+ """ Featurize coordinates as an attributed graph """
225
+ V = self._dihedrals(X)
226
+ V = self.node_embedding(V)
227
+ V = self.norm_nodes(V)
228
+ return V
229
+
230
+ @staticmethod
231
+ def _dihedrals(X, eps=1e-7, return_angles=False):
232
+ # First 3 coordinates are N, CA, C
233
+ X = X[:,:,:3,:].reshape(X.shape[0], 3*X.shape[1], 3)
234
+
235
+ # Shifted slices of unit vectors
236
+ dX = X[:,1:,:] - X[:,:-1,:]
237
+ U = F.normalize(dX, dim=-1)
238
+ u_2 = U[:,:-2,:]
239
+ u_1 = U[:,1:-1,:]
240
+ u_0 = U[:,2:,:]
241
+ # Backbone normals
242
+ n_2 = F.normalize(torch.cross(u_2, u_1, dim=-1), dim=-1)
243
+ n_1 = F.normalize(torch.cross(u_1, u_0, dim=-1), dim=-1)
244
+
245
+ # Angle between normals
246
+ cosD = (n_2 * n_1).sum(-1)
247
+ cosD = torch.clamp(cosD, -1+eps, 1-eps)
248
+ D = torch.sign((u_2 * n_1).sum(-1)) * torch.acos(cosD)
249
+
250
+ # This scheme will remove phi[0], psi[-1], omega[-1]
251
+ D = F.pad(D, (1,2), 'constant', 0)
252
+ D = D.view((D.size(0), int(D.size(1)/3), 3))
253
+ phi, psi, omega = torch.unbind(D,-1)
254
+
255
+ if return_angles:
256
+ return phi, psi, omega
257
+
258
+ # Lift angle representations to the circle
259
+ D_features = torch.cat((torch.cos(D), torch.sin(D)), 2)
260
+ return D_features
261
+
262
+
263
+ class GVPGraphEmbedding(GVPInputFeaturizer):
264
+
265
+ def __init__(self, args):
266
+ super().__init__()
267
+ self.top_k_neighbors = args.top_k_neighbors
268
+ self.num_positional_embeddings = 16
269
+ self.remove_edges_without_coords = True
270
+ node_input_dim = (7, 3)
271
+ edge_input_dim = (34, 1)
272
+ node_hidden_dim = (args.node_hidden_dim_scalar,
273
+ args.node_hidden_dim_vector)
274
+ edge_hidden_dim = (args.edge_hidden_dim_scalar,
275
+ args.edge_hidden_dim_vector)
276
+ self.embed_node = nn.Sequential(
277
+ GVP(node_input_dim, node_hidden_dim, activations=(None, None)),
278
+ LayerNorm(node_hidden_dim, eps=1e-4)
279
+ )
280
+ self.embed_edge = nn.Sequential(
281
+ GVP(edge_input_dim, edge_hidden_dim, activations=(None, None)),
282
+ LayerNorm(edge_hidden_dim, eps=1e-4)
283
+ )
284
+ self.embed_confidence = nn.Linear(16, args.node_hidden_dim_scalar)
285
+
286
+ def forward(self, coords, coord_mask, padding_mask, confidence):
287
+ with torch.no_grad():
288
+ node_features = self.get_node_features(coords, coord_mask)
289
+ edge_features, edge_index = self.get_edge_features(
290
+ coords, coord_mask, padding_mask)
291
+ node_embeddings_scalar, node_embeddings_vector = self.embed_node(node_features)
292
+ edge_embeddings = self.embed_edge(edge_features)
293
+
294
+ rbf_rep = rbf(confidence, 0., 1.)
295
+ node_embeddings = (
296
+ node_embeddings_scalar + self.embed_confidence(rbf_rep),
297
+ node_embeddings_vector
298
+ )
299
+
300
+ node_embeddings, edge_embeddings, edge_index = flatten_graph(
301
+ node_embeddings, edge_embeddings, edge_index)
302
+ return node_embeddings, edge_embeddings, edge_index
303
+
304
+ def get_edge_features(self, coords, coord_mask, padding_mask):
305
+ X_ca = coords[:, :, 1]
306
+ # Get distances to the top k neighbors
307
+ E_dist, E_idx, E_coord_mask, E_residue_mask = GVPInputFeaturizer._dist(
308
+ X_ca, coord_mask, padding_mask, self.top_k_neighbors)
309
+ # Flatten the graph to be batch size 1 for torch_geometric package
310
+ dest = E_idx
311
+ B, L, k = E_idx.shape[:3]
312
+ src = torch.arange(L, device=E_idx.device).view([1, L, 1]).expand(B, L, k)
313
+ # After flattening, [2, B, E]
314
+ edge_index = torch.stack([src, dest], dim=0).flatten(2, 3)
315
+ # After flattening, [B, E]
316
+ E_dist = E_dist.flatten(1, 2)
317
+ E_coord_mask = E_coord_mask.flatten(1, 2).unsqueeze(-1)
318
+ E_residue_mask = E_residue_mask.flatten(1, 2)
319
+ # Calculate relative positional embeddings and distance RBF
320
+ pos_embeddings = GVPInputFeaturizer._positional_embeddings(
321
+ edge_index,
322
+ num_positional_embeddings=self.num_positional_embeddings,
323
+ )
324
+ D_rbf = rbf(E_dist, 0., 20.)
325
+ # Calculate relative orientation
326
+ X_src = X_ca.unsqueeze(2).expand(-1, -1, k, -1).flatten(1, 2)
327
+ X_dest = torch.gather(
328
+ X_ca,
329
+ 1,
330
+ edge_index[1, :, :].unsqueeze(-1).expand([B, L*k, 3])
331
+ )
332
+ coord_mask_src = coord_mask.unsqueeze(2).expand(-1, -1, k).flatten(1, 2)
333
+ coord_mask_dest = torch.gather(
334
+ coord_mask,
335
+ 1,
336
+ edge_index[1, :, :].expand([B, L*k])
337
+ )
338
+ E_vectors = X_src - X_dest
339
+ # For the ones without coordinates, substitute in the average vector
340
+ E_vector_mean = torch.sum(E_vectors * E_coord_mask, dim=1,
341
+ keepdims=True) / torch.sum(E_coord_mask, dim=1, keepdims=True)
342
+ E_vectors = E_vectors * E_coord_mask + E_vector_mean * ~(E_coord_mask)
343
+ # Normalize and remove nans
344
+ edge_s = torch.cat([D_rbf, pos_embeddings], dim=-1)
345
+ edge_v = normalize(E_vectors).unsqueeze(-2)
346
+ edge_s, edge_v = map(nan_to_num, (edge_s, edge_v))
347
+ # Also add indications of whether the coordinates are present
348
+ edge_s = torch.cat([
349
+ edge_s,
350
+ (~coord_mask_src).float().unsqueeze(-1),
351
+ (~coord_mask_dest).float().unsqueeze(-1),
352
+ ], dim=-1)
353
+ edge_index[:, ~E_residue_mask] = -1
354
+ if self.remove_edges_without_coords:
355
+ edge_index[:, ~E_coord_mask.squeeze(-1)] = -1
356
+ return (edge_s, edge_v), edge_index.transpose(0, 1)
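The edge features above combine an RBF expansion of the CA–CA distance with a sinusoidal embedding of the signed sequence offset between the two endpoints (`_positional_embeddings`). A self-contained sketch of the positional part, with an illustrative three-edge index:

import numpy as np
import torch

def relative_position_embedding(edge_index, num_embeddings=16):
    # Signed sequence offset d = src - dst, encoded as sin/cos at
    # log-spaced frequencies, mirroring _positional_embeddings above.
    d = edge_index[0] - edge_index[1]
    freq = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d.unsqueeze(-1) * freq
    return torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)

edges = torch.tensor([[0, 1, 2], [3, 3, 3]])        # three edges into residue 3
print(relative_position_embedding(edges).shape)     # torch.Size([3, 16])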
esm/inverse_folding/gvp_encoder.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from argparse import Namespace
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+ from .features import GVPGraphEmbedding
13
+ from .gvp_modules import GVPConvLayer, LayerNorm
14
+ from .gvp_utils import unflatten_graph
15
+
16
+
17
+
18
+ class GVPEncoder(nn.Module):
19
+
20
+ def __init__(self, args):
21
+ super().__init__()
22
+ self.args = args
23
+ self.embed_graph = GVPGraphEmbedding(args)
24
+
25
+ node_hidden_dim = (args.node_hidden_dim_scalar,
26
+ args.node_hidden_dim_vector)
27
+ edge_hidden_dim = (args.edge_hidden_dim_scalar,
28
+ args.edge_hidden_dim_vector)
29
+
30
+ conv_activations = (F.relu, torch.sigmoid)
31
+ self.encoder_layers = nn.ModuleList(
32
+ GVPConvLayer(
33
+ node_hidden_dim,
34
+ edge_hidden_dim,
35
+ drop_rate=args.dropout,
36
+ vector_gate=True,
37
+ attention_heads=0,
38
+ n_message=3,
39
+ conv_activations=conv_activations,
40
+ n_edge_gvps=0,
41
+ eps=1e-4,
42
+ layernorm=True,
43
+ )
44
+ for i in range(args.num_encoder_layers)
45
+ )
46
+
47
+ def forward(self, coords, coord_mask, padding_mask, confidence):
48
+ node_embeddings, edge_embeddings, edge_index = self.embed_graph(
49
+ coords, coord_mask, padding_mask, confidence)
50
+
51
+ for i, layer in enumerate(self.encoder_layers):
52
+ node_embeddings, edge_embeddings = layer(node_embeddings,
53
+ edge_index, edge_embeddings)
54
+
55
+ node_embeddings = unflatten_graph(node_embeddings, coords.shape[0])
56
+ return node_embeddings
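`GVPEncoder` reads everything off an `args` namespace. The sketch below shows one way to instantiate it on toy inputs; the field names mirror what the class accesses above, the values are illustrative rather than the trained checkpoint's, and running it assumes this repo's `esm` package plus torch_geometric and torch_scatter are installed.

from argparse import Namespace
import torch
from esm.inverse_folding.gvp_encoder import GVPEncoder

args = Namespace(
    top_k_neighbors=8,
    node_hidden_dim_scalar=64, node_hidden_dim_vector=16,
    edge_hidden_dim_scalar=32, edge_hidden_dim_vector=1,
    dropout=0.1, num_encoder_layers=2,
)
encoder = GVPEncoder(args)

B, L = 1, 12
coords = torch.randn(B, L, 3, 3)                    # N, CA, C backbone coordinates
coord_mask = torch.ones(B, L, dtype=torch.bool)     # True where coordinates exist
padding_mask = torch.zeros(B, L, dtype=torch.bool)  # True at padded positions
confidence = torch.ones(B, L)

scalars, vectors = encoder(coords, coord_mask, padding_mask, confidence)
print(scalars.shape, vectors.shape)  # (1, 12, 64) and (1, 12, 16, 3)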
esm/inverse_folding/gvp_modules.py ADDED
@@ -0,0 +1,475 @@
1
+ # Contents of this file are from the open source code for
2
+ #
3
+ # Jing, B., Eismann, S., Suriana, P., Townshend, R. J. L., & Dror, R. (2020).
4
+ # Learning from Protein Structure with Geometric Vector Perceptrons. In
5
+ # International Conference on Learning Representations.
6
+ #
7
+ # MIT License
8
+ #
9
+ # Copyright (c) 2020 Bowen Jing, Stephan Eismann, Patricia Suriana, Raphael Townshend, Ron Dror
10
+ #
11
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
12
+ # of this software and associated documentation files (the "Software"), to deal
13
+ # in the Software without restriction, including without limitation the rights
14
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
15
+ # copies of the Software, and to permit persons to whom the Software is
16
+ # furnished to do so, subject to the following conditions:
17
+ #
18
+ # The above copyright notice and this permission notice shall be included in all
19
+ # copies or substantial portions of the Software.
20
+ #
21
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
22
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
23
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
24
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
25
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
26
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
27
+ # SOFTWARE.
28
+
29
+ import typing as T
30
+ import torch
31
+ from torch import nn
32
+ import torch.nn.functional as F
33
+ print("gvp_module1")
34
+ from torch_geometric.nn import MessagePassing
35
+ print("gvp_module2")
36
+ from torch_scatter import scatter_add, scatter
37
+
38
+ def tuple_size(tp):
39
+ return tuple([0 if a is None else a.size() for a in tp])
40
+
41
+ def tuple_sum(tp1, tp2):
42
+ s1, v1 = tp1
43
+ s2, v2 = tp2
44
+ if v1 is None and v2 is None:
45
+ return (s1 + s2, None)
46
+ return (s1 + s2, v1 + v2)
47
+
48
+ def tuple_cat(*args, dim=-1):
49
+ '''
50
+ Concatenates any number of tuples (s, V) elementwise.
51
+
52
+ :param dim: dimension along which to concatenate when viewed
53
+ as the `dim` index for the scalar-channel tensors.
54
+ This means that `dim=-1` will be applied as
55
+ `dim=-2` for the vector-channel tensors.
56
+ '''
57
+ dim %= len(args[0][0].shape)
58
+ s_args, v_args = list(zip(*args))
59
+ return torch.cat(s_args, dim=dim), torch.cat(v_args, dim=dim)
60
+
61
+ def tuple_index(x, idx):
62
+ '''
63
+ Indexes into a tuple (s, V) along the first dimension.
64
+
65
+ :param idx: any object which can be used to index into a `torch.Tensor`
66
+ '''
67
+ return x[0][idx], x[1][idx]
68
+
69
+ def randn(n, dims, device="cpu"):
70
+ '''
71
+ Returns random tuples (s, V) drawn elementwise from a normal distribution.
72
+
73
+ :param n: number of data points
74
+ :param dims: tuple of dimensions (n_scalar, n_vector)
75
+
76
+ :return: (s, V) with s.shape = (n, n_scalar) and
77
+ V.shape = (n, n_vector, 3)
78
+ '''
79
+ return torch.randn(n, dims[0], device=device), \
80
+ torch.randn(n, dims[1], 3, device=device)
81
+
82
+ def _norm_no_nan(x, axis=-1, keepdims=False, eps=1e-8, sqrt=True):
83
+ '''
84
+ L2 norm of tensor clamped above a minimum value `eps`.
85
+
86
+ :param sqrt: if `False`, returns the square of the L2 norm
87
+ '''
88
+ # clamp is slow
89
+ # out = torch.clamp(torch.sum(torch.square(x), axis, keepdims), min=eps)
90
+ out = torch.sum(torch.square(x), axis, keepdims) + eps
91
+ return torch.sqrt(out) if sqrt else out
92
+
93
+ def _split(x, nv):
94
+ '''
95
+ Splits a merged representation of (s, V) back into a tuple.
96
+ Should be used only with `_merge(s, V)` and only if the tuple
97
+ representation cannot be used.
98
+
99
+ :param x: the `torch.Tensor` returned from `_merge`
100
+ :param nv: the number of vector channels in the input to `_merge`
101
+ '''
102
+ v = torch.reshape(x[..., -3*nv:], x.shape[:-1] + (nv, 3))
103
+ s = x[..., :-3*nv]
104
+ return s, v
105
+
106
+ def _merge(s, v):
107
+ '''
108
+ Merges a tuple (s, V) into a single `torch.Tensor`, where the
109
+ vector channels are flattened and appended to the scalar channels.
110
+ Should be used only if the tuple representation cannot be used.
111
+ Use `_split(x, nv)` to reverse.
112
+ '''
113
+ v = torch.reshape(v, v.shape[:-2] + (3*v.shape[-2],))
114
+ return torch.cat([s, v], -1)
115
+
116
+ class GVP(nn.Module):
117
+ '''
118
+ Geometric Vector Perceptron. See manuscript and README.md
119
+ for more details.
120
+
121
+ :param in_dims: tuple (n_scalar, n_vector)
122
+ :param out_dims: tuple (n_scalar, n_vector)
123
+ :param h_dim: intermediate number of vector channels, optional
124
+ :param activations: tuple of functions (scalar_act, vector_act)
125
+ :param tuple_io: whether to keep accepting tuple inputs and outputs when vi
126
+ or vo = 0
127
+ '''
128
+ def __init__(self, in_dims, out_dims, h_dim=None, vector_gate=False,
129
+ activations=(F.relu, torch.sigmoid), tuple_io=True,
130
+ eps=1e-8):
131
+ super(GVP, self).__init__()
132
+ self.si, self.vi = in_dims
133
+ self.so, self.vo = out_dims
134
+ self.tuple_io = tuple_io
135
+ if self.vi:
136
+ self.h_dim = h_dim or max(self.vi, self.vo)
137
+ self.wh = nn.Linear(self.vi, self.h_dim, bias=False)
138
+ self.ws = nn.Linear(self.h_dim + self.si, self.so)
139
+ if self.vo:
140
+ self.wv = nn.Linear(self.h_dim, self.vo, bias=False)
141
+ if vector_gate:
142
+ self.wg = nn.Linear(self.so, self.vo)
143
+ else:
144
+ self.ws = nn.Linear(self.si, self.so)
145
+
146
+ self.vector_gate = vector_gate
147
+ self.scalar_act, self.vector_act = activations
148
+ self.eps = eps
149
+
150
+ def forward(self, x):
151
+ '''
152
+ :param x: tuple (s, V) of `torch.Tensor`,
153
+ or (if vectors_in is 0), a single `torch.Tensor`
154
+ :return: tuple (s, V) of `torch.Tensor`,
155
+ or (if vectors_out is 0), a single `torch.Tensor`
156
+ '''
157
+ if self.vi:
158
+ s, v = x
159
+ v = torch.transpose(v, -1, -2)
160
+ vh = self.wh(v)
161
+ vn = _norm_no_nan(vh, axis=-2, eps=self.eps)
162
+ s = self.ws(torch.cat([s, vn], -1))
163
+ if self.scalar_act:
164
+ s = self.scalar_act(s)
165
+ if self.vo:
166
+ v = self.wv(vh)
167
+ v = torch.transpose(v, -1, -2)
168
+ if self.vector_gate:
169
+ g = self.wg(s).unsqueeze(-1)
170
+ else:
171
+ g = _norm_no_nan(v, axis=-1, keepdims=True, eps=self.eps)
172
+ if self.vector_act:
173
+ g = self.vector_act(g)
174
+ v = v * g
175
+ else:
176
+ if self.tuple_io:
177
+ assert x[1] is None
178
+ x = x[0]
179
+ s = self.ws(x)
180
+ if self.scalar_act:
181
+ s = self.scalar_act(s)
182
+ if self.vo:
183
+ v = torch.zeros(list(s.shape)[:-1] + [self.vo, 3],
184
+ device=s.device)
185
+
186
+ if self.vo:
187
+ return (s, v)
188
+ elif self.tuple_io:
189
+ return (s, None)
190
+ else:
191
+ return s
192
+
193
+
194
+ class _VDropout(nn.Module):
195
+ '''
196
+ Vector channel dropout where the elements of each
197
+ vector channel are dropped together.
198
+ '''
199
+ def __init__(self, drop_rate):
200
+ super(_VDropout, self).__init__()
201
+ self.drop_rate = drop_rate
202
+
203
+ def forward(self, x):
204
+ '''
205
+ :param x: `torch.Tensor` corresponding to vector channels
206
+ '''
207
+ if x is None:
208
+ return None
209
+ device = x.device
210
+ if not self.training:
211
+ return x
212
+ mask = torch.bernoulli(
213
+ (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
214
+ ).unsqueeze(-1)
215
+ x = mask * x / (1 - self.drop_rate)
216
+ return x
217
+
218
+ class Dropout(nn.Module):
219
+ '''
220
+ Combined dropout for tuples (s, V).
221
+ Takes tuples (s, V) as input and as output.
222
+ '''
223
+ def __init__(self, drop_rate):
224
+ super(Dropout, self).__init__()
225
+ self.sdropout = nn.Dropout(drop_rate)
226
+ self.vdropout = _VDropout(drop_rate)
227
+
228
+ def forward(self, x):
229
+ '''
230
+ :param x: tuple (s, V) of `torch.Tensor`,
231
+ or single `torch.Tensor`
232
+ (will be assumed to be scalar channels)
233
+ '''
234
+ if type(x) is torch.Tensor:
235
+ return self.sdropout(x)
236
+ s, v = x
237
+ return self.sdropout(s), self.vdropout(v)
238
+
239
+ class LayerNorm(nn.Module):
240
+ '''
241
+ Combined LayerNorm for tuples (s, V).
242
+ Takes tuples (s, V) as input and as output.
243
+ '''
244
+ def __init__(self, dims, tuple_io=True, eps=1e-8):
245
+ super(LayerNorm, self).__init__()
246
+ self.tuple_io = tuple_io
247
+ self.s, self.v = dims
248
+ self.scalar_norm = nn.LayerNorm(self.s)
249
+ self.eps = eps
250
+
251
+ def forward(self, x):
252
+ '''
253
+ :param x: tuple (s, V) of `torch.Tensor`,
254
+ or single `torch.Tensor`
255
+ (will be assumed to be scalar channels)
256
+ '''
257
+ if not self.v:
258
+ if self.tuple_io:
259
+ return self.scalar_norm(x[0]), None
260
+ return self.scalar_norm(x)
261
+ s, v = x
262
+ vn = _norm_no_nan(v, axis=-1, keepdims=True, sqrt=False, eps=self.eps)
263
+ nonzero_mask = (vn > 2 * self.eps)
264
+ vn = torch.sum(vn * nonzero_mask, dim=-2, keepdim=True
265
+ ) / (self.eps + torch.sum(nonzero_mask, dim=-2, keepdim=True))
266
+ vn = torch.sqrt(vn + self.eps)
267
+ v = nonzero_mask * (v / vn)
268
+ return self.scalar_norm(s), v
269
+
270
+ class GVPConv(MessagePassing):
271
+ '''
272
+ Graph convolution / message passing with Geometric Vector Perceptrons.
273
+ Takes in a graph with node and edge embeddings,
274
+ and returns new node embeddings.
275
+
276
+ This does NOT do residual updates and pointwise feedforward layers
277
+ ---see `GVPConvLayer`.
278
+
279
+ :param in_dims: input node embedding dimensions (n_scalar, n_vector)
280
+ :param out_dims: output node embedding dimensions (n_scalar, n_vector)
281
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
282
+ :param n_layers: number of GVPs in the message function
283
+ :param module_list: preconstructed message function, overrides n_layers
284
+ :param aggr: should be "add" if some incoming edges are masked, as in
285
+ a masked autoregressive decoder architecture
286
+ '''
287
+ def __init__(self, in_dims, out_dims, edge_dims, n_layers=3,
288
+ vector_gate=False, module_list=None, aggr="mean", eps=1e-8,
289
+ activations=(F.relu, torch.sigmoid)):
290
+ super(GVPConv, self).__init__(aggr=aggr)
291
+ self.eps = eps
292
+ self.si, self.vi = in_dims
293
+ self.so, self.vo = out_dims
294
+ self.se, self.ve = edge_dims
295
+
296
+ module_list = module_list or []
297
+ if not module_list:
298
+ if n_layers == 1:
299
+ module_list.append(
300
+ GVP((2*self.si + self.se, 2*self.vi + self.ve),
301
+ (self.so, self.vo), activations=(None, None)))
302
+ else:
303
+ module_list.append(
304
+ GVP((2*self.si + self.se, 2*self.vi + self.ve), out_dims,
305
+ vector_gate=vector_gate, activations=activations)
306
+ )
307
+ for i in range(n_layers - 2):
308
+ module_list.append(GVP(out_dims, out_dims,
309
+ vector_gate=vector_gate))
310
+ module_list.append(GVP(out_dims, out_dims,
311
+ activations=(None, None)))
312
+ self.message_func = nn.Sequential(*module_list)
313
+
314
+ def forward(self, x, edge_index, edge_attr):
315
+ '''
316
+ :param x: tuple (s, V) of `torch.Tensor`
317
+ :param edge_index: array of shape [2, n_edges]
318
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
319
+ '''
320
+ x_s, x_v = x
321
+ message = self.propagate(edge_index,
322
+ s=x_s, v=x_v.reshape(x_v.shape[0], 3*x_v.shape[1]),
323
+ edge_attr=edge_attr)
324
+ return _split(message, self.vo)
325
+
326
+ def message(self, s_i, v_i, s_j, v_j, edge_attr):
327
+ v_j = v_j.view(v_j.shape[0], v_j.shape[1]//3, 3)
328
+ v_i = v_i.view(v_i.shape[0], v_i.shape[1]//3, 3)
329
+ message = tuple_cat((s_j, v_j), edge_attr, (s_i, v_i))
330
+ message = self.message_func(message)
331
+ return _merge(*message)
332
+
333
+
334
+ class GVPConvLayer(nn.Module):
335
+ '''
336
+ Full graph convolution / message passing layer with
337
+ Geometric Vector Perceptrons. Residually updates node embeddings with
338
+ aggregated incoming messages, applies a pointwise feedforward
339
+ network to node embeddings, and returns updated node embeddings.
340
+
341
+ To only compute the aggregated messages, see `GVPConv`.
342
+
343
+ :param node_dims: node embedding dimensions (n_scalar, n_vector)
344
+ :param edge_dims: input edge embedding dimensions (n_scalar, n_vector)
345
+ :param n_message: number of GVPs to use in message function
346
+ :param n_feedforward: number of GVPs to use in feedforward function
347
+ :param drop_rate: drop probability in all dropout layers
348
+ :param autoregressive: if `True`, this `GVPConvLayer` will be used
349
+ with a different set of input node embeddings for messages
350
+ where src >= dst
351
+ '''
352
+ def __init__(self, node_dims, edge_dims, vector_gate=False,
353
+ n_message=3, n_feedforward=2, drop_rate=.1,
354
+ autoregressive=False, attention_heads=0,
355
+ conv_activations=(F.relu, torch.sigmoid),
356
+ n_edge_gvps=0, layernorm=True, eps=1e-8):
357
+
358
+ super(GVPConvLayer, self).__init__()
359
+ if attention_heads == 0:
360
+ self.conv = GVPConv(
361
+ node_dims, node_dims, edge_dims, n_layers=n_message,
362
+ vector_gate=vector_gate,
363
+ aggr="add" if autoregressive else "mean",
364
+ activations=conv_activations,
365
+ eps=eps,
366
+ )
367
+ else:
368
+ raise NotImplementedError
369
+ if layernorm:
370
+ self.norm = nn.ModuleList([LayerNorm(node_dims, eps=eps) for _ in range(2)])
371
+ else:
372
+ self.norm = nn.ModuleList([nn.Identity() for _ in range(2)])
373
+ self.dropout = nn.ModuleList([Dropout(drop_rate) for _ in range(2)])
374
+
375
+ ff_func = []
376
+ if n_feedforward == 1:
377
+ ff_func.append(GVP(node_dims, node_dims, activations=(None, None)))
378
+ else:
379
+ hid_dims = 4*node_dims[0], 2*node_dims[1]
380
+ ff_func.append(GVP(node_dims, hid_dims, vector_gate=vector_gate))
381
+ for i in range(n_feedforward-2):
382
+ ff_func.append(GVP(hid_dims, hid_dims, vector_gate=vector_gate))
383
+ ff_func.append(GVP(hid_dims, node_dims, activations=(None, None)))
384
+ self.ff_func = nn.Sequential(*ff_func)
385
+
386
+ self.edge_message_func = None
387
+ if n_edge_gvps > 0:
388
+ si, vi = node_dims
389
+ se, ve = edge_dims
390
+ module_list = [
391
+ GVP((2*si + se, 2*vi + ve), edge_dims, vector_gate=vector_gate)
392
+ ]
393
+ for i in range(n_edge_gvps - 2):
394
+ module_list.append(GVP(edge_dims, edge_dims,
395
+ vector_gate=vector_gate))
396
+ if n_edge_gvps > 1:
397
+ module_list.append(GVP(edge_dims, edge_dims,
398
+ activations=(None, None)))
399
+ self.edge_message_func = nn.Sequential(*module_list)
400
+ if layernorm:
401
+ self.edge_norm = LayerNorm(edge_dims, eps=eps)
402
+ else:
403
+ self.edge_norm = nn.Identity()
404
+ self.edge_dropout = Dropout(drop_rate)
405
+
406
+ def forward(self, x, edge_index, edge_attr,
407
+ autoregressive_x=None, node_mask=None):
408
+ '''
409
+ :param x: tuple (s, V) of `torch.Tensor`
410
+ :param edge_index: array of shape [2, n_edges]
411
+ :param edge_attr: tuple (s, V) of `torch.Tensor`
412
+ :param autoregressive_x: tuple (s, V) of `torch.Tensor`.
413
+ If not `None`, will be used as src node embeddings
415
+ for forming messages where src >= dst. The current node
415
+ embeddings `x` will still be the base of the update and the
416
+ pointwise feedforward.
417
+ :param node_mask: array of type `bool` to index into the first
418
+ dim of node embeddings (s, V). If not `None`, only
419
+ these nodes will be updated.
420
+ '''
421
+ if self.edge_message_func:
422
+ src, dst = edge_index
423
+ if autoregressive_x is None:
424
+ x_src = x[0][src], x[1][src]
425
+ else:
426
+ mask = (src < dst).unsqueeze(-1)
427
+ x_src = (
428
+ torch.where(mask, x[0][src], autoregressive_x[0][src]),
429
+ torch.where(mask.unsqueeze(-1), x[1][src],
430
+ autoregressive_x[1][src])
431
+ )
432
+ x_dst = x[0][dst], x[1][dst]
433
+ x_edge = (
434
+ torch.cat([x_src[0], edge_attr[0], x_dst[0]], dim=-1),
435
+ torch.cat([x_src[1], edge_attr[1], x_dst[1]], dim=-2)
436
+ )
437
+ edge_attr_dh = self.edge_message_func(x_edge)
438
+ edge_attr = self.edge_norm(tuple_sum(edge_attr,
439
+ self.edge_dropout(edge_attr_dh)))
440
+
441
+ if autoregressive_x is not None:
442
+ src, dst = edge_index
443
+ mask = src < dst
444
+ edge_index_forward = edge_index[:, mask]
445
+ edge_index_backward = edge_index[:, ~mask]
446
+ edge_attr_forward = tuple_index(edge_attr, mask)
447
+ edge_attr_backward = tuple_index(edge_attr, ~mask)
448
+
449
+ dh = tuple_sum(
450
+ self.conv(x, edge_index_forward, edge_attr_forward),
451
+ self.conv(autoregressive_x, edge_index_backward, edge_attr_backward)
452
+ )
453
+
454
+ count = scatter_add(torch.ones_like(dst), dst,
455
+ dim_size=dh[0].size(0)).clamp(min=1).unsqueeze(-1)
456
+
457
+ dh = dh[0] / count, dh[1] / count.unsqueeze(-1)
458
+
459
+ else:
460
+ dh = self.conv(x, edge_index, edge_attr)
461
+
462
+ if node_mask is not None:
463
+ x_ = x
464
+ x, dh = tuple_index(x, node_mask), tuple_index(dh, node_mask)
465
+
466
+ x = self.norm[0](tuple_sum(x, self.dropout[0](dh)))
467
+
468
+ dh = self.ff_func(x)
469
+ x = self.norm[1](tuple_sum(x, self.dropout[1](dh)))
470
+
471
+ if node_mask is not None:
472
+ x_[0][node_mask], x_[1][node_mask] = x[0], x[1]
473
+ x = x_
474
+
475
+ return x, edge_attr
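A quick usage sketch for the `GVP` layer itself, using the `randn` helper defined in this file (importing the module assumes torch_geometric and torch_scatter are available):

import torch
from esm.inverse_folding.gvp_modules import GVP, randn

# Map (8 scalar, 4 vector) channels to (16 scalar, 2 vector) channels;
# the vector outputs stay equivariant to rotations of the input vectors.
gvp = GVP((8, 4), (16, 2), vector_gate=True)
s, v = randn(n=5, dims=(8, 4))                      # 5 nodes with random features
out_s, out_v = gvp((s, v))
print(out_s.shape, out_v.shape)                     # torch.Size([5, 16]) torch.Size([5, 2, 3])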
esm/inverse_folding/gvp_transformer.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # import argparse
7
+ # from typing import Any, Dict, List, Optional, Tuple, NamedTuple
8
+ import torch
9
+ from torch import nn
10
+ # from torch import Tensor
11
+ import torch.nn.functional as F
12
+ # from scipy.spatial import transform
13
+ #
14
+ # from esm.data import Alphabet
15
+
16
+ # from .features import DihedralFeatures
17
+ # from .gvp_encoder import GVPEncoder
18
+ # from .gvp_utils import unflatten_graph
19
+ print("gvp1_transformer")
20
+ from .gvp_transformer_encoder import GVPTransformerEncoder
21
+ print("gvp2_transformer")
22
+ from .transformer_decoder import TransformerDecoder
23
+ print("gvp3_transformer")
24
+ from .util import rotate, CoordBatchConverter
25
+ print("gvp4_transformer")
26
+
27
+
28
+ class GVPTransformerModel(nn.Module):
29
+ """
30
+ GVP-Transformer inverse folding model.
31
+
32
+ Architecture: Geometric GVP-GNN as initial layers, followed by
33
+ sequence-to-sequence Transformer encoder and decoder.
34
+ """
35
+
36
+ def __init__(self, args, alphabet):
37
+ super().__init__()
38
+ encoder_embed_tokens = self.build_embedding(
39
+ args, alphabet, args.encoder_embed_dim,
40
+ )
41
+ decoder_embed_tokens = self.build_embedding(
42
+ args, alphabet, args.decoder_embed_dim,
43
+ )
44
+ encoder = self.build_encoder(args, alphabet, encoder_embed_tokens)
45
+ decoder = self.build_decoder(args, alphabet, decoder_embed_tokens)
46
+ self.args = args
47
+ self.encoder = encoder
48
+ self.decoder = decoder
49
+
50
+ @classmethod
51
+ def build_encoder(cls, args, src_dict, embed_tokens):
52
+ encoder = GVPTransformerEncoder(args, src_dict, embed_tokens)
53
+ return encoder
54
+
55
+ @classmethod
56
+ def build_decoder(cls, args, tgt_dict, embed_tokens):
57
+ decoder = TransformerDecoder(
58
+ args,
59
+ tgt_dict,
60
+ embed_tokens,
61
+ )
62
+ return decoder
63
+
64
+ @classmethod
65
+ def build_embedding(cls, args, dictionary, embed_dim):
66
+ num_embeddings = len(dictionary)
67
+ padding_idx = dictionary.padding_idx
68
+ emb = nn.Embedding(num_embeddings, embed_dim, padding_idx)
69
+ nn.init.normal_(emb.weight, mean=0, std=embed_dim ** -0.5)
70
+ nn.init.constant_(emb.weight[padding_idx], 0)
71
+ return emb
72
+
73
+ def forward(
74
+ self,
75
+ coords,
76
+ padding_mask,
77
+ confidence,
78
+ prev_output_tokens,
79
+ return_all_hiddens: bool = False,
80
+ features_only: bool = False,
81
+ ):
82
+ encoder_out = self.encoder(coords, padding_mask, confidence,
83
+ return_all_hiddens=return_all_hiddens)
84
+ logits, extra = self.decoder(
85
+ prev_output_tokens,
86
+ encoder_out=encoder_out,
87
+ features_only=features_only,
88
+ return_all_hiddens=return_all_hiddens,
89
+ )
90
+ return logits, extra
91
+
92
+ def sample(self, coords, partial_seq=None, temperature=1.0, confidence=None, device=None):
93
+ """
94
+ Samples sequences based on multinomial sampling (no beam search).
95
+
96
+ Args:
97
+ coords: L x 3 x 3 list representing one backbone
98
+ partial_seq: Optional, partial sequence with mask tokens if part of
99
+ the sequence is known
100
+ temperature: sampling temperature, use low temperature for higher
101
+ sequence recovery and high temperature for higher diversity
102
+ confidence: optional length L list of confidence scores for coordinates
103
+ """
104
+ L = len(coords)
105
+ # Convert to batch format
106
+ batch_converter = CoordBatchConverter(self.decoder.dictionary)
107
+ batch_coords, confidence, _, _, padding_mask = (
108
+ batch_converter([(coords, confidence, None)], device=device)
109
+ )
110
+
111
+ # Start with prepend token
112
+ mask_idx = self.decoder.dictionary.get_idx('<mask>')
113
+ sampled_tokens = torch.full((1, 1+L), mask_idx, dtype=int)
114
+ sampled_tokens[0, 0] = self.decoder.dictionary.get_idx('<cath>')
115
+ if partial_seq is not None:
116
+ for i, c in enumerate(partial_seq):
117
+ sampled_tokens[0, i+1] = self.decoder.dictionary.get_idx(c)
118
+
119
+ # Save incremental states for faster sampling
120
+ incremental_state = dict()
121
+
122
+ # Run encoder only once
123
+ encoder_out = self.encoder(batch_coords, padding_mask, confidence)
124
+
125
+ # Make sure all tensors are on the same device if a GPU is present
126
+ if device:
127
+ sampled_tokens = sampled_tokens.to(device)
128
+
129
+ # Decode one token at a time
130
+ for i in range(1, L+1):
131
+ logits, _ = self.decoder(
132
+ sampled_tokens[:, :i],
133
+ encoder_out,
134
+ incremental_state=incremental_state,
135
+ )
136
+ logits = logits[0].transpose(0, 1)
137
+ logits /= temperature
138
+ probs = F.softmax(logits, dim=-1)
139
+ if sampled_tokens[0, i] == mask_idx:
140
+ sampled_tokens[:, i] = torch.multinomial(probs, 1).squeeze(-1)
141
+ sampled_seq = sampled_tokens[0, 1:]
142
+
143
+ # Convert back to string via lookup
144
+ return ''.join([self.decoder.dictionary.get_tok(a) for a in sampled_seq]), encoder_out
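A hedged usage sketch for `sample`. It assumes the standard ESM-IF1 loader `esm.pretrained.esm_if1_gvp4_t16_142M_UR50()` is exposed by this repo's `esm` package (it downloads weights on first use), and it feeds a random toy backbone purely to exercise the API. Note that this version of `sample` returns `(sequence, encoder_out)` rather than just the sequence string.

import torch
import esm

model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()

L = 20
coords = torch.randn(L, 3, 3).tolist()   # toy L x 3 x 3 backbone (N, CA, C per residue)

with torch.no_grad():
    sequence, encoder_out = model.sample(coords, temperature=1.0)
print(sequence)                           # one sampled amino-acid string of length L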
esm/inverse_folding/gvp_transformer_encoder.py ADDED
@@ -0,0 +1,189 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import argparse
9
+ import math
10
+ from typing import Dict, List, Optional
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch import Tensor
15
+ print("gvp1_transformer_encoder")
16
+ from esm.modules import SinusoidalPositionalEmbedding
17
+ print("gvp2_transformer_encoder")
18
+ from .features import GVPInputFeaturizer, DihedralFeatures
19
+ print("gvp3_transformer_encoder")
20
+ from .gvp_encoder import GVPEncoder
21
+ print("gvp4_transformer_encoder")
22
+ from .transformer_layer import TransformerEncoderLayer
23
+ print("gvp5_transformer_encoder")
24
+ from .util import nan_to_num, get_rotation_frames, rotate, rbf
25
+ print("gvp6_transformer_encoder")
26
+
27
+
28
+ class GVPTransformerEncoder(nn.Module):
29
+ """
30
+ Transformer encoder consisting of *args.encoder.layers* layers. Each layer
31
+ is a :class:`TransformerEncoderLayer`.
32
+
33
+ Args:
34
+ args (argparse.Namespace): parsed command-line arguments
35
+ dictionary (~fairseq.data.Dictionary): encoding dictionary
36
+ embed_tokens (torch.nn.Embedding): input embedding
37
+ """
38
+
39
+ def __init__(self, args, dictionary, embed_tokens):
40
+ super().__init__()
41
+ self.args = args
42
+ self.dictionary = dictionary
43
+
44
+ self.dropout_module = nn.Dropout(args.dropout)
45
+
46
+ embed_dim = embed_tokens.embedding_dim
47
+ self.padding_idx = embed_tokens.padding_idx
48
+
49
+ self.embed_tokens = embed_tokens
50
+ self.embed_scale = math.sqrt(embed_dim)
51
+ self.embed_positions = SinusoidalPositionalEmbedding(
52
+ embed_dim,
53
+ self.padding_idx,
54
+ )
55
+ self.embed_gvp_input_features = nn.Linear(15, embed_dim)
56
+ self.embed_confidence = nn.Linear(16, embed_dim)
57
+ self.embed_dihedrals = DihedralFeatures(embed_dim)
58
+
59
+ gvp_args = argparse.Namespace()
60
+ for k, v in vars(args).items():
61
+ if k.startswith("gvp_"):
62
+ setattr(gvp_args, k[4:], v)
63
+ self.gvp_encoder = GVPEncoder(gvp_args)
64
+ gvp_out_dim = gvp_args.node_hidden_dim_scalar + (3 *
65
+ gvp_args.node_hidden_dim_vector)
66
+ self.embed_gvp_output = nn.Linear(gvp_out_dim, embed_dim)
67
+
68
+ self.layers = nn.ModuleList([])
69
+ self.layers.extend(
70
+ [self.build_encoder_layer(args) for i in range(args.encoder_layers)]
71
+ )
72
+ self.num_layers = len(self.layers)
73
+ self.layer_norm = nn.LayerNorm(embed_dim)
74
+
75
+ def build_encoder_layer(self, args):
76
+ return TransformerEncoderLayer(args)
77
+
78
+ def forward_embedding(self, coords, padding_mask, confidence):
79
+ """
80
+ Args:
81
+ coords: N, CA, C backbone coordinates in shape length x 3 (atoms) x 3
82
+ padding_mask: boolean Tensor (true for padding) of shape length
83
+ confidence: confidence scores between 0 and 1 of shape length
84
+ """
85
+ components = dict()
86
+ coord_mask = torch.all(torch.all(torch.isfinite(coords), dim=-1), dim=-1)
87
+ coords = nan_to_num(coords)
88
+ mask_tokens = (
89
+ padding_mask * self.dictionary.padding_idx +
90
+ ~padding_mask * self.dictionary.get_idx("<mask>")
91
+ )
92
+ components["tokens"] = self.embed_tokens(mask_tokens) * self.embed_scale
93
+ components["diherals"] = self.embed_dihedrals(coords)
94
+
95
+ # GVP encoder
96
+ gvp_out_scalars, gvp_out_vectors = self.gvp_encoder(coords,
97
+ coord_mask, padding_mask, confidence)
98
+ R = get_rotation_frames(coords)
99
+ # Rotate to local rotation frame for rotation-invariance
100
+ gvp_out_features = torch.cat([
101
+ gvp_out_scalars,
102
+ rotate(gvp_out_vectors, R.transpose(-2, -1)).flatten(-2, -1),
103
+ ], dim=-1)
104
+ components["gvp_out"] = self.embed_gvp_output(gvp_out_features)
105
+
106
+ components["confidence"] = self.embed_confidence(
107
+ rbf(confidence, 0., 1.))
108
+
109
+ # In addition to GVP encoder outputs, also directly embed GVP input node
110
+ # features to the Transformer
111
+ scalar_features, vector_features = GVPInputFeaturizer.get_node_features(
112
+ coords, coord_mask, with_coord_mask=False)
113
+ features = torch.cat([
114
+ scalar_features,
115
+ rotate(vector_features, R.transpose(-2, -1)).flatten(-2, -1),
116
+ ], dim=-1)
117
+ components["gvp_input_features"] = self.embed_gvp_input_features(features)
118
+
119
+ embed = sum(components.values())
120
+ # for k, v in components.items():
121
+ # print(k, torch.mean(v, dim=(0,1)), torch.std(v, dim=(0,1)))
122
+
123
+ x = embed
124
+ x = x + self.embed_positions(mask_tokens)
125
+ x = self.dropout_module(x)
126
+ return x, components
127
+
128
+ def forward(
129
+ self,
130
+ coords,
131
+ encoder_padding_mask,
132
+ confidence,
133
+ return_all_hiddens: bool = False,
134
+ ):
135
+ """
136
+ Args:
137
+ coords (Tensor): backbone coordinates
138
+ shape batch_size x num_residues x num_atoms (3 for N, CA, C) x 3
139
+ encoder_padding_mask (ByteTensor): the positions of
140
+ padding elements of shape `(batch_size x num_residues)`
141
+ confidence (Tensor): the confidence score of shape (batch_size x
142
+ num_residues). The value is between 0. and 1. for each residue
143
+ coordinate, or -1. if no coordinate is given
144
+ return_all_hiddens (bool, optional): also return all of the
145
+ intermediate hidden states (default: False).
146
+
147
+ Returns:
148
+ dict:
149
+ - **encoder_out** (Tensor): the last encoder layer's output of
150
+ shape `(num_residues, batch_size, embed_dim)`
151
+ - **encoder_padding_mask** (ByteTensor): the positions of
152
+ padding elements of shape `(batch_size, num_residues)`
153
+ - **encoder_embedding** (Tensor): the (scaled) embedding lookup
154
+ of shape `(batch_size, num_residues, embed_dim)`
155
+ - **encoder_states** (List[Tensor]): all intermediate
156
+ hidden states of shape `(num_residues, batch_size, embed_dim)`.
157
+ Only populated if *return_all_hiddens* is True.
158
+ """
159
+ x, encoder_embedding = self.forward_embedding(coords,
160
+ encoder_padding_mask, confidence)
161
+ # account for padding while computing the representation
162
+ x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
163
+
164
+ # B x T x C -> T x B x C
165
+ x = x.transpose(0, 1)
166
+
167
+ encoder_states = []
168
+
169
+ if return_all_hiddens:
170
+ encoder_states.append(x)
171
+
172
+ # encoder layers
173
+ for layer in self.layers:
174
+ x = layer(
175
+ x, encoder_padding_mask=encoder_padding_mask
176
+ )
177
+ if return_all_hiddens:
178
+ assert encoder_states is not None
179
+ encoder_states.append(x)
180
+
181
+ if self.layer_norm is not None:
182
+ x = self.layer_norm(x)
183
+
184
+ return {
185
+ "encoder_out": [x], # T x B x C
186
+ "encoder_padding_mask": [encoder_padding_mask], # B x T
187
+ "encoder_embedding": [encoder_embedding], # dictionary
188
+ "encoder_states": encoder_states, # List[T x B x C]
189
+ }
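`forward_embedding` rotates every vector feature into a per-residue local frame (`get_rotation_frames` / `rotate`) before flattening it into the Transformer input, which makes the embedding invariant to global rotations of the structure. A sketch of that property on the raw GVP input features, assuming the two helpers in `esm.inverse_folding.util` behave as they are used above and that torch_geometric/torch_scatter are installed (pulled in by the module imports):

import torch
from esm.inverse_folding.features import GVPInputFeaturizer
from esm.inverse_folding.util import get_rotation_frames, rotate

coords = torch.randn(1, 10, 3, 3)                    # toy N, CA, C backbone
coord_mask = torch.ones(1, 10, dtype=torch.bool)

def local_frame_features(c):
    # Vector node features expressed in each residue's local frame.
    _, vec = GVPInputFeaturizer.get_node_features(c, coord_mask, with_coord_mask=False)
    R = get_rotation_frames(c)
    return rotate(vec, R.transpose(-2, -1)).flatten(-2, -1)

Q, _ = torch.linalg.qr(torch.randn(3, 3))            # random orthogonal matrix
if torch.det(Q) < 0:
    Q[:, 0] = -Q[:, 0]                               # force a proper rotation
rotated = coords @ Q.T                               # rotate the whole structure

print(torch.allclose(local_frame_features(coords),
                     local_frame_features(rotated), atol=1e-4))  # True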
esm/inverse_folding/gvp_utils.py ADDED
@@ -0,0 +1,68 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ def flatten_graph(node_embeddings, edge_embeddings, edge_index):
10
+ """
11
+ Flattens the graph to batch size one (with disconnected subgraphs for
12
+ each example) to be compatible with the pytorch-geometric package.
13
+ Args:
14
+ node_embeddings: node embeddings in tuple form (scalar, vector)
15
+ - scalar: shape batch size x nodes x node_embed_dim
16
+ - vector: shape batch size x nodes x node_embed_dim x 3
17
+ edge_embeddings: edge embeddings in tuple form (scalar, vector)
18
+ - scalar: shape batch size x edges x edge_embed_dim
19
+ - vector: shape batch size x edges x edge_embed_dim x 3
20
+ edge_index: shape batch_size x 2 (source node and target node) x edges
21
+ Returns:
22
+ node_embeddings: node embeddings in tuple form (scalar, vector)
23
+ - scalar: shape batch total_nodes x node_embed_dim
24
+ - vector: shape batch total_nodes x node_embed_dim x 3
25
+ edge_embeddings: edge embeddings in tuple form (scalar, vector)
26
+ - scalar: shape batch total_edges x edge_embed_dim
27
+ - vector: shape batch total_edges x edge_embed_dim x 3
28
+ edge_index: shape 2 x total_edges
29
+ """
30
+ x_s, x_v = node_embeddings
31
+ e_s, e_v = edge_embeddings
32
+ batch_size, N = x_s.shape[0], x_s.shape[1]
33
+ node_embeddings = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
34
+ edge_embeddings = (torch.flatten(e_s, 0, 1), torch.flatten(e_v, 0, 1))
35
+
36
+ edge_mask = torch.any(edge_index != -1, dim=1)
37
+ # Re-number the nodes by adding batch_idx * N to each batch
38
+ edge_index = edge_index + (torch.arange(batch_size, device=edge_index.device) *
39
+ N).unsqueeze(-1).unsqueeze(-1)
40
+ edge_index = edge_index.permute(1, 0, 2).flatten(1, 2)
41
+ edge_mask = edge_mask.flatten()
42
+ edge_index = edge_index[:, edge_mask]
43
+ edge_embeddings = (
44
+ edge_embeddings[0][edge_mask, :],
45
+ edge_embeddings[1][edge_mask, :]
46
+ )
47
+ return node_embeddings, edge_embeddings, edge_index
48
+
49
+
50
+ def unflatten_graph(node_embeddings, batch_size):
51
+ """
52
+ Unflattens node embeddings.
53
+ Args:
54
+ node_embeddings: node embeddings in tuple form (scalar, vector)
55
+ - scalar: shape batch total_nodes x node_embed_dim
56
+ - vector: shape batch total_nodes x node_embed_dim x 3
57
+ batch_size: int
58
+ Returns:
59
+ node_embeddings: node embeddings in tuple form (scalar, vector)
60
+ - scalar: shape batch size x nodes x node_embed_dim
61
+ - vector: shape batch size x nodes x node_embed_dim x 3
62
+ """
63
+ x_s, x_v = node_embeddings
64
+ x_s = x_s.reshape(batch_size, -1, x_s.shape[1])
65
+ x_v = x_v.reshape(batch_size, -1, x_v.shape[1], x_v.shape[2])
66
+ return (x_s, x_v)
67
+
68
+
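A short round-trip sketch for the two helpers above; shapes are illustrative and no edges are masked with -1 in this toy graph:

import torch
from esm.inverse_folding.gvp_utils import flatten_graph, unflatten_graph

B, N, E = 2, 5, 7
node_emb = (torch.randn(B, N, 8), torch.randn(B, N, 4, 3))
edge_emb = (torch.randn(B, E, 6), torch.randn(B, E, 1, 3))
edge_index = torch.randint(0, N, (B, 2, E))          # -1 entries would drop edges

flat_nodes, flat_edges, flat_index = flatten_graph(node_emb, edge_emb, edge_index)
print(flat_nodes[0].shape)    # torch.Size([10, 8]): all B*N nodes in one graph
print(flat_index.shape)       # torch.Size([2, 14]): node indices offset per example

nodes_back = unflatten_graph(flat_nodes, batch_size=B)
print(nodes_back[0].shape)    # torch.Size([2, 5, 8])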
esm/inverse_folding/multichain_util.py ADDED
@@ -0,0 +1,152 @@
1
+ # # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # #
3
+ # # This source code is licensed under the MIT license found in the
4
+ # # LICENSE file in the root directory of this source tree.
5
+ #
6
+ # import biotite.structure
7
+ # import numpy as np
8
+ # import torch
9
+ # from typing import Sequence, Tuple, List
10
+ #
11
+ # from esm.inverse_folding.util import (
12
+ # load_structure,
13
+ # extract_coords_from_structure,
14
+ # load_coords,
15
+ # get_sequence_loss,
16
+ # get_encoder_output,
17
+ # )
18
+ #
19
+ #
20
+ # def extract_coords_from_complex(structure: biotite.structure.AtomArray):
21
+ # """
22
+ # Args:
23
+ # structure: biotite AtomArray
24
+ # Returns:
25
+ # Tuple (coords_list, seq_list)
26
+ # - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
27
+ # coordinates representing the backbone of each chain
28
+ # - seqs: Dictionary mapping chain ids to native sequences of each chain
29
+ # """
30
+ # coords = {}
31
+ # seqs = {}
32
+ # all_chains = biotite.structure.get_chains(structure)
33
+ # for chain_id in all_chains:
34
+ # chain = structure[structure.chain_id == chain_id]
35
+ # coords[chain_id], seqs[chain_id] = extract_coords_from_structure(chain)
36
+ # return coords, seqs
37
+ #
38
+ #
39
+ # def load_complex_coords(fpath, chains):
40
+ # """
41
+ # Args:
42
+ # fpath: filepath to either pdb or cif file
43
+ # chains: the chain ids (the order matters for autoregressive model)
44
+ # Returns:
45
+ # Tuple (coords_list, seq_list)
46
+ # - coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
47
+ # coordinates representing the backbone of each chain
48
+ # - seqs: Dictionary mapping chain ids to native sequences of each chain
49
+ # """
50
+ # structure = load_structure(fpath, chains)
51
+ # return extract_coords_from_complex(structure)
52
+ #
53
+ #
54
+ # def _concatenate_coords(coords, target_chain_id, padding_length=10):
55
+ # """
56
+ # Args:
57
+ # coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
58
+ # coordinates representing the backbone of each chain
59
+ # target_chain_id: The chain id to sample sequences for
60
+ # padding_length: Length of padding between concatenated chains
61
+ # Returns:
62
+ # Tuple (coords, seq)
63
+ # - coords is an L x 3 x 3 array for N, CA, C coordinates, a
64
+ # concatenation of the chains with padding in between
65
+ # - seq is the extracted sequence, with padding tokens inserted
66
+ # between the concatenated chains
67
+ # """
68
+ # pad_coords = np.full((padding_length, 3, 3), np.nan, dtype=np.float32)
69
+ # # For best performance, put the target chain first in concatenation.
70
+ # coords_list = [coords[target_chain_id]]
71
+ # for chain_id in coords:
72
+ # if chain_id == target_chain_id:
73
+ # continue
74
+ # coords_list.append(pad_coords)
75
+ # coords_list.append(coords[chain_id])
76
+ # coords_concatenated = np.concatenate(coords_list, axis=0)
77
+ # return coords_concatenated
78
+ #
79
+ #
80
+ # def sample_sequence_in_complex(model, coords, target_chain_id, temperature=1.,
81
+ # padding_length=10):
82
+ # """
83
+ # Samples sequence for one chain in a complex.
84
+ # Args:
85
+ # model: An instance of the GVPTransformer model
86
+ # coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
87
+ # coordinates representing the backbone of each chain
88
+ # target_chain_id: The chain id to sample sequences for
89
+ # padding_length: padding length in between chains
90
+ # Returns:
91
+ # Sampled sequence for the target chain
92
+ # """
93
+ # target_chain_len = coords[target_chain_id].shape[0]
94
+ # all_coords = _concatenate_coords(coords, target_chain_id)
95
+ # device = next(model.parameters()).device
96
+ #
97
+ # # Supply padding tokens for other chains to avoid unused sampling for speed
98
+ # padding_pattern = ['<pad>'] * all_coords.shape[0]
99
+ # for i in range(target_chain_len):
100
+ # padding_pattern[i] = '<mask>'
101
+ # sampled = model.sample(all_coords, partial_seq=padding_pattern,
102
+ # temperature=temperature, device=device)
103
+ # sampled = sampled[:target_chain_len]
104
+ # return sampled
105
+ #
106
+ #
107
+ # def score_sequence_in_complex(model, alphabet, coords, target_chain_id,
108
+ # target_seq, padding_length=10):
109
+ # """
110
+ # Scores sequence for one chain in a complex.
111
+ # Args:
112
+ # model: An instance of the GVPTransformer model
113
+ # alphabet: Alphabet for the model
114
+ # coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
115
+ # coordinates representing the backbone of each chain
116
+ # target_chain_id: The chain id to sample sequences for
117
+ # target_seq: Target sequence for the target chain for scoring.
118
+ # padding_length: padding length in between chains
119
+ # Returns:
120
+ # Tuple (ll_fullseq, ll_withcoord)
121
+ # - ll_fullseq: Average log-likelihood over the full target chain
122
+ # - ll_withcoord: Average log-likelihood in target chain excluding those
123
+ # residues without coordinates
124
+ # """
125
+ # all_coords = _concatenate_coords(coords, target_chain_id)
126
+ #
127
+ # loss, target_padding_mask = get_sequence_loss(model, alphabet, all_coords,
128
+ # target_seq)
129
+ # ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(
130
+ # ~target_padding_mask)
131
+ #
132
+ # # Also calculate average when excluding masked portions
133
+ # coord_mask = np.all(np.isfinite(coords[target_chain_id]), axis=(-1, -2))
134
+ # ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
135
+ # return ll_fullseq, ll_withcoord
136
+ #
137
+ #
138
+ # def get_encoder_output_for_complex(model, alphabet, coords, target_chain_id):
139
+ # """
140
+ # Args:
141
+ # model: An instance of the GVPTransformer model
142
+ # alphabet: Alphabet for the model
143
+ # coords: Dictionary mapping chain ids to L x 3 x 3 array for N, CA, C
144
+ # coordinates representing the backbone of each chain
145
+ # target_chain_id: The chain id to sample sequences for
146
+ # Returns:
147
+ # Dictionary mapping chain id to encoder output for each chain
148
+ # """
149
+ # all_coords = _concatenate_coords(coords, target_chain_id)
150
+ # all_rep = get_encoder_output(model, alphabet, all_coords)
151
+ # target_chain_len = coords[target_chain_id].shape[0]
152
+ # return all_rep[:target_chain_len]
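The commented-out helpers above mirror the upstream ESM-IF1 multichain utilities. Purely as a hedged sketch (it assumes these helpers are re-enabled, that an inverse-folding model/alphabet pair is already loaded as `model`/`alphabet`, and that "complex.pdb" and the chain ids are placeholders), their intended call pattern would be roughly:

from esm.inverse_folding.util import load_coords

coords = {}
for chain_id in ["A", "B"]:                                      # placeholder chain ids
    coords[chain_id], _ = load_coords("complex.pdb", chain_id)   # placeholder path

# design a new sequence for chain A in the context of the whole complex ...
designed = sample_sequence_in_complex(model, coords, target_chain_id="A")

# ... or score a candidate sequence for chain A
ll_fullseq, ll_withcoord = score_sequence_in_complex(
    model, alphabet, coords, target_chain_id="A", target_seq=designed)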
esm/inverse_folding/transformer_decoder.py ADDED
@@ -0,0 +1,228 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import math
9
+ from typing import Any, Dict, List, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+ from esm.modules import SinusoidalPositionalEmbedding
16
+ from .transformer_layer import TransformerDecoderLayer
17
+
18
+
19
+ def fill_with_neg_inf(t):
20
+ """FP16-compatible function that fills a tensor with -inf."""
21
+ return t.float().fill_(float("-inf")).type_as(t)
22
+
23
+
24
+ class TransformerDecoder(nn.Module):
25
+ """
26
+ Transformer decoder consisting of *args.decoder_layers* layers. Each layer
27
+ is a :class:`TransformerDecoderLayer`.
28
+
29
+ Args:
30
+ args (argparse.Namespace): parsed command-line arguments
31
+ dictionary (~fairseq.data.Dictionary): decoding dictionary
32
+ embed_tokens (torch.nn.Embedding): output embedding
33
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
34
+ (default: False).
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ args,
40
+ dictionary,
41
+ embed_tokens,
42
+ ):
43
+ super().__init__()
44
+ self.args = args
45
+ self.dictionary = dictionary
46
+ self._future_mask = torch.empty(0)
47
+
48
+ self.dropout_module = nn.Dropout(args.dropout)
49
+
50
+ input_embed_dim = embed_tokens.embedding_dim
51
+ embed_dim = args.decoder_embed_dim
52
+ self.embed_dim = embed_dim
53
+
54
+ self.padding_idx = embed_tokens.padding_idx
55
+
56
+ self.embed_tokens = embed_tokens
57
+ self.embed_scale = math.sqrt(embed_dim)
58
+
59
+ self.project_in_dim = (
60
+ nn.Linear(input_embed_dim, embed_dim, bias=False)
61
+ if embed_dim != input_embed_dim
62
+ else None
63
+ )
64
+ self.embed_positions = SinusoidalPositionalEmbedding(
65
+ embed_dim,
66
+ self.padding_idx,
67
+ )
68
+
69
+ self.layers = nn.ModuleList([])
70
+ self.layers.extend(
71
+ [
72
+ self.build_decoder_layer(args)
73
+ for _ in range(args.decoder_layers)
74
+ ]
75
+ )
76
+ self.num_layers = len(self.layers)
77
+ self.layer_norm = nn.LayerNorm(embed_dim)
78
+
79
+ self.build_output_projection(args, dictionary)
80
+
81
+ def build_output_projection(self, args, dictionary):
82
+ self.output_projection = nn.Linear(
83
+ args.decoder_embed_dim, len(dictionary), bias=False
84
+ )
85
+ nn.init.normal_(
86
+ self.output_projection.weight, mean=0, std=args.decoder_embed_dim ** -0.5
87
+ )
88
+
89
+ def build_decoder_layer(self, args):
90
+ return TransformerDecoderLayer(args)
91
+
92
+ def forward(
93
+ self,
94
+ prev_output_tokens,
95
+ encoder_out: Optional[Dict[str, List[Tensor]]] = None,
96
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
97
+ features_only: bool = False,
98
+ return_all_hiddens: bool = False,
99
+ ):
100
+ """
101
+ Args:
102
+ prev_output_tokens (LongTensor): previous decoder outputs of shape
103
+ `(batch, tgt_len)`, for teacher forcing
104
+ encoder_out (optional): output from the encoder, used for
105
+ encoder-side attention, should be of size T x B x C
106
+ incremental_state (dict): dictionary used for storing state during
107
+ :ref:`Incremental decoding`
108
+ features_only (bool, optional): only return features without
109
+ applying output layer (default: False).
110
+
111
+ Returns:
112
+ tuple:
113
+ - the decoder's output of shape `(batch, vocab, tgt_len)` (logits are transposed for cross-entropy)
114
+ - a dictionary with any model-specific outputs
115
+ """
116
+
117
+ x, extra = self.extract_features(
118
+ prev_output_tokens,
119
+ encoder_out=encoder_out,
120
+ incremental_state=incremental_state,
121
+ )
122
+
123
+ if not features_only:
124
+ x = self.output_layer(x)
125
+ x = x.transpose(1, 2) # B x T x C -> B x C x T
126
+ return x, extra
127
+
128
+ def extract_features(
129
+ self,
130
+ prev_output_tokens,
131
+ encoder_out: Optional[Dict[str, List[Tensor]]],
132
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
133
+ ):
134
+ """
135
+ Similar to *forward* but only return features.
136
+
137
+ Includes several features from "Jointly Learning to Align and
138
+ Translate with Transformer Models" (Garg et al., EMNLP 2019).
139
+
140
+ Returns:
141
+ tuple:
142
+ - the decoder's features of shape `(batch, tgt_len, embed_dim)`
143
+ - a dictionary with any model-specific outputs
144
+ """
145
+ bs, slen = prev_output_tokens.size()
146
+
147
+ enc: Optional[Tensor] = None
148
+ padding_mask: Optional[Tensor] = None
149
+ if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
150
+ enc = encoder_out["encoder_out"][0]
151
+ assert (
152
+ enc.size()[1] == bs
153
+ ), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
154
+ if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
155
+ padding_mask = encoder_out["encoder_padding_mask"][0]
156
+
157
+ # embed positions
158
+ positions = self.embed_positions(
159
+ prev_output_tokens
160
+ )
161
+
162
+ if incremental_state is not None:
163
+ prev_output_tokens = prev_output_tokens[:, -1:]
164
+ positions = positions[:, -1:]
165
+
166
+ # embed tokens and positions
167
+ x = self.embed_scale * self.embed_tokens(prev_output_tokens)
168
+
169
+ if self.project_in_dim is not None:
170
+ x = self.project_in_dim(x)
171
+
172
+ x += positions
173
+
174
+ x = self.dropout_module(x)
175
+
176
+ # B x T x C -> T x B x C
177
+ x = x.transpose(0, 1)
178
+
179
+ self_attn_padding_mask: Optional[Tensor] = None
180
+ if prev_output_tokens.eq(self.padding_idx).any():
181
+ self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
182
+
183
+ # decoder layers
184
+ attn: Optional[Tensor] = None
185
+ inner_states: List[Optional[Tensor]] = [x]
186
+ for idx, layer in enumerate(self.layers):
187
+ if incremental_state is None:
188
+ self_attn_mask = self.buffered_future_mask(x)
189
+ else:
190
+ self_attn_mask = None
191
+
192
+ x, layer_attn, _ = layer(
193
+ x,
194
+ enc,
195
+ padding_mask,
196
+ incremental_state,
197
+ self_attn_mask=self_attn_mask,
198
+ self_attn_padding_mask=self_attn_padding_mask,
199
+ need_attn=False,
200
+ need_head_weights=False,
201
+ )
202
+ inner_states.append(x)
203
+
204
+ if self.layer_norm is not None:
205
+ x = self.layer_norm(x)
206
+
207
+ # T x B x C -> B x T x C
208
+ x = x.transpose(0, 1)
209
+
210
+ return x, {"inner_states": inner_states}
211
+
212
+ def output_layer(self, features):
213
+ """Project features to the vocabulary size."""
214
+ return self.output_projection(features)
215
+
216
+ def buffered_future_mask(self, tensor):
217
+ dim = tensor.size(0)
218
+ # self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
219
+ if (
220
+ self._future_mask.size(0) == 0
221
+ or (not self._future_mask.device == tensor.device)
222
+ or self._future_mask.size(0) < dim
223
+ ):
224
+ self._future_mask = torch.triu(
225
+ fill_with_neg_inf(torch.zeros([dim, dim])), 1
226
+ )
227
+ self._future_mask = self._future_mask.to(tensor)
228
+ return self._future_mask[:dim, :dim]
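`buffered_future_mask` above builds the standard causal mask from `fill_with_neg_inf`. A minimal standalone sketch of the same construction (the 4x4 size is arbitrary):

import torch

def fill_with_neg_inf(t):
    # FP16-compatible -inf fill, same helper as defined at the top of this file
    return t.float().fill_(float("-inf")).type_as(t)

mask = torch.triu(fill_with_neg_inf(torch.zeros([4, 4])), 1)
# mask[i, j] == -inf for j > i and 0 otherwise, so once it is added to the
# attention logits, position i can only attend to positions <= i.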
esm/inverse_folding/transformer_layer.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # Contents of this file were adapted from the open source fairseq repository.
4
+ #
5
+ # This source code is licensed under the MIT license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ from typing import Dict, List, Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from esm.multihead_attention import MultiheadAttention
14
+ from torch import Tensor
15
+
16
+
17
+ class TransformerEncoderLayer(nn.Module):
18
+ """Encoder layer block.
19
+ `layernorm -> dropout -> add residual`
20
+
21
+ Args:
22
+ args (argparse.Namespace): parsed command-line arguments
23
+ """
24
+
25
+ def __init__(self, args):
26
+ super().__init__()
27
+ self.args = args
28
+ self.embed_dim = args.encoder_embed_dim
29
+ self.self_attn = self.build_self_attention(self.embed_dim, args)
30
+ self.self_attn_layer_norm = torch.nn.LayerNorm(self.embed_dim)
31
+ self.dropout_module = nn.Dropout(args.dropout)
32
+ self.activation_fn = F.relu
33
+ self.fc1 = self.build_fc1(
34
+ self.embed_dim,
35
+ args.encoder_ffn_embed_dim,
36
+ )
37
+ self.fc2 = self.build_fc2(
38
+ args.encoder_ffn_embed_dim,
39
+ self.embed_dim,
40
+ )
41
+
42
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
43
+
44
+ def build_fc1(self, input_dim, output_dim):
45
+ return nn.Linear(input_dim, output_dim)
46
+
47
+ def build_fc2(self, input_dim, output_dim):
48
+ return nn.Linear(input_dim, output_dim)
49
+
50
+ def build_self_attention(self, embed_dim, args):
51
+ return MultiheadAttention(
52
+ embed_dim,
53
+ args.encoder_attention_heads,
54
+ dropout=args.attention_dropout,
55
+ self_attention=True,
56
+ )
57
+
58
+ def residual_connection(self, x, residual):
59
+ return residual + x
60
+
61
+ def forward(
62
+ self,
63
+ x,
64
+ encoder_padding_mask: Optional[Tensor],
65
+ attn_mask: Optional[Tensor] = None,
66
+ ):
67
+ """
68
+ Args:
69
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
70
+ encoder_padding_mask (ByteTensor): binary ByteTensor of shape
71
+ `(batch, seq_len)` where padding elements are indicated by ``1``.
72
+ attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
73
+ where `tgt_len` is the length of output and `src_len` is the
74
+ length of input, though here both are equal to `seq_len`.
75
+ `attn_mask[tgt_i, src_j] = 1` means that when calculating the
76
+ embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
77
+ useful for strided self-attention.
78
+
79
+ Returns:
80
+ encoded output of shape `(seq_len, batch, embed_dim)`
81
+ """
82
+ # anything in original attn_mask = 1, becomes -1e8
83
+ # anything in original attn_mask = 0, becomes 0
84
+ # Note that we cannot use -inf here, because at some edge cases,
85
+ # the attention weight (before softmax) for some padded element in query
86
+ # will become -inf, which results in NaN in model parameters
87
+ if attn_mask is not None:
88
+ attn_mask = attn_mask.masked_fill(
89
+ attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
90
+ )
91
+
92
+ residual = x
93
+ x = self.self_attn_layer_norm(x)
94
+ x, _ = self.self_attn(
95
+ query=x,
96
+ key=x,
97
+ value=x,
98
+ key_padding_mask=encoder_padding_mask,
99
+ need_weights=False,
100
+ attn_mask=attn_mask,
101
+ )
102
+ x = self.dropout_module(x)
103
+ x = self.residual_connection(x, residual)
104
+
105
+ residual = x
106
+ x = self.final_layer_norm(x)
107
+ x = self.activation_fn(self.fc1(x))
108
+ x = self.fc2(x)
109
+ x = self.dropout_module(x)
110
+ x = self.residual_connection(x, residual)
111
+ return x
112
+
113
+
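# Aside (not part of the diff): the masked_fill conversion above turns boolean
# "exclude" entries of attn_mask into large negative additive biases. A tiny
# standalone illustration of the same trick:
import torch
m = torch.tensor([[False, True], [False, False]])   # True marks positions to exclude
bias = torch.zeros(2, 2).masked_fill(m, -1e8)
# adding `bias` to the attention logits drives the masked entry to ~0 softmax weight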
114
+ class TransformerDecoderLayer(nn.Module):
115
+ """Decoder layer block.
116
+ `layernorm -> dropout -> add residual`
117
+
118
+ Args:
119
+ args (argparse.Namespace): parsed command-line arguments
120
+ no_encoder_attn (bool, optional): whether to attend to encoder outputs
121
+ (default: False).
122
+ """
123
+
124
+ def __init__(
125
+ self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False
126
+ ):
127
+ super().__init__()
128
+ self.embed_dim = args.decoder_embed_dim
129
+ self.dropout_module = nn.Dropout(args.dropout)
130
+
131
+ self.self_attn = self.build_self_attention(
132
+ self.embed_dim,
133
+ args,
134
+ add_bias_kv=add_bias_kv,
135
+ add_zero_attn=add_zero_attn,
136
+ )
137
+ self.nh = self.self_attn.num_heads
138
+ self.head_dim = self.self_attn.head_dim
139
+
140
+ self.activation_fn = F.relu
141
+
142
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
143
+
144
+ if no_encoder_attn:
145
+ self.encoder_attn = None
146
+ self.encoder_attn_layer_norm = None
147
+ else:
148
+ self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
149
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
150
+
151
+ self.ffn_layernorm = (
152
+ nn.LayerNorm(args.decoder_ffn_embed_dim)
153
+ if getattr(args, "scale_fc", False)
154
+ else None
155
+ )
156
+ self.w_resid = (
157
+ nn.Parameter(
158
+ torch.ones(
159
+ self.embed_dim,
160
+ ),
161
+ requires_grad=True,
162
+ )
163
+ if getattr(args, "scale_resids", False)
164
+ else None
165
+ )
166
+
167
+ self.fc1 = self.build_fc1(
168
+ self.embed_dim,
169
+ args.decoder_ffn_embed_dim,
170
+ )
171
+ self.fc2 = self.build_fc2(
172
+ args.decoder_ffn_embed_dim,
173
+ self.embed_dim,
174
+ )
175
+
176
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
177
+ self.need_attn = True
178
+
179
+ def build_fc1(self, input_dim, output_dim):
180
+ return nn.Linear(input_dim, output_dim)
181
+
182
+ def build_fc2(self, input_dim, output_dim):
183
+ return nn.Linear(input_dim, output_dim)
184
+
185
+ def build_self_attention(
186
+ self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
187
+ ):
188
+ return MultiheadAttention(
189
+ embed_dim,
190
+ args.decoder_attention_heads,
191
+ dropout=args.attention_dropout,
192
+ add_bias_kv=add_bias_kv,
193
+ add_zero_attn=add_zero_attn,
194
+ self_attention=True,
195
+ )
196
+
197
+ def build_encoder_attention(self, embed_dim, args):
198
+ return MultiheadAttention(
199
+ embed_dim,
200
+ args.decoder_attention_heads,
201
+ kdim=args.encoder_embed_dim,
202
+ vdim=args.encoder_embed_dim,
203
+ dropout=args.attention_dropout,
204
+ encoder_decoder_attention=True,
205
+ )
206
+
207
+ def residual_connection(self, x, residual):
208
+ return residual + x
209
+
210
+ def forward(
211
+ self,
212
+ x,
213
+ encoder_out: Optional[torch.Tensor] = None,
214
+ encoder_padding_mask: Optional[torch.Tensor] = None,
215
+ incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
216
+ prev_self_attn_state: Optional[List[torch.Tensor]] = None,
217
+ prev_attn_state: Optional[List[torch.Tensor]] = None,
218
+ self_attn_mask: Optional[torch.Tensor] = None,
219
+ self_attn_padding_mask: Optional[torch.Tensor] = None,
220
+ need_attn: bool = False,
221
+ need_head_weights: bool = False,
222
+ ):
223
+ """
224
+ Args:
225
+ x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
226
+ encoder_padding_mask (ByteTensor, optional): binary
227
+ ByteTensor of shape `(batch, src_len)` where padding
228
+ elements are indicated by ``1``.
229
+ need_attn (bool, optional): return attention weights
230
+ need_head_weights (bool, optional): return attention weights
231
+ for each head (default: return average over heads).
232
+
233
+ Returns:
234
+ encoded output of shape `(seq_len, batch, embed_dim)`
235
+ """
236
+ if need_head_weights:
237
+ need_attn = True
238
+
239
+ residual = x
240
+ x = self.self_attn_layer_norm(x)
241
+ if prev_self_attn_state is not None:
242
+ prev_key, prev_value = prev_self_attn_state[:2]
243
+ saved_state: Dict[str, Optional[Tensor]] = {
244
+ "prev_key": prev_key,
245
+ "prev_value": prev_value,
246
+ }
247
+ if len(prev_self_attn_state) >= 3:
248
+ saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
249
+ assert incremental_state is not None
250
+ self.self_attn._set_input_buffer(incremental_state, saved_state)
251
+ _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
252
+ y = x
253
+
254
+ x, attn = self.self_attn(
255
+ query=x,
256
+ key=y,
257
+ value=y,
258
+ key_padding_mask=self_attn_padding_mask,
259
+ incremental_state=incremental_state,
260
+ need_weights=False,
261
+ attn_mask=self_attn_mask,
262
+ )
263
+ x = self.dropout_module(x)
264
+ x = self.residual_connection(x, residual)
265
+
266
+ if self.encoder_attn is not None and encoder_out is not None:
267
+ residual = x
268
+ x = self.encoder_attn_layer_norm(x)
269
+ if prev_attn_state is not None:
270
+ prev_key, prev_value = prev_attn_state[:2]
271
+ saved_state: Dict[str, Optional[Tensor]] = {
272
+ "prev_key": prev_key,
273
+ "prev_value": prev_value,
274
+ }
275
+ if len(prev_attn_state) >= 3:
276
+ saved_state["prev_key_padding_mask"] = prev_attn_state[2]
277
+ assert incremental_state is not None
278
+ self.encoder_attn._set_input_buffer(incremental_state, saved_state)
279
+
280
+ x, attn = self.encoder_attn(
281
+ query=x,
282
+ key=encoder_out,
283
+ value=encoder_out,
284
+ key_padding_mask=encoder_padding_mask,
285
+ incremental_state=incremental_state,
286
+ static_kv=True,
287
+ need_weights=need_attn or (not self.training and self.need_attn),
288
+ need_head_weights=need_head_weights,
289
+ )
290
+ x = self.dropout_module(x)
291
+ x = self.residual_connection(x, residual)
292
+
293
+ residual = x
294
+ x = self.final_layer_norm(x)
295
+
296
+ x = self.activation_fn(self.fc1(x))
297
+ if self.ffn_layernorm is not None:
298
+ x = self.ffn_layernorm(x)
299
+ x = self.fc2(x)
300
+ x = self.dropout_module(x)
301
+ if self.w_resid is not None:
302
+ residual = torch.mul(self.w_resid, residual)
303
+ x = self.residual_connection(x, residual)
304
+ return x, attn, None
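As a rough, untested usage sketch (the hyper-parameter values are assumptions, and the MultiheadAttention signature is assumed to follow the fairseq-style module in esm/multihead_attention.py), a decoder layer can be exercised on its own:

from types import SimpleNamespace
import torch
from esm.inverse_folding.transformer_layer import TransformerDecoderLayer

# only the fields the layer actually reads are supplied
args = SimpleNamespace(
    decoder_embed_dim=64, decoder_ffn_embed_dim=128, decoder_attention_heads=4,
    encoder_embed_dim=64, dropout=0.1, attention_dropout=0.0,
)
layer = TransformerDecoderLayer(args)
x = torch.randn(7, 2, 64)        # (seq_len, batch, embed_dim)
out, attn, _ = layer(x)          # no encoder context, no masks
print(out.shape)                 # torch.Size([7, 2, 64])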
esm/inverse_folding/util.py ADDED
@@ -0,0 +1,323 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import math
8
+
9
+ import biotite.structure
10
+ from biotite.structure.io import pdbx, pdb
11
+ from biotite.structure.residues import get_residues
12
+ from biotite.structure import filter_backbone
13
+ from biotite.structure import get_chains
14
+ from biotite.sequence import ProteinSequence
15
+ import numpy as np
16
+ # from scipy.spatial import transform
17
+ # from scipy.stats import special_ortho_group
18
+ import torch
19
+ # import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ # import torch.utils.data as data
22
+ from typing import Sequence, Tuple, List
23
+
24
+ from esm.data import BatchConverter
25
+
26
+
27
+ def load_structure(fpath, chain=None):
28
+ """
29
+ Args:
30
+ fpath: filepath to either pdb or cif file
31
+ chain: the chain id or list of chain ids to load
32
+ Returns:
33
+ biotite.structure.AtomArray
34
+ """
35
+ if fpath.endswith('cif'):
36
+ with open(fpath) as fin:
37
+ pdbxf = pdbx.PDBxFile.read(fin)
38
+ structure = pdbx.get_structure(pdbxf, model=1)
39
+ elif fpath.endswith('pdb'):
40
+ with open(fpath) as fin:
41
+ pdbf = pdb.PDBFile.read(fin)
42
+ structure = pdb.get_structure(pdbf, model=1)
43
+ bbmask = filter_backbone(structure)
44
+ structure = structure[bbmask]
45
+ all_chains = get_chains(structure)
46
+ if len(all_chains) == 0:
47
+ raise ValueError('No chains found in the input file.')
48
+ if chain is None:
49
+ chain_ids = all_chains
50
+ elif isinstance(chain, list):
51
+ chain_ids = chain
52
+ else:
53
+ chain_ids = [chain]
54
+ for chain in chain_ids:
55
+ if chain not in all_chains:
56
+ raise ValueError(f'Chain {chain} not found in input file')
57
+ chain_filter = [a.chain_id in chain_ids for a in structure]
58
+ structure = structure[chain_filter]
59
+ return structure
60
+
61
+
62
+ def extract_coords_from_structure(structure: biotite.structure.AtomArray):
63
+ """
64
+ Args:
65
+ structure: An instance of biotite AtomArray
66
+ Returns:
67
+ Tuple (coords, seq)
68
+ - coords is an L x 3 x 3 array for N, CA, C coordinates
69
+ - seq is the extracted sequence
70
+ """
71
+ coords = get_atom_coords_residuewise(["N", "CA", "C"], structure)
72
+ residue_identities = get_residues(structure)[1]
73
+ seq = ''.join([ProteinSequence.convert_letter_3to1(r) for r in residue_identities])
74
+ return coords, seq
75
+
76
+
77
+ def load_coords(fpath, chain):
78
+ """
79
+ Args:
80
+ fpath: filepath to either pdb or cif file
81
+ chain: the chain id
82
+ Returns:
83
+ Tuple (coords, seq)
84
+ - coords is an L x 3 x 3 array for N, CA, C coordinates
85
+ - seq is the extracted sequence
86
+ """
87
+ structure = load_structure(fpath, chain)
88
+ return extract_coords_from_structure(structure)
89
+
90
+
91
+ def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
92
+ """
93
+ Example for atoms argument: ["N", "CA", "C"]
94
+ """
95
+ def filterfn(s, axis=None):
96
+ filters = np.stack([s.atom_name == name for name in atoms], axis=1)
97
+ sum = filters.sum(0)
98
+ if not np.all(sum <= np.ones(filters.shape[1])):
99
+ raise RuntimeError("structure has multiple atoms with same name")
100
+ index = filters.argmax(0)
101
+ coords = s[index].coord
102
+ coords[sum == 0] = float("nan")
103
+ return coords
104
+
105
+ return biotite.structure.apply_residue_wise(struct, struct, filterfn)
106
+
107
+
108
+ def get_sequence_loss(model, alphabet, coords, seq):
109
+ device = next(model.parameters()).device
110
+ batch_converter = CoordBatchConverter(alphabet)
111
+ batch = [(coords, None, seq)]
112
+ coords, confidence, strs, tokens, padding_mask = batch_converter(
113
+ batch, device=device)
114
+
115
+ prev_output_tokens = tokens[:, :-1].to(device)
116
+ target = tokens[:, 1:]
117
+ target_padding_mask = (target == alphabet.padding_idx)
118
+ logits, _ = model.forward(coords, padding_mask, confidence, prev_output_tokens)
119
+ loss = F.cross_entropy(logits, target, reduction='none')
120
+ loss = loss[0].cpu().detach().numpy()
121
+ target_padding_mask = target_padding_mask[0].cpu().numpy()
122
+ return loss, target_padding_mask
123
+
124
+
125
+ def score_sequence(model, alphabet, coords, seq):
126
+ loss, target_padding_mask = get_sequence_loss(model, alphabet, coords, seq)
127
+ ll_fullseq = -np.sum(loss * ~target_padding_mask) / np.sum(~target_padding_mask)
128
+ # Also calculate average when excluding masked portions
129
+ coord_mask = np.all(np.isfinite(coords), axis=(-1, -2))
130
+ ll_withcoord = -np.sum(loss * coord_mask) / np.sum(coord_mask)
131
+ return ll_fullseq, ll_withcoord
132
+
133
+
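# Usage sketch (not part of the diff): score a candidate sequence against a
# backbone. The pretrained-model loader and the PDB path are assumptions --
# they come from the upstream fair-esm package, not from this file.
import esm
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()
coords, native_seq = load_coords("example.pdb", chain="A")
ll_fullseq, ll_withcoord = score_sequence(model, alphabet, coords, native_seq)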
134
+ def get_encoder_output(model, alphabet, coords):
135
+ device = next(model.parameters()).device
136
+ batch_converter = CoordBatchConverter(alphabet)
137
+ batch = [(coords, None, None)]  # no target sequence here; the batch converter fills placeholder residues
138
+ coords, confidence, strs, tokens, padding_mask = batch_converter(
139
+ batch, device=device)
140
+ encoder_out = model.encoder.forward(coords, padding_mask, confidence,
141
+ return_all_hiddens=False)
142
+ # remove beginning and end (bos and eos tokens)
143
+ return encoder_out['encoder_out'][0][1:-1, 0]
144
+
145
+
146
+ def rotate(v, R):
147
+ """
148
+ Rotates a vector by a rotation matrix.
149
+
150
+ Args:
151
+ v: 3D vector, tensor of shape (length x batch_size x channels x 3)
152
+ R: rotation matrix, tensor of shape (length x batch_size x 3 x 3)
153
+
154
+ Returns:
155
+ Rotated version of v by rotation matrix R.
156
+ """
157
+ R = R.unsqueeze(-3)
158
+ v = v.unsqueeze(-1)
159
+ return torch.sum(v * R, dim=-2)
160
+
161
+
162
+ def get_rotation_frames(coords):
163
+ """
164
+ Returns a local rotation frame defined by N, CA, C positions.
165
+
166
+ Args:
167
+ coords: coordinates, tensor of shape (batch_size x length x 3 x 3)
168
+ where the third dimension is in order of N, CA, C
169
+
170
+ Returns:
171
+ Local relative rotation frames in shape (batch_size x length x 3 x 3)
172
+ """
173
+ v1 = coords[:, :, 2] - coords[:, :, 1]
174
+ v2 = coords[:, :, 0] - coords[:, :, 1]
175
+ e1 = normalize(v1, dim=-1)
176
+ u2 = v2 - e1 * torch.sum(e1 * v2, dim=-1, keepdim=True)
177
+ e2 = normalize(u2, dim=-1)
178
+ e3 = torch.cross(e1, e2, dim=-1)
179
+ R = torch.stack([e1, e2, e3], dim=-2)
180
+ return R
181
+
182
+
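# Quick shape check (illustration only, assumes this module is importable):
# one residue with idealized N, CA, C positions yields one orthonormal frame.
import torch
toy = torch.tensor([[[[1.5, 0.0, 0.0],     # N
                      [0.0, 0.0, 0.0],     # CA
                      [0.0, 1.5, 0.0]]]])  # C  -> shape (batch=1, length=1, 3, 3)
print(get_rotation_frames(toy).shape)      # torch.Size([1, 1, 3, 3])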
183
+ def nan_to_num(ts, val=0.0):
184
+ """
185
+ Replaces nans in tensor with a fixed value.
186
+ """
187
+ val = torch.tensor(val, dtype=ts.dtype, device=ts.device)
188
+ return torch.where(~torch.isfinite(ts), val, ts)
189
+
190
+
191
+ def rbf(values, v_min, v_max, n_bins=16):
192
+ """
193
+ Returns RBF encodings in a new dimension at the end.
194
+ """
195
+ rbf_centers = torch.linspace(v_min, v_max, n_bins, device=values.device)
196
+ rbf_centers = rbf_centers.view([1] * len(values.shape) + [-1])
197
+ rbf_std = (v_max - v_min) / n_bins
198
+ v_expand = torch.unsqueeze(values, -1)
199
+ z = (v_expand - rbf_centers) / rbf_std
200
+ return torch.exp(-z ** 2)
201
+
202
+
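# Illustration (assumes this module is importable): rbf() spreads each scalar
# over n_bins Gaussian bumps between v_min and v_max, adding a feature dimension.
import torch
d = torch.tensor([2.0, 5.0, 10.0])               # e.g. CA-CA distances in Angstroms
feats = rbf(d, v_min=0.0, v_max=20.0, n_bins=16)
print(feats.shape)                               # torch.Size([3, 16])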
203
+ def norm(tensor, dim, eps=1e-8, keepdim=False):
204
+ """
205
+ Returns L2 norm along a dimension.
206
+ """
207
+ return torch.sqrt(
208
+ torch.sum(torch.square(tensor), dim=dim, keepdim=keepdim) + eps)
209
+
210
+
211
+ def normalize(tensor, dim=-1):
212
+ """
213
+ Normalizes a tensor along a dimension after removing nans.
214
+ """
215
+ return nan_to_num(
216
+ torch.div(tensor, norm(tensor, dim=dim, keepdim=True))
217
+ )
218
+
219
+
220
+ class CoordBatchConverter(BatchConverter):
221
+ def __call__(self, raw_batch: Sequence[Tuple[Sequence, str]], device=None):
222
+ """
223
+ Args:
224
+ raw_batch: List of tuples (coords, confidence, seq)
225
+ In each tuple,
226
+ coords: list of floats, shape L x 3 x 3
227
+ confidence: list of floats, shape L; or scalar float; or None
228
+ seq: string of length L
229
+ Returns:
230
+ coords: Tensor of shape batch_size x L x 3 x 3
231
+ confidence: Tensor of shape batch_size x L
232
+ strs: list of strings
233
+ tokens: LongTensor of shape batch_size x L
234
+ padding_mask: ByteTensor of shape batch_size x L
235
+ """
236
+ self.alphabet.cls_idx = self.alphabet.get_idx("<cath>")
237
+ batch = []
238
+ for coords, confidence, seq in raw_batch:
239
+ if confidence is None:
240
+ confidence = 1.
241
+ if isinstance(confidence, float) or isinstance(confidence, int):
242
+ confidence = [float(confidence)] * len(coords)
243
+ if seq is None:
244
+ seq = 'X' * len(coords)
245
+ batch.append(((coords, confidence), seq))
246
+
247
+ coords_and_confidence, strs, tokens = super().__call__(batch)
248
+
249
+ # pad beginning and end of each protein due to legacy reasons
250
+ coords = [
251
+ F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)
252
+ for cd, _ in coords_and_confidence
253
+ ]
254
+ confidence = [
255
+ F.pad(torch.tensor(cf), (1, 1), value=-1.)
256
+ for _, cf in coords_and_confidence
257
+ ]
258
+ coords = self.collate_dense_tensors(coords, pad_v=np.nan)
259
+ confidence = self.collate_dense_tensors(confidence, pad_v=-1.)
260
+ if device is not None:
261
+ coords = coords.to(device)
262
+ confidence = confidence.to(device)
263
+ tokens = tokens.to(device)
264
+ padding_mask = torch.isnan(coords[:,:,0,0])
265
+ coord_mask = torch.isfinite(coords.sum(-2).sum(-1))
266
+ confidence = confidence * coord_mask + (-1.) * padding_mask
267
+ return coords, confidence, strs, tokens, padding_mask
268
+
269
+ def from_lists(self, coords_list, confidence_list=None, seq_list=None, device=None):
270
+ """
271
+ Args:
272
+ coords_list: list of length batch_size, each item is a list of
273
+ floats in shape L x 3 x 3 to describe a backbone
274
+ confidence_list: one of
275
+ - None, default to highest confidence
276
+ - list of length batch_size, each item is a scalar
277
+ - list of length batch_size, each item is a list of floats of
278
+ length L to describe the confidence scores for the backbone
279
+ with values between 0. and 1.
280
+ seq_list: either None or a list of strings
281
+ Returns:
282
+ coords: Tensor of shape batch_size x L x 3 x 3
283
+ confidence: Tensor of shape batch_size x L
284
+ strs: list of strings
285
+ tokens: LongTensor of shape batch_size x L
286
+ padding_mask: ByteTensor of shape batch_size x L
287
+ """
288
+ batch_size = len(coords_list)
289
+ if confidence_list is None:
290
+ confidence_list = [None] * batch_size
291
+ if seq_list is None:
292
+ seq_list = [None] * batch_size
293
+ raw_batch = zip(coords_list, confidence_list, seq_list)
294
+ return self.__call__(raw_batch, device)
295
+
296
+ @staticmethod
297
+ def collate_dense_tensors(samples, pad_v):
298
+ """
299
+ Takes a list of tensors with the following dimensions:
300
+ [(d_11, ..., d_1K),
301
+ (d_21, ..., d_2K),
302
+ ...,
303
+ (d_N1, ..., d_NK)]
304
+ and stack + pads them into a single tensor of:
305
+ (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
306
+ """
307
+ if len(samples) == 0:
308
+ return torch.Tensor()
309
+ if len(set(x.dim() for x in samples)) != 1:
310
+ raise RuntimeError(
311
+ f"Samples has varying dimensions: {[x.dim() for x in samples]}"
312
+ )
313
+ (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
314
+ max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
315
+ result = torch.empty(
316
+ len(samples), *max_shape, dtype=samples[0].dtype, device=device
317
+ )
318
+ result.fill_(pad_v)
319
+ for i in range(len(samples)):
320
+ result_i = result[i]
321
+ t = samples[i]
322
+ result_i[tuple(slice(0, k) for k in t.shape)] = t
323
+ return result
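`collate_dense_tensors` is the helper that pads variable-length backbones into one batch tensor; a small worked example of the padding behaviour (toy tensors, not from the repo):

import torch
from esm.inverse_folding.util import CoordBatchConverter

a = torch.ones(2, 3)
b = torch.ones(3, 2)
out = CoordBatchConverter.collate_dense_tensors([a, b], pad_v=0.0)
print(out.shape)   # torch.Size([2, 3, 3]); each sample padded to the per-dimension maxima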
esm/model/__init__.py ADDED
@@ -0,0 +1 @@
1
+
esm/model/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes).
 
esm/model/__pycache__/esm1.cpython-310.pyc ADDED
Binary file (5.16 kB).
 
esm/model/__pycache__/esm2.cpython-310.pyc ADDED
Binary file (3.5 kB).