"""Gradio demo for PepMLM: sample peptide binders conditioned on a target protein.

The masked language model is given the target protein followed by a run of mask
tokens, fills those masks by top-k sampling, and each sampled binder is then
scored with ``compute_pseudo_perplexity``.
"""
import math

import gradio as gr
import torch
from torch.distributions.categorical import Categorical
from transformers import AutoModelForMaskedLM, AutoTokenizer

MODEL_NAME = "TianlaiChen/PepMLM-650M"

# Load the model and tokenizer once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()


def compute_pseudo_perplexity(model, tokenizer, protein_seq, binder_seq):
    """Score ``binder_seq`` as a binder for ``protein_seq``.

    All binder positions are masked at once and the model's mean cross-entropy
    loss on them is exponentiated, so lower values mean the model finds the
    binder more likely given the target.

    Args:
        model: masked language model returning ``.loss`` when given ``labels``.
        tokenizer: the tokenizer matching ``model``.
        protein_seq: target protein sequence.
        binder_seq: candidate peptide sequence.

    Returns:
        The perplexity as a Python float; NaN when ``binder_seq`` is empty
        (there are no positions to score).
    """
    if not binder_seq:
        return float("nan")

    sequence = protein_seq + binder_seq
    tensor_input = tokenizer.encode(sequence, return_tensors="pt").to(model.device)

    # Mark the binder positions. The slice stops at -1 to skip the trailing
    # special token the tokenizer presumably appends, and it assumes exactly
    # one token per residue -- TODO confirm for this checkpoint's vocabulary.
    binder_mask = torch.zeros(tensor_input.shape, dtype=torch.bool, device=model.device)
    binder_mask[0, -len(binder_seq) - 1 : -1] = True

    # Hide the binder from the model and score only those positions:
    # -100 is the label value ignored by the cross-entropy loss.
    masked_input = tensor_input.masked_fill(binder_mask, tokenizer.mask_token_id)
    labels = tensor_input.masked_fill(~binder_mask, -100)

    with torch.no_grad():
        loss = model(masked_input, labels=labels).loss
    return math.exp(loss.item())


def _sample_binders(protein_seq, peptide_length, top_k, num_binders):
    """Sample ``num_binders`` peptides of up to ``peptide_length`` residues for ``protein_seq``."""
    masked_input = protein_seq + tokenizer.mask_token * peptide_length
    inputs = tokenizer(masked_input, return_tensors="pt").to(model.device)

    # The model input is identical for every binder, so a single forward pass
    # is enough; only the sampling step below is repeated.
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]

    # Top-k sampling: each masked position draws from its k most likely tokens.
    top_k_logits, top_k_indices = logits[0, mask_positions].topk(top_k, dim=-1)
    distribution = Categorical(torch.softmax(top_k_logits, dim=-1))

    binders = []
    for _ in range(num_binders):
        choice = distribution.sample()
        token_ids = top_k_indices.gather(-1, choice.unsqueeze(-1)).squeeze(-1)
        # Special tokens are dropped here, so a binder can come out shorter
        # than requested if one is sampled.
        binders.append(tokenizer.decode(token_ids, skip_special_tokens=True).replace(" ", ""))
    return binders


def generate_peptide(protein_seq, peptide_length, top_k, num_binders):
    """Gradio callback: sample binders for a target protein and report their perplexity.

    Args:
        protein_seq: target protein sequence; whitespace is removed and letters
            are upper-cased before use.
        peptide_length: number of residues to generate per binder.
        top_k: each residue is sampled from the ``top_k`` most likely tokens.
        num_binders: how many binders to sample.

    Returns:
        One ``"Binder: <sequence>, PPL: <perplexity>"`` line per binder.

    Raises:
        gr.Error: if the protein sequence is empty after normalization.
    """
    # Pasted sequences often contain line breaks or lower-case letters, which
    # would presumably tokenize as unknown symbols.
    protein_seq = "".join((protein_seq or "").split()).upper()
    if not protein_seq:
        raise gr.Error("Please enter a protein sequence.")

    peptide_length = int(peptide_length)
    top_k = int(top_k)
    num_binders = int(num_binders)

    binders = _sample_binders(protein_seq, peptide_length, top_k, num_binders)

    lines = []
    for binder in binders:
        ppl = compute_pseudo_perplexity(model, tokenizer, protein_seq, binder)
        lines.append(f"Binder: {binder}, PPL: {ppl:.2f}")
    return "\n".join(lines)


# Define the Gradio interface
interface = gr.Interface(
    fn=generate_peptide,
    inputs=[
        gr.Textbox(label="Protein Sequence", info="Enter protein sequence here", type="text"),
        gr.Slider(3, 50, value=15, label="Peptide Length", step=1, info="Default value is 15"),
        gr.Slider(1, 10, value=3, label="Top K Value", step=1, info="Default value is 3"),
        gr.Dropdown(choices=[1, 2, 4, 8, 16, 32], label="Number of Binders", value=4),
    ],
    # The gr.outputs.* namespace was deprecated in Gradio 3 and removed in Gradio 4.
    outputs=gr.Textbox(label="Binders (with Perplexity)"),
    title="PepMLM: Target Sequence-Conditioned Generation of Peptide Binders via Masked Language Modeling",
)

if __name__ == "__main__":
    interface.launch()