forked from BUPT-GAMMA/GammaGL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeepwalk_trainer.py
125 lines (103 loc) · 4.98 KB
/
deepwalk_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# os.environ['TL_BACKEND'] = 'torch'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 0:Output all; 1:Filter out INFO; 2:Filter out INFO and WARNING; 3:Filter out INFO, WARNING, and ERROR
import argparse
import tensorlayerx as tlx
from gammagl.datasets import Planetoid
from gammagl.models import DeepWalkModel
from gammagl.utils import calc_gcn_norm, mask_to_index
from tensorlayerx.model import TrainOneStep, WithLoss
from sklearn.linear_model import LogisticRegression
class Unsupervised_Loss(WithLoss):
    """Adapter that lets ``TrainOneStep`` drive a model trained without labels.

    The wrapped backbone is called with only ``data["edge_index"]``. Its output
    and the label supplied by the caller are then handed to ``loss_fn``.
    """

    def __init__(self, net, loss_fn):
        super().__init__(backbone=net, loss_fn=loss_fn)

    def forward(self, data, label):
        # Only the edge index is fed to the backbone; the other entries of
        # ``data`` are ignored here.
        output = self.backbone_network(data["edge_index"])
        return self._loss_fn(output, label)
def calculate_acc(train_z, train_y, test_z, test_y, solver='lbfgs', multi_class='auto', max_iter=150):
    """Score embeddings with a logistic-regression probe.

    Fits ``LogisticRegression`` on ``(train_z, train_y)`` and returns its mean
    accuracy (``clf.score``) on ``(test_z, test_y)``. All four inputs are first
    converted with ``tlx.convert_to_numpy``, so backend tensors are accepted.

    Args:
        train_z: embeddings of the nodes the probe is fitted on.
        train_y: labels aligned with the rows of ``train_z``.
        test_z: embeddings of the nodes the probe is scored on.
        test_y: labels aligned with the rows of ``test_z``.
        solver: forwarded to ``LogisticRegression``.
        multi_class: forwarded to ``LogisticRegression`` only when it is not
            ``'auto'`` (see the note in the body).
        max_iter: forwarded to ``LogisticRegression``.

    Returns:
        The probe's accuracy on the test inputs, as returned by ``clf.score``.
    """
    train_z, train_y, test_z, test_y = (
        tlx.convert_to_numpy(t) for t in (train_z, train_y, test_z, test_y)
    )

    clf_kwargs = {'solver': solver, 'max_iter': max_iter}
    # ``multi_class`` has been deprecated since scikit-learn 1.5 (passing it at
    # all emits a FutureWarning) and is slated for removal. ``'auto'`` is what
    # the library does when the argument is omitted, so only forward an
    # explicit non-default choice.
    if multi_class != 'auto':
        clf_kwargs['multi_class'] = multi_class

    clf = LogisticRegression(**clf_kwargs).fit(train_z, train_y)
    return clf.score(test_z, test_y)
def _embedding_accuracy(model, graph, train_idx, eval_idx):
    """Fit the linear probe on ``train_idx`` embeddings and score it on ``eval_idx``."""
    model.set_eval()
    z = model.campute()
    # NOTE(review): ``campute`` (sic) is the model's own method name; element 0
    # of its result is used as the node-embedding matrix - confirm against
    # DeepWalkModel.
    emb = z[0]
    return calculate_acc(tlx.gather(emb, train_idx), tlx.gather(graph.y, train_idx),
                         tlx.gather(emb, eval_idx), tlx.gather(graph.y, eval_idx),
                         max_iter=150)


