[Example] Hypergraph attention (dmlc#4941)
* hypergraph attention
* address comments
Showing 1 changed file with 116 additions and 0 deletions.
@@ -0,0 +1,116 @@
"""
Hypergraph Convolution and Hypergraph Attention
(https://arxiv.org/pdf/1901.08150.pdf).
"""
import dgl
import dgl.mock_sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.functional import accuracy
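# NOTE: with torchmetrics >= 0.11, accuracy() additionally requires
# task="multiclass" and num_classes arguments.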
import tqdm


def hypergraph_laplacian(H):
    ###########################################################
    # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API
    ###########################################################
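    # For an N x M incidence matrix H, this builds the normalized operator
    # from the paper:
    #     D_V^(-1/2) @ H @ W @ D_E^(-1) @ H^T @ D_V^(-1/2)
    # where W is the diagonal matrix of hyperedge weights (identity here,
    # i.e. all hyperedges weighted equally).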
    d_V = H.sum(1)  # node degree
    d_E = H.sum(0)  # edge degree
    n_edges = d_E.shape[0]
    D_V_invsqrt = dglsp.diag(d_V ** -0.5)  # D_V ** (-1/2)
    D_E_inv = dglsp.diag(d_E ** -1)  # D_E ** (-1)
    W = dglsp.identity((n_edges, n_edges))
    return D_V_invsqrt @ H @ W @ D_E_inv @ H.T @ D_V_invsqrt


class HypergraphAttention(nn.Module):
    """Hypergraph Attention module as in the paper
    `Hypergraph Convolution and Hypergraph Attention
    <https://arxiv.org/pdf/1901.08150.pdf>`_.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        self.P = nn.Linear(in_size, out_size)
        self.a = nn.Linear(2 * out_size, 1)

    def forward(self, H, X, X_edges):
        Z = self.P(X)
        Z_edges = self.P(X_edges)
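        # Score each nonzero (node, hyperedge) pair of H with additive
        # attention over the projected features: LeakyReLU(a^T [Z_u || Z_e]).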
        sim = self.a(torch.cat([Z[H.row], Z_edges[H.col]], 1))
        sim = F.leaky_relu(sim, 0.2).squeeze(1)
        # Reassign the attention scores as the new nonzero weights of the
        # hypergraph incidence matrix.
        H_att = dglsp.create_from_coo(H.row, H.col, sim, shape=H.shape)
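        # Normalize the scores per node, i.e. softmax over the nonzero
        # entries of each row (the hyperedges incident to that node).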
        H_att = H_att.softmax()
        return hypergraph_laplacian(H_att) @ Z


class Net(nn.Module):
    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()

        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
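        # Each hyperedge is built around one node (see load_data), so the
        # node features double as the hyperedge features.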
        Z = self.layer1(H, X, X)
        Z = F.elu(Z)
        Z = self.layer2(H, Z, Z)
        return Z


def train(model, optimizer, H, X, Y, train_mask):
    model.train()
    Y_hat = model(H, X)
    loss = F.cross_entropy(Y_hat[train_mask], Y[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def evaluate(model, H, X, Y, val_mask, test_mask):
    model.eval()
    Y_hat = model(H, X)
    val_acc = accuracy(Y_hat[val_mask], Y[val_mask])
    test_acc = accuracy(Y_hat[test_mask], Y[test_mask])
    return val_acc, test_acc


def load_data():
    dataset = dgl.data.CoraGraphDataset()

    graph = dataset[0]
    # The paper creates a hypergraph from the original graph: for each node,
    # a hyperedge is added that connects the node's neighbors and the node
    # itself. The incidence matrix of this hypergraph is therefore identical
    # to the adjacency matrix of the original graph with self-loops.
    # Following the paper, rows of the incidence matrix index nodes and
    # columns index hyperedges.
    src, dst = graph.edges()
    H = dglsp.create_from_coo(dst, src)
    H = H + dglsp.identity(H.shape)
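    # H is square: H[u, e] is nonzero iff node u lies in the closed
    # neighborhood of node e.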

    X = graph.ndata["feat"]
    Y = graph.ndata["label"]
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask


def main():
    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    with tqdm.trange(500) as tq:
        for epoch in tq:
            loss = train(model, optimizer, H, X, Y, train_mask)
            val_acc, test_acc = evaluate(model, H, X, Y, val_mask, test_mask)
            tq.set_postfix(
                {
                    "Loss": f"{loss:.5f}",
                    "Val acc": f"{val_acc:.5f}",
                    "Test acc": f"{test_acc:.5f}",
                },
                refresh=False,
            )


if __name__ == "__main__":
    main()
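
For a quick sanity check of hypergraph_laplacian, here is a minimal sketch on a hand-made toy incidence matrix (the matrix is invented for illustration; only calls already used in the file above are assumed):

import torch
import dgl.mock_sparse as dglsp

# Toy hypergraph with 3 nodes and 2 hyperedges:
# hyperedge 0 = {node 0, node 1}, hyperedge 1 = {node 1, node 2}.
row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([0, 0, 1, 1])
H = dglsp.create_from_coo(row, col, torch.ones(4), shape=(3, 2))

L = hypergraph_laplacian(H)  # 3 x 3 normalized operator
X = torch.randn(3, 8)
Z = L @ X  # smoothed node features, shape (3, 8)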