import numpy as np
import torch
from colorama import Fore, Style
from wilds import get_dataset
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper
from modules.fly_lsh import FlyLSH
from modules.gnn import GNN
from utils import create_ff_network
# get the dataset
root_dir = "./data"
dataset = get_dataset("ogb-molpcba", download=True, root_dir=root_dir)
# get the train and test subsets
train_data = dataset.get_subset("train", frac=1.0)
test_data = dataset.get_subset("test")
grouper = CombinatorialGrouper(dataset, ['scaffold'])
# prepare a group-based train loader
train_loader = get_train_loader(
    "group", train_data, grouper=grouper, n_groups_per_batch=8, batch_size=256)
test_loader = get_eval_loader("standard", test_data, batch_size=5000)
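# the group loader draws each batch from 8 scaffold groups, so every batch contains
# both same-scaffold (similar) and cross-scaffold (dissimilar) pairs for the
# similarity-matching loss computed in the training loop below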
# load an example batch
batch, _, batch_metadata = next(iter(train_loader))
# get relevant information about the data set
num_tasks = dataset.ogb_dataset.num_tasks
num_features = dataset.ogb_dataset.num_features
num_edge_features = dataset.ogb_dataset.num_edge_features
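# ogb-molpcba is a multi-task dataset: 128 binary targets per molecule,
# with many entries unlabeled (stored as NaN)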
""" Build a GNN """
embedding_dim = 300
gnn = GNN(num_tasks, num_layer=5, emb_dim=embedding_dim,
          gnn_type='gin', virtual_node=True, residual=False, drop_ratio=0.5,
          JK="sum", graph_pooling="mean")
""" Build a FlyLSH layer """
lsh_out_dim = 2000
tag_dim = 6
sr = tag_dim / embedding_dim
fly_lsh = FlyLSH(input_dim=embedding_dim, out_dim=lsh_out_dim, tag_dim=tag_dim, weight_sparsity=sr)
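# FlyLSH is a project-local layer; presumably it expands the 300-d embedding into a
# sparse 2000-d hash code in the style of the fruit-fly olfactory circuit, with
# weight_sparsity = tag_dim / embedding_dim meaning each expansion unit samples
# roughly tag_dim of the input dimensions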
lsh_optimizer = torch.optim.Adam(fly_lsh.parameters(), lr=0.001)
""" Build a downstream classifier """
cls_loss_fn = torch.nn.BCEWithLogitsLoss()
classifier = create_ff_network([embedding_dim, num_tasks], h_activation='none', out_activation="none")
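# with a [embedding_dim, num_tasks] layer spec, create_ff_network presumably reduces
# to a single linear readout; both activations are 'none', so BCEWithLogitsLoss
# receives raw logits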
cls_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
""" Training """
# move models to the gpu (fall back to cpu if none is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
gnn = gnn.to(device)
classifier = classifier.to(device)
fly_lsh = fly_lsh.to(device)
# create an optimizer for the GNN encoder
optimizer = torch.optim.Adam(gnn.parameters(), lr=0.003)
num_epochs = 50
num_batches = len(train_loader)
losses = []
cls_losses = []
#accs = []
for epoch in range(num_epochs):
    epoch_losses = []
    epoch_cls_losses = []
    #epoch_accs = []
    epoch_pred = []
    epoch_true = []
    epoch_metadata = []
    batch_itr = 1
    for batch, _, batch_metadata in train_loader:
        # move to device
        batch = batch.to(device)
        # compute embeddings
        _, embedding = gnn(batch)
        # compute lsh tag
        tag = fly_lsh(embedding)
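        # NOTE: the tag is computed for inspection only; it never enters a loss,
        # so fly_lsh receives no gradient in this script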
        # pass to classifier
        cls_logits = classifier(embedding)
        cls_pred = (torch.sigmoid(cls_logits) > 0.5).float()
        # compute the similarity-matching loss: molecules count as "similar" iff
        # they share a scaffold group (metadata column 0)
        similarity = (batch_metadata[:, [0]] == batch_metadata[:, [0]].T).float().to(device)
        similarity[similarity == 0] = -1
        embed_sim = ((torch.tanh(embedding) @ torch.tanh(embedding).T) / embedding_dim)
        sim_loss = torch.mean((similarity - embed_sim) ** 2) / 4
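        # since similarity is in {-1, +1} and embed_sim lies in [-1, 1], the squared
        # difference is at most 4; dividing by 4 keeps the loss in [0, 1]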
        # compute the classification loss (NaN labels mark unlabeled tasks, and
        # NaN != NaN, so this mask keeps only the labeled entries)
        is_labeled = batch.y == batch.y
        cls_loss = cls_loss_fn(cls_logits.float()[is_labeled], batch.y.float()[is_labeled])
        #accuracy = (cls_pred[is_labeled] == batch.y[is_labeled]).float().mean()
        # optimize
        optimizer.zero_grad()
        lsh_optimizer.zero_grad()
        cls_optimizer.zero_grad()
        cls_loss.backward()
        optimizer.step()
        lsh_optimizer.step()
        cls_optimizer.step()
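        # only cls_loss was backpropagated: the GNN and classifier are updated through
        # it, while lsh_optimizer.step() is a no-op (no gradients reach fly_lsh) and
        # sim_loss is tracked purely for monitoring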
        # append (moved to cpu so gpu memory does not accumulate over the epoch)
        epoch_losses.append(sim_loss.detach().item())
        epoch_cls_losses.append(cls_loss.detach().item())
        #epoch_accs.append(accuracy.detach().item())
        epoch_pred.append(cls_pred.cpu())
        epoch_true.append(batch.y.cpu())
        epoch_metadata.append(batch_metadata)
        # per-batch average precision (undefined when the batch has no positive labels)
        try:
            ap = dataset.eval(cls_pred, batch.y, batch_metadata)[0]['ap']
        except RuntimeError:
            ap = float('nan')
        # update the progress line
        print(Fore.YELLOW + f'[train] Epoch {epoch + 1} ({batch_itr} / {num_batches}):\t '
              f'\033[1mSM Loss\033[0m = {sim_loss.detach().item():0.3f}\t'
              f'\033[1mBCE Loss\033[0m = {cls_loss.detach().item():0.3f}\t'
              #f'\033[1mAccuracy\033[0m = {accuracy.detach().item():0.3f}\t'
              f'\033[1mAP\033[0m = {ap:0.3f}' + Style.RESET_ALL,
              end='\r')
        batch_itr += 1
    print('')
    avg_loss = np.mean(epoch_losses)
    avg_cls_loss = np.mean(epoch_cls_losses)
    #avg_acc = np.mean(epoch_accs)
    epoch_ap = dataset.eval(torch.cat(epoch_pred), torch.cat(epoch_true), torch.cat(epoch_metadata))[0]['ap']
    losses.append(avg_loss)
    cls_losses.append(avg_cls_loss)
    #accs.append(avg_acc)
    print(f"Epoch {epoch + 1} | SM Loss {avg_loss:0.3f} | BCE Loss {avg_cls_loss:0.3f} | AP {epoch_ap:0.3f}")
print("Done")