|
--- |
|
library_name: peft |
|
license: mit |
|
language: |
|
- en |
|
tags: |
|
- transformers |
|
- biology |
|
- esm |
|
- esm2 |
|
- protein |
|
- protein language model |
|
--- |
|
# ESM-2 RNA Binding Site LoRA |
|
|
|
This is a Parameter-Efficient Fine-Tuning (PEFT) Low-Rank Adaptation (LoRA) of
|
the [esm2_t6_8M_UR50D](https://huggingface.co/facebook/esm2_t6_8M_UR50D) model for the (binary) token classification task of |
|
predicting RNA binding sites of proteins. You can also find a version of this model |
|
that was fine-tuned without LoRA [here](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_UR50D_rna_binding_site_predictor). |
|
|
|
## Training procedure |
|
|
|
This is a Low-Rank Adaptation (LoRA) of `esm2_t6_8M_UR50D`,
|
trained on `166` protein sequences in the [RNA binding sites dataset](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites) |
|
using an `80/20` train/test split. This model was trained with class weighting due to the imbalanced nature
|
of the RNA binding site dataset (fewer binding sites than non-binding sites). You can train your own version |
|
using [this notebook](https://huggingface.co/AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding/blob/main/LoRA_binding_sites_no_sweeps_v2.ipynb)! |
|
You just need the RNA `binding_sites.xml` file [found here](https://huggingface.co/datasets/AmelieSchreiber/data_of_protein-rna_binding_sites). |
|
You may also need to run some `pip install` statements at the beginning of the script. If you are running in Colab, run:
|
|
|
```python |
|
!pip install transformers[torch] datasets peft -q |
|
``` |
|
```python |
|
!pip install accelerate -U -q |
|
``` |
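For orientation, below is a minimal sketch of what the LoRA and class-weighted training setup could look like. The rank, alpha, dropout, target modules, and class weights are illustrative assumptions rather than the exact values used for this checkpoint; the linked notebook is the authoritative recipe.

```python
import torch
from torch import nn
from transformers import AutoModelForTokenClassification, Trainer
from peft import LoraConfig, get_peft_model

# Base ESM-2 model with a 2-class token classification head.
base_model = AutoModelForTokenClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=2
)

# Wrap the base model with LoRA adapters on the attention projections.
# (r, lora_alpha, lora_dropout, and target_modules are illustrative.)
lora_config = LoraConfig(
    task_type="TOKEN_CLS",
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["query", "key", "value"],
)
model = get_peft_model(base_model, lora_config)

class WeightedTrainer(Trainer):
    """Trainer with a class-weighted loss to offset the scarcity of binding-site residues."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Placeholder weights: up-weight the rarer "binding site" class.
        weight = torch.tensor([0.2, 0.8], device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)  # ignores -100 labels by default
        loss = loss_fct(logits.view(-1, 2), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
```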
|
Try to improve upon these metrics by adjusting the hyperparameters: |
|
``` |
|
{'eval_loss': 0.49476009607315063, |
|
'eval_precision': 0.14372964169381108, |
|
'eval_recall': 0.7526652452025586, |
|
'eval_f1': 0.24136752136752138, |
|
'eval_auc': 0.7710141129858947, |
|
'epoch': 15.0} |
|
``` |
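The evaluation metrics above come from a token-level `compute_metrics` callback. A rough sketch of how such metrics can be computed is shown below; it assumes binary labels with `-100` marking padding and special tokens, which is the usual convention for token classification.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

def compute_metrics(eval_pred):
    """Token-level binary metrics, ignoring positions labeled -100 (padding/special tokens)."""
    logits, labels = eval_pred

    # Numerically stable softmax to get the probability of the "binding site" class.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=-1, keepdims=True)
    preds = np.argmax(logits, axis=-1)

    mask = labels != -100
    labels, preds, pos_probs = labels[mask], preds[mask], probs[..., 1][mask]

    return {
        "precision": precision_score(labels, preds),
        "recall": recall_score(labels, preds),
        "f1": f1_score(labels, preds),
        "auc": roc_auc_score(labels, pos_probs),
    }
```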
|
|
|
A similar model can also be trained using the GitHub repository, which provides a training script and a conda environment YAML and can be [found here](https://github.com/Amelie-Schreiber/esm2_LoRA_binding_sites/tree/main). That version uses wandb sweeps for hyperparameter search; however, it does not use class weighting.
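For reference, a wandb sweep definition for this kind of run might look roughly like the following sketch; the search space and metric name are illustrative and not necessarily what the repository uses.

```python
import wandb

# Illustrative sweep over common LoRA/training hyperparameters.
sweep_config = {
    "method": "bayes",
    "metric": {"name": "eval/f1", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"min": 1e-5, "max": 1e-3},
        "lora_r": {"values": [4, 8, 16]},
        "lora_alpha": {"values": [8, 16, 32]},
        "num_train_epochs": {"values": [10, 15, 20]},
    },
}

sweep_id = wandb.sweep(sweep_config, project="esm2-lora-rna-binding")
# wandb.agent(sweep_id, function=train)  # `train` would wrap the training setup above
```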
|
|
|
|
|
### Framework versions |
|
|
|
- PEFT 0.4.0 |
|
|
|
## Using the Model |
|
|
|
To use the model, try running the following `pip install` statement:
|
```python |
|
!pip install transformers peft -q |
|
``` |
|
then try running:
|
```python |
|
from transformers import AutoModelForTokenClassification, AutoTokenizer |
|
from peft import PeftModel |
|
import torch |
|
|
|
# Path to the saved LoRA model |
|
model_path = "AmelieSchreiber/esm2_t6_8M_weighted_lora_rna_binding" |
|
# ESM2 base model |
|
base_model_path = "facebook/esm2_t6_8M_UR50D" |
|
|
|
# Load the model |
|
base_model = AutoModelForTokenClassification.from_pretrained(base_model_path) |
|
loaded_model = PeftModel.from_pretrained(base_model, model_path) |
|
|
|
# Ensure the model is in evaluation mode |
|
loaded_model.eval() |
|
|
|
# Load the tokenizer |
|
loaded_tokenizer = AutoTokenizer.from_pretrained(base_model_path) |
|
|
|
# Protein sequence for inference |
|
protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT" # Replace with your actual sequence |
|
|
|
# Tokenize the sequence |
|
inputs = loaded_tokenizer(protein_sequence, return_tensors="pt", truncation=True, max_length=1024, padding='max_length') |
|
|
|
# Run the model |
|
with torch.no_grad():
    logits = loaded_model(**inputs).logits
|
|
|
# Get predictions |
|
tokens = loaded_tokenizer.convert_ids_to_tokens(inputs["input_ids"][0]) # Convert input ids back to tokens |
|
predictions = torch.argmax(logits, dim=2) |
|
|
|
# Define labels |
|
id2label = { |
|
0: "No binding site", |
|
1: "Binding site" |
|
} |
|
|
|
# Print the predicted labels for each token |
|
for token, prediction in zip(tokens, predictions[0].numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, id2label[prediction]))
|
|
|
``` |
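If you would rather see per-residue probabilities than hard labels, you can apply a softmax to the same logits. This small convenience snippet reuses the variables from the block above:

```python
# Probability of the "Binding site" class for each residue.
probs = torch.softmax(logits, dim=2)[0, :, 1]
for token, p in zip(tokens, probs.numpy()):
    if token not in ['<pad>', '<cls>', '<eos>']:
        print((token, round(float(p), 3)))
```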
|
|