jannisborn's picture
update
14da265 unverified
raw
history blame
3.87 kB
import logging
import pathlib
import gradio as gr
import numpy as np
import pandas as pd
from gt4sd.properties.molecules import MOLECULE_PROPERTY_PREDICTOR_FACTORY
from utils import draw_grid_predict
logger = logging.getLogger(__name__)
# Library-style logging: emit nothing unless the application configures handlers.
logger.addHandler(logging.NullHandler())

# Predictor names (keys of MOLECULE_PROPERTY_PREDICTOR_FACTORY) that are not
# offered in the demo's property dropdown.
# NOTE(review): presumably these need extra inputs (seeds, targets) or external
# services that this UI does not expose — confirm.
REMOVE = ["docking", "docking_tdc", "molecule_one", "askcos", "plogp"]
REMOVE.extend(["similarity_seed", "activity_against_target", "organtox"])

# Multi-output predictors, keyed by their capitalized dropdown name. Each value
# is a comma-separated list of output labels, one per column of the prediction
# array; `main` splits it on "," to label the result grid. Labels are joined
# without surrounding spaces so the split yields clean names.
MODEL_PROP_DESCRIPTION = {
    "Tox21": ",".join(
        (
            "NR-AR",
            "NR-AR-LBD",
            "NR-AhR",
            "NR-Aromatase",
            "NR-ER",
            "NR-ER-LBD",
            "NR-PPAR-gamma",
            "SR-ARE",
            "SR-ATAD5",
            "SR-HSE",
            "SR-MMP",
            "SR-p53",
        )
    ),
    "Sider": ",".join(
        (
            "Hepatobiliary disorders",
            "Metabolism and nutrition disorders",
            "Product issues",
            "Eye disorders",
            "Investigations",
            "Musculoskeletal disorders",
            "Gastrointestinal disorders",
            "Social circumstances",
            "Immune system disorders",
            "Reproductive system and breast disorders",
            "Benign & malignant",
            "General disorders",
            "Endocrine disorders",
            "Surgical & medical procedures",
            "Vascular disorders",
            "Blood & lymphatic disorders",
            "Skin & subcutaneous disorders",
            "Congenital & genetic disorders",
            "Infections",
            "Respiratory & thoracic disorders",
            "Psychiatric disorders",
            "Renal & urinary disorders",
            "Pregnancy conditions",
            "Ear disorders",
            "Cardiac disorders",
            "Nervous system disorders",
            "Injury & procedural complications",
        )
    ),
    "Clintox": ",".join(("FDA approval", "Clinical trial failure")),
}
def _collect_smiles(smiles, smiles_file):
    """Return the SMILES to score, taken from exactly one of the two inputs.

    Args:
        smiles: A single SMILES string. ``None`` and blank strings count as
            "not provided" (an empty Gradio ``Textbox`` yields ``""``).
        smiles_file: An uploaded tab-separated ``.smi`` file whose first column
            holds SMILES. Either a file-like object exposing its path as
            ``.name`` or a plain path; ``None`` if not provided.

    Returns:
        A list of SMILES strings.

    Raises:
        ValueError: If both inputs are provided, or neither is.
    """
    has_smiles = smiles is not None and smiles.strip() != ""
    has_file = smiles_file is not None
    if has_smiles and has_file:
        raise ValueError("Pass either smiles or smiles_file, not both.")
    if has_smiles:
        return [smiles.strip()]
    if has_file:
        # Older Gradio versions hand over a tempfile wrapper (path in `.name`),
        # newer ones a plain path string; accept both.
        path = getattr(smiles_file, "name", smiles_file)
        return pd.read_csv(path, header=None, sep="\t")[0].tolist()
    raise ValueError("Pass either smiles or smiles_file.")


def main(property: str, smiles: str, smiles_file: str):
    """Predict a molecular property for one SMILES or a file of SMILES.

    Args:
        property: Predictor name as shown in the dropdown (e.g. ``"Qed"``);
            looked up case-insensitively in MOLECULE_PROPERTY_PREDICTOR_FACTORY.
        smiles: A single SMILES string, or ``None``/blank if unused.
        smiles_file: An uploaded tab-separated ``.smi`` file, or ``None``.

    Returns:
        Whatever ``draw_grid_predict`` renders for the molecules and their
        predictions (rounded to 2 decimals), displayed by the HTML output.

    Raises:
        ValueError: If both or neither of ``smiles``/``smiles_file`` are given.
        KeyError: If ``property`` is not a registered predictor.
    """
    # Validate inputs before paying for model construction.
    smiles_list = _collect_smiles(smiles, smiles_file)

    # Multi-output predictors are registered under capitalized names; the
    # dropdown capitalizes too, so normalize the same way for the lookup.
    description = MODEL_PROP_DESCRIPTION.get(property.capitalize())

    algo, config = MOLECULE_PROPERTY_PREDICTOR_FACTORY[property.lower()]
    # Only the multi-output predictors are configured with an explicit version.
    kwargs = {"algorithm_version": "v0"} if description is not None else {}
    model = algo(config(**kwargs))

    props = np.array([model(s) for s in smiles_list]).round(2)
    # Single-output predictors give one value per molecule; make it a column
    # so the result is always (n_molecules, n_outputs).
    if props.ndim == 1:
        props = np.expand_dims(props, -1)

    if description is not None:
        # Strip so labels written as "a, b" do not keep leading spaces.
        property_names = [name.strip() for name in description.split(",")]
    else:
        property_names = [property]

    return draw_grid_predict(
        smiles_list, props, property_names=property_names, domain="Molecules"
    )
if __name__ == "__main__":
    # Preparation: every registered predictor, in reverse registration order,
    # minus the ones listed in REMOVE, capitalized for display.
    available = list(MOLECULE_PROPERTY_PREDICTOR_FACTORY.keys())
    removed = set(REMOVE)
    missing = removed.difference(available)
    if missing:
        # Previously this crashed the app with a KeyError; just note it instead.
        logger.warning("Predictors listed in REMOVE are not registered: %s", sorted(missing))
    properties = [
        name.capitalize() for name in reversed(available) if name not in removed
    ]

    # Load metadata (example inputs, markdown shown around the interface).
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")
    examples = [
        ["Qed", None, str(metadata_root.joinpath("examples.smi"))],
        [
            "Esol",
            "CN1CCN(CCCOc2ccc(N3C(=O)C(=Cc4ccc(Oc5ccc([N+](=O)[O-])cc5)cc4)SC3=S)cc2)CC1",
            None,
        ],
    ]
    # Explicit UTF-8 so reading does not depend on the host locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(encoding="utf-8")

    demo = gr.Interface(
        fn=main,
        title="Molecular properties",
        inputs=[
            # The default must be one of the (capitalized) choices.
            gr.Dropdown(properties, label="Property", value="Qed"),
            gr.Textbox(
                label="Single SMILES",
                placeholder="CC(C#C)N(C)C(=O)NC1=CC=C(Cl)C=C1",
                lines=1,
            ),
            gr.File(
                file_types=[".smi"],
                label="Multiple SMILES (tab-separated, `.smi` file)",
            ),
        ],
        outputs=gr.HTML(label="Output"),
        article=article,
        description=description,
        examples=examples,
    )
    demo.launch(debug=True, show_error=True)