# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from pytorch_lightning import LightningDataModule
import torch_geometric
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.loader.dataloader import Collater
from data_provider.reaction_action_dataset import ActionDataset
from data_provider.synthesis_dataset import SynthesisDataset
from data_provider.caption_dataset import CaptionDataset
from data_provider.chebi_dataset import ChEBI_dataset
import re

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# Token added to implement a custom sequence tokenization. This token is added at the
# corpus cleaning step and removed in pretokenization. The digits are added to increase
# the chance that they do not occur in the corpus. The digits are escaped so that the
# token does not appear literally in the source code in case we ever include it in the
# training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"


def _insert_split_marker(m: re.Match):
    """
    Applies the split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Regex match covering a custom sequence span such as [START_DNA]...[END_DNA]

    Returns
    ----------
    str - the matched text with the split marker inserted before every character
    of the wrapped sequence
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


def smiles_handler(text, mol_ph, is_gal=True):
    """Extracts the SMILES wrapped in custom start/end tokens and appends the
    molecule placeholder string after each matched span."""
    smiles_list = []
    for match in CUSTOM_SEQ_RE.finditer(text):
        smiles = match.group(3)
        smiles_list.append(smiles)

    if is_gal:
        # keep the start/end tokens and add character-level split markers
        text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
        text = escape_custom_split_sequence(text)
        return text, smiles_list
    else:
        # drop the start/end tokens and keep only the raw sequence
        text = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
        return text, smiles_list


def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)

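# Illustrative example (comments only, not executed; the I_SMILES span below is a
# made-up input, not one of the project's prompt templates). For the text
#   "[START_I_SMILES]CCO[END_I_SMILES]"
# escape_custom_split_sequence() prefixes every character of the wrapped sequence
# with SPLIT_MARKER, producing
#   "[START_I_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EO"
#   "SPL1T-TH1S-Pl3A5E[END_I_SMILES]"
# so that the downstream (Galactica-style) pretokenization step, which removes the
# marker again, ends up splitting the SMILES into individual characters.
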
class TrainCollater:
    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id,
                 is_gal=True, use_graph=True, use_qa_pair=True):
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.use_graph = use_graph
        self.use_qa_pair = use_qa_pair

    def __call__(self, batch):
        return self.collate_qa(batch) if self.use_qa_pair else self.collate(batch)

    def collate(self, batch):
        rxn_ids, graphs, texts, smiles_prompt = zip(*batch)
        if graphs:
            graphs = self.collater(graphs)

        ## deal with prompt
        if self.use_graph:
            smiles_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in smiles_prompt]
        else:
            smiles_prompt = [escape_custom_split_sequence(p) for p in smiles_prompt]

        self.tokenizer.padding_side = 'left'
        smiles_prompt_tokens = self.tokenizer(text=smiles_prompt,
                                              truncation=False,
                                              padding='longest',
                                              add_special_tokens=True,
                                              return_tensors='pt',
                                              return_attention_mask=True)

        is_mol_token = smiles_prompt_tokens.input_ids == self.mol_token_id
        smiles_prompt_tokens['is_mol_token'] = is_mol_token

        self.tokenizer.padding_side = 'right'
        text_tokens = self.tokenizer(text=texts,
                                     truncation=True,
                                     padding='longest',
                                     add_special_tokens=True,
                                     max_length=self.text_max_len,
                                     return_tensors='pt',
                                     return_attention_mask=True)
        return rxn_ids, graphs, smiles_prompt_tokens, text_tokens

    def collate_qa(self, batch):
        rxn_ids, graphs, texts, input_prompt = zip(*batch)
        graphs = [graph for graph_batch in graphs for graph in graph_batch]
        if graphs:
            graphs = self.collater(graphs)

        ## deal with prompt
        if self.use_graph:
            input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]
        else:
            input_prompt = [escape_custom_split_sequence(p) for p in input_prompt]

        self.tokenizer.padding_side = 'right'
        qa_pair = [[q, a] for q, a in zip(input_prompt, texts)]
        qa_batch = self.tokenizer(qa_pair,
                                  truncation=True,
                                  padding='longest',
                                  add_special_tokens=True,
                                  max_length=self.rxn_max_len + self.text_max_len,
                                  return_tensors='pt',
                                  return_attention_mask=True,
                                  return_token_type_ids=True)
        is_mol_token = qa_batch.input_ids == self.mol_token_id
        qa_batch['is_mol_token'] = is_mol_token
        return rxn_ids, graphs, qa_batch


class InferenceCollater:
    def __init__(self, tokenizer, text_max_len, rxn_max_len, mol_ph, mol_token_id, is_gal=True):
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal

    def __call__(self, batch):
        rxn_ids, graphs, texts, input_prompt = zip(*batch)
        inputs = input_prompt
        graphs = [graph for graph_batch in graphs for graph in graph_batch]
        if graphs:
            graphs = self.collater(graphs)
        input_prompt = [smiles_handler(p, self.mol_ph, self.is_gal)[0] for p in input_prompt]

        ## deal with prompt
        self.tokenizer.padding_side = 'left'
        input_prompt_tokens = self.tokenizer(input_prompt,
                                             truncation=True,
                                             padding='longest',
                                             add_special_tokens=True,
                                             max_length=self.rxn_max_len,
                                             return_tensors='pt',
                                             return_attention_mask=True)
        is_mol_token = input_prompt_tokens.input_ids == self.mol_token_id
        input_prompt_tokens['is_mol_token'] = is_mol_token
        return rxn_ids, graphs, input_prompt_tokens, texts, inputs

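# For reference, the collaters above produce the following batch layouts (field names
# follow the code; this note is descriptive, not an API contract):
#   TrainCollater.collate      -> (rxn_ids, graphs, smiles_prompt_tokens, text_tokens)
#   TrainCollater.collate_qa   -> (rxn_ids, graphs, qa_batch), where qa_batch encodes each
#                                 (prompt, answer) pair jointly, carries token_type_ids that
#                                 can distinguish prompt from answer tokens (when the tokenizer
#                                 supports sequence pairs), plus an is_mol_token mask marking
#                                 molecule-placeholder positions
#   InferenceCollater.__call__ -> (rxn_ids, graphs, input_prompt_tokens, texts, inputs), with
#                                 left padding so that, for decoder-only generation, new tokens
#                                 are appended directly after the prompt
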
class TuneDM(LightningDataModule):
    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        smi_max_len: int = 128,
        rxn_max_len: int = 128,
        tokenizer=None,
        downstream_task='action',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.rxn_max_len = rxn_max_len
        self.text_max_len = text_max_len
        self.prompt = args.prompt

        DownstreamDataset = {
            'action': ActionDataset,
            'synthesis': SynthesisDataset,
            'caption': CaptionDataset,
            'chebi': ChEBI_dataset,
        }[downstream_task]
        ds_args = {
            'use_graph': not args.disable_graphs,
            'disable_graph_cache': args.disable_graph_cache,
            'smiles_type': args.smiles_type,
        }
        if downstream_task == 'action':
            ds_args['predict_rxn_condition'] = args.predict_rxn_condition
        if downstream_task == 'synthesis':
            ds_args['roundrobin_train'] = args.roundrobin_train
            ds_args['test_subset'] = args.test_subset

        self.train_dataset = DownstreamDataset(root, 'train', smi_max_len, **ds_args)
        self.val_dataset = DownstreamDataset(root, 'valid', smi_max_len, **ds_args)
        self.test_dataset = DownstreamDataset(root, 'test', smi_max_len, **ds_args)
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '' * self.args.num_query_token
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.use_graph = not args.disable_graphs
        self.is_t5 = args.opt_model.find('t5') >= 0

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.train_dataset.tokenizer = tokenizer
        self.val_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("", add_special_tokens=False).input_ids[0]

    def train_dataloader(self):
        if self.args.roundrobin_train:
            self.train_dataset.reload_data()
        if hasattr(self.train_dataset, 'renew_r_smiles'):
            self.train_dataset.renew_r_smiles()
        loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=True,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                use_graph=self.use_graph,
                use_qa_pair=not self.is_t5,
            ),
        )
        return loader

    def val_dataloader(self):
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=True,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return [test_loader]

        # NOTE: everything below this early return is currently unreachable, so only the
        # test split is used during validation.
        val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=True,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return [val_loader, test_loader]

    def test_dataloader(self):
        loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=True,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                rxn_max_len=self.rxn_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
            ),
        )
        return loader
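

if __name__ == '__main__':
    # Minimal smoke test for the prompt-preprocessing helpers only; it needs no datasets,
    # tokenizer, or graphs. The '<mol>' placeholder and the prompt text are illustrative
    # stand-ins chosen for this sketch, not values taken from the project's configs.
    demo_prompt = 'Reactant: [START_I_SMILES]CCO[END_I_SMILES] Predict the product.'
    demo_mol_ph = '<mol>' * 4  # stands in for TuneDM.mol_ph_token

    gal_text, gal_smiles = smiles_handler(demo_prompt, demo_mol_ph, is_gal=True)
    plain_text, _ = smiles_handler(demo_prompt, demo_mol_ph, is_gal=False)

    print('extracted SMILES:', gal_smiles)      # ['CCO']
    print('galactica-style prompt:', gal_text)  # start/end tokens kept, split markers inserted, placeholders appended
    print('plain prompt:', plain_text)          # start/end tokens dropped, placeholders appended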