File size: 4,866 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
from torch_geometric.data import Data
from ogb.utils import smiles2graph
from rdkit import Chem
import random
import os
import json
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from .r_smiles import multi_process
import multiprocessing

def reformat_smiles(smiles, smiles_type='default'):
    """
    Rewrite a SMILES string according to ``smiles_type``.

    Parameters
    ----------
    smiles : str or None
        Input SMILES. Empty or None input short-circuits to None.
    smiles_type : str
        One of the following:

        * ``'default'`` — return ``smiles`` unchanged.
        * ``'r_smiles'`` — also returned unchanged; root-aligned augmentation is
          implemented separately in r_smiles.py.
        * ``'canonical'`` — RDKit canonical SMILES.
        * ``'restricted'`` — atoms renumbered in a random order (using the global
          ``random`` module), then written non-canonically.
        * ``'unrestricted'`` — RDKit random SMILES (``doRandom=True``).

        For the RDKit-based types, stereochemistry is not written
        (``isomericSmiles=False``).

    Returns
    ----------
    str or None - the reformatted SMILES, or None for empty input.

    Raises
    ----------
    NotImplementedError - if ``smiles_type`` is not one of the values above.
    ValueError - if RDKit cannot parse ``smiles`` (RDKit-based types only).
    """
    if not smiles:
        return None
    if smiles_type in ('default', 'r_smiles'):
        # r_smiles: the implementation of root-aligned smiles is in r_smiles.py
        return smiles
    if smiles_type not in ('canonical', 'restricted', 'unrestricted'):
        raise NotImplementedError(f"smiles_type {smiles_type} not implemented")

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail loudly
        # here instead of with an AttributeError deeper in RDKit calls.
        raise ValueError(f"could not parse SMILES: {smiles!r}")

    if smiles_type == 'canonical':
        return Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False)
    if smiles_type == 'restricted':
        new_atom_order = list(range(mol.GetNumAtoms()))
        random.shuffle(new_atom_order)
        random_mol = Chem.RenumberAtoms(mol, newOrder=new_atom_order)
        return Chem.MolToSmiles(random_mol, canonical=False, isomericSmiles=False)
    # smiles_type == 'unrestricted'
    return Chem.MolToSmiles(mol, canonical=False, doRandom=True, isomericSmiles=False)