def main(args):
    """Train DeepWalk on a Planetoid dataset and report linear-probe accuracy.

    Each epoch runs one optimisation step, then scores the embeddings with a
    logistic-regression probe on the validation split. The weights with the
    best validation accuracy are saved to
    ``os.path.join(args.best_model_path, model.name + ".npz")``. After
    training, that checkpoint is reloaded and scored on the test split.

    Args:
        args: parsed command-line namespace (see the ``__main__`` block).

    Returns:
        Test accuracy of the reloaded best checkpoint. If no epoch ran, this is
        the accuracy of the untrained model.

    Raises:
        ValueError: if ``args.dataset`` is not cora, pubmed or citeseer.
    """
    # load datasets
    if args.dataset.lower() not in ['cora', 'pubmed', 'citeseer']:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))
    dataset = Planetoid(args.dataset_path, args.dataset)
    graph = dataset[0]
    edge_index = graph.edge_index
    edge_weight = tlx.convert_to_tensor(calc_gcn_norm(edge_index, graph.num_nodes, edge_weight=graph.edge_weight))

    # Boolean split masks are converted to node indices (the original note says
    # MindSpore requires indices here).
    train_idx = mask_to_index(graph.train_mask)
    test_idx = mask_to_index(graph.test_mask)
    val_idx = mask_to_index(graph.val_mask)

    model = DeepWalkModel(edge_index=edge_index,
                          edge_weight=edge_weight,
                          embedding_dim=args.embedding_dim,
                          walk_length=args.walk_length,
                          num_walks=args.num_walks,
                          window_size=args.window_size,
                          name="DeepWalk")

    optimizer = tlx.optimizers.Adam(lr=args.lr)
    train_weights = model.trainable_weights
    # The model output is compared against a constant label of 0 with an
    # absolute-difference loss (see the train_one_step call below).
    # NOTE(review): presumably the model already returns its own unsupervised
    # loss, so this just yields |loss| - confirm against DeepWalkModel.
    loss_func = Unsupervised_Loss(net=model, loss_fn=tlx.losses.absolute_difference_error)
    train_one_step = TrainOneStep(loss_func, optimizer, train_weights)

    data = {
        "x": graph.x,
        "y": graph.y,
        "edge_index": graph.edge_index,
        "train_idx": train_idx,
        "test_idx": test_idx,
        "val_idx": val_idx,
        "num_nodes": graph.num_nodes,
    }

    # os.path.join avoids gluing the file name onto a directory given without
    # a trailing separator; for the default './' the path is unchanged.
    checkpoint_path = os.path.join(args.best_model_path, model.name + ".npz")
    # Below any real accuracy, so the first completed epoch always checkpoints.
    best_val_acc = -1.0
    for epoch in range(args.n_epoch):
        model.set_train()
        train_loss = train_one_step(data, tlx.convert_to_tensor(0, dtype=tlx.float32))
        val_acc = _embedding_accuracy(model, graph, data['train_idx'], data['val_idx'])
        print("Epoch [{:0>3d}]  train loss: {:.4f} val acc: {:.4f}".format(
            epoch + 1, train_loss.item(), val_acc))

        # save best model on evaluation set
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_weights(checkpoint_path, format='npz_dict')

    # Reload only if this run actually wrote a checkpoint. Otherwise the load
    # would crash on a missing file, or silently evaluate a stale file left by
    # an earlier run.
    if best_val_acc >= 0:
        model.load_weights(checkpoint_path, format='npz_dict')
        if tlx.BACKEND == 'torch':
            model.to(data['x'].device)

    test_acc = _embedding_accuracy(model, graph, data['train_idx'], data['test_idx'])
    print("Test acc: {:.4f}".format(test_acc))
    return test_acc
if __name__ == "__main__":
    # Command-line configuration for the DeepWalk example.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="cora", help="dataset")
    parser.add_argument("--dataset_path", type=str, default="", help="path to save dataset")
    parser.add_argument("--best_model_path", type=str, default="./", help="path to save best model")
    parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
    parser.add_argument("--n_epoch", type=int, default=100, help="number of epoch")
    # Random-walk / skip-gram hyper-parameters forwarded to DeepWalkModel.
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--walk_length", type=int, default=10)
    parser.add_argument("--num_walks", type=int, default=10)
    parser.add_argument("--window_size", type=int, default=5)
    # A negative GPU id selects the CPU.
    parser.add_argument("--gpu", type=int, default=-1)
    cli_args = parser.parse_args()

    if cli_args.gpu < 0:
        tlx.set_device("CPU")
    else:
        tlx.set_device("GPU", cli_args.gpu)

    main(cli_args)