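"""Entity classification on DGL's RDF datasets with a two-layer RGCN.

Follows the relational GCN of Schlichtkrull et al., "Modeling Relational
Data with Graph Convolutional Networks": the graph's nodes carry no input
features, so a learnable embedding per node is used instead, and each edge
type gets its own basis-decomposed weight.
"""
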
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE: torchmetrics >= 0.11 requires a task argument for functional accuracy,
# e.g. accuracy(preds, target, task='multiclass', num_classes=...); the calls
# below use the older positional API.
from torchmetrics.functional import accuracy

import dgl
from dgl.data.rdf import AIFBDataset, MUTAGDataset, BGSDataset, AMDataset
from dgl.nn.pytorch import RelGraphConv

class RGCN(nn.Module):
    def __init__(self, num_nodes, h_dim, out_dim, num_rels):
        super().__init__()
        # featureless input: learn one embedding per node
        self.emb = nn.Embedding(num_nodes, h_dim)
        # two-layer RGCN with basis-decomposition regularization
        self.conv1 = RelGraphConv(h_dim, h_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)
        self.conv2 = RelGraphConv(h_dim, out_dim, num_rels, regularizer='basis',
                                  num_bases=num_rels, self_loop=False)

    def forward(self, g):
        x = self.emb.weight
        h = F.relu(self.conv1(g, x, g.edata[dgl.ETYPE], g.edata['norm']))
        h = self.conv2(g, h, g.edata[dgl.ETYPE], g.edata['norm'])
        return h
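
# Evaluation helper: run a full-graph forward pass, restrict the logits to the
# target node type via target_idx, and score accuracy on the test split.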
def evaluate(g, target_idx, labels, test_mask, model):
    test_idx = torch.nonzero(test_mask, as_tuple=False).squeeze()
    model.eval()
    with torch.no_grad():
        logits = model(g)
    logits = logits[target_idx]
    return accuracy(logits[test_idx].argmax(dim=1), labels[test_idx]).item()

def train(g, target_idx, labels, train_mask, model):
    # define train idx, loss function and optimizer
    train_idx = torch.nonzero(train_mask, as_tuple=False).squeeze()
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    model.train()
    for epoch in range(50):
        logits = model(g)
        logits = logits[target_idx]
        loss = loss_fcn(logits[train_idx], labels[train_idx])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = accuracy(logits[train_idx].argmax(dim=1), labels[train_idx]).item()
        print("Epoch {:05d} | Loss {:.4f} | Train Accuracy {:.4f}"
              .format(epoch, loss.item(), acc))
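
# Entry point: parse arguments, load the chosen RDF dataset, preprocess the
# graph, then train and evaluate on the dataset's predefined splits.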
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RGCN for entity classification')
    parser.add_argument("--dataset", type=str, default="aifb",
                        help="Dataset name ('aifb', 'mutag', 'bgs', 'am').")
    args = parser.parse_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Training with DGL built-in RGCN module.')

    # load and preprocess dataset
    if args.dataset == 'aifb':
        data = AIFBDataset()
    elif args.dataset == 'mutag':
        data = MUTAGDataset()
    elif args.dataset == 'bgs':
        data = BGSDataset()
    elif args.dataset == 'am':
        data = AMDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
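
    # each RDF dataset yields a single heterogeneous graph; labels and
    # train/test masks are stored on the dataset's prediction category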
    g = data[0]
    g = g.int().to(device)
    num_rels = len(g.canonical_etypes)
    category = data.predict_category
    labels = g.nodes[category].data.pop('labels')
    train_mask = g.nodes[category].data.pop('train_mask')
    test_mask = g.nodes[category].data.pop('test_mask')

    # calculate normalization weight for each edge
    for cetype in g.canonical_etypes:
        g.edges[cetype].data['norm'] = dgl.norm_by_dst(g, cetype).unsqueeze(1)

    # convert to a homogeneous graph, keeping the edge norms; original node
    # types survive in g.ndata[dgl.NTYPE], so the target category's nodes can
    # still be located after relabeling
    category_id = g.ntypes.index(category)
    g = dgl.to_homogeneous(g, edata=['norm'])
    node_ids = torch.arange(g.num_nodes()).to(device)
    target_idx = node_ids[g.ndata[dgl.NTYPE] == category_id]

    # create RGCN model
    in_size = g.num_nodes()  # featureless with one-hot encoding
    out_size = data.num_classes
    model = RGCN(in_size, 16, out_size, num_rels).to(device)

    train(g, target_idx, labels, train_mask, model)
    acc = evaluate(g, target_idx, labels, test_mask, model)
    print("Test accuracy {:.4f}".format(acc))