AmelieSchreiber commited on
Commit
1e414d1
1 Parent(s): 65808d9

Upload testing_and_inference.ipynb

Browse files
Files changed (1) hide show
  1. testing_and_inference.ipynb +324 -0
testing_and_inference.ipynb ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "04879e6b-3718-4d23-90fa-35e5c8956861",
6
+ "metadata": {},
7
+ "source": [
8
+ "# ESMB for Protein Binding Residue Prediction\n",
9
+ "\n",
10
+ "**ESMBind** (or ESMB) is for predicting residues in protein sequences that are binding sites or active sites. The `ESMB_35M` series of models are Low Rank Adaptation (LoRA) finetuned versions of the protein language model `esm2_t12_35M_UR50D`. The models were finetuned on ~549K protein sequences and appear to achieve competative performance compared to current SOTA geometric/structural models, surpassing many on certain classes and for certain metrics. These models have an especially high recall score, meaning they are highly likely to discover most binding residues. \n",
11
+ "\n",
12
+ "However, they have relatively low precision, meaning they may return false positives as well. Their MCC, AUC, and accuracy values are also quite high. We hope that scaling the model and dataset size in a 1-to-1 fashion will achieve SOTA performace, but the simplicity of the models and training procedure already make them an attractive set of models as the domain knowledge and data preparation required to use them are very modest and make the barrier to entry low for using these models in practive. \n",
13
+ "\n",
14
+ "These models also predict binding residues from sequence alone. Since most proteins have yet to have their 3D folds and backbone structures predicted, we hope this resource will be valuable to the community for this reason as well. Before we proceed, we need to run a few pip install statements. "
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "id": "b8df2453-f478-4ef5-b69f-633aff114438",
21
+ "metadata": {},
22
+ "outputs": [],
23
+ "source": [
24
+ "!pip install transformers -q \n",
25
+ "!pip install accelerate -q \n",
26
+ "!pip install peft -q \n",
27
+ "!pip install datasets -q "
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "bc67c0c9-6847-432d-89c0-2465981e7ebe",
33
+ "metadata": {},
34
+ "source": [
35
+ "## Running Inference \n",
36
+ "\n",
37
+ "To run inference, simply replace the protein sequence below with your own. Afterwards, you'll be guided through how to check the train/test metrics on the datasets yourself. "
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "84d3cdeb-1a77-425c-8641-123107589868",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "from transformers import AutoModelForTokenClassification, AutoTokenizer\n",
48
+ "from peft import PeftModel\n",
49
+ "import torch\n",
50
+ "\n",
51
+ "# Path to the saved LoRA model\n",
52
+ "model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3\"\n",
53
+ "# ESM2 base model\n",
54
+ "base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n",
55
+ "\n",
56
+ "# Load the model\n",
57
+ "base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)\n",
58
+ "loaded_model = PeftModel.from_pretrained(base_model, model_path)\n",
59
+ "\n",
60
+ "# Ensure the model is in evaluation mode\n",
61
+ "loaded_model.eval()\n",
62
+ "\n",
63
+ "# Load the tokenizer\n",
64
+ "loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path)\n",
65
+ "\n",
66
+ "# Protein sequence for inference (replace with your own)\n",
67
+ "protein_sequence = \"MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT\" # @param {type:\"string\"}\n",
68
+ "\n",
69
+ "# Tokenize the sequence\n",
70
+ "inputs = loaded_tokenizer(protein_sequence, return_tensors=\"pt\", truncation=True, max_length=1024, padding='max_length')\n",
71
+ "\n",
72
+ "# Run the model\n",
73
+ "with torch.no_grad():\n",
74
+ " logits = loaded_model(**inputs).logits\n",
75
+ "\n",
76
+ "# Get predictions\n",
77
+ "tokens = loaded_tokenizer.convert_ids_to_tokens(inputs[\"input_ids\"][0]) # Convert input ids back to tokens\n",
78
+ "predictions = torch.argmax(logits, dim=2)\n",
79
+ "\n",
80
+ "# Define labels\n",
81
+ "id2label = {\n",
82
+ " 0: \"No binding site\",\n",
83
+ " 1: \"Binding site\"\n",
84
+ "}\n",
85
+ "\n",
86
+ "# Print the predicted labels for each token\n",
87
+ "for token, prediction in zip(tokens, predictions[0].numpy()):\n",
88
+ " if token not in ['<pad>', '<cls>', '<eos>']:\n",
89
+ " print((token, id2label[prediction]))"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "markdown",
94
+ "id": "38cebb36-6758-4e7d-b82d-e43cabfbe798",
95
+ "metadata": {},
96
+ "source": [
97
+ "## Train/Test Metrics\n",
98
+ "\n",
99
+ "### Loading and Tokenizing the Datasets\n",
100
+ "\n",
101
+ "To use this notebook to run the model on the train/test split and get the various metrics (accuracy, precision, recall, F1 score, AUC, and MCC) you will need to download the pickle files [found on Hugging Face here](https://huggingface.co/datasets/AmelieSchreiber/binding_sites_random_split_by_family_550K). Navigate to the \"Files and versions\" and download the four pickle files (you can ignore the TSV files unless you want to preprocess the data in a different way yourself). Once you have downloaded the pickle files, change the four file pickle paths in the cell below to match the local paths of the pickle files on your machine, then run the cell. "
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 1,
107
+ "id": "763eba61-fd1e-45d5-a427-0075e46c6293",
108
+ "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "data": {
112
+ "text/plain": [
113
+ "(Dataset({\n",
114
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
115
+ " num_rows: 450330\n",
116
+ " }),\n",
117
+ " Dataset({\n",
118
+ " features: ['input_ids', 'attention_mask', 'labels'],\n",
119
+ " num_rows: 113475\n",
120
+ " }))"
121
+ ]
122
+ },
123
+ "execution_count": 1,
124
+ "metadata": {},
125
+ "output_type": "execute_result"
126
+ }
127
+ ],
128
+ "source": [
129
+ "from datasets import Dataset\n",
130
+ "from transformers import AutoTokenizer\n",
131
+ "import pickle\n",
132
+ "\n",
133
+ "# Load tokenizer\n",
134
+ "tokenizer = AutoTokenizer.from_pretrained(\"facebook/esm2_t12_35M_UR50D\")\n",
135
+ "\n",
136
+ "# Function to truncate labels\n",
137
+ "def truncate_labels(labels, max_length):\n",
138
+ " \"\"\"Truncate labels to the specified max_length.\"\"\"\n",
139
+ " return [label[:max_length] for label in labels]\n",
140
+ "\n",
141
+ "# Set the maximum sequence length\n",
142
+ "max_sequence_length = 1000\n",
143
+ "\n",
144
+ "# Load the data from pickle files (change to match your local paths)\n",
145
+ "with open(\"train_sequences_chunked_by_family.pkl\", \"rb\") as f:\n",
146
+ " train_sequences = pickle.load(f)\n",
147
+ "with open(\"test_sequences_chunked_by_family.pkl\", \"rb\") as f:\n",
148
+ " test_sequences = pickle.load(f)\n",
149
+ "with open(\"train_labels_chunked_by_family.pkl\", \"rb\") as f:\n",
150
+ " train_labels = pickle.load(f)\n",
151
+ "with open(\"test_labels_chunked_by_family.pkl\", \"rb\") as f:\n",
152
+ " test_labels = pickle.load(f)\n",
153
+ "\n",
154
+ "# Tokenize the sequences\n",
155
+ "train_tokenized = tokenizer(train_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
156
+ "test_tokenized = tokenizer(test_sequences, padding=True, truncation=True, max_length=max_sequence_length, return_tensors=\"pt\", is_split_into_words=False)\n",
157
+ "\n",
158
+ "# Truncate the labels to match the tokenized sequence lengths\n",
159
+ "train_labels = truncate_labels(train_labels, max_sequence_length)\n",
160
+ "test_labels = truncate_labels(test_labels, max_sequence_length)\n",
161
+ "\n",
162
+ "# Create train and test datasets\n",
163
+ "train_dataset = Dataset.from_dict({k: v for k, v in train_tokenized.items()}).add_column(\"labels\", train_labels)\n",
164
+ "test_dataset = Dataset.from_dict({k: v for k, v in test_tokenized.items()}).add_column(\"labels\", test_labels)\n",
165
+ "\n",
166
+ "train_dataset, test_dataset\n"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "markdown",
171
+ "id": "c56556a3-93a5-45c6-935d-dc959b18c608",
172
+ "metadata": {},
173
+ "source": [
174
+ "### Getting the Train/Test Metrics\n",
175
+ "\n",
176
+ "Next, run the following cell. Depending on your hardware, this may take a while. There are ~549K protein sequences to process in total. The train dataset will obviously take much longer than the test dataset. Be patient and let both of them complete to see both the train and test metrics."
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "id": "65dd11e8-f502-44cd-b439-a593bf4d5019",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "name": "stderr",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "Some weights of EsmForTokenClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
190
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
191
+ ]
192
+ },
193
+ {
194
+ "data": {
195
+ "application/vnd.jupyter.widget-view+json": {
196
+ "model_id": "f110a2bca7314f278e1b97a37f4ab033",
197
+ "version_major": 2,
198
+ "version_minor": 0
199
+ },
200
+ "text/plain": [
201
+ "Downloading (…)/adapter_config.json: 0%| | 0.00/457 [00:00<?, ?B/s]"
202
+ ]
203
+ },
204
+ "metadata": {},
205
+ "output_type": "display_data"
206
+ },
207
+ {
208
+ "data": {
209
+ "application/vnd.jupyter.widget-view+json": {
210
+ "model_id": "2bd08fb8fcb644d080746c42dc4d77d1",
211
+ "version_major": 2,
212
+ "version_minor": 0
213
+ },
214
+ "text/plain": [
215
+ "Downloading adapter_model.bin: 0%| | 0.00/307k [00:00<?, ?B/s]"
216
+ ]
217
+ },
218
+ "metadata": {},
219
+ "output_type": "display_data"
220
+ },
221
+ {
222
+ "data": {
223
+ "text/html": [
224
+ "\n",
225
+ " <div>\n",
226
+ " \n",
227
+ " <progress value='200' max='56292' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
228
+ " [ 200/56292 01:32 < 7:13:37, 2.16 it/s]\n",
229
+ " </div>\n",
230
+ " "
231
+ ],
232
+ "text/plain": [
233
+ "<IPython.core.display.HTML object>"
234
+ ]
235
+ },
236
+ "metadata": {},
237
+ "output_type": "display_data"
238
+ }
239
+ ],
240
+ "source": [
241
+ "from sklearn.metrics import(\n",
242
+ " matthews_corrcoef, \n",
243
+ " accuracy_score, \n",
244
+ " precision_recall_fscore_support, \n",
245
+ " roc_auc_score\n",
246
+ ")\n",
247
+ "from peft import PeftModel\n",
248
+ "from transformers import DataCollatorForTokenClassification, AutoModelForTokenClassification\n",
249
+ "from transformers import Trainer\n",
250
+ "from accelerate import Accelerator\n",
251
+ "\n",
252
+ "# Instantiate the accelerator\n",
253
+ "accelerator = Accelerator()\n",
254
+ "\n",
255
+ "# Define paths to the LoRA and base models\n",
256
+ "base_model_path = \"facebook/esm2_t12_35M_UR50D\"\n",
257
+ "lora_model_path = \"AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp3\" # \"path/to/your/lora/model\" # Replace with the correct path to your LoRA model\n",
258
+ "\n",
259
+ "# Load the base model\n",
260
+ "base_model = AutoModelForTokenClassification.from_pretrained(base_model_path)\n",
261
+ "\n",
262
+ "# Load the LoRA model\n",
263
+ "model = PeftModel.from_pretrained(base_model, lora_model_path)\n",
264
+ "model = accelerator.prepare(model) # Prepare the model using the accelerator\n",
265
+ "\n",
266
+ "# Define label mappings\n",
267
+ "id2label = {0: \"No binding site\", 1: \"Binding site\"}\n",
268
+ "label2id = {v: k for k, v in id2label.items()}\n",
269
+ "\n",
270
+ "# Create a data collator\n",
271
+ "data_collator = DataCollatorForTokenClassification(tokenizer)\n",
272
+ "\n",
273
+ "# Define a function to compute the metrics\n",
274
+ "def compute_metrics(dataset):\n",
275
+ " # Get the predictions using the trained model\n",
276
+ " trainer = Trainer(model=model, data_collator=data_collator)\n",
277
+ " predictions, labels, _ = trainer.predict(test_dataset=dataset)\n",
278
+ " \n",
279
+ " # Remove padding and special tokens\n",
280
+ " mask = labels != -100\n",
281
+ " true_labels = labels[mask].flatten()\n",
282
+ " flat_predictions = np.argmax(predictions, axis=2)[mask].flatten().tolist()\n",
283
+ "\n",
284
+ " # Compute the metrics\n",
285
+ " accuracy = accuracy_score(true_labels, flat_predictions)\n",
286
+ " precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average='binary')\n",
287
+ " auc = roc_auc_score(true_labels, flat_predictions)\n",
288
+ " mcc = matthews_corrcoef(true_labels, flat_predictions) # Compute the MCC\n",
289
+ " \n",
290
+ " return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1, \"auc\": auc, \"mcc\": mcc} # Include the MCC in the returned dictionary\n",
291
+ "\n",
292
+ "# Get the metrics for the training and test datasets\n",
293
+ "train_metrics = compute_metrics(train_dataset)\n",
294
+ "test_metrics = compute_metrics(test_dataset)\n",
295
+ "\n",
296
+ "train_metrics, test_metrics"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "execution_count": null,
302
+ "id": "d8cc0058-1f81-466d-9fed-4a7ef55ba11f",
303
+ "metadata": {},
304
+ "outputs": [],
305
+ "source": []
306
+ }
307
+ ],
308
+ "metadata": {
309
+ "language_info": {
310
+ "codemirror_mode": {
311
+ "name": "ipython",
312
+ "version": 3
313
+ },
314
+ "file_extension": ".py",
315
+ "mimetype": "text/x-python",
316
+ "name": "python",
317
+ "nbconvert_exporter": "python",
318
+ "pygments_lexer": "ipython3",
319
+ "version": "3.8.17"
320
+ }
321
+ },
322
+ "nbformat": 4,
323
+ "nbformat_minor": 5
324
+ }