ReactXT / data_provider /caption_dataset.py
SyrWin
init
95f97c5
raw
history blame
2.93 kB
import torch
from torch_geometric.data import Dataset
import os
from torch_geometric.data import InMemoryDataset
from .data_utils import reformat_smiles
import random
import json
class PubChemDataset(InMemoryDataset):
def __init__(self, path):
super(PubChemDataset, self).__init__()
self.data, self.slices = torch.load(path)
def __getitem__(self, idx):
return self.get(idx)
class CaptionDataset(Dataset):
def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
super(CaptionDataset, self).__init__(root)
self.root = root
self.file_path = os.path.join(root, f'{mode}.pt')
self.smi_max_len = smi_max_len
self.tokenizer = None
self.use_graph = use_graph
self.smiles_type = smiles_type
self.data = PubChemDataset(self.file_path)
def get(self, index):
return self.__getitem__(index)
def len(self):
return len(self)
def __len__(self):
return len(self.data)
def __getitem__(self, index):
data = self.data[index]
smiles = reformat_smiles(data.smiles, smiles_type=self.smiles_type)
smiles_prompt = f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
text_list = []
count = 0
for line in data.text.split('\n'):
count += 1
text_list.append(line.strip())
if count > 100:
break
text = ' '.join(text_list) + '\n'
graph_list = [data] if self.use_graph else []
return index, graph_list, text, smiles_prompt
class PretrainCaptionDataset(Dataset):
def __init__(self, root, smi_max_len=128, use_graph=True, disable_graph_cache=False):
super(PretrainCaptionDataset, self).__init__(root)
self.pre_train_data = CaptionDataset(
root,
'pretrain',
smi_max_len=smi_max_len,
use_graph=use_graph,
)
self.train_data = CaptionDataset(
root,
'train',
smi_max_len=smi_max_len,
use_graph=use_graph,
)
def get(self, index):
return self.__getitem__(index)
def len(self):
return len(self)
def __len__(self):
return len(self.pre_train_data) + len(self.train_data)
def __getitem__(self, index):
if index < len(self.pre_train_data):
index, graph_list, text, smiles_prompt = self.pre_train_data[index]
else:
index, graph_list, text, smiles_prompt = self.train_data[index - len(self.pre_train_data)]
graph_item = graph_list[0]
if hasattr(graph_item, 'iupac'):
del graph_item.iupac
if hasattr(graph_item, 'cid'):
del graph_item.cid
del graph_item.text
del graph_item.smiles
return graph_item, text, smiles_prompt