|
""" |
|
© Battelle Memorial Institute 2023 |
|
Made available under the GNU General Public License v 2.0 |
|
|
|
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY |
|
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN |
|
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES |
|
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED |
|
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF |
|
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS |
|
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE |
|
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, |
|
REPAIR OR CORRECTION. |
|
""" |
|
|
|
import re |
|
import torch |
|
import pandas as pd |
|
from rdkit import Chem |
|
from rdkit.Chem.SaltRemover import SaltRemover |
|
|
|
|
|
class InvalidSmile(Exception):
    """Signals that a SMILES string could not be turned into a molecule."""
|
|
|
|
|
def load_vocab(vocab_file_name):
    """
    Load an existing vocabulary from a file. Assumes a single
    token definition per line of the file.

    Parameters
    ----------
    vocab_file_name : str
        The file name of the vocabulary to load.

    Returns
    -------
    vocab_dict : dict
        A dict of tokens (always ``str``, exactly as written in the file)
        as the keys and the corresponding zero-based token index as the
        items. Blank lines are skipped and do not consume an index; if a
        token appears more than once, its last index wins.
    """
    # Force every entry to be read verbatim as a string. With pandas'
    # defaults, entries such as "NA", "nan" or "null" are turned into NaN
    # and numeric-looking columns into numbers, producing keys that can
    # never match the string tokens looked up by callers.
    vocab = pd.read_csv(
        vocab_file_name, header=None, dtype=str, keep_default_na=False
    )[0].to_list()

    return {token: index for index, token in enumerate(vocab)}
|
|
|
|
|
# Compiled once at import time. Bracket atoms such as "[nH]" or "[NH3+]"
# are captured as single tokens, "Br?"/"Cl?" greedily match the two-letter
# halogens before falling back to "B"/"C", and "%NN" matches two-digit ring
# closures. A raw string is used so the regex escapes are not treated as
# (invalid) Python string escapes.
_SMILES_TOKEN_PATTERN = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smiles_tokenizer(smiles):
    """
    Tokenize a SMILES string.

    Parameters
    ----------
    smiles : str
        A SMILES string to turn into tokens.

    Returns
    -------
    tokens : list
        A list of tokens after tokenizing the input string. An empty
        input yields an empty list.

    Raises
    ------
    ValueError
        If ``smiles`` contains characters not covered by the token
        pattern. ValueError is a subclass of Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    tokens = _SMILES_TOKEN_PATTERN.findall(smiles)

    # findall silently skips characters that no alternative matches, so a
    # shorter rejoined string means the input held unrecognised characters.
    if len("".join(tokens)) < len(smiles):
        raise ValueError("Input smiles string contained invalid characters.")

    return tokens
|
|
|
|
|
def smiles_to_tensor(
    smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True
):
    """
    Converts a SMILES string to a tensor using the provided vocabulary.

    The encoded sequence is laid out as
    ``[CLS] [EDGE] <SMILES tokens> [EDGE] [PAD] ... [PAD]``.

    Parameters
    ----------
    smiles : str
        A SMILES string to convert to a tensor.
    vocab_dict : dict
        A dictionary of SMILES tokens and integer value as the dictionary key
        and item, respectively. Must contain the special tokens "[CLS]",
        "[EDGE]" and "[PAD]" plus every token produced by tokenizing the
        (optionally desalted / canonicalized) SMILES.
    max_seq_len : int
        The maximum sequence length allowed for SMILES strings, including the
        three special tokens. Smaller sequences are padded to the maximum
        length using the [PAD] token from the vocabulary provided.
    desalt : bool, optional
        Flag for removing salts and solvents from SMILES string, by default True.
    canonical : bool, optional
        Flag enabling the conversion of the SMILES to canonical form, by default True.
    isomeric : bool, optional
        Flag enabling the conversion of the SMILES to isomeric form, by default True.

    Returns
    -------
    smiles_ten_long : tensor
        A ``torch.long`` tensor representing the converted SMILES string based
        on the provided vocabulary with shape (1, max_seq_len).

    Raises
    ------
    InvalidSmile
        If RDKit cannot construct a molecule from ``smiles``.
    KeyError
        If ``vocab_dict`` lacks a special token or a token of the SMILES.
    Exception
        Propagated from ``smiles_tokenizer`` if the regenerated SMILES
        contains characters it does not recognise.
    RuntimeError
        If the encoded sequence is longer than ``max_seq_len``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise InvalidSmile('Molecule could not be constructed from smile string')

    if desalt:
        # Only construct the remover when it is actually used; per RDKit,
        # SaltRemover parses its salt definitions on construction.
        mol = SaltRemover().StripMol(mol, dontRemoveEverything=True)

    smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)

    tok = [vocab_dict["[CLS]"], vocab_dict["[EDGE]"]]
    tok += [vocab_dict[x] for x in smiles_tokenizer(smiles)]
    tok.append(vocab_dict["[EDGE]"])

    # Fail with a clear message instead of torch's opaque shape-mismatch
    # error; RuntimeError keeps the exception type callers previously saw.
    if len(tok) > max_seq_len:
        raise RuntimeError(
            f"Encoded SMILES has {len(tok)} tokens, which exceeds "
            f"max_seq_len={max_seq_len}."
        )

    smiles_ten_long = torch.full(
        (1, max_seq_len), vocab_dict["[PAD]"], dtype=torch.long
    )
    smiles_ten_long[0, : len(tok)] = torch.tensor(tok, dtype=torch.long)

    return smiles_ten_long
|
|