# ReactXT / data_provider / pretrain_dm.py
# Repository page header (not part of the program): author SyrWin, commit "init" (95f97c5), file size 13 kB.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from pytorch_lightning import LightningDataModule
import torch_geometric
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.loader.dataloader import Collater
from data_provider.molecule_abstract_dataset import MoleculeAbstract
import re
from transformers import BatchEncoding
# Matches a bracketed sequence span such as "[START_DNA]...[END_DNA]" so that the
# individual characters inside it can be split apart.
# Groups: 1 = start token, 2 = sequence kind (DNA, SMILES, I_SMILES or AMINO),
# 3 = sequence body (non-greedy), 4 = end token, whose kind must repeat group 2
# through the "\2" backreference.
# No flags are set, so "." does not match a newline and a span cannot cross lines.
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")
# Token added to implement a custom sequence tokenization. This token is added at the
# corpus cleaning step and removed in pretokenization. The digits are added to reduce the
# chance that the token occurs in the corpus. The digits are written as f-string
# substitutions so that the token does not appear literally in the source code in case
# this file is ever included in the training data.
# The expression evaluates to "SPL1T-TH1S-Pl3A5E".
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
def _insert_split_marker(m: re.Match):
    """
    Build the replacement text for one match of a delimited sequence span
    such as [START_DNA]...[END_DNA].

    Parameters
    ----------
    m : re.Match
        Match with exactly four groups: start token, sequence kind,
        sequence body and end token (the layout produced by CUSTOM_SEQ_RE).

    Returns
    ----------
    str - the start token, the body with SPLIT_MARKER placed before every
    character, one more SPLIT_MARKER, then the end token
    """
    opening, _kind, body, closing = m.groups()
    # Prefix every character (newlines included) with the marker.
    marked_body = "".join(f"{SPLIT_MARKER}{char}" for char in body)
    return "".join((opening, marked_body, SPLIT_MARKER, closing))
def smiles_handler(text, mol_ph, is_gal=True):
    """
    Extract the delimited sequences from a prompt and append the molecule
    placeholder after each of them.

    Parameters
    ----------
    text : str
        Prompt that may contain spans matched by CUSTOM_SEQ_RE, e.g.
        "[START_SMILES]CCO[END_SMILES]".
    mol_ph : str
        Placeholder text inserted right after every matched span. It is
        inserted literally; it is never interpreted as a regex template.
    is_gal : bool
        When true, the start/end tokens are kept and the result is passed
        through escape_custom_split_sequence. When false, the start/end
        tokens are dropped and only the sequence body is kept.

    Returns
    ----------
    tuple - (rewritten text, list of the sequence bodies found in the
    original text, in order of appearance)
    """
    smiles_list = [match.group(3) for match in CUSTOM_SEQ_RE.finditer(text)]

    # Callable replacements are used instead of a "%s"-built template string:
    # a template would make the regex engine parse mol_ph, so a backslash or a
    # leading digit (which fuses with the preceding "\4" into a bogus group
    # number) would raise re.error or corrupt the output.
    if is_gal:
        def _keep_delimiters(match):
            return f"{match.group(1)}{match.group(3)}{match.group(4)}{mol_ph}"

        text = CUSTOM_SEQ_RE.sub(_keep_delimiters, text)
        text = escape_custom_split_sequence(text)
        return text, smiles_list

    def _drop_delimiters(match):
        return f"{match.group(3)}{mol_ph}"

    text = CUSTOM_SEQ_RE.sub(_drop_delimiters, text)
    return text, smiles_list
