# train.py
import opt  # hyperparameter namespace (opt.args), assumed from the repo layout
import torch
import tqdm
from utils import *
from torch.optim import Adam

def train(model, X, y, A, A_norm, Ad):
    """
    Train the Dual Correlation Reduction Network (DCRN).
    Args:
        model: Dual Correlation Reduction Network
        X: input feature matrix
        y: ground-truth labels (used only for evaluation)
        A: original adjacency matrix
        A_norm: normalized adjacency matrix
        Ad: graph-diffusion adjacency matrix
    Returns: best acc, nmi, ari, f1 over all epochs
    """
print("Training…")
# calculate embedding similarity and cluster centers
sim, centers = model_init(model, X, y, A_norm)
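    # `model_init` (from utils) is assumed to run a warm-up forward pass and
    # K-means on the resulting embedding, returning a node-similarity matrix
    # `sim` and initial cluster `centers`; this is a reading of the call site,
    # not a verified contract.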
    # initialize cluster centers
    model.cluster_centers.data = torch.tensor(centers).to(opt.args.device)
    # edge-masked adjacency matrix (Am): remove edges based on feature similarity
    Am = remove_edge(A, sim, remove_rate=0.1)
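    # Hedged sketch of `remove_edge`: presumably it drops the `remove_rate`
    # fraction of edges whose endpoints are least similar under `sim` and
    # re-normalizes, roughly
    #     thresh = torch.quantile(sim[A > 0], remove_rate)
    #     Am = normalize(A * (sim >= thresh).float())
    # where `normalize` is a hypothetical re-normalization helper; the exact
    # masking lives in utils.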
    optimizer = Adam(model.parameters(), lr=opt.args.lr)
    for epoch in tqdm.tqdm(range(opt.args.epoch)):
        # add gaussian noise to X to build two augmented views
        X_tilde1, X_tilde2 = gaussian_noised_feature(X)
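        # `gaussian_noised_feature` is assumed to perturb X with independently
        # sampled Gaussian noise per view (e.g. roughly X * N(1, 0.1)); the
        # exact noise model lives in utils.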
# input & output
X_hat, Z_hat, A_hat, _, Z_ae_all, Z_gae_all, Q, Z, AZ_all, Z_all = model(X_tilde1, Ad, X_tilde2, Am)
# calculate loss: L_{DICR}, L_{REC} and L_{KL}
L_DICR = dicr_loss(Z_ae_all, Z_gae_all, AZ_all, Z_all)
L_REC = reconstruction_loss(X, A_norm, X_hat, Z_hat, A_hat)
L_KL = distribution_loss(Q, target_distribution(Q[0].data))
loss = L_DICR + L_REC + opt.args.lambda_value * L_KL
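        # Overall objective: L = L_DICR + L_REC + lambda * L_KL.
        # `target_distribution` is assumed to be the standard DEC-style
        # sharpening of the soft assignments Q:
        #     p_ij = (q_ij**2 / sum_i q_ij) / sum_j' (q_ij'**2 / sum_i q_ij')
        # i.e. roughly (Q**2 / Q.sum(0)) renormalized along each row.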
        # optimization
        optimizer.zero_grad()
        loss.backward(retain_graph=True)
        optimizer.step()
        # clustering & evaluation
        acc, nmi, ari, f1, _ = clustering(Z, y)
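        # `clustering` is assumed to run K-means on the embedding Z and score
        # it against y (ACC via Hungarian matching, plus NMI, ARI and F1), as
        # is common in deep-clustering repos; not a verified reading of utils.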
        # keep the best metrics seen so far (stored on opt.args)
        if acc > opt.args.acc:
            opt.args.acc = acc
            opt.args.nmi = nmi
            opt.args.ari = ari
            opt.args.f1 = f1
    return opt.args.acc, opt.args.nmi, opt.args.ari, opt.args.f1
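

# Usage sketch (not part of the original file): a hypothetical driver showing
# how train() might be wired up. `load_graph_data`, `DCRN`, and the opt.args
# fields below are illustrative placeholders, not the repo's verified API;
# see the repository's actual entry point (e.g. main.py) for the real wiring.
if __name__ == "__main__":
    from DCRN import DCRN  # hypothetical model module/class

    # hypothetical loader: features, labels, raw / normalized / diffused adjacency
    X, y, A, A_norm, Ad = load_graph_data(opt.args.name)
    model = DCRN(n_input=X.shape[1], n_clusters=opt.args.n_clusters).to(opt.args.device)
    # best-so-far metrics that train() compares against and updates
    opt.args.acc = opt.args.nmi = opt.args.ari = opt.args.f1 = 0
    acc, nmi, ari, f1 = train(model, X, y, A, A_norm, Ad)
    print("ACC={:.4f} NMI={:.4f} ARI={:.4f} F1={:.4f}".format(acc, nmi, ari, f1))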