train.py
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.sampler import SubsetRandomSampler
from sklearn.model_selection import StratifiedKFold

from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling


class MLP(nn.Module):
    """Two-layer MLP aggregator for the GIN model."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linears = nn.ModuleList()
        # two-layer MLP
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.batch_norm = nn.BatchNorm1d(hidden_dim)

    def forward(self, x):
        h = x
        h = F.relu(self.batch_norm(self.linears[0](h)))
        return self.linears[1](h)


class GIN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        num_layers = 5
        # five-layer GIN with a two-layer MLP aggregator and sum-neighbor-pooling scheme
        for layer in range(num_layers - 1):  # excluding the input layer
            if layer == 0:
                mlp = MLP(input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(mlp, learn_eps=False)  # set learn_eps=True if learning epsilon
            )
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum pooling of the output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linear_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        self.pool = SumPooling()  # change to mean readout (AvgPooling) on social network datasets

    def forward(self, g, h):
        # list of hidden representations at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes at each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer
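
# Illustration (not part of the original script): a minimal, hedged sketch of a
# standalone forward pass through GIN, assuming a toy random DGL graph with
# 7-dimensional node features. Kept commented out so the script's behavior is unchanged.
#
# import dgl
# g = dgl.rand_graph(10, 30)                    # 10 nodes, 30 random edges
# feat = torch.randn(10, 7)                     # node features; dim must equal input_dim
# net = GIN(input_dim=7, hidden_dim=16, output_dim=2)
# net.eval()                                    # use running BatchNorm stats for the demo
# with torch.no_grad():
#     scores = net(g, feat)                     # shape (1, 2): one score vector per graph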


def split_fold10(labels, fold_idx=0):
    skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    idx_list = []
    for idx in skf.split(np.zeros(len(labels)), labels):
        idx_list.append(idx)
    train_idx, valid_idx = idx_list[fold_idx]
    return train_idx, valid_idx
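
# Illustration (not part of the original script): split_fold10 picks one of ten
# stratified folds, so roughly 90% of the indices go to training and 10% to
# validation, with class proportions preserved. The label list below is a made-up
# example, kept commented out so nothing runs at import time.
#
# toy_labels = [0] * 10 + [1] * 10
# tr_idx, va_idx = split_fold10(toy_labels, fold_idx=0)
# # len(tr_idx) == 18, len(va_idx) == 2 (one sample per class in the validation fold)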


def evaluate(dataloader, device, model):
    model.eval()
    total = 0
    total_correct = 0
    for batched_graph, labels in dataloader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        feat = batched_graph.ndata.pop('attr')
        total += len(labels)
        logits = model(batched_graph, feat)
        _, predicted = torch.max(logits, 1)
        total_correct += (predicted == labels).sum().item()
    acc = 1.0 * total_correct / total
    return acc


def train(train_loader, val_loader, device, model):
    # loss function, optimizer and scheduler
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    # training loop
    for epoch in range(350):
        model.train()
        total_loss = 0
        for batch, (batched_graph, labels) in enumerate(train_loader):
            batched_graph = batched_graph.to(device)
            labels = labels.to(device)
            feat = batched_graph.ndata.pop('attr')
            logits = model(batched_graph, feat)
            loss = loss_fcn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        scheduler.step()
        train_acc = evaluate(train_loader, device, model)
        valid_acc = evaluate(val_loader, device, model)
        print("Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f}"
              .format(epoch, total_loss / (batch + 1), train_acc, valid_acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="MUTAG",
                        choices=['MUTAG', 'PTC', 'NCI1', 'PROTEINS'],
                        help='name of dataset (default: MUTAG)')
    args = parser.parse_args()
    print('Training with DGL built-in GINConv module with a fixed epsilon = 0')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # load and split dataset
    dataset = GINDataset(args.dataset, self_loop=True,
                         degree_as_nlabel=False)  # add self-loops; use node labels (not degrees) as input features
    labels = [l for _, l in dataset]
    train_idx, val_idx = split_fold10(labels)

    # create dataloaders
    train_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(train_idx),
                                   batch_size=128, pin_memory=torch.cuda.is_available())
    val_loader = GraphDataLoader(dataset, sampler=SubsetRandomSampler(val_idx),
                                 batch_size=128, pin_memory=torch.cuda.is_available())

    # create GIN model
    in_size = dataset.dim_nfeats
    out_size = dataset.gclasses
    model = GIN(in_size, 16, out_size).to(device)

    # model training/validation
    print('Training...')
    train(train_loader, val_loader, device, model)
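
# Example invocation (a hedged sketch, assuming DGL, PyTorch and scikit-learn are
# installed and the GINDataset files can be downloaded; the dataset names come from
# the --dataset choices above):
#
#   python train.py --dataset MUTAG
#   python train.py --dataset PROTEINS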