import logging
from typing import List

import mols2grid
import numpy as np
import pandas as pd
from rdkit import Chem

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


def _protein_to_smiles(sequence: str) -> str:
    """Convert an amino-acid (FASTA) sequence to SMILES via RDKit."""
    return Chem.MolToSmiles(Chem.MolFromFASTA(sequence))


def _identity(sequence: str) -> str:
    """Return the sequence unchanged (molecule inputs are used as-is)."""
    return sequence


def draw_grid_predict(
    sequences: List[str],
    properties: np.ndarray,
    property_names: List[str],
    domain: str,
) -> str:
    """
    Uses mols2grid to draw a HTML grid for the prediction.

    Sequences that cannot be converted are skipped (with a logged warning),
    and their rows in ``properties`` are dropped as well so that every drawn
    structure stays aligned with its own predicted values.

    Args:
        sequences: Sequences for which properties are predicted.
        properties: Predicted properties. Array of shape
            (n_samples, n_properties), with rows in the same order as
            ``sequences``.
        property_names: List of property names, one per column of
            ``properties``.
        domain: Domain of the prediction; must be exactly "Molecules"
            (sequences used as-is) or "Proteins" (sequences converted from
            FASTA to SMILES with RDKit).

    Returns:
        HTML to display.

    Raises:
        ValueError: If ``domain`` is not one of the supported values.
    """
    if domain not in ["Molecules", "Proteins"]:
        raise ValueError(f"Unsupported domain {domain}")
    converter = _protein_to_smiles if domain == "Proteins" else _identity

    smiles = []
    kept_rows = []  # indices into `properties` of successfully converted rows
    for row, sequence in enumerate(sequences):
        try:
            converted = converter(sequence)
        except Exception:
            # Deliberately broad: RDKit can fail in several ways (e.g. a
            # None molecule from an invalid FASTA string). Drawing is
            # best-effort, so skip the sequence rather than abort the grid.
            logger.warning("Could not draw sequence %s", sequence)
            continue
        smiles.append(converted)
        kept_rows.append(row)

    properties = np.asarray(properties)
    row_index = np.asarray(kept_rows, dtype=int)
    result = pd.DataFrame({"SMILES": smiles})
    for i, name in enumerate(property_names):
        # Select only the rows whose sequence survived conversion so the
        # column length matches `smiles` and values stay aligned.
        result[name] = properties[row_index, i]

    n_cols = min(3, len(result))
    # Few results get large tiles; larger result sets get compact tiles.
    size = (140, 200) if len(result) > 3 else (600, 700)
    obj = mols2grid.display(
        result,
        tooltip=list(result.keys()),
        subset=["img"] + list(result.keys()),
        height=1100,
        n_cols=n_cols,
        name="Results",
        size=size,
    )
    return obj.data