Spaces:
Runtime error
Runtime error
import torch | |
from torch_geometric.data import Data | |
from ogb.utils import smiles2graph | |
from rdkit import Chem | |
import random | |
import os | |
import json | |
from rdkit import RDLogger | |
RDLogger.DisableLog('rdApp.*') | |
from .r_smiles import multi_process | |
import multiprocessing | |
def reformat_smiles(smiles, smiles_type='default'):
    """Return ``smiles`` re-written in the requested SMILES flavour.

    Parameters
    ----------
    smiles : str
        Input SMILES string. Any falsy value (``''``, ``None``) yields ``None``.
    smiles_type : str
        One of:

        * ``'default'``      - return the input unchanged.
        * ``'r_smiles'``     - return the input unchanged; root-aligned SMILES
          are implemented in ``r_smiles.py``.
        * ``'canonical'``    - RDKit canonical, non-isomeric SMILES.
        * ``'restricted'``   - atoms are renumbered in a random order
          (``random.shuffle``) and written non-canonically.
        * ``'unrestricted'`` - RDKit's own randomised output (``doRandom=True``).

    Returns
    -------
    str or None
        The reformatted SMILES, or ``None`` when ``smiles`` is falsy.

    Raises
    ------
    NotImplementedError
        If ``smiles_type`` is not one of the values listed above.
    ValueError
        If RDKit cannot parse ``smiles``.
    """
    if not smiles:
        return None
    if smiles_type in ('default', 'r_smiles'):
        # the implementation of root-aligned smiles is in r_smiles.py
        return smiles
    if smiles_type not in ('canonical', 'restricted', 'unrestricted'):
        raise NotImplementedError(f"smiles_type {smiles_type} not implemented")

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail here
        # with a clear message instead of crashing further down.
        raise ValueError(f"could not parse SMILES string: {smiles!r}")

    if smiles_type == 'canonical':
        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
    if smiles_type == 'restricted':
        new_atom_order = list(range(mol.GetNumAtoms()))
        random.shuffle(new_atom_order)
        random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
        return Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
    # smiles_type == 'unrestricted'
    return Chem.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False)
