Spaces:
Runtime error
Runtime error
import torch | |
from torch_geometric.data import Dataset | |
import os | |
from torch_geometric.data import InMemoryDataset | |
from .data_utils import reformat_smiles | |
import random | |
import json | |
class PubChemDataset(InMemoryDataset):
    """In-memory dataset whose ``data`` / ``slices`` pair is read from one file."""

    def __init__(self, path):
        """Populate the dataset from the file at ``path``.

        Args:
            path: Location handed to ``torch.load``; the loaded object is
                unpacked into exactly two values (data, slices).
        """
        super().__init__()
        # NOTE(review): torch.load unpickles its input by default, so ``path``
        # should only ever point at a trusted file.
        collated, slice_table = torch.load(path)
        self.data = collated
        self.slices = slice_table

    def __getitem__(self, idx):
        """Return the item at ``idx`` by delegating straight to ``self.get``."""
        return self.get(idx)
class CaptionDataset(Dataset):
    """Caption dataset backed by the ``{mode}.pt`` file found under ``root``.

    Every item is a 4-tuple ``(index, graph_list, text, smiles_prompt)``.
    """

    def __init__(self, root, mode, smi_max_len=128, use_graph=True, disable_graph_cache=False, smiles_type='default'):
        """Open the split named ``mode``.

        Args:
            root: Directory that holds the ``{mode}.pt`` file.
            mode: Split name, used to build the file name.
            smi_max_len: Maximum number of SMILES characters kept in the prompt.
            use_graph: When true, each item carries its graph object.
            disable_graph_cache: Accepted but not read anywhere in this class.
            smiles_type: Forwarded to ``reformat_smiles``.
        """
        super().__init__(root)
        self.root = root
        self.file_path = os.path.join(root, f'{mode}.pt')
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.smiles_type = smiles_type
        self.data = PubChemDataset(self.file_path)

    def get(self, index):
        """Alias of ``self[index]``."""
        return self[index]

    def len(self):
        """Alias of ``len(self)``."""
        return self.__len__()

    def __len__(self):
        """Number of records in the underlying ``PubChemDataset``."""
        return len(self.data)

    def __getitem__(self, index):
        """Build the caption sample for ``index``.

        Returns:
            ``(index, graph_list, text, smiles_prompt)`` where ``graph_list``
            is ``[record]`` when ``use_graph`` is set and ``[]`` otherwise,
            ``text`` is the stripped caption lines joined by single spaces
            plus a trailing newline, and ``smiles_prompt`` wraps the truncated
            SMILES string in the start/end marker tokens.
        """
        record = self.data[index]

        smiles = reformat_smiles(record.smiles, smiles_type=self.smiles_type)
        truncated = smiles[:self.smi_max_len]
        smiles_prompt = f'[START_I_SMILES]{truncated}[END_I_SMILES]. '

        # Only the first 101 caption lines are kept; the rest are dropped.
        kept_lines = record.text.split('\n')[:101]
        text = ' '.join(raw.strip() for raw in kept_lines) + '\n'

        if self.use_graph:
            graph_list = [record]
        else:
            graph_list = []
        return index, graph_list, text, smiles_prompt
class PretrainCaptionDataset(Dataset):
    """Concatenation of the ``'pretrain'`` and ``'train'`` caption splits.

    Indices below ``len(self.pre_train_data)`` address the pretrain split;
    the remaining indices address the train split. Every item is a 3-tuple
    ``(graph_item, text, smiles_prompt)``.
    """

    def __init__(self, root, smi_max_len=128, use_graph=True, disable_graph_cache=False):
        """Open both splits from ``root``.

        Args:
            root: Directory passed to both underlying ``CaptionDataset``s.
            smi_max_len: Forwarded to both underlying datasets.
            use_graph: Forwarded to both underlying datasets.
            disable_graph_cache: Accepted but not read anywhere in this class.
        """
        super().__init__(root)
        shared = dict(smi_max_len=smi_max_len, use_graph=use_graph)
        self.pre_train_data = CaptionDataset(root, 'pretrain', **shared)
        self.train_data = CaptionDataset(root, 'train', **shared)

    def get(self, index):
        """Alias of ``self[index]``."""
        return self[index]

    def len(self):
        """Alias of ``len(self)``."""
        return self.__len__()

    def __len__(self):
        """Combined number of samples in the two splits."""
        return len(self.pre_train_data) + len(self.train_data)

    def __getitem__(self, index):
        """Return ``(graph_item, text, smiles_prompt)`` for ``index``.

        ``graph_item`` is the first entry of the underlying sample's graph
        list with its ``iupac`` / ``cid`` (when present), ``text`` and
        ``smiles`` attributes removed. When the underlying sample has an
        empty graph list, ``graph_item`` is ``None`` instead of the lookup
        raising ``IndexError``.
        """
        n_pretrain = len(self.pre_train_data)
        if index < n_pretrain:
            sample = self.pre_train_data[index]
        else:
            sample = self.train_data[index - n_pretrain]
        _, graph_list, text, smiles_prompt = sample

        # An empty graph list means there is no graph object to strip or return.
        if not graph_list:
            return None, text, smiles_prompt

        graph_item = graph_list[0]
        # These two attributes are only removed when present.
        for optional_field in ('iupac', 'cid'):
            if hasattr(graph_item, optional_field):
                delattr(graph_item, optional_field)
        # The caption and SMILES string are returned separately, so they are
        # removed from the graph object itself.
        del graph_item.text
        del graph_item.smiles
        return graph_item, text, smiles_prompt