FupBERT / smiles.py
c-dunlap's picture
Added model vocab
b4dd09c
"""
© Battelle Memorial Institute 2023
Made available under the GNU General Public License v 2.0
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
REPAIR OR CORRECTION.
"""
import re
import torch
import pandas as pd
from rdkit import Chem
from rdkit.Chem.SaltRemover import SaltRemover
class InvalidSmile(Exception):
    """Error raised when a SMILES string cannot be processed."""
def load_vocab(vocab_file_name):
    """
    Load an existing vocabulary from a file. Assumes a single
    token definition per line of the file.

    Parameters
    ----------
    vocab_file_name : str
        The file name of the vocabulary to load.

    Returns
    -------
    vocab_dict : dict
        A dict of tokens (always ``str``) as the keys and the
        corresponding zero-based line index as the items.
    """
    # Read every entry verbatim as text. Without ``dtype=str`` and
    # ``keep_default_na=False`` pandas applies type inference and NA
    # detection, so tokens such as "NA" or "nan" would turn into float NaN
    # and an all-digit vocabulary into integers, leaving those tokens
    # unreachable through their string key.
    vocab = pd.read_csv(
        vocab_file_name, header=None, dtype=str, keep_default_na=False
    )[0].to_list()
    return {token: index for index, token in enumerate(vocab)}
# Token grammar for SMILES strings, compiled once at import time.
# Written as a raw string so the regex escapes (``\[``, ``\(``, ``\.`` ...)
# are not interpreted as (invalid) Python string escapes.
_SMILES_TOKEN_REGEX = re.compile(
    r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
)


def smiles_tokenizer(smiles):
    """
    Tokenize a SMILES string.

    Parameters
    ----------
    smiles : str
        A SMILES string to turn into tokens.

    Returns
    -------
    tokens : list
        A list of tokens after tokenizing the input string.

    Raises
    ------
    InvalidSmile
        If the input contains characters that are not matched by the
        token pattern.
    """
    tokens = _SMILES_TOKEN_REGEX.findall(smiles)
    # findall silently skips characters that match no alternative, so the
    # joined tokens being shorter than the input means something was dropped.
    # solution based on https://stackoverflow.com/a/3879574
    if len("".join(tokens)) < len(smiles):
        raise InvalidSmile(
            "Input smiles string contained invalid characters."
        )
    return tokens
def smiles_to_tensor(
    smiles, vocab_dict, max_seq_len, desalt=True, canonical=True, isomeric=True
):
    """
    Converts a SMILES string to a tensor using the provided vocabulary.

    The encoded sequence is laid out as
    ``[CLS] [EDGE] <SMILES tokens> [EDGE]`` followed by ``[PAD]`` up to
    ``max_seq_len``.

    Parameters
    ----------
    smiles : str
        A SMILES string to convert to a tensor.
    vocab_dict : dict
        A dictionary of SMILES tokens and integer value as the dictionary key
        and item, respectively. Must contain the special tokens [CLS],
        [EDGE] and [PAD].
    max_seq_len : int
        The maximum sequence length allowed for SMILES strings. Smaller
        strings are padded to the maximum length using the [PAD] token
        from the vocabulary provided.
    desalt : bool, optional
        Flag for removing salts and solvents from SMILES string, by default True.
    canonical : bool, optional
        Flag enabling the conversion of the SMILES to canonical form, by default True.
    isomeric : bool, optional
        Flag enabling the conversion of the SMILES to isomeric form, by default True.

    Returns
    -------
    smiles_ten_long : tensor
        A tensor representing the converted SMILES string based on the provided
        vocabulary with shape (1, max_seq_len).

    Raises
    ------
    InvalidSmile
        If RDKit cannot construct a molecule from ``smiles``.
    KeyError
        If a SMILES token or one of the special tokens is missing from
        ``vocab_dict``.
    RuntimeError
        If the encoded sequence (SMILES tokens plus the three special
        tokens) is longer than ``max_seq_len``.
    """
    # Convert the SMILES to molecule
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise InvalidSmile('Molecule could not be constructed from smile string')

    # Remove the salts/solvents; the remover is only built when it is needed
    if desalt:
        mol = SaltRemover().StripMol(mol, dontRemoveEverything=True)

    # Convert back to SMILES
    smiles = Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)

    # Tokenize the SMILES and map every token to its vocabulary index
    smiles_tok = smiles_tokenizer(smiles)
    tok = [vocab_dict["[CLS]"], vocab_dict["[EDGE]"]]
    tok += [vocab_dict[x] for x in smiles_tok]
    tok += [vocab_dict["[EDGE]"]]

    pad_index = vocab_dict["[PAD]"]

    # Fail with a readable message instead of torch's shape-mismatch error.
    # RuntimeError is kept because that is the type the slice assignment
    # below would raise for an over-long sequence.
    if len(tok) > max_seq_len:
        raise RuntimeError(
            f"Encoded SMILES has {len(tok)} tokens, which exceeds "
            f"max_seq_len={max_seq_len}."
        )

    # Start from an all-[PAD] row and overwrite the leading positions
    smiles_ten_long = torch.full((1, max_seq_len), pad_index, dtype=torch.long)
    smiles_ten_long[0, : len(tok)] = torch.tensor(tok, dtype=torch.long)
    return smiles_ten_long