-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_loader.py
61 lines (57 loc) · 2.77 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from torch.utils.data import Dataset, DataLoader
import dgl
from dgl.data import PPIDataset
import collections
# Record produced by the collate function that ``batcher`` builds:
# ``graph`` is the batched DGL graph and ``label`` holds its node labels.
PPIBatch = collections.namedtuple('PPIBatch', ('graph', 'label'))
def batcher(device):
    """Return a DataLoader ``collate_fn`` that merges DGL graphs into a PPIBatch.

    Args:
        device: Target handed to ``Tensor.to`` for the batched node labels.

    Returns:
        A callable that takes a list of graphs and returns a ``PPIBatch``.
        Its ``graph`` is ``dgl.batch`` of the inputs, and its ``label`` is that
        batched graph's ``ndata['label']`` moved to ``device``. Only the labels
        are moved; the batched graph itself is returned as built.
    """
    def collate(samples):
        merged = dgl.batch(samples)
        node_labels = merged.ndata['label'].to(device)
        return PPIBatch(graph=merged, label=node_labels)

    return collate
def _add_self_loop_relation(dataset):
    """Rebuild every graph in ``dataset`` in place with an extra 'self' relation.

    Each graph's existing edges become relation ``('_N', '_E', '_N')``. A new
    relation ``('_N', 'self', '_N')`` links every node to itself. The node
    fields 'label', 'feat' and '_ID' and the edge field '_ID' are copied over.
    The rebuilt graph replaces the original in ``dataset.graphs``.
    """
    for i in range(len(dataset)):
        # Fetch once per graph instead of re-indexing the dataset for every field.
        old = dataset[i]
        nodes = old.nodes()
        # A heterograph's schema cannot be changed after construction, so a new
        # graph is built that carries both relations.
        new = dgl.heterograph({
            ('_N', '_E', '_N'): old.edges(),
            ('_N', 'self', '_N'): (nodes, nodes),
        })
        for key in ('label', 'feat', '_ID'):
            new.ndata[key] = old.ndata[key]
        new.edges['_E'].data['_ID'] = old.edata['_ID']
        dataset.graphs[i] = new


def load_PPI(batch_size=1, device='cpu'):
    """Load the PPI splits with an added self-loop relation and wrap them in DataLoaders.

    Args:
        batch_size: Number of graphs per batch for all three loaders.
        device: Device passed to ``batcher``; only the batched labels are moved there.

    Returns:
        A tuple ``(train_loader, valid_loader, test_loader, etypes, in_size,
        out_size)``. ``etypes`` is the edge-type list of the first rebuilt
        training graph. ``in_size`` and ``out_size`` are ``shape[1]`` of that
        graph's 'feat' and 'label' node data.
    """
    train_set = PPIDataset(mode='train')
    valid_set = PPIDataset(mode='valid')
    test_set = PPIDataset(mode='test')
    for dataset in (train_set, valid_set, test_set):
        _add_self_loop_relation(dataset)

    sample = train_set[0]
    etypes = sample.etypes
    in_size = sample.ndata['feat'].shape[1]
    out_size = sample.ndata['label'].shape[1]

    # The collate closure holds no state, so one instance serves all three loaders.
    collate = batcher(device)
    train_loader = DataLoader(train_set, batch_size=batch_size, collate_fn=collate, shuffle=True)
    # Evaluation splits are iterated in a fixed order so results are reproducible.
    valid_loader = DataLoader(valid_set, batch_size=batch_size, collate_fn=collate, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, collate_fn=collate, shuffle=False)
    return train_loader, valid_loader, test_loader, etypes, in_size, out_size