def json_read(path):
    """
    Load and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 rather than with the
    platform-default encoding. That default varies by OS and locale and can
    mis-decode non-ASCII text.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def json_write(path, data):
    """
    Serialize ``data`` to ``path`` as JSON indented by 4 spaces.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``), so the
    file is opened explicitly as UTF-8. Relying on the platform-default
    encoding could raise or mis-encode those characters on some systems.
    """
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

def format_float_from_string(s):
    """
    Render ``s`` with two decimal places if ``float(s)`` accepts it.

    Values that ``float`` rejects with ValueError (e.g. non-numeric text) are
    returned unchanged.
    """
    try:
        parsed = float(s)
    except ValueError:
        return s
    return '{:.2f}'.format(parsed)

def make_abstract(mol_dict, abstract_max_len=256, property_max_len=256):
    """
    Build a text prompt from a molecule record's abstract and properties.

    Parameters
    ----------
    mol_dict : mapping
        May contain ``'abstract'`` (sliceable text) and ``'property'``.
        ``'property'`` maps ``'Experimental Properties'`` and/or
        ``'Computed Properties'`` to a key->value mapping.
    abstract_max_len : int
        The abstract is cut to at most this many characters.
    property_max_len : int
        Character budget for the concatenated ``"key: value; "`` entries.
        Within each section, entries are appended in order until the next one
        would exceed the budget. That section then stops, and the next section
        is tried against the same remaining budget.

    Returns
    ----------
    str - ``"[Abstract] ... "`` and/or ``"[Properties] ... . "``, or ``''`` when
    neither part is present.
    """
    def _render(key, value):
        # Floats (and numeric strings) are shown with two decimals.
        if isinstance(value, float):
            return f'{key}: {value:.2f}; '
        if isinstance(value, str):
            return f'{key}: {format_float_from_string(value)}; '
        return f'{key}: {value}; '

    parts = []
    if 'abstract' in mol_dict:
        parts.append(f"[Abstract] {mol_dict['abstract'][:abstract_max_len]} ")

    property_dict = {}
    if 'property' in mol_dict:
        property_dict = mol_dict['property']

    collected = ''
    for section in ('Experimental Properties', 'Computed Properties'):
        if section not in property_dict:
            continue
        for key, value in property_dict[section].items():
            entry = _render(key, value)
            if len(collected + entry) > property_max_len:
                break
            collected += entry

    if collected:
        parts.append(f'[Properties] {collected[:property_max_len]}. ')
    return ''.join(parts)

def smiles2data(smiles):
    """
    Convert a SMILES string into a torch_geometric ``Data`` graph.

    ogb's ``smiles2graph`` supplies the arrays under ``'node_feat'``,
    ``'edge_index'`` and ``'edge_feat'``. They are wrapped with
    ``torch.from_numpy`` as ``x``, ``edge_index`` and ``edge_attr``.
    ``from_numpy`` shares memory with the source array rather than copying it.
    """
    graph = smiles2graph(smiles)
    field_sources = (
        ('x', 'node_feat'),
        ('edge_index', 'edge_index'),
        ('edge_attr', 'edge_feat'),
    )
    tensors = {field: torch.from_numpy(graph[key]) for field, key in field_sources}
    return Data(**tensors)

import re
# Marker string inserted before each character of a special sequence by
# _insert_split_marker (used for the custom split in escape_custom_split_sequence).
# The f-string evaluates to the literal "SPL1T-TH1S-Pl3A5E".
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"

# Matches "[START_<TYPE>]...[END_<TYPE>]" for TYPE in DNA, SMILES, I_SMILES, AMINO.
# Groups: 1 = start token, 2 = TYPE, 3 = enclosed sequence (non-greedy),
# 4 = end token. The backreference \2 requires the END tag to name the same
# TYPE as the START tag. The pattern is compiled without re.DOTALL, so a
# sequence containing a newline will not match.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")


def _insert_split_marker(m: re.Match):
    """
    Insert SPLIT_MARKER before every character of a matched special sequence.

    Intended as the replacement callback for ``CUSTOM_SEQ_RE.sub``.

    Parameters
    ----------
    m : re.Match
        A match of CUSTOM_SEQ_RE with groups (start token, sequence type,
        sequence, end token), e.g. for ``[START_SMILES]CCO[END_SMILES]``.

    Returns
    ----------
    str - the start token, then the sequence with SPLIT_MARKER before each
    character, then one more SPLIT_MARKER followed by the end token.
    """
    start_token, _, sequence, end_token = m.groups()
    # Same result as re.sub(r"(.)", SPLIT_MARKER + r"\1", sequence, flags=re.DOTALL),
    # without running a second regex over every match.
    sequence = "".join(SPLIT_MARKER + ch for ch in sequence)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

def escape_custom_split_sequence(text):
    """
    Apply the custom character-level splitting used for GALILEO's tokenization.

    Every ``[START_<TYPE>]...[END_<TYPE>]`` span matched by CUSTOM_SEQ_RE is
    rewritten by _insert_split_marker. Text outside those spans is unchanged.

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    escaped = CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
    return escaped

def generate_rsmiles(reactants, products, augmentation=20, processes=None):
    """
    Generate root-aligned SMILES augmentations for (reactant, product) pairs.

    Each SMILES is stripped of surrounding whitespace and internal spaces.
    The pairs are then processed in parallel by ``multi_process`` (from
    r_smiles.py).

    Parameters
    ----------
    reactants : iterable of str
        Reactant SMILES, one per reaction (N items).
    products : iterable of str
        Product SMILES aligned with ``reactants``. Pairs are formed with
        ``zip``, so surplus items in the longer input are ignored.
    augmentation : int
        Number of augmentations requested per reaction.
    processes : int or None
        Number of worker processes. Defaults to ``multiprocessing.cpu_count()``.

    Returns
    ----------
    tuple (reactant_smiles, product_smiles) - two flat lists, built from each
    result's ``'tgt_data'`` and ``'src_data'`` respectively, in input order.
    Presumably each holds N x augmentation entries; this depends on
    ``multi_process``, so verify there.
    """
    data = [{
        'reactant': r.strip().replace(' ', ''),
        'product': p.strip().replace(' ', ''),
        'augmentation': augmentation,
        'root_aligned': True,
    } for r, p in zip(reactants, products)]
    if not data:
        # Nothing to augment: skip spawning a pool of worker processes.
        return [], []
    if processes is None:
        processes = multiprocessing.cpu_count()
    # The context manager terminates the workers once map() has returned all
    # results, so no processes are left running after the call.
    with multiprocessing.Pool(processes=processes) as pool:
        results = pool.map(multi_process, data)
    product_smiles = [smi for r in results for smi in r['src_data']]
    reactant_smiles = [smi for r in results for smi in r['tgt_data']]
    return reactant_smiles, product_smiles