"""Streamlit front end for Lang2mol-Diff: text description -> molecule via diffusion."""
import torch
import selfies as sf
from transformers import T5EncoderModel
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
import streamlit as st
import spaces
import os

# Special tokens removed from the decoded token string before SELFIES decoding.
# NOTE(review): assumes the project Tokenizer emits T5-style "<pad>" / "</s>";
# confirm against src.scripts.mytokenizers.Tokenizer.
_SPECIAL_TOKENS = ("<pad>", "</s>")


@st.cache_resource
def get_encoder(device):
    """Load the BioT5 text2mol T5 encoder on ``device`` in eval mode.

    Cached with ``st.cache_resource`` so the weights are loaded once and
    reused across Streamlit reruns.
    """
    model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
    model.to(device)
    model.eval()
    return model


@st.cache_resource
def get_tokenizer():
    """Return the project's shared Tokenizer (cached across reruns)."""
    return Tokenizer()


@st.cache_resource
def get_model(device):
    """Build the transformer denoiser and load its EMA checkpoint onto ``device``.

    The hyper-parameters must match those used to train the checkpoint,
    otherwise ``load_state_dict`` will fail on shape mismatches.
    """
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(
        torch.load(
            os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt"),
            map_location=torch.device(device),
        )
    )
    model.to(device)
    model.eval()
    return model


@st.cache_resource
def get_diffusion():
    """Create the respaced diffusion process used for sampling.

    The base schedule has 2000 "sqrt" steps. Sampling visits every 10th
    step, i.e. 200 timesteps.
    """
    return SpacedDiffusion(
        use_timesteps=list(range(0, 2000, 10)),
        betas=gd.get_named_beta_schedule("sqrt", 2000),
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


def _strip_special_tokens(text):
    """Remove decoder special tokens and tab separators from a decoded string."""
    for token in _SPECIAL_TOKENS:
        text = text.replace(token, "")
    return text.replace("\t", "")


@spaces.GPU
def generate(text_input):
    """Generate a molecule from a natural-language description.

    The description is encoded with the T5 encoder. A (1, 256, 32) sample is
    drawn from the diffusion model conditioned on that encoding. Each
    position is mapped to its highest-scoring token. The resulting SELFIES
    string is decoded with ``selfies.decoder``.

    Args:
        text_input: Free-text molecule description from the UI.

    Returns:
        The string produced by ``sf.decoder``, with tab characters removed.

    Raises:
        selfies.DecoderError: if the generated token string is not valid SELFIES.
    """
    # Pure inference: no_grad avoids building an autograd graph through the
    # encoder and the sampling loop.
    with st.spinner("Please wait..."), torch.no_grad():
        output = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        caption_state = encoder(
            input_ids=output["input_ids"].to(device),
            attention_mask=output["attention_mask"].to(device),
        ).last_hidden_state
        caption_mask = output["attention_mask"]

        samples = diffusion.p_sample_loop(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state.to(device), caption_mask.to(device)),
        )
        # as_tensor avoids the copy (and PyTorch's copy-construct warning)
        # when the sampler already returns a tensor.
        logits = model.get_logits(torch.as_tensor(samples))
        token_ids = torch.topk(logits, k=1, dim=-1).indices.squeeze(-1)
        decoded = tokenizer.decode(token_ids)

        return sf.decoder(_strip_special_tokens(decoded[0])).replace("\t", "")


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = get_tokenizer()
encoder = get_encoder(device)
model = get_model(device)
diffusion = get_diffusion()

st.title("Lang2mol-Diff")
text_input = st.text_area("Enter molecule description")
button = st.button("Submit")
if button:
    try:
        result = generate(text_input)
    except sf.DecoderError as err:
        # The sampler can emit token sequences that are not valid SELFIES;
        # report it instead of crashing the page with a traceback.
        st.error(f"Could not decode the generated molecule: {err}")
    else:
        st.write(result)