Skip to content

Commit

Permalink
Add random_walk_pe (#273)
Browse files Browse the repository at this point in the history
* Add randomwalk positional encoding

* Add test randomwalkPE

* export randomwalkPE

* Fix degree

Co-authored-by: Carlo Lucibello <[email protected]>

* Fix initialization matrix

Co-authored-by: Carlo Lucibello <[email protected]>

* Change return

* Add clearer adjacency matrix parameters

Co-authored-by: Carlo Lucibello <[email protected]>

* Fix dense_zeros_like

* Rename function

* Export correct function

* Add new test compared with PyTorch

* Add docstring

---------

Co-authored-by: Carlo Lucibello <[email protected]>
  • Loading branch information
aurorarossi and CarloLucibello authored Mar 25, 2023
1 parent 5b23e9c commit 05fca7c
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 11 deletions.
1 change: 1 addition & 0 deletions src/GNNGraphs/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export add_nodes,
set_edge_weight,
to_bidirected,
to_unidirected,
random_walk_pe,
# from Flux
batch,
unbatch,
Expand Down
27 changes: 27 additions & 0 deletions src/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,33 @@ function rand_edge_split(g::GNNGraph, frac; bidirected = is_bidirected(g))
return g1, g2
end

"""
random_walk_pe(g, walk_length)
Return the random walk positional encoding from the paper [Graph Neural Networks with Learnable Structural and Positional Representations](https://arxiv.org/abs/2110.07875) of the given graph `g` and the length of the walk `walk_length` as a matrix of size `(walk_length, g.num_nodes)`.
"""
function random_walk_pe(g::GNNGraph, walk_length::Int)
matrix = zeros(walk_length, g.num_nodes)
adj = adjacency_matrix(g, Float32; dir = :out)
matrix = dense_zeros_like(adj, Float32, (walk_length, g.num_nodes))
deg = sum(adj, dims = 2) |> vec
deg_inv = inv.(deg)
deg_inv[isinf.(deg_inv)] .= 0
RW = adj * Diagonal(deg_inv)
out = RW
matrix[1, :] .= diag(RW)
for i in 2:walk_length
out = out * RW
matrix[i, :] .= diag(out)
end
return matrix
end

dense_zeros_like(a::SparseMatrixCSC, T::Type, sz = size(a)) = zeros(T, sz)
dense_zeros_like(a::AbstractArray, T::Type, sz = size(a)) = fill!(similar(a, T, sz), 0)
dense_zeros_like(a::CUMAT_T, T::Type, sz = size(a)) = CUDA.zeros(T, sz)
dense_zeros_like(x, sz = size(x)) = dense_zeros_like(x, eltype(x), sz)

# """
# Transform vector of cartesian indexes into a tuple of vectors containing integers.
# """
Expand Down
33 changes: 22 additions & 11 deletions test/GNNGraphs/transform.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
@testset "add self-loops" begin
A = [1 1 0 0
0 0 1 0
0 0 0 1
1 0 0 0]
0 0 1 0
0 0 0 1
1 0 0 0]
A2 = [2 1 0 0
0 1 1 0
0 0 1 1
1 0 0 1]
0 1 1 0
0 0 1 1
1 0 0 1]

g = GNNGraph(A; graph_type = GRAPH_T)
fg2 = add_self_loops(g)
Expand All @@ -18,7 +18,7 @@ end

@testset "batch" begin
g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
graph_type = GRAPH_T)
graph_type = GRAPH_T)
g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)

Expand All @@ -44,7 +44,7 @@ end
# Batch of batches
g123123 = Flux.batch([g123, g123])
@test g123123.graph_indicator ==
[fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
[fill(1, 10); fill(2, 4); fill(3, 7); fill(4, 10); fill(5, 4); fill(6, 7)]
@test g123123.num_graphs == 6
end

Expand All @@ -67,8 +67,8 @@ end
c = 3
ngraphs = 10
gs = [rand_graph(n, c * n, ndata = rand(2, n), edata = rand(3, c * n),
graph_type = GRAPH_T)
for _ in 1:ngraphs]
graph_type = GRAPH_T)
for _ in 1:ngraphs]
gall = Flux.batch(gs)
gs2 = Flux.unbatch(gall)
@test gs2[1] == gs[1]
Expand All @@ -77,7 +77,7 @@ end

@testset "getgraph" begin
g1 = GNNGraph(random_regular_graph(10, 2), ndata = rand(16, 10),
graph_type = GRAPH_T)
graph_type = GRAPH_T)
g2 = GNNGraph(random_regular_graph(4, 2), ndata = rand(16, 4), graph_type = GRAPH_T)
g3 = GNNGraph(random_regular_graph(7, 2), ndata = rand(16, 7), graph_type = GRAPH_T)
g = Flux.batch([g1, g2, g3])
Expand Down Expand Up @@ -268,3 +268,14 @@ end end
@test nv(DG) == g.num_nodes
@test ne(DG) == g.num_edges
end

@testset "random_walk_pe" begin
s = [1, 2, 2, 3]
t = [2, 1, 3, 2]
ndata = [-1, 0, 1]
g = GNNGraph(s, t, graph_type = GRAPH_T, ndata = ndata)
output = random_walk_pe(g, 3)
@test output == [0.0 0.0 0.0
0.5 1.0 0.5
0.0 0.0 0.0]
end

0 comments on commit 05fca7c

Please sign in to comment.