def escape_custom_split_sequence(text):
    """
    Insert the split marker into every delimited sequence span of the text.

    Each span matched by CUSTOM_SEQ_RE is rewritten by _insert_split_marker;
    text outside those spans is left unchanged. smiles_handler calls this on
    its `is_gal` path.

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    escaped_text = CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
    return escaped_text
def tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length):
    """
    Tokenize groups of QA strings and merge each group into one padded row.

    Every group in `qa_pairs_list` becomes one row: its strings are tokenized
    one by one, each truncated to an equal share of `max_length`, the token
    ids are concatenated in order, and the row is right-padded with
    `tokenizer.pad_token_id` up to `max_length`.

    Parameters
    ----------
    tokenizer : callable
        Tokenizer called as `tokenizer(text, ...)`; it must return a mapping
        with 'input_ids' and 'attention_mask' tensors whose first dimension
        is 1, and expose `pad_token_id`.
    qa_pairs_list : sequence of sequences of str
        One inner sequence per output row. An empty inner sequence yields a
        row made only of padding.
    max_length : int
        Length of every output row.

    Returns
    ----------
    BatchEncoding - with 'input_ids' and 'attention_mask' tensors of shape
    (len(qa_pairs_list), max_length)
    """
    input_id_rows = []
    attention_mask_rows = []
    for qa_pairs in qa_pairs_list:
        row_input_ids = []
        row_attention_mask = []
        # Guard against an empty group, which would otherwise divide by zero.
        if qa_pairs:
            # Split the token budget evenly so the merged row never exceeds max_length.
            max_length_per_qa = max_length // len(qa_pairs)
            for qa in qa_pairs:
                # here qa should be string
                tokens = tokenizer(
                    qa,
                    truncation=True,
                    padding=False,
                    add_special_tokens=False,
                    max_length=max_length_per_qa,
                    return_tensors='pt',
                    return_attention_mask=True,
                )
                # squeeze(0) drops only the leading batch dimension. A bare
                # squeeze() would also collapse a single-token result to a
                # 0-dim tensor, whose tolist() is an int and breaks extend().
                row_input_ids.extend(tokens['input_ids'].squeeze(0).tolist())
                row_attention_mask.extend(tokens['attention_mask'].squeeze(0).tolist())

        # Pad the row to max_length
        padding_length = max_length - len(row_input_ids)
        row_input_ids.extend([tokenizer.pad_token_id] * padding_length)
        row_attention_mask.extend([0] * padding_length)

        input_id_rows.append(torch.tensor(row_input_ids).unsqueeze(0))
        attention_mask_rows.append(torch.tensor(row_attention_mask).unsqueeze(0))

    tokenized_batches = {
        'input_ids': torch.cat(input_id_rows, dim=0),
        'attention_mask': torch.cat(attention_mask_rows, dim=0),
    }
    tokenized_batch = BatchEncoding(data=tokenized_batches, tensor_type='pt')
    return tokenized_batch
class TrainCollater:
    """
    Collate function for training batches.

    Each batch item is unpacked as a (graphs, mol prompts, text prompts)
    triple. The graphs of all items are flattened and collated (unless graphs
    are disabled), and every item's prompts are joined pairwise into QA
    strings that are tokenized into one padded row per item.
    """

    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False):
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.collater = Collater([], [])

    def __call__(self, batch):
        """
        Collate a list of dataset items.

        Returns
        ----------
        tuple - (graphs, qa_batch). `graphs` is the collated graph batch, or
        the untouched per-item tuple when graphs are disabled. `qa_batch` is
        the tokenized text with an extra boolean 'is_mol_token' entry marking
        positions whose id equals `mol_token_id`.
        """
        graphs, mol_prompts, text_prompts = zip(*batch)
        if not self.disable_graphs:
            flat_graphs = [graph for item_graphs in graphs for graph in item_graphs]
            graphs = self.collater(flat_graphs)

        # One list of "<molecule prompt> <text prompt>" strings per batch item.
        qa_pairs = [
            [
                f'{smiles_handler(mol_text, self.mol_ph, self.is_gal)[0]} {answer_text}'
                for mol_text, answer_text in zip(item_mols, item_texts)
            ]
            for item_mols, item_texts in zip(mol_prompts, text_prompts)
        ]

        self.tokenizer.padding_side = 'right'
        qa_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, qa_pairs, self.text_max_len)
        qa_batch['is_mol_token'] = qa_batch.input_ids == self.mol_token_id
        return graphs, qa_batch
class InferenceCollater:
    """
    Collate function for inference batches.

    Each batch item is unpacked as a (graphs, mol prompts, text prompts)
    triple. For every item, all prompt pairs but the last form the context;
    the last molecule prompt is appended without its text, and that last text
    is returned separately as the expected output.
    """

    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False, last_only=False):
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.last_only = last_only
        self.collater = Collater([], [])

    def __call__(self, batch):
        """
        Collate a list of dataset items.

        Returns
        ----------
        tuple - (rxn_ids, graphs, input_batch, output_text, input_text).
        `rxn_ids` is a list of zeros, one per item. `graphs` is the collated
        graph batch (or the per-item graphs when graphs are disabled).
        `input_batch` is the tokenized input with a boolean 'is_mol_token'
        entry. `output_text` holds the last text prompt of every item and
        `input_text` the list of input strings built for every item.
        """
        graphs, mol_prompts, text_prompts = zip(*batch)
        rxn_ids = [0] * len(mol_prompts)

        if self.last_only:
            # Keep only the final entry of every item, still wrapped in a list.
            mol_prompts = [[item_mols[-1]] for item_mols in mol_prompts]
            text_prompts = [[item_texts[-1]] for item_texts in text_prompts]
            graphs = [[item_graphs[-1]] for item_graphs in graphs]

        if not self.disable_graphs:
            flat_graphs = [graph for item_graphs in graphs for graph in item_graphs]
            graphs = self.collater(flat_graphs)

        input_text = []
        output_text = []
        for item_mols, item_texts in zip(mol_prompts, text_prompts):
            context_pairs = list(zip(item_mols, item_texts))[:-1]
            qa_list = [
                f'{smiles_handler(mol_text, self.mol_ph, self.is_gal)[0]} {answer_text}'
                for mol_text, answer_text in context_pairs
            ]
            # The final molecule prompt is left open (trailing space, no text).
            final_prompt = smiles_handler(item_mols[-1], self.mol_ph, self.is_gal)[0]
            qa_list.append(f'{final_prompt} ')
            output_text.append(item_texts[-1])
            input_text.append(qa_list)

        self.tokenizer.padding_side = 'right'
        input_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, input_text, self.text_max_len)
        input_batch['is_mol_token'] = input_batch.input_ids == self.mol_token_id
        return rxn_ids, graphs, input_batch, output_text, input_text
class PretrainDM(LightningDataModule):
    """
    Lightning data module that builds the pretraining and test
    MoleculeAbstract datasets and their data loaders.

    Parameters
    ----------
    num_workers : int
        Number of DataLoader worker processes.
    batch_size : int
        Batch size of the training loader.
    root : str
        Dataset root passed to MoleculeAbstract.
    text_max_len : int
        Row length used when tokenizing merged prompts.
    rxn_max_len : int
        Stored on the instance; not used elsewhere in this class.
    smi_max_len : int
        Forwarded to MoleculeAbstract.
    tokenizer
        Tokenizer shared with both datasets; must expose `mol_token_id`.
    args
        Parsed command-line namespace. Required despite the `None` default:
        its attributes are read unconditionally during construction.
    """

    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        rxn_max_len: int = 128,
        smi_max_len: int = 128,
        tokenizer=None,
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len

        # Options shared by the pretraining and the test dataset.
        shared_dataset_kwargs = dict(
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
        )
        # Only the pretraining dataset receives the synthesis options.
        self.pretrain_dataset = MoleculeAbstract(
            root,
            synthesis_datasetpath=args.pretrain_synthesis_path,
            synthesis_batch_num=args.synthesis_batch_num,
            **shared_dataset_kwargs,
        )
        self.test_dataset = MoleculeAbstract(
            root,
            mode='test',
            **shared_dataset_kwargs,
        )
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = 'galactica' in args.opt_model
        self.disable_graphs = args.disable_graphs
        self.last_only = args.pretrain_eval_last_only

    def init_tokenizer(self, tokenizer):
        """Attach the tokenizer to this module and to both datasets."""
        self.tokenizer = tokenizer
        self.pretrain_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]

    def train_dataloader(self):
        """Reload the pretraining data list and return a shuffled training loader."""
        self.pretrain_dataset.reload_data_list()
        loader = DataLoader(
            self.pretrain_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            # DataLoader raises ValueError for persistent_workers=True with
            # num_workers == 0, which is this class's default.
            persistent_workers=self.num_workers > 0,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
            ),
        )
        return loader

    def val_dataloader(self):
        """Return a one-element list holding the unshuffled loader over the test dataset."""
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            # See train_dataloader: persistent workers require num_workers > 0.
            persistent_workers=self.num_workers > 0,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
                last_only=self.last_only,
            ),
        )
        return [test_loader]

    # Declared static because the function takes no `self`; this keeps
    # `PretrainDM.add_model_specific_args(parser)` working and also allows
    # calling it on an instance.
    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register the data-module command-line options on `parent_parser` and return it."""
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=4)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--use_smiles', action='store_true', default=False)
        parser.add_argument('--root', type=str, default='data/action_data')
        parser.add_argument('--context_style', type=str, default='weighted_rxn', choices=['weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid'])
        parser.add_argument('--rxn_max_len', type=int, default=512)
        parser.add_argument('--text_max_len', type=int, default=512)
        parser.add_argument('--smi_max_len', type=int, default=128)
        parser.add_argument('--pretrain_rxn_num', type=int, default=50000)
        parser.add_argument('--reverse_ratio', type=float, default=0.5, help='ratio of reversed reactions (retro reactions)')
        parser.add_argument('--disable_abstract', action='store_true', default=False)
        parser.add_argument('--disable_property', action='store_true', default=False)
        parser.add_argument('--pretrain_use_caption', action='store_true', default=False)
        parser.add_argument('--caption_batch_num', type=int, default=5000)
        parser.add_argument('--pretrain_synthesis_path', type=str, default=None)
        parser.add_argument('--synthesis_batch_num', type=int, default=5000)
        parser.add_argument('--rxn_batch_size', type=int, default=4)
        parser.add_argument('--roundrobin_train', action='store_true', default=False)
        parser.add_argument('--test_subset', type=int, default=-1)
        parser.add_argument('--pretrain_eval_last_only', default=False, action='store_true')
        parser.add_argument('--prompt', type=str, default=None)
        return parent_parser