"""Gradio app that predicts molecular properties with GT4SD property predictors.

The user selects a predictor from a dropdown and provides either one SMILES
string or a tab-separated ``.smi`` file whose first column holds SMILES; the
predictions are rendered by ``utils.draw_grid_predict`` into an HTML output.
"""
import logging
import pathlib

import gradio as gr
import numpy as np
import pandas as pd
from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY
from utils import draw_grid_predict

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Algorithm versions offered for each MolFormer predictor. Every version becomes
# its own dropdown entry, labelled "<Key> (<version>)".
MOLFORMER_VERSIONS = {
    "molformer_classification": ["bace", "bbbp", "hiv"],
    "molformer_regression": [
        "alpha",
        "cv",
        "g298",
        "gap",
        "h298",
        "homo",
        "lipo",
        "lumo",
        "mu",
        "r2",
        "u0",
    ],
    "molformer_multitask_classification": ["clintox", "sider", "tox21"],
}

# Factory keys hidden from the dropdown. The generic MolFormer keys are listed
# here because they are re-exposed once per version via MOLFORMER_VERSIONS.
REMOVE = [
    "docking",
    "docking_tdc",
    "molecule_one",
    "askcos",
    "plogp",
    "similarity_seed",
    "activity_against_target",
    "organtox",
    *MOLFORMER_VERSIONS,
]

# Comma-separated output names of the multi-output predictors, keyed by their
# dropdown label. These predictors are instantiated with algorithm_version "v0".
MODEL_PROP_DESCRIPTION = {
    "Tox21": "NR-AR, NR-AR-LBD, NR-AhR, NR-Aromatase, NR-ER, NR-ER-LBD, NR-PPAR-gamma, SR-ARE, SR-ATAD5, SR-HSE, SR-MMP, SR-p53",
    "Sider": "Hepatobiliary disorders,Metabolism and nutrition disorders,Product issues,Eye disorders,Investigations,Musculoskeletal disorders,Gastrointestinal disorders,Social circumstances,Immune system disorders,Reproductive system and breast disorders,Benign & malignant,General disorders,Endocrine disorders,Surgical & medical procedures,Vascular disorders,Blood & lymphatic disorders,Skin & subcutaneous disorders,Congenital & genetic disorders,Infections,Respiratory & thoracic disorders,Psychiatric disorders,Renal & urinary disorders,Pregnancy conditions,Ear disorders,Cardiac disorders,Nervous system disorders,Injury & procedural complications",
    "Clintox": "FDA approval, Clinical trial failure",
}


def _build_predictor(label):
    """Instantiate the GT4SD predictor selected in the dropdown.

    Args:
        label: dropdown label, e.g. ``"Qed"``, ``"Tox21"`` or
            ``"Molformer_classification (bace)"``.

    Returns:
        ``(model, name)``: the instantiated predictor (called once per SMILES)
        and the label with any MolFormer ``" (<version>)"`` suffix removed.

    Raises:
        KeyError: if the label does not name a predictor in the factory.
        ValueError: if a MolFormer predictor is requested without a version.
    """
    version = None
    name = label
    if "Molformer" in label:
        # "Molformer_classification (bace)" -> name "Molformer_classification",
        # version "bace".
        version = label.split(" ")[-1].split("(")[-1].split(")")[0]
        name = label.split(" ")[0]

    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[name.lower()]
    kwargs = {"algorithm_version": "v0"} if name in MODEL_PROP_DESCRIPTION else {}
    if name.lower() in MOLFORMER_VERSIONS:
        if version is None:
            raise ValueError(
                f"MolFormer predictor {label!r} needs a version, "
                f"e.g. '{name.capitalize()} (<version>)'."
            )
        kwargs["algorithm_version"] = version
    return algo(config(**kwargs)), name


def _load_smiles(smiles, smiles_file):
    """Return the SMILES to score, taken from exactly one of the two inputs.

    Args:
        smiles: a single SMILES string; ``""`` or ``None`` means "not given".
        smiles_file: uploaded tab-separated ``.smi`` file (SMILES in the first
            column, no header), or ``None`` when not given.

    Returns:
        A list of SMILES strings.

    Raises:
        ValueError: if both inputs are given, or if neither is.
    """
    has_smiles = bool(smiles)
    has_file = smiles_file is not None
    if has_smiles and has_file:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if has_smiles:
        return [smiles]
    if has_file:
        # NOTE(review): depending on the Gradio version the upload arrives as a
        # tempfile-like object exposing ``.name`` or as a plain path string
        # (the example row below passes a string); accept both.
        path = getattr(smiles_file, "name", smiles_file)
        return pd.read_csv(path, header=None, sep="\t")[0].tolist()
    raise ValueError("Pass either smiles or smiles_file.")


def main(property: str, smiles: str, smiles_file: str):
    """Predict the selected property for the given molecules and render them.

    Args:
        property: dropdown label of the predictor, e.g. ``"Qed"`` or
            ``"Molformer_regression (homo)"``.
        smiles: a single SMILES string, or ``""`` when a file is used instead.
        smiles_file: uploaded ``.smi`` file, or ``None`` when ``smiles`` is used.

    Returns:
        The result of ``draw_grid_predict``, shown in the app's ``gr.HTML``
        output.

    Raises:
        ValueError: if both or neither of ``smiles`` / ``smiles_file`` are given.
    """
    # Validate the inputs before paying for model instantiation.
    smiles_list = _load_smiles(smiles, smiles_file)
    model, name = _build_predictor(property)

    props = np.array([model(molecule) for molecule in smiles_list]).round(2)
    # Single-output predictors give one value per molecule; lift them to a
    # (molecules, 1) array so every predictor yields a 2-D result.
    if props.ndim == 1:
        props = np.expand_dims(props, -1)

    if name in MODEL_PROP_DESCRIPTION:
        # Strip: some descriptions separate names with ", " rather than ",".
        property_names = [
            output.strip() for output in MODEL_PROP_DESCRIPTION[name].split(",")
        ]
    else:
        # NOTE(review): MolFormer multitask versions get a single name here even
        # if they return several outputs — confirm draw_grid_predict copes.
        property_names = [name]

    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )


def _dropdown_choices():
    """Return the dropdown labels for all offered predictors.

    Factory keys not listed in REMOVE come first, in reverse factory order and
    capitalised; they are followed by one ``"<Key> (<version>)"`` entry for
    every MolFormer version. REMOVE entries missing from the factory are ignored.
    """
    hidden = set(REMOVE)
    choices = [
        key.capitalize()
        for key in list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())[::-1]
        if key not in hidden
    ]
    for key, versions in MOLFORMER_VERSIONS.items():
        choices.extend(f"{key.capitalize()} ({version})" for version in versions)
    return choices


if __name__ == "__main__":
    # Preparation (retrieve all available algorithms)
    properties = _dropdown_choices()

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = [
        ["Qed", "", str(metadata_root.joinpath("examples.smi"))],
        [
            "Esol",
            "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1",
            None,
        ],
    ]
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    demo = gr.Interface(
        fn=main,
        title="Molecular properties",
        inputs=[
            gr.Dropdown(properties, label="Property", value="Scscore"),
            gr.Textbox(
                label="Single SMILES",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.File(
                file_types=[".smi"],
                label="Multiple SMILES (tab-separated, `.smi` file)",
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True, share=True)