TransHLA model
TransHLA is a tool designed to determine whether a peptide will be recognized by HLA as an epitope. TransHLA is the first tool capable of directly identifying peptides as epitopes without requiring HLA alleles as input. Because epitopes differ in length, we trained two models: TransHLA_I, which detects HLA-I epitopes, and TransHLA_II, which detects HLA-II epitopes.
Model description
TransHLA is a hybrid transformer model that combines a transformer encoder module with a deep CNN module. Its inputs are pretrained sequence embeddings from ESM2 and contact-map structural features. It can serve as a preliminary screening step ahead of the currently popular tools dedicated to predicting HLA-epitope binding affinity.
Intended uses
Because epitopes vary in length, TransHLA is split into two models: TransHLA_I identifies epitopes presented by HLA class I molecules, and TransHLA_II identifies epitopes presented by HLA class II molecules. TransHLA_I is designed for shorter peptides of 8 to 14 amino acids, while TransHLA_II targets longer peptides of 13 to 21 amino acids. Each model returns two outputs. The first is the epitope prediction: a two-column array with one row per peptide, where the two numbers in each row are probabilities that sum to 1. If the value in the second column is greater than or equal to 0.5, the peptide is classified as an epitope; otherwise, it is classified as a non-epitope peptide. The second output is the sequence embedding produced by the model. Usage examples for both models are given below.
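Since the two length ranges overlap (13- and 14-residue peptides fall inside both), it can help to check which model covers a peptide before tokenizing it. The following is a minimal sketch, not part of TransHLA itself: `MODEL_LENGTH_RANGES` and `applicable_models` are hypothetical names, and returning both models for overlapping lengths is our own assumption about how to handle that case.

```python
# Inclusive peptide-length ranges stated for each TransHLA model.
MODEL_LENGTH_RANGES = {
    "TransHLA_I": (8, 14),
    "TransHLA_II": (13, 21),
}


def applicable_models(peptide: str) -> list:
    """Hypothetical helper: names of the TransHLA models whose length range covers `peptide`."""
    length = len(peptide)
    return [
        name
        for name, (low, high) in MODEL_LENGTH_RANGES.items()
        if low <= length <= high
    ]


if __name__ == "__main__":
    for pep in ["EDSAIVTPSR", "SVWEPAKAKYVFR", "ARGDFFRATSRLTTDFG", "ACDEFG"]:
        # An empty list means neither model was designed for this length.
        print(pep, len(pep), applicable_models(pep))
```

A peptide that matches no model (such as the 6-residue example) is outside the ranges the models were designed for, so its predictions should not be relied on.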
How to use
First, install the following packages: PyTorch, fair-esm, and transformers. Running on the GPU requires CUDA 11.8 or higher; otherwise, the model must be run on the CPU.
```shell
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
pip install transformers
pip install fair-esm
```
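After installation, the CUDA requirement can be checked from Python. This is a minimal sketch: `cuda_meets_minimum` is a hypothetical helper, and it assumes you pass it `torch.version.cuda`, which holds the CUDA version PyTorch was built against (for example `"11.8"`) and is `None` for CPU-only builds.

```python
from typing import Optional, Tuple


def cuda_meets_minimum(cuda_version: Optional[str], minimum: Tuple[int, int] = (11, 8)) -> bool:
    """Return True if a CUDA version string such as "12.1" is at least `minimum`.

    `cuda_version` is expected to be `torch.version.cuda`, which is None
    for CPU-only PyTorch builds (in which case the model runs on the CPU).
    """
    if cuda_version is None:
        return False
    # Compare (major, minor) numerically so that "12.1" > "11.8" and "11.10" > "11.8".
    major_minor = tuple(int(part) for part in cuda_version.split(".")[:2])
    return major_minor >= minimum
```

In practice you would combine it with a runtime check, e.g. `cuda_meets_minimum(torch.version.cuda) and torch.cuda.is_available()`, since a CUDA build of PyTorch can still lack a usable GPU.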
Here is how to use the TransHLA_I model to predict whether a peptide is an HLA-I epitope:
```python
from transformers import AutoTokenizer
from transformers import AutoModel
import torch


def pad_inner_lists_to_length(outer_list, target_length=16):
    # Right-pad each token-id list with 1, the ESM2 <pad> token id.
    # 16 fits the longest TransHLA_I peptide (14 residues) plus <cls> and <eos>.
    for inner_list in outer_list:
        padding_length = target_length - len(inner_list)
        if padding_length > 0:
            inner_list.extend([1] * padding_length)
    return outer_list


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device")

    # TransHLA uses the ESM2 tokenizer; the model itself is loaded with its custom remote code.
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = AutoModel.from_pretrained("SkywalkerLu/TransHLA_I", trust_remote_code=True)
    model.to(device)
    model.eval()  # inference mode (e.g. disables dropout)

    peptide_examples = ['EDSAIVTPSR', 'SVWEPAKAKYVFR']
    peptide_encoding = tokenizer(peptide_examples)['input_ids']
    peptide_encoding = pad_inner_lists_to_length(peptide_encoding)
    print(peptide_encoding)
    peptide_encoding = torch.tensor(peptide_encoding)

    with torch.no_grad():  # no gradients are needed for prediction
        outputs, representations = model(peptide_encoding.to(device))
    print(outputs)          # one row per peptide: [P(non-epitope), P(epitope)]
    print(representations)  # sequence embeddings
```
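The first output can be turned into labels with the 0.5 rule described under Intended uses. A minimal sketch, assuming `outputs` has been moved to the CPU and converted to a nested Python list (for example with `outputs.cpu().tolist()`); `label_peptides` is a hypothetical helper, and the probabilities in the demo are made-up values, not real TransHLA predictions.

```python
from typing import List, Sequence, Tuple


def label_peptides(
    peptides: Sequence[str],
    probabilities: Sequence[Sequence[float]],
    threshold: float = 0.5,
) -> List[Tuple[str, float, bool]]:
    """Pair each peptide with its epitope probability and an epitope/non-epitope call.

    Each row of `probabilities` is [P(non-epitope), P(epitope)]; a peptide is
    called an epitope when the second column is >= threshold.
    """
    if len(peptides) != len(probabilities):
        raise ValueError("expected one probability row per peptide")
    results = []
    for peptide, (_, p_epitope) in zip(peptides, probabilities):
        results.append((peptide, p_epitope, p_epitope >= threshold))
    return results


if __name__ == "__main__":
    # Illustrative probabilities only, not real model output.
    demo_probs = [[0.2, 0.8], [0.7, 0.3]]
    for peptide, p, is_epitope in label_peptides(['EDSAIVTPSR', 'SVWEPAKAKYVFR'], demo_probs):
        print(f"{peptide}\tP(epitope)={p:.2f}\t{'epitope' if is_epitope else 'non-epitope'}")
```

With the TransHLA_I example above, the call would be `label_peptides(peptide_examples, outputs.cpu().tolist())`.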
And here is how to use the TransHLA_II model to predict whether a peptide is an HLA-II epitope:
```python
from transformers import AutoTokenizer
from transformers import AutoModel
import torch


def pad_inner_lists_to_length(outer_list, target_length=23):
    # Right-pad each token-id list with 1, the ESM2 <pad> token id.
    # 23 fits the longest TransHLA_II peptide (21 residues) plus <cls> and <eos>.
    for inner_list in outer_list:
        padding_length = target_length - len(inner_list)
        if padding_length > 0:
            inner_list.extend([1] * padding_length)
    return outer_list


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} device")

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    model = AutoModel.from_pretrained("SkywalkerLu/TransHLA_II", trust_remote_code=True)
    model.to(device)
    model.eval()  # inference mode (e.g. disables dropout)

    peptide_examples = ['KMIYSYSSHAASSL', 'ARGDFFRATSRLTTDFG']
    peptide_encoding = tokenizer(peptide_examples)['input_ids']
    peptide_encoding = pad_inner_lists_to_length(peptide_encoding)
    peptide_encoding = torch.tensor(peptide_encoding)

    with torch.no_grad():  # no gradients are needed for prediction
        outputs, representations = model(peptide_encoding.to(device))
    print(outputs)          # one row per peptide: [P(non-epitope), P(epitope)]
    print(representations)  # sequence embeddings
```
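Both examples pad with token id 1 (the ESM2 `<pad>` id) to a fixed length that fits the longest supported peptide plus the `<cls>` and `<eos>` tokens the tokenizer adds: 14 + 2 = 16 for TransHLA_I and 21 + 2 = 23 for TransHLA_II. `pad_inner_lists_to_length` leaves over-long sequences untouched, which can produce ragged lists that `torch.tensor` rejects. The variant below is a hypothetical alternative, not part of TransHLA: it derives the target length from a maximum peptide length, raises a clear error for longer inputs, and returns new lists instead of mutating its argument.

```python
from typing import List, Sequence

ESM_PAD_ID = 1      # id of <pad> in the ESM2 vocabulary
SPECIAL_TOKENS = 2  # <cls> at the start and <eos> at the end of each sequence


def pad_token_ids(
    token_id_lists: Sequence[Sequence[int]],
    max_peptide_length: int,
    pad_id: int = ESM_PAD_ID,
) -> List[List[int]]:
    """Right-pad token-id lists to max_peptide_length + 2, rejecting longer inputs."""
    target_length = max_peptide_length + SPECIAL_TOKENS
    padded = []
    for ids in token_id_lists:
        if len(ids) > target_length:
            raise ValueError(
                f"sequence has {len(ids) - SPECIAL_TOKENS} residues; "
                f"the model supports at most {max_peptide_length}"
            )
        # Copy the ids and append pad tokens up to the fixed length.
        padded.append(list(ids) + [pad_id] * (target_length - len(ids)))
    return padded
```

For TransHLA_II this would be called as `pad_token_ids(tokenizer(peptide_examples)['input_ids'], max_peptide_length=21)`, and with `max_peptide_length=14` for TransHLA_I.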