""" © Battelle Memorial Institute 2023 Made available under the GNU General Public License v 2.0 BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. """ import re import torch import pandas as pd from rdkit import Chem from rdkit.Chem.SaltRemover import SaltRemover class InvalidSmile(Exception): pass def load_vocab(vocab_file_name): """ Load an existing vocabulary from a file. Assumes a single token definition per line of the file. Parameters ---------- vocab_file_name : str The file name of the vocabulary to load. Returns ------- vocab_dict : dict A dict of tokens as the keys and the corresponding token index as the items. """ # Get vocabulary vocab = pd.read_csv(vocab_file_name, header=None)[0].to_list() vocab_dict = {v: ind for ind, v in enumerate(vocab)} return vocab_dict def smiles_tokenizer(smiles): """ Tokenize a SMILES string. Parameters ---------- smiles : str A SMILES string to turn into tokens. Returns ------- tokens : list A list of tokens after tokenizing the input string. """ pattern = "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])" regex = re.compile(pattern) tokens = [token for token in regex.findall(smiles)] # check if the smiles string had extra characters not recognized by regex # solution based on https://stackoverflow.com/a/3879574 if len("".join(tokens)) < len(smiles): raise Exception( "Input smiles string contained invalid characters." ) return tokens def smiles_to_tensor( smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True ): """ Converts a SMILES string to a tensor using the provided vocabulary. Parameters ---------- smiles : str A SMILES string to convert to a tensor. vocab_dict : dict A dictionary of SMILES tokens and integer value as the dictionary key and item, respectively. max_seq_len : int The maximum sequence length allowed for SMILES strings. Smaller strings are padded to the maximum length using the [PAD] token from the vocabulary provided. desalt : bool, optional Flag for removing salts and solvents from SMILES string, by default True. canonical : bool, optional Flag enabling the conversion of the SMILES to canonical form, by default True. isomeric : bool, optional Flag enabling the conversion of the SMILES to isomeric form, by default True. Returns ------- smiles_ten_long : tensor A tensor representing the converted SMILES string based on the provided vocabulary with shape (1, max_seq_len). """ # Initialize the salt/solvent remover remover = SaltRemover() # Convert the SMILES to molecule mol = Chem.MolFromSmiles(smiles) if mol is None: raise InvalidSmile('Molecule could not be constructed from smile string') # Remove the salts/solvents if desalt: mol = remover.StripMol(mol, dontRemoveEverything=True) # Convert back to SMILES smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric) # Tokenize the SMILES smiles_tok = smiles_tokenizer(smiles) tok = [vocab_dict["[CLS]"], vocab_dict["[EDGE]"]] tok += [vocab_dict[x] for x in smiles_tok] tok += [vocab_dict["[EDGE]"]] smiles_ten = torch.tensor(tok, dtype=torch.long) smiles_ten_long = ( torch.ones((1, max_seq_len), dtype=torch.long) * vocab_dict["[PAD]"] ) smiles_ten_long[0, : smiles_ten.shape[0]] = smiles_ten return smiles_ten_long