Spaces:
Running
Running
import logging | |
import pathlib | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY | |
from utils import draw_grid_predict | |
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Dataset versions available for each MolFormer predictor. The __main__ block
# turns every (predictor, version) pair into a "<Name> (<version>)" dropdown
# entry, and main() parses that label back into name + algorithm_version.
MOLFORMER_VERSIONS = {
    "molformer_classification": ["bace", "bbbp", "hiv"],
    "molformer_regression": [
        "alpha",
        "cv",
        "g298",
        "gap",
        "h298",
        "homo",
        "lipo",
        "lumo",
        "mu",
        "r2",
        "u0",
    ],
    "molformer_multitask_classification": ["clintox", "sider", "tox21"],
}

# Predictor names hidden from the "Property" dropdown. The MolFormer keys are
# hidden because they are re-added once per version (see MOLFORMER_VERSIONS).
REMOVE = [
    "docking",
    "docking_tdc",
    "molecule_one",
    "askcos",
    "plogp",
    "similarity_seed",
    "activity_against_target",
    "organtox",
    *MOLFORMER_VERSIONS,
]

# Comma-separated output labels for the multi-output predictors, keyed by the
# capitalized dropdown name. main() splits these on "," to label the result
# columns, so entries are joined with a bare "," (no padding spaces).
MODEL_PROP_DESCRIPTION = {
    "Tox21": ",".join(
        [
            "NR-AR",
            "NR-AR-LBD",
            "NR-AhR",
            "NR-Aromatase",
            "NR-ER",
            "NR-ER-LBD",
            "NR-PPAR-gamma",
            "SR-ARE",
            "SR-ATAD5",
            "SR-HSE",
            "SR-MMP",
            "SR-p53",
        ]
    ),
    "Sider": ",".join(
        [
            "Hepatobiliary disorders",
            "Metabolism and nutrition disorders",
            "Product issues",
            "Eye disorders",
            "Investigations",
            "Musculoskeletal disorders",
            "Gastrointestinal disorders",
            "Social circumstances",
            "Immune system disorders",
            "Reproductive system and breast disorders",
            "Benign & malignant",
            "General disorders",
            "Endocrine disorders",
            "Surgical & medical procedures",
            "Vascular disorders",
            "Blood & lymphatic disorders",
            "Skin & subcutaneous disorders",
            "Congenital & genetic disorders",
            "Infections",
            "Respiratory & thoracic disorders",
            "Psychiatric disorders",
            "Renal & urinary disorders",
            "Pregnancy conditions",
            "Ear disorders",
            "Cardiac disorders",
            "Nervous system disorders",
            "Injury & procedural complications",
        ]
    ),
    "Clintox": ",".join(["FDA approval", "Clinical trial failure"]),
}
def main(property: str, smiles: str, smiles_file: str):
    """Predict a molecular property for the given SMILES and render the result.

    Gradio callback for the ``gr.Interface`` built in the ``__main__`` block.
    Exactly one of ``smiles`` / ``smiles_file`` must be provided.

    Args:
        property: Label chosen in the "Property" dropdown. Plain entries are
            capitalized predictor names (e.g. ``"Qed"``); MolFormer entries
            look like ``"Molformer_classification (bace)"``, i.e. predictor
            name followed by the dataset version in parentheses.
        smiles: A single SMILES string; ``""`` (or ``None`` / whitespace only)
            means "not provided".
        smiles_file: The uploaded ``.smi`` file, or ``None``. Either a path
            (``str`` / ``pathlib.PurePath``) or a file wrapper exposing the
            path as ``.name``. Only the first tab-separated column is read.

    Returns:
        The result of ``draw_grid_predict`` for the input molecules and their
        predictions rounded to 2 decimals.

    Raises:
        ValueError: If both inputs, or neither, are provided.
    """
    text = (smiles or "").strip()
    has_file = smiles_file is not None
    # Validate the inputs before building the model so bad input fails fast
    # instead of after a predictor has been instantiated.
    if text and has_file:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if not text and not has_file:
        raise ValueError("Pass either smiles or smiles_file.")

    if text:
        smiles_list = [text]
    else:
        # The original code read the path from ``.name``; also accept a plain
        # path, which some Gradio versions may pass for ``gr.File``.
        if isinstance(smiles_file, (str, pathlib.PurePath)):
            path = smiles_file
        else:
            path = smiles_file.name
        smiles_list = pd.read_csv(path, header=None, sep="\t")[0].tolist()

    version = None
    if "Molformer" in property:
        # "Molformer_classification (bace)" -> version "bace",
        # property "Molformer_classification".
        version = property.split(" ")[-1].split("(")[-1].split(")")[0]
        property = property.split(" ")[0]

    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[property.lower()]
    kwargs = {"algorithm_version": "v0"} if property in MODEL_PROP_DESCRIPTION else {}
    if property.lower() in MOLFORMER_VERSIONS:
        kwargs["algorithm_version"] = version
    model = algo(config(**kwargs))

    props = np.array([model(s) for s in smiles_list]).round(2)
    if props.ndim == 1:
        # Scalar predictions give a 1-D array; add a trailing axis so the
        # result is (n_molecules, 1) like the multi-output case.
        props = props[:, np.newaxis]

    if property in MODEL_PROP_DESCRIPTION:
        # Strip so labels are clean even if a description uses ", ".
        property_names = [
            name.strip() for name in MODEL_PROP_DESCRIPTION[property].split(",")
        ]
    else:
        # NOTE(review): MolFormer multitask versions (e.g. tox21) may return
        # several columns but get a single label here — confirm what
        # draw_grid_predict expects.
        property_names = [property]

    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )
if __name__ == "__main__":
    # Dropdown choices: every registered predictor (reverse registration
    # order) except those in REMOVE, capitalized for display. Names in REMOVE
    # that the factory does not know are ignored rather than crashing start-up.
    removed = set(REMOVE)
    properties = [
        name.capitalize()
        for name in list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
        if name not in removed
    ]
    # MolFormer predictors are offered once per dataset version, e.g.
    # "Molformer_classification (bace)"; main() parses this label back.
    for key, versions in MOLFORMER_VERSIONS.items():
        properties.extend(f"{key.capitalize()} ({version})" for version in versions)

    # Load metadata shipped in the model_cards directory next to this file.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = [
        ["Qed", "", str(metadata_root.joinpath("examples.smi"))],
        [
            "Esol",
            "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1",
            None,
        ],
    ]
    # Decode explicitly as UTF-8 so the Markdown reads the same regardless of
    # the platform's default locale encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Molecular properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Scscore"),
            gr.Textbox(
                label="Single SMILES",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.File(
                file_types=[".smi"],
                label="Multiple SMILES (tab-separated, `.smi` file)",
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True, share=True)