sagawa's picture
Update README.md
ff61674 verified
metadata
license: mit
tags:
  - biology
  - protein

PLTNUM-SaProt-NIH3T3

PLTNUM is a protein language model trained to predict protein half-lives based on their sequences.
This model is based on westlake-repl/SaProt_650M_AF2 and was trained on a protein half-life dataset from the NIH3T3 mouse embryonic fibroblast cell line (paper link).

Model Sources

Uses

How to Get Started with the Model

Use the code below to get started with the model.

from torch import sigmoid
import torch.nn as nn
from transformers import AutoModel, AutoConfig, PreTrainedModel, AutoTokenizer


class PLTNUM_PreTrainedModel(PreTrainedModel):
    """Backbone loaded via ``AutoModel`` plus a single-output linear head.

    The backbone is fetched from the path recorded in the config
    (``config._name_or_path``). The head maps ``config.hidden_size``
    features to one value and is evaluated under two different dropout
    rates whose results are averaged.
    """

    config_class = AutoConfig

    def __init__(self, config):
        """Create the backbone, both dropout layers and the linear head.

        Args:
            config: Model configuration; ``_name_or_path`` and
                ``hidden_size`` are read here, ``initializer_range`` is read
                when the head is initialised.
        """
        super().__init__(config)
        self.model = AutoModel.from_pretrained(self.config._name_or_path)

        # Two dropout rates applied to the same features in ``forward``.
        self.fc_dropout1 = nn.Dropout(0.8)
        self.fc_dropout2 = nn.Dropout(0.4)
        self.fc = nn.Linear(self.config.hidden_size, 1)
        self._init_weights(self.fc)

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its layer type.

        Linear and embedding weights are drawn from a zero-mean normal with
        std ``config.initializer_range``; linear biases and the embedding
        padding row are zeroed; LayerNorm is reset to weight 1 / bias 0.
        Any other module type is left untouched.
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.config.initializer_range
            )
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            return

        if isinstance(module, nn.Embedding):
            nn.init.normal_(
                module.weight, mean=0.0, std=self.config.initializer_range
            )
            if module.padding_idx is not None:
                nn.init.constant_(module.weight[module.padding_idx], 0.0)
            return

        if isinstance(module, nn.LayerNorm):
            nn.init.constant_(module.bias, 0)
            nn.init.constant_(module.weight, 1.0)

    def forward(self, inputs):
        """Return the head output averaged over the two dropout layers.

        Args:
            inputs: Mapping that is unpacked as keyword arguments into the
                backbone (e.g. the tokenizer output).

        Returns:
            The mean of ``fc`` applied to the two dropped-out copies of the
            position-0 hidden state. No sigmoid is applied here.
        """
        # Index 0 along the second axis — presumably the leading special
        # token of each sequence; confirm against the tokenizer.
        pooled = self.model(**inputs).last_hidden_state[:, 0]
        strong_drop_out = self.fc(self.fc_dropout1(pooled))
        mild_drop_out = self.fc(self.fc_dropout2(pooled))
        return (strong_drop_out + mild_drop_out) / 2

    def create_embedding(self, inputs):
        """Return the backbone's position-0 hidden state without the head.

        Args:
            inputs: Mapping that is unpacked as keyword arguments into the
                backbone.
        """
        return self.model(**inputs).last_hidden_state[:, 0]


# Repository id passed to both ``from_pretrained`` calls below.
REPO_ID = "sagawa/PLTNUM-SaProt-NIH3T3"

model = PLTNUM_PreTrainedModel.from_pretrained(REPO_ID)
tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

# Example input: upper-case letters interleaved with lower-case letters
# (presumably residue + structure tokens as used by SaProt — confirm).
seq = "MdSdGdRdGdKpQpGpGpKdApRpApKpAdKdTaRpScSvRvAlGvLaQpFfPrVlGvRvVqHvRvLvLvRvKvGvNpYpSdEpRdVdGdAsGcAnPsVsYvLvArAvVvLvErYvLvTvAvEqIlLcEvLqAlGcNvAqAcRvDvNvKvKhTrRdIrIdPlRlHsLsQqLvAsIqRcNvDdEpEvLsNcKvLvLcGvRpVpTdIrApQpGnGdVhLdPdNdIdQdApVvLpLdPdKdKdTdEpSpHpHpKpPpKpGdKd"

# Pad/truncate every sequence to exactly 512 tokens and return PyTorch tensors.
tokenizer_options = {
    "add_special_tokens": True,
    "max_length": 512,
    "padding": "max_length",
    "truncation": True,
    "return_offsets_mapping": False,
    "return_attention_mask": True,
    "return_tensors": "pt",
}
# NOTE: `input` shadows the builtin; the name is kept as in the published snippet.
input = tokenizer([seq], **tokenizer_options)

# The model returns a raw score; sigmoid squashes it into (0, 1).
raw_score = model(input)
print(sigmoid(raw_score))  # tensor([[0.9798]], grad_fn=<SigmoidBackward0>)

Citation

Prediction of Protein Half-lives from Amino Acid Sequences by Protein Language Models
Tatsuya Sagawa, Eisuke Kanao, Kosuke Ogata, Koshi Imami, Yasushi Ishihama
bioRxiv 2024.09.10.612367; doi: https://doi.org/10.1101/2024.09.10.612367