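"""Entity classification on RDF datasets with a sampling-based RGCN.

Usage:
    python entity_sample.py --dataset aifb
"""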
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy

import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.dataloading import MultiLayerNeighborSampler, DataLoader
from dgl.nn.pytorch import RelGraphConv


class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        # learnable node embeddings stand in for input features,
        # since the RDF datasets ship without node features
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN with basis-decomposition regularization
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
        # g is the list of sampled blocks (message-flow graphs);
        # look up embeddings for the source nodes of the first block
        x = self.emb(g[0].srcdata[dgl.NID])
        h = F.relu(self.conv1(g[0], x, g[0].edata[dgl.ETYPE], g[0].edata['norm']))
        h = self.conv2(g[1], h, g[1].edata[dgl.ETYPE], g[1].edata['norm'])
        return h


def evaluate(model, labels, dataloader, inv_target):
model.eval()
eval_logits = []
eval_seeds = []
with torch.no_grad():
for input_nodes, output_nodes, blocks in dataloader:
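            # map sampled seed nodes from global IDs back to type-specific
            # IDs so they can index the label tensor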
output_nodes = inv_target[output_nodes]
for block in blocks:
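                # per-edge normalization (1 / in-degree of the destination),
                # consumed as the 'norm' edge feature by RelGraphConv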
block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
logits = model(blocks)
eval_logits.append(logits.cpu().detach())
eval_seeds.append(output_nodes.cpu().detach())
eval_logits = torch.cat(eval_logits)
eval_seeds = torch.cat(eval_seeds)
    return accuracy(eval_logits.argmax(dim=1), labels[eval_seeds].cpu()).item()


def train(device, g, target_idx, labels, train_mask, model, inv_target):
# define train idx, loss function and optimizer
train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # construct sampler and dataloaders: sample 4 neighbors per node
    # for each of the two layers
    sampler = MultiLayerNeighborSampler([4, 4])
    train_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                              batch_size=100, shuffle=True)
    # these datasets have no separate validation split, so reuse the
    # training indices for validation
    val_loader = DataLoader(g, target_idx[train_idx], sampler, device=device,
                            batch_size=100, shuffle=False)
for epoch in range(50):
model.train()
total_loss = 0
for it, (input_nodes, output_nodes, blocks) in enumerate(train_loader):
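            # global seed IDs -> type-specific IDs for label lookup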
output_nodes = inv_target[output_nodes]
for block in blocks:
block.edata['norm'] = dgl.norm_by_dst(block).unsqueeze(1)
logits = model(blocks)
loss = loss_fcn(logits, labels[output_nodes])
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
acc = evaluate(model, labels, val_loader, inv_target)
print("Epoch {:05d} | Loss {:.4f} | Val. Accuracy {:.4f} "
. format(epoch, total_loss / (it+1), acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='RGCN for entity classification with sampling')
parser.add_argument("--dataset", type=str, default="aifb",
help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
args = parser.parse_args()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Training with DGL built-in RGCN module with sampling.')
# load and preprocess dataset
if args.dataset == 'aifb':
data = AIFBDataset()
elif args.dataset == 'mutag':
data = MUTAGDataset()
elif args.dataset == 'bgs':
data = BGSDataset()
elif args.dataset == 'am':
data = AMDataset()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
g = data[0]
num_rels = len(g.canonical_etypes)
category = data.predict_category
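    # pop labels and masks off the graph so they are not carried into
    # the homogeneous graph built below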
labels = g.nodes[category].data.pop('labels').to(device)
train_mask = g.nodes[category].data.pop('train_mask')
test_mask = g.nodes[category].data.pop('test_mask')
    # find the index of the target node type and the IDs of its nodes
category_id = g.ntypes.index(category)
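    # flatten to a homogeneous graph; it stores edge types in
    # edata[dgl.ETYPE], which forward() passes to RelGraphConv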
g = dgl.to_homogeneous(g)
node_ids = torch.arange(g.num_nodes())
target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]
    # rename these fields, as the DataLoader may overwrite them
g.ndata['ntype'] = g.ndata.pop(dgl.NTYPE)
g.ndata['type_id'] = g.ndata.pop(dgl.NID)
# find the mapping (inv_target) from global node IDs to type-specific node IDs
inv_target = torch.empty((g.num_nodes(),), dtype=torch.int64).to(device)
inv_target[target_idx] = torch.arange(0, target_idx.shape[0], dtype=inv_target.dtype).to(device)
# create RGCN model
    in_size = g.num_nodes()  # featureless: one-hot node IDs via the embedding table
out_size = data.num_classes
model = RGCN(in_size, 16, out_size, num_rels).to(device)
    train(device, g, target_idx, labels, train_mask, model, inv_target)
test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    test_sampler = MultiLayerNeighborSampler([-1, -1])  # -1: take all neighbors
test_loader = DataLoader(g, target_idx[test_idx], test_sampler, device=device,
batch_size=32, shuffle=False)
acc = evaluate(model, labels, test_loader, inv_target)
print("Test accuracy {:.4f}".format(acc))