ESM-2 (esm2_t6_8M_UR50D) for Token Classification

This is a fine-tuned version of esm2_t6_8M_UR50D for token classification: it assigns each amino acid in a protein sequence to one of three secondary-structure classes (0: other, 1: alpha helix, 2: beta strand). It was trained with this notebook and achieves an accuracy of about 78.14%.
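For token classification, accuracy is typically measured per residue rather than per sequence. The following is a minimal sketch of such a metric. It assumes that positions which should not be scored (for example the `<cls>` and `<eos>` special tokens, or padding) carry the label -100, a common Hugging Face convention that is not confirmed for the training notebook. The helper name `token_accuracy` is hypothetical.

```python
import numpy as np

def token_accuracy(logits: np.ndarray, labels: np.ndarray, ignore_index: int = -100) -> float:
    """Fraction of scored positions whose argmax class matches the label.

    logits: shape (batch, seq_len, num_classes)
    labels: shape (batch, seq_len); positions equal to ignore_index
            (hypothetically special tokens or padding) are skipped.
    """
    predictions = logits.argmax(axis=-1)
    mask = labels != ignore_index
    return float((predictions[mask] == labels[mask]).mean())

# Toy example: 1 sequence, 4 positions, 3 classes (0: other, 1: helix, 2: strand)
logits = np.array([[[0.1, 0.2, 0.7],    # special-token position, ignored
                    [0.8, 0.1, 0.1],    # predicted 0
                    [0.1, 0.9, 0.0],    # predicted 1
                    [0.3, 0.3, 0.4]]])  # predicted 2
labels = np.array([[-100, 0, 1, 1]])
print(token_accuracy(logits, labels))  # 2 of 3 scored positions are correct
```

Masking before averaging keeps special tokens from inflating or deflating the score.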

Using the Model

To use the model, run the following (requires `transformers`, `torch`, and `numpy`):

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

# 1. Prepare the model and tokenizer
# Replace with the path where your trained model is saved if you're training a new model
model_dir = "AmelieSchreiber/esm2_t6_8M_UR50D-finetuned-secondary-structure"

model = AutoModelForTokenClassification.from_pretrained(model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model.eval()

# Define a mapping from label IDs to their string representations
label_map = {0: "Other", 1: "Helix", 2: "Strand"}

# 2. Tokenize the new protein sequence
new_protein_sequence = "MAVPETRPNHTIYINNLNEKIKKDELKKSLHAIFSRFGQILDILVSRSLKMRGQAFVIFKEVSSATNALRSMQGFPFYDKPMRIQYAKTDSDIIAKMKGT"  # Replace with your protein sequence
tokens = tokenizer.tokenize(new_protein_sequence)  # residue tokens only, no special tokens
inputs = tokenizer(new_protein_sequence, return_tensors="pt")  # adds <cls> ... <eos>

# 3. Predict with the model
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, seq_len + 2, 3)
predictions = np.argmax(logits[0].numpy(), axis=1)

# Drop the predictions for the <cls> and <eos> special tokens so that
# each remaining prediction lines up with one residue token
predictions = predictions[1:-1]

# 4. Decode the predictions
predicted_labels = [label_map[int(label_id)] for label_id in predictions]

# Print the tokens along with their predicted labels
for token, label in zip(tokens, predicted_labels):
    print(f"{token}: {label}")
```
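Per-residue labels are often easier to read as a compact secondary-structure string. The sketch below assumes the one-letter codes H (helix), E (strand), and C (other), a common three-state convention that this model card does not itself define; the helper names `to_ss_string` and `segments` are hypothetical.

```python
from itertools import groupby

# Hypothetical one-letter codes for the three classes
# (a common three-state convention; not defined by this model card)
CODE = {0: "C", 1: "H", 2: "E"}

def to_ss_string(label_ids):
    """Map a sequence of class IDs (0, 1, 2) to a string such as 'CCHHHE'."""
    return "".join(CODE[int(i)] for i in label_ids)

def segments(ss_string):
    """Collapse a secondary-structure string into (code, start, end) runs.

    start is 0-based and inclusive; end is exclusive.
    """
    runs = []
    pos = 0
    for code, group in groupby(ss_string):
        length = len(list(group))
        runs.append((code, pos, pos + length))
        pos += length
    return runs

# Example with made-up class IDs (not real model output)
ss = to_ss_string([0, 0, 1, 1, 1, 2, 2, 0])
print(ss)            # CCHHHEEC
print(segments(ss))  # [('C', 0, 2), ('H', 2, 5), ('E', 5, 7), ('C', 7, 8)]
```

Applied to the model's per-residue predictions with the `<cls>` and `<eos>` positions excluded, `to_ss_string` yields one character per amino acid, and `segments` lists each predicted helix or strand with its residue range.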
Model size: 7.74M parameters (Safetensors, F32 tensors)