def json_read(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8, matching what ``json_write`` produces.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data


def json_write(path, data):
    """Serialise ``data`` as JSON to ``path`` (4-space indent).

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8 rather than with the platform default
    encoding, which may be unable to represent them.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)
def format_float_from_string(s):
    """Format ``s`` with two decimals if it parses as a float, else return it as-is.

    Parameters
    ----------
    s : str
        Candidate numeric text, e.g. ``'3.14159'``.

    Returns
    -------
    str or object
        ``f'{float(s):.2f}'`` when ``float(s)`` succeeds; otherwise ``s``
        itself, unchanged. This also covers values ``float`` rejects with a
        ``TypeError`` (such as ``None``).
    """
    try:
        number = float(s)
    except (TypeError, ValueError):
        # Not numeric text (or not text at all): hand the value back untouched.
        return s
    return f'{number:.2f}'
def make_abstract(mol_dict, abstract_max_len=256, property_max_len=256):
    """Assemble a text prompt from a molecule record's abstract and properties.

    Parameters
    ----------
    mol_dict : dict
        May contain ``'abstract'`` (sliced to ``abstract_max_len``) and
        ``'property'``, a dict whose ``'Experimental Properties'`` and
        ``'Computed Properties'`` entries map property names to values.
    abstract_max_len : int
        Maximum number of abstract characters kept.
    property_max_len : int
        Maximum length of the accumulated ``"name: value; "`` text.

    Returns
    -------
    str
        ``'[Abstract] ... '`` followed by ``'[Properties] ... . '``; either
        part is omitted when it has no content, so the result may be ``''``.

    Notes
    -----
    Float values are shown with two decimals, strings go through
    ``format_float_from_string``, and anything else is interpolated as-is.
    An entry that would push the text past ``property_max_len`` ends the scan
    of the *current* section only; the next section is still visited.
    """
    pieces = []
    if 'abstract' in mol_dict:
        clipped_abstract = mol_dict['abstract'][:abstract_max_len]
        pieces.append(f'[Abstract] {clipped_abstract} ')

    all_properties = mol_dict.get('property', {})
    collected = ''
    for section in ('Experimental Properties', 'Computed Properties'):
        if section in all_properties:
            for name, raw in all_properties[section].items():
                if isinstance(raw, float):
                    shown = f'{raw:.2f}'
                elif isinstance(raw, str):
                    shown = format_float_from_string(raw)
                else:
                    shown = raw
                entry = f'{name}: {shown}; '
                if len(collected) + len(entry) > property_max_len:
                    # Budget exhausted for this section; stop scanning it.
                    break
                collected += entry

    if collected:
        pieces.append(f'[Properties] {collected[:property_max_len]}. ')
    return ''.join(pieces)
def smiles2data(smiles):
    """Convert a SMILES string into a ``torch_geometric`` ``Data`` graph.

    The string is featurised with ``smiles2graph``. The resulting
    ``'node_feat'``, ``'edge_index'`` and ``'edge_feat'`` arrays are wrapped
    with ``torch.from_numpy`` and stored as ``x``, ``edge_index`` and
    ``edge_attr`` respectively.
    """
    mol_graph = smiles2graph(smiles)
    return Data(
        x=torch.from_numpy(mol_graph['node_feat']),
        edge_index=torch.from_numpy(mol_graph['edge_index']),
        edge_attr=torch.from_numpy(mol_graph['edge_feat']),
    )
import re | |
# Sentinel inserted between the characters of tagged sequences; the f-string
# evaluates to "SPL1T-TH1S-Pl3A5E".
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

# Matches one tagged span such as [START_SMILES]...[END_SMILES].
# No flags are set, so the payload cannot span a newline.
CUSTOM_SEQ_RE = re.compile(
    r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])"  # 1: opening tag, 2: sequence kind
    r"(.*?)"                                    # 3: payload (non-greedy)
    r"(\[END_\2])"                              # 4: closing tag of the same kind
)
def _insert_split_marker(m: re.Match):
    """
    Interleaves the split marker through the payload of one tagged sequence,
    e.g. the text between [START_DNA] and [END_DNA].
    Parameters
    ----------
    m : re.Match
        Match with four groups: opening tag, sequence kind, payload and
        closing tag (the group layout of CUSTOM_SEQ_RE).
    Returns
    ----------
    str - the opening tag, the payload with SPLIT_MARKER placed before every
    character (newlines included), one more SPLIT_MARKER, then the closing tag
    """
    opening, _kind, payload, closing = m.groups()
    marked = "".join(f"{SPLIT_MARKER}{char}" for char in payload)
    return opening + marked + SPLIT_MARKER + closing
def escape_custom_split_sequence(text):
    """
    Rewrites every tagged span in the text (as matched by CUSTOM_SEQ_RE) so
    that its payload is interleaved with the split marker, for GALILEO's
    tokenization. Text outside tagged spans is left as it is.
    Parameters
    ----------
    text : str
        Input text to split
    Returns
    ----------
    str - the text with the split token added
    """
    escaped = CUSTOM_SEQ_RE.sub(repl=_insert_split_marker, string=text)
    return escaped
def generate_rsmiles(reactants, products, augmentation=20):
    """
    Generate root-aligned SMILES for reactant/product pairs in parallel.

    reactants: list of N, reactant smiles
    products: list of N, product smiles; paired positionally with reactants
        via zip, so surplus items in the longer input are ignored
    augmentation: int, number of augmentations requested for each pair
    return: tuple (reactant_smiles, product_smiles). Each is a flat list built
        by concatenating, over all pairs in order, the 'tgt_data' (reactants)
        and 'src_data' (products) lists returned by multi_process; presumably
        N x augmentation entries each - confirm against r_smiles.py.
    """
    data = [{
        'reactant': r.strip().replace(' ', ''),
        'product': p.strip().replace(' ', ''),
        'augmentation': augmentation,
        'root_aligned': True,
    } for r, p in zip(reactants, products)]
    if not data:
        # Nothing to process: skip spawning a pool of worker processes.
        return [], []
    # The context manager tears the pool down on exit, so worker processes are
    # not leaked on every call; pool.map blocks until all results are in.
    with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
        results = pool.map(func=multi_process, iterable=data)
    product_smiles = [smi for r in results for smi in r['src_data']]
    reactant_smiles = [smi for r in results for smi in r['tgt_data']]
    return reactant_smiles, product_smiles