-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdataset.py
109 lines (82 loc) · 3.23 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from torch.utils.data import Dataset
from rdkit import Chem
import os, random, gc
import pickle
from hgraph.chemutils import get_leaves
from hgraph.mol_graph import MolGraph
class MoleculeDataset(Dataset):
    """Molecules whose ``mol_tree`` is fully covered by ``vocab``, in batches.

    At construction every SMILES string in ``data`` is turned into a
    ``MolGraph``. A molecule survives only when, for every node of its
    ``mol_tree``, the node's ``label`` is a key of ``vocab.vmap`` and every
    ``(node smiles, s)`` pair built from the node's ``inter_label`` entries is
    a key as well. Survivors are cut into consecutive slices of
    ``batch_size`` (the last slice may be shorter); indexing tensorizes one
    slice on demand via ``MolGraph.tensorize``.
    """

    def __init__(self, data, vocab, avocab, batch_size):
        kept = [mol_s for mol_s in data if self._fully_in_vocab(MolGraph(mol_s), vocab)]
        print(f'After pruning {len(data)} -> {len(kept)}')

        starts = range(0, len(kept), batch_size)
        self.batches = [kept[lo : lo + batch_size] for lo in starts]
        self.vocab = vocab
        self.avocab = avocab

    @staticmethod
    def _fully_in_vocab(hmol, vocab):
        """Return True if all tree labels / inter-label pairs of ``hmol`` are in ``vocab.vmap``.

        Every node is inspected even after a miss (a running ``&=``
        accumulator rather than a short-circuiting ``all``).
        """
        covered = True
        for _node, attr in hmol.mol_tree.nodes(data=True):
            node_smiles = attr['smiles']
            covered &= attr['label'] in vocab.vmap
            pair_hits = [(node_smiles, s) in vocab.vmap for _i, s in attr['inter_label']]
            covered &= all(pair_hits)
        return covered

    def __len__(self):
        """Number of batches (not molecules)."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Tensorize batch ``idx`` with ``MolGraph.tensorize``."""
        batch = self.batches[idx]
        return MolGraph.tensorize(batch, self.vocab, self.avocab)
class MolEnumRootDataset(Dataset):
    """One item per input molecule: its vocabulary-covered leaf-rooted SMILES.

    ``data`` is indexed directly (no batching). For item ``idx`` the SMILES is
    parsed with RDKit, re-written as non-isomeric SMILES rooted at each atom
    index yielded by ``get_leaves``, de-duplicated, sorted, filtered to those
    whose ``mol_tree`` node labels are all keys of ``vocab.vmap``, and the
    survivors are tensorized together as one batch.
    """

    def __init__(self, data, vocab, avocab):
        self.batches = data
        self.vocab = vocab
        self.avocab = avocab

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        """Tensorize the vocabulary-covered leaf-rooted SMILES of item ``idx``.

        Returns:
            The result of ``MolGraph.tensorize`` on the surviving SMILES, or
            ``None`` when the input SMILES cannot be parsed by RDKit or when
            no rooted variant passes the vocabulary filter.
        """
        mol = Chem.MolFromSmiles(self.batches[idx])
        if mol is None:
            # RDKit reports a parse failure by returning None; without this
            # guard that None would be handed straight to get_leaves.
            return None

        # Set to drop duplicate rootings, then sort so the output order is
        # reproducible across runs.
        rooted = sorted({
            Chem.MolToSmiles(mol, rootedAtAtom=i, isomericSmiles=False)
            for i in get_leaves(mol)
        })

        safe_list = [
            s for s in rooted
            if all(attr['label'] in self.vocab.vmap
                   for _node, attr in MolGraph(s).mol_tree.nodes(data=True))
        ]
        if not safe_list:
            return None
        return MolGraph.tensorize(safe_list, self.vocab, self.avocab)
class MolPairDataset(Dataset):
    """Batches of (source, target) SMILES pairs.

    ``data`` is a sequence of 2-item pairs, chunked into consecutive slices
    of ``batch_size`` (the last slice may be shorter). Indexing tensorizes
    the sources and the targets separately and returns the source tuple
    without its final element (the order, which is not needed for the input
    side) concatenated with the full target tuple.
    """

    def __init__(self, data, vocab, avocab, batch_size):
        starts = range(0, len(data), batch_size)
        self.batches = [data[lo : lo + batch_size] for lo in starts]
        self.vocab = vocab
        self.avocab = avocab

    def __len__(self):
        """Number of batches (not pairs)."""
        return len(self.batches)

    def __getitem__(self, idx):
        sources, targets = zip(*self.batches[idx])
        src_tensors = MolGraph.tensorize(sources, self.vocab, self.avocab)
        tgt_tensors = MolGraph.tensorize(targets, self.vocab, self.avocab)
        # Drop the trailing order entry from the source side only.
        return src_tensors[:-1] + tgt_tensors
class DataFolder(object):
    """Stream pre-formed batches from a folder of pickle files.

    Every regular file in ``data_folder`` is expected to contain one pickled
    sequence of batches (presumably a list — it must be mutable when
    ``shuffle`` is True, since it is shuffled in place). Files are visited in
    sorted name order; within a file the batches are optionally shuffled and
    yielded one by one, and the file's data is released before the next file
    is loaded.

    Security: ``pickle.load`` can execute arbitrary code. Only point this at
    folders whose files you produced yourself.
    """

    def __init__(self, data_folder, batch_size, shuffle=True, batches_per_file=1000):
        """
        Args:
            data_folder: Directory containing the pickled batch files.
            batch_size: Stored on the instance; not used by this class
                itself, because the files already hold formed batches.
            shuffle: If True, shuffle the batches inside each file before
                yielding them. The order of files is not shuffled.
            batches_per_file: Nominal batches per file, used only by
                ``__len__``. Defaults to the previously hard-coded 1000.
        """
        self.data_folder = data_folder
        # Sorted so iteration order does not depend on os.listdir's arbitrary
        # order; non-files (e.g. sub-directories) are skipped because open()
        # on them would fail mid-iteration.
        self.data_files = sorted(
            fn for fn in os.listdir(data_folder)
            if os.path.isfile(os.path.join(data_folder, fn))
        )
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.batches_per_file = batches_per_file

    def __len__(self):
        """Return an estimated batch count: number of files x ``batches_per_file``.

        Files are not opened, so this is exact only if every file holds
        exactly ``batches_per_file`` batches.
        """
        return len(self.data_files) * self.batches_per_file

    def __iter__(self):
        for fn in self.data_files:
            path = os.path.join(self.data_folder, fn)
            with open(path, 'rb') as f:
                batches = pickle.load(f)  # trusted, self-produced files only
            if self.shuffle:
                random.shuffle(batches)  # shuffle data before batch
            yield from batches
            # Drop the reference and collect eagerly so this file's batches
            # are freed before the next file is loaded.
            del batches
            gc.collect()