# Hugging Face Space status: Runtime error
import gradio as gr
import numpy as np
import torch
from torch.distributions.categorical import Categorical
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Hugging Face Hub checkpoint shared by the tokenizer and the masked-LM model.
_CHECKPOINT = "TianlaiChen/PepMLM-650M"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(_CHECKPOINT)
def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Return exp(masked-LM loss) of ``binder_seq`` conditioned on ``protein_seq``.

    The two sequences are concatenated and tokenized, and every binder position
    is replaced by the mask token.  Labels for all other positions are set to
    -100, which Hugging Face masked-LM losses ignore, so the loss is taken over
    the binder positions only.  All binder positions are masked together in a
    single forward pass rather than one residue at a time.

    Args:
        model: Masked-LM model exposing ``.device`` and accepting ``labels=``;
            its call result must carry a ``.loss`` attribute.
        tokenizer: Tokenizer providing ``encode(..., return_tensors='pt')``
            and ``mask_token_id``.
        protein_seq: Target protein sequence (context; never masked).
        binder_seq: Binder sequence appended after ``protein_seq`` and scored.

    Returns:
        ``numpy.exp`` of the loss, or ``nan`` when ``binder_seq`` is empty
        (there are no positions to score).
    """
    if not binder_seq:
        # Nothing to score: a loss over zero labelled positions would
        # typically be NaN anyway, so skip the forward pass entirely.
        return float("nan")

    tensor_input = tokenizer.encode(
        protein_seq + binder_seq, return_tensors='pt'
    ).to(model.device)

    # Binder positions are the last len(binder_seq) tokens before the final
    # token.  Assumes one token per binder character and exactly one trailing
    # special token (e.g. EOS) -- TODO confirm if the tokenizer changes.
    binder_mask = torch.zeros_like(tensor_input, dtype=torch.bool)
    binder_mask[0, -len(binder_seq) - 1:-1] = True

    # Hide the binder from the model and score only those positions.
    masked_input = tensor_input.masked_fill(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.masked_fill(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())
def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Sample peptide binders for a target protein and score each one.

    ``peptide_length`` mask tokens are appended to ``protein_seq``, and the
    module-level ``model`` predicts logits for those positions.  Each binder is
    drawn by sampling every masked position independently from the softmax
    over its ``top_k`` highest logits.  It is then scored with
    ``compute_pseudo_perplexity``.

    Args:
        protein_seq: Target protein sequence.
        peptide_length: Number of residues to generate (cast to ``int``).
        top_k: Number of highest-scoring tokens kept per position (cast to ``int``).
        num_binders: Number of binders to sample (cast to ``int``).

    Returns:
        One line per binder, formatted ``"Binder: <seq>, PPL: <ppl>"`` with the
        perplexity to two decimals, joined by newlines.
    """
    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    # The model input is the same for every binder, so the forward pass,
    # top-k selection and softmax are computed once.  Only sampling repeats.
    # Assumes the model is in inference mode (no dropout); from_pretrained
    # returns it that way.
    input_sequence = protein_seq + '<mask>' * peptide_length
    inputs = tokenizer(input_sequence, return_tensors="pt").to(model.device)
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_token_indices = (
        inputs["input_ids"] == tokenizer.mask_token_id
    ).nonzero(as_tuple=True)[1]
    logits_at_masks = logits[0, mask_token_indices]

    # Top-k sampling: one categorical distribution per masked position.
    top_k_logits, top_k_indices = logits_at_masks.topk(top_k, dim=-1)
    probabilities = torch.nn.functional.softmax(top_k_logits, dim=-1)
    sampler = Categorical(probabilities)

    binders_with_ppl = []
    for _ in range(num_binders):
        predicted_indices = sampler.sample()
        predicted_token_ids = top_k_indices.gather(
            -1, predicted_indices.unsqueeze(-1)
        ).squeeze(-1)
        # Special tokens are dropped by skip_special_tokens=True, so a binder
        # can come out shorter than peptide_length.
        generated_binder = tokenizer.decode(
            predicted_token_ids, skip_special_tokens=True
        ).replace(' ', '')
        ppl_value = compute_pseudo_perplexity(model, tokenizer, protein_seq, generated_binder)
        binders_with_ppl.append((generated_binder, ppl_value))

    return "\n".join(
        f"Binder: {binder}, PPL: {ppl:.2f}" for binder, ppl in binders_with_ppl
    )
# Define the Gradio interface.
# `gr.outputs.*` was deprecated in Gradio 3 and removed in Gradio 4 (it raises
# AttributeError there), so the unified `gr.Textbox` component is used as the
# output; it works in both versions.
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info='Default value is 15'),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info='Default value is 3'),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=4),
    ],
    outputs=gr.Textbox(label="Binders (with Perplexity)"),
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling",
)
interface.launch()