sampler.py
import os
import random

import dgl.function as fn
import numpy as np
import torch

from partition_utils import *


class ClusterIter(object):
    '''The partition sampler, given a DGLGraph and the number of partitions.

    METIS is used as the graph partition backend. See the usage sketch at the
    end of this file.
    '''
    def __init__(self, dn, g, psize, batch_size, seed_nid, use_pp=True):
        """Initialize the sampler.

        Parameters
        ----------
        dn : str
            The dataset name.
        g : DGLGraph
            The full graph of the dataset.
        psize : int
            The number of partitions.
        batch_size : int
            The number of partitions in one batch.
        seed_nid : np.ndarray
            The training node ids, used to extract the training graph.
        use_pp : bool
            Whether to precompute the aggregated features AX.
        """
        self.use_pp = use_pp
        self.g = g.subgraph(seed_nid)
        # precompute the aggregated features from the training graph only
        if use_pp:
            print('precalculating')
            self.precalc(self.g)
        self.psize = psize
        self.batch_size = batch_size
        # cache the partitions for known dataset & partition-number pairs;
        # use a distinct local name so the dgl.function alias `fn` is not shadowed
        if dn:
            cache_path = os.path.join('./datasets/', dn + '_{}.npy'.format(psize))
            if os.path.exists(cache_path):
                self.par_li = np.load(cache_path, allow_pickle=True)
            else:
                os.makedirs('./datasets/', exist_ok=True)
                self.par_li = get_partition_list(self.g, psize)
                np.save(cache_path, self.par_li)
        else:
            self.par_li = get_partition_list(self.g, psize)
        self.max = int(psize // batch_size)
        random.shuffle(self.par_li)
        self.get_fn = get_subgraph
    def precalc(self, g):
        """Concatenate each node's raw features with its normalized
        neighbourhood aggregation (GraphSAGE-style)."""
        norm = self.get_norm(g)
        g.ndata['norm'] = norm
        features = g.ndata['feat']
        print("features shape, ", features.shape)
        with torch.no_grad():
            # sum neighbour features into 'feat'; in newer DGL releases,
            # fn.copy_u('feat', 'm') replaces the deprecated fn.copy_src
            g.update_all(fn.copy_src(src='feat', out='m'),
                         fn.sum(msg='m', out='feat'),
                         None)
            pre_feats = g.ndata['feat'] * norm
            # GraphSAGE-style aggregation: concatenate self and neighbour features
            g.ndata['feat'] = torch.cat([features, pre_feats], dim=1)
    # one-sided normalization: scale the aggregated features by 1 / in-degree
    def get_norm(self, g):
        norm = 1. / g.in_degrees().float().unsqueeze(1)
        # isolated nodes have in-degree 0; replace the resulting inf with 0
        norm[torch.isinf(norm)] = 0
        norm = norm.to(self.g.ndata['feat'].device)
        return norm
    def __len__(self):
        return self.max

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.max:
            result = self.get_fn(self.g, self.par_li, self.n,
                                 self.psize, self.batch_size)
            self.n += 1
            return result
        else:
            # reshuffle the partitions so the next epoch sees a new batching order
            random.shuffle(self.par_li)
            raise StopIteration
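
# ---------------------------------------------------------------------------
# Usage sketch (hypothetical, not part of the original module): how this
# sampler is typically driven in a Cluster-GCN style training loop. Here
# `g`, `train_nid`, `model`, `loss_fn`, `optimizer`, and `num_epochs` are
# placeholders that a caller would have to provide.
# ---------------------------------------------------------------------------
# cluster_iter = ClusterIter('cora', g, psize=50, batch_size=5,
#                            seed_nid=train_nid, use_pp=True)
# for epoch in range(num_epochs):
#     for subg in cluster_iter:  # yields psize // batch_size subgraphs per epoch
#         logits = model(subg, subg.ndata['feat'])
#         loss = loss_fn(logits, subg.ndata['label'])
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()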