import argparse
import json
import os
import random
from collections import Counter, defaultdict

import numpy as np
from matplotlib import pyplot as plt

from .data_utils import json_read


def set_random_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and PYTHONHASHSEED with `seed`."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)


class Reaction_Cluster:
    """Sample groups of molecules from reactions, weighted by molecule rarity.

    Two JSON files are read from `root`:

    * `reaction_filename`: a list of reaction dicts. Each must provide the keys
      'REACTANT', 'CATALYST', 'SOLVENT' and 'PRODUCT' (iterables of molecule
      identifiers that are matched against `canon_smiles`), and may provide
      'rsmiles_map' (a non-empty sequence of mappings keyed by canon SMILES).
    * 'Abstract_property.json': a list of molecule records, each with a
      'canon_smiles' key. Only records that also have an 'abstract' key are
      considered valid for sampling.

    During construction every reaction dict is annotated in place with
    `valid_mols`, `mol_role_map`, `weight` (P(r)) and `mol_weight` (P(i|r)).
    """

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        """Load the data files and pre-compute all sampling weights.

        Args:
            root: directory containing both JSON files.
            reaction_filename: name of the reaction JSON file inside `root`.
            reverse_ratio: probability with which `choose_mol` returns the
                sampled molecules in reversed order.
        """
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        # Not read anywhere in this class; kept because it is a public attribute.
        self.rxn_mols_attr = defaultdict(lambda: {
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })
        self._read_reaction_mols()  # add `valid_mols` in each rxn_dict
        # Number of reactions in which each molecule is valid (`valid_mols`
        # holds no duplicates within a single reaction).
        self.mol_counter = Counter(
            mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols']
        )
        self._calculate_Pr()  # calculate P(r), add `weight` in each rxn_dict
        self._calculate_Pir()  # calculate P(i|r), add `mol_weight` in each rxn_dict

    def _read_reaction_mols(self):
        """Annotate every reaction with `mol_role_map` and `valid_mols`.

        `mol_role_map` maps each molecule that has a property record to the
        first role (in R, C, S, P order) under which it appears. `valid_mols`
        keeps those molecules whose record has an 'abstract' entry. Indices of
        reactions with at least one valid molecule are collected in
        `self.valid_rxn_indices`.
        """
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
                for m in rxn_dict[key]:
                    if m in mol_role_map:
                        continue  # keep the first role the molecule was seen in
                    if m in self.mol_property_map:
                        mol_role_map[m] = key

            valid_mols = []
            for mol in mol_role_map:
                # This is guaranteed by the membership test above.
                assert mol in self.mol_property_map
                if 'abstract' not in self.mol_property_map[mol]:
                    continue
                # Dict order is insertion order, so molecules stay in R, C, S, P order.
                valid_mols.append(mol)

            if len(valid_mols) > 0:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Compute the reaction sampling probability P(r) as `rxn_dict['weight']`.

        The unnormalised weight of a reaction is the sum of
        1 / mol_counter[mol] over its valid molecules; weights are then
        normalised to sum to 1.

        Raises:
            ZeroDivisionError: if there are reactions but none of them has a
                valid molecule.
        """
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum(1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols'])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Compute the in-reaction molecule probability P(i|r) as `rxn_dict['mol_weight']`.

        Each valid molecule gets a weight proportional to 1 / mol_counter[mol],
        normalised within its reaction. Reactions without valid molecules get
        an empty dict.
        """
        for rxn_dict in self.reaction_data:
            mol_weight = {mol: 1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols']}
            total_weight = sum(mol_weight.values())
            rxn_dict['mol_weight'] = {m: w / total_weight for m, w in mol_weight.items()}

    def _build_property_batch(self, rxn, mols):
        """Return per-sample copies of the property records of `mols`, tagged with their role.

        The records in `self.mol_property_map` are shared by every reaction a
        molecule takes part in, so they are shallow-copied before the
        reaction-specific 'role' is written; otherwise a later sample would
        overwrite the role stored in batches that were already returned.
        """
        role_map = rxn['mol_role_map']
        batch = []
        for mol in mols:
            record = dict(self.mol_property_map[mol])
            record['role'] = role_map[mol]
            batch.append(record)
        return batch

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to `k` molecules from `valid_mols`, keeping or reversing their order.

        Args:
            valid_mols: indexable sequence of molecules.
            k: maximum number of molecules to return; if `k` is at least
                `len(valid_mols)` all molecules are returned.
            weights: optional probabilities (summing to 1) aligned with
                `valid_mols`; `None` means uniform.

        Returns:
            List of the chosen molecules in their original relative order, or
            in reversed order with probability `self.reverse_ratio`.
        """
        if k >= len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = np.random.choice(len(valid_mols), k, replace=False, p=weights)
            sampled_indices = list(sampled_indices)
        sampled_indices = sorted(sampled_indices)
        # Reverse the indices with `reverse_ratio` chance.
        if random.random() < self.reverse_ratio:
            sampled_indices.reverse()
        return [valid_mols[i] for i in sampled_indices]

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to `k` molecule records from one reaction.

        Args:
            index: index into `self.reaction_data`; if `None`, one reaction is
                drawn with `sample_rxn_index`.
            k: maximum number of molecules to sample.

        Returns:
            List of property dicts (copies of the shared records) each carrying
            a 'role' key and, when the reaction has an 'rsmiles_map' containing
            the molecule, an 'r_smiles' key. Empty if the reaction has no valid
            molecule.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]

        # A reaction without valid molecules has nothing to sample; unpacking
        # `zip(*{}.items())` would raise instead.
        if not rxn['mol_weight']:
            return []

        valid_mols, weights = zip(*rxn['mol_weight'].items())
        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._build_property_batch(rxn, sampled_mols)

        if 'rsmiles_map' in rxn:
            # One of the reaction's alternative SMILES mappings is used for the whole batch.
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw `num_samples` distinct reaction indices according to P(r).

        Returns:
            NumPy array of indices into `self.reaction_data`.
        """
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Sample `rxn_num` distinct reactions by P(r) and a molecule batch from each.

        Returns:
            List with one `sample_mol_batch` result per sampled reaction.
        """
        sampled_indices = self.sample_rxn_index(rxn_num)
        return [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Sample reactions uniformly among those with valid molecules, then molecules uniformly.

        Returns:
            List of `rxn_num` batches; each batch is a list of property dicts
            (copies of the shared records) carrying a 'role' key.
        """
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._build_property_batch(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Build `rxn_num` batches of `k` molecules drawn without regard to reactions.

        Molecules are drawn without replacement from the multiset in which each
        molecule appears once per reaction it is valid in.

        Returns:
            List of `rxn_num` lists, each holding `k` entries of
            `self.mol_property_map` (the shared records, not copies).
        """
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num * k <= len(valid_mols)
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num * k)
        sampled_batch = []
        for i in range(rxn_num):
            chunk = sampled_mol_ids[i * k:(i + 1) * k]
            sampled_batch.append([self.mol_property_map[valid_mols[mol_id]] for mol_id in chunk])
        return sampled_batch

    def generate_batch_single(self, rxn_num=1000):
        """Build `rxn_num` single-molecule batches.

        Molecules are drawn without replacement from the multiset in which each
        molecule appears once per reaction it is valid in.

        Returns:
            List of `rxn_num` one-element lists of shared property records.
        """
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        return [[self.mol_property_map[mol]] for mol in sampled_mols]

    # Visualize probability for molecules in caption dataset.
    def visualize_mol_distribution(self):
        """Return sampling probabilities scaled so that 1.0 means "uniform".

        Returns:
            Tuple `(prob_values, rxn_weights)`: the marginal probability of
            each molecule in `self.mol_property_map` multiplied by the number
            of molecules, and P(r) of each reaction multiplied by the number of
            reactions.
        """
        prob_dict = {mol: 0.0 for mol in self.mol_property_map.keys()}
        N = len(prob_dict)
        M = len(self.reaction_data)
        assert N == len(self.mol_property_map)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')

        # Prob distribution for molecules: P(i) = sum_r P(i|r) * P(r).
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # Sum of prob_dict.values() should already be 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N

        # Prob distribution for reactions.
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        # Sum of rxn_weights should already be 1.
        rxn_weights *= M
        return prob_values, rxn_weights

    # Visualize the frequency for molecules in caption dataset.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Empirically count how often molecules and reactions get sampled.

        Runs `epochs` rounds of sampling `rxn_num` reactions and up to `k`
        molecules from each.

        Returns:
            Tuple `(sampled_mols_count, sampled_rxns_count)` of NumPy arrays
            holding the counts ordered by sorted molecule / reaction index.
            Only items sampled at least once are included.
        """
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if len(rxn['valid_mols']) == 0:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Run `func(*args, **kwargs)` with all sampling weights made uniform.

        Every reaction temporarily gets weight 1 / len(reaction_data) and every
        valid molecule an equal share within its reaction. The real weights are
        restored afterwards, even if `func` raises.

        Returns:
            Whatever `func` returns.
        """
        # Make fake weights and back up the real ones.
        for rxn_dict in self.reaction_data:
            rxn_dict['weight_bak'] = rxn_dict['weight']
            rxn_dict['weight'] = 1 / len(self.reaction_data)
            rxn_dict['mol_weight_bak'] = rxn_dict['mol_weight']
            rxn_dict['mol_weight'] = {
                m: 1 / len(rxn_dict['mol_weight']) for m in rxn_dict['mol_weight']
            }
        try:
            # Run the function.
            result = func(*args, **kwargs)
        finally:
            # Weights recovery; must also happen when `func` fails so the
            # instance is not left with uniform weights and stray backup keys.
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = rxn_dict.pop('weight_bak')
                rxn_dict['mol_weight'] = rxn_dict.pop('mol_weight_bak')
        return result