From f72cacbe30b93eab63e9d5fd4ac0f4d9a96c9ff1 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 13:10:53 +0100 Subject: [PATCH 01/17] Add functions --- src/utils.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 1ff5fe74f..8c7bb98fe 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -99,3 +99,49 @@ function broadcast_edges(g::GNNGraph, x) gi = graph_indicator(g, edges = true) return gather(x, gi) end + +function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) + index = sortperm(view(matrix, :, sortby); rev) + return matrix[index, :] +end + +function _sort_row2(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) #scalarindexing + sorted_matrix=sort(collect(eachrow(matrix)),by= x->x[end]) + reduce(hcat,sorted_matrix)' +end + +function _topk(feat::DataStore, k::Int; rev::Bool = true, sortby = nothing) + matrices = values(feat) + if sortby === nothing + return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices) + else + return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices) + end +end + +function _topk2(matrices, k::Int; rev::Bool = true, sortby = nothing) + if sortby === nothing + return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices) + else + return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices) + end +end + +function _topk_tensor(feat::DataStore,numgra, k::Int; rev::Bool = true, sortby = nothing) + matrices = values(feat) + p=map(matrix -> reshape(matrix,size(matrix,1),size(matrix,2)÷numgra,numgra),matrices) + v=map(x -> _topk2(collect(eachslice(x,dims=3)), k; rev,sortby), p) + p=map(matrix -> reduce(hcat,matrix),v) +end + + + + +function topk_nodes(g::GNNGraph, k::Int; rev = true, sortby = nothing) + return _topk(g.ndata, k; rev, sortby) +end + +function topk_edges(g::GNNGraph, k::Int; rev = true, sortby = nothing) + return _topk(g.edata, k; rev, sortby) +end + From 92e731472f88bdfe972baac361e07449d20015ef Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 14:45:07 +0100 Subject: [PATCH 02/17] Add test --- test/utils.jl | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/utils.jl b/test/utils.jl index 1b36ea59f..6070c6229 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -59,4 +59,43 @@ @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2) @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) end + + @testset "topk_nodes" begin + A = [0.0297 0.8307 0.9140 0.6702 0.3346; + 0.5901 0.3030 0.9280 0.6893 0.7997; + 0.0880 0.6515 0.4451 0.7507 0.5297; + 0.5171 0.6379 0.2695 0.8954 0.5197] + B = [0.3168 0.3174 0.5303 0.0804 0.3808; + 0.1752 0.9105 0.5692 0.8489 0.0539; + 0.1931 0.4954 0.3455 0.3934 0.0857; + 0.5065 0.5182 0.5418 0.1520 0.3872] + C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346; + 0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; + 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; + 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] + g1 = rand_graph(5, 6, ndata = (w = A,)) + g2 = rand_graph(5, 6, ndata = (w = B,)) + g3 = rand_graph(5, 6, edata = (e = C,)) + g = Flux.batch([g1, g2]) + output1 = topk_nodes(g1, :w, 3) + output2 = topk_nodes(g1, :w, 3; sortby = 5) + output3 = topk_edges(g3, :e, 3; sortby = 6) + output_batch = topk_nodes(g, :w, 3; sortby = 5) + correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997; + 0.5171 0.6515 0.9140 0.7507 0.5297; + 0.0880 0.6379 0.4451 0.6893 0.5197] + correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997; + 0.0880 0.6515 0.4451 0.7507 0.5297; + 0.5171 0.6379 0.2695 0.8954 0.5197] + correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; + 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; + 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] + correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872; + 0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808; + 0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857] + @test output1 == correctout1 + @test output2 == correctout2 + @test output3 == correctout3 + @test output_batch == correctout_batch + end end From cc0a0155caa17e2caacaabc241071d0edb1dea1e Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 14:45:19 +0100 Subject: [PATCH 03/17] Fix functions --- src/utils.jl | 52 +++++++++++++++++++++++++--------------------------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8c7bb98fe..8255704d2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -105,43 +105,41 @@ function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) return matrix[index, :] end -function _sort_row2(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) #scalarindexing - sorted_matrix=sort(collect(eachrow(matrix)),by= x->x[end]) - reduce(hcat,sorted_matrix)' -end - -function _topk(feat::DataStore, k::Int; rev::Bool = true, sortby = nothing) - matrices = values(feat) +function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing - return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices) + return sort(matrix, dims = 1; rev)[1:k, :] else - return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices) + return _sort_row(matrix; rev, sortby)[1:k, :] end end -function _topk2(matrices, k::Int; rev::Bool = true, sortby = nothing) - if sortby === nothing - return map(matrix -> sort(matrix, dims = 1; rev)[1:k, :], matrices) - else - return map(matrix -> _sort_row(matrix; rev, sortby)[1:k, :], matrices) - end +function _sort_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) + return map(x -> _sort_matrix(x, k; rev, sortby), matrices) end -function _topk_tensor(feat::DataStore,numgra, k::Int; rev::Bool = true, sortby = nothing) - matrices = values(feat) - p=map(matrix -> reshape(matrix,size(matrix,1),size(matrix,2)÷numgra,numgra),matrices) - v=map(x -> _topk2(collect(eachslice(x,dims=3)), k; rev,sortby), p) - p=map(matrix -> reduce(hcat,matrix),v) +function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, + sortby = nothing) + tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs, + number_graphs) + sorted_matrix = _sort_batch(collect(eachslice(tensor_matrix, dims = 3)), k; rev, sortby) + return reduce(hcat, sorted_matrix) end - - - -function topk_nodes(g::GNNGraph, k::Int; rev = true, sortby = nothing) - return _topk(g.ndata, k; rev, sortby) +function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, + sortby = nothing) + if number_graphs==1 + return _sort_matrix(matrix, k; rev, sortby) + else + return _topk_batch(matrix, number_graphs, k; rev, sortby) + end end -function topk_edges(g::GNNGraph, k::Int; rev = true, sortby = nothing) - return _topk(g.edata, k; rev, sortby) +function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) + matrix = getproperty(g.ndata, feat) + return _topk(matrix, g.num_graphs, k; rev, sortby) end +function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) + matrix = getproperty(g.edata, feat) + return _topk(matrix, g.num_graphs, k; rev, sortby) +end From 6c1abb02cc305626ca0481c3520864dbf0172452 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 14:45:29 +0100 Subject: [PATCH 04/17] Export functions --- src/GraphNeuralNetworks.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 06dfa178a..b9d9f8dcb 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -27,6 +27,8 @@ export broadcast_nodes, broadcast_edges, softmax_edge_neighbors, + topk_nodes, + topk_edges, # msgpass apply_edges, From 4d788f20112d2fd808992c1f6ea482cd5252c529 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 19:31:09 +0100 Subject: [PATCH 05/17] Fix --- src/utils.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8255704d2..10c43ad32 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -100,16 +100,16 @@ function broadcast_edges(g::GNNGraph, x) return gather(x, gi) end -function _sort_row(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) - index = sortperm(view(matrix, :, sortby); rev) - return matrix[index, :] +function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) + index = sortperm(view(matrix, sortby, : ); rev) + return matrix[ :, index] end function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing - return sort(matrix, dims = 1; rev)[1:k, :] + return sort(matrix, dims = 2; rev)[:, 1:k] else - return _sort_row(matrix; rev, sortby)[1:k, :] + return _sort_col(matrix; rev, sortby)[:, 1:k] end end From 9de994f67906fbc171a1a863838d9861e060b2b2 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 19:53:02 +0100 Subject: [PATCH 06/17] Simplify test --- test/utils.jl | 62 +++++++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 36 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 6070c6229..973de68b1 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -61,41 +61,31 @@ end @testset "topk_nodes" begin - A = [0.0297 0.8307 0.9140 0.6702 0.3346; - 0.5901 0.3030 0.9280 0.6893 0.7997; - 0.0880 0.6515 0.4451 0.7507 0.5297; - 0.5171 0.6379 0.2695 0.8954 0.5197] - B = [0.3168 0.3174 0.5303 0.0804 0.3808; - 0.1752 0.9105 0.5692 0.8489 0.0539; - 0.1931 0.4954 0.3455 0.3934 0.0857; - 0.5065 0.5182 0.5418 0.1520 0.3872] - C = [0.0297 0.0297 0.8307 0.9140 0.6702 0.3346; - 0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; - 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; - 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] - g1 = rand_graph(5, 6, ndata = (w = A,)) - g2 = rand_graph(5, 6, ndata = (w = B,)) - g3 = rand_graph(5, 6, edata = (e = C,)) + A = [1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0] + B = [0.318907 0.189981 0.991791; + 0.547022 0.977349 0.680538; + 0.921823 0.35132 0.494715; + 0.451793 0.00704976 0.0189275] + g1 = rand_graph(3, 6, ndata = (x = A,)) + g2 = rand_graph(3, 6, ndata = B) + + # + output1 = topk_nodes(g1, :x, 2) + output2 = topk_nodes(g2, :x, 1, sortby = 2) + + @test output1 == [9.0 5.0; + 10.0 6.0; + 11.0 7.0; + 12.0 8.0] + @test output2 == [0.189981; + 0.977349; + 0.35132; + 0.00704976;;] g = Flux.batch([g1, g2]) - output1 = topk_nodes(g1, :w, 3) - output2 = topk_nodes(g1, :w, 3; sortby = 5) - output3 = topk_edges(g3, :e, 3; sortby = 6) - output_batch = topk_nodes(g, :w, 3; sortby = 5) - correctout1 = [0.5901 0.8307 0.9280 0.8954 0.7997; - 0.5171 0.6515 0.9140 0.7507 0.5297; - 0.0880 0.6379 0.4451 0.6893 0.5197] - correctout2 = [0.5901 0.3030 0.9280 0.6893 0.7997; - 0.0880 0.6515 0.4451 0.7507 0.5297; - 0.5171 0.6379 0.2695 0.8954 0.5197] - correctout3 = [0.5901 0.5901 0.3030 0.9280 0.6893 0.7997; - 0.0880 0.0880 0.6515 0.4451 0.7507 0.5297; - 0.5171 0.5171 0.6379 0.2695 0.8954 0.5197] - correctout_batch = [0.5901 0.3030 0.9280 0.6893 0.7997 0.5065 0.5182 0.5418 0.1520 0.3872; - 0.0880 0.6515 0.4451 0.7507 0.5297 0.3168 0.3174 0.5303 0.0804 0.3808; - 0.5171 0.6379 0.2695 0.8954 0.5197 0.1931 0.4954 0.3455 0.3934 0.0857] - @test output1 == correctout1 - @test output2 == correctout2 - @test output3 == correctout3 - @test output_batch == correctout_batch + output3 = topk_nodes(g, :x, 2; sortby = 4) + @test output3 == [9.0 5.0 0.318907 0.991791; + 10.0 6.0 0.547022 0.680538; + 11.0 7.0 0.921823 0.494715; + 12.0 8.0 0.451793 0.0189275] end -end +end; From e69402ea6eeb7080cdb1a67e69bbe2a0b083dca7 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 22:35:25 +0100 Subject: [PATCH 07/17] Add docstrings --- src/utils.jl | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 10c43ad32..ee3e85d6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -100,11 +100,13 @@ function broadcast_edges(g::GNNGraph, x) return gather(x, gi) end +# return a permuted matrix according to the sorting of the sortby column function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) - index = sortperm(view(matrix, sortby, : ); rev) - return matrix[ :, index] + index = sortperm(view(matrix, sortby, :); rev) + return matrix[:, index] end +# sort and reshape matrix function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing return sort(matrix, dims = 2; rev)[:, 1:k] @@ -113,32 +115,45 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = end end -function _sort_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) +# sort the iterator of batch matrices +function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing) return map(x -> _sort_matrix(x, k; rev, sortby), matrices) end +# sort and reshape batch matrix function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, sortby = nothing) tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs, number_graphs) - sorted_matrix = _sort_batch(collect(eachslice(tensor_matrix, dims = 3)), k; rev, sortby) + sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby) return reduce(hcat, sorted_matrix) end +# topk for a feature matrix function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, sortby = nothing) - if number_graphs==1 + if number_graphs == 1 return _sort_matrix(matrix, k; rev, sortby) else return _topk_batch(matrix, number_graphs, k; rev, sortby) end end +""" + topk_nodes(g, feat, k; rev = true, sortby = nothing) + +Graph-wise top-k on node features `feat` according to the `sortby` feature index. +""" function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) matrix = getproperty(g.ndata, feat) return _topk(matrix, g.num_graphs, k; rev, sortby) end +""" + topk_edges(g, feat, k; rev = true, sortby = nothing) + +Graph-wise top-k on edge features `feat` according to the `sortby` feature index. +""" function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) matrix = getproperty(g.edata, feat) return _topk(matrix, g.num_graphs, k; rev, sortby) From 6d2579f4d8a369f99bb39fd9eed639a8a72aed52 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Thu, 9 Mar 2023 22:37:08 +0100 Subject: [PATCH 08/17] Remove comments --- src/utils.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ee3e85d6d..42a7f6b05 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -100,13 +100,11 @@ function broadcast_edges(g::GNNGraph, x) return gather(x, gi) end -# return a permuted matrix according to the sorting of the sortby column function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) index = sortperm(view(matrix, sortby, :); rev) return matrix[:, index] end -# sort and reshape matrix function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing return sort(matrix, dims = 2; rev)[:, 1:k] @@ -115,12 +113,10 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = end end -# sort the iterator of batch matrices function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing) return map(x -> _sort_matrix(x, k; rev, sortby), matrices) end -# sort and reshape batch matrix function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, sortby = nothing) tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs, @@ -129,7 +125,6 @@ function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Boo return reduce(hcat, sorted_matrix) end -# topk for a feature matrix function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, sortby = nothing) if number_graphs == 1 From e10c4a92c32779e486988d10a19232f74faa0bf7 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Fri, 10 Mar 2023 10:09:40 +0100 Subject: [PATCH 09/17] Add topk_edges tests --- test/utils.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index 973de68b1..bf34a2ca8 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -68,11 +68,8 @@ 0.451793 0.00704976 0.0189275] g1 = rand_graph(3, 6, ndata = (x = A,)) g2 = rand_graph(3, 6, ndata = B) - - # output1 = topk_nodes(g1, :x, 2) output2 = topk_nodes(g2, :x, 1, sortby = 2) - @test output1 == [9.0 5.0; 10.0 6.0; 11.0 7.0; @@ -88,4 +85,13 @@ 11.0 7.0 0.921823 0.494715; 12.0 8.0 0.451793 0.0189275] end -end; + + @testset "topk_edges" begin + A = [0.157163 0.561874 0.886584 0.0475203 0.72576 0.815986; + 0.852048 0.974619 0.0345627 0.874303 0.614322 0.113491] + g1 = rand_graph(5, 6, edata = (x = A,)) + output1 = topk_edges(g1, :x, 2) + @test output1 == [0.886584 0.815986; + 0.974619 0.874303] + end +end From eec3a46014385ade56f389d5ef0406f86984b41d Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Fri, 10 Mar 2023 13:08:05 +0100 Subject: [PATCH 10/17] Fix batch case and reorder --- src/utils.jl | 41 +++++++++++++++++++---------------------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 42a7f6b05..e3578517b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -105,7 +105,7 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) return matrix[:, index] end -function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) +function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing return sort(matrix, dims = 2; rev)[:, 1:k] else @@ -113,35 +113,26 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = end end -function _sort_batch(matrices, k::Int; rev::Bool = true, sortby = nothing) - return map(x -> _sort_matrix(x, k; rev, sortby), matrices) -end - -function _topk_batch(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, +function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) - tensor_matrix = reshape(matrix, size(matrix, 1), size(matrix, 2) ÷ number_graphs, - number_graphs) - sorted_matrix = _sort_batch(eachslice(tensor_matrix, dims = 3), k; rev, sortby) + sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby), matrices) return reduce(hcat, sorted_matrix) end -function _topk(matrix::AbstractArray, number_graphs::Int, k::Int; rev::Bool = true, - sortby = nothing) - if number_graphs == 1 - return _sort_matrix(matrix, k; rev, sortby) - else - return _topk_batch(matrix, number_graphs, k; rev, sortby) - end -end - """ topk_nodes(g, feat, k; rev = true, sortby = nothing) Graph-wise top-k on node features `feat` according to the `sortby` feature index. """ function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) - matrix = getproperty(g.ndata, feat) - return _topk(matrix, g.num_graphs, k; rev, sortby) + if g.num_graphs == 1 + matrix = getproperty(g.ndata, feat) + return _topk_matrix(matrix, k; rev, sortby) + else + graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] + matrices = map(graph -> getproperty(graph.ndata, feat), graphs) + return _topk_batch(matrices, k; rev, sortby) + end end """ @@ -150,6 +141,12 @@ end Graph-wise top-k on edge features `feat` according to the `sortby` feature index. """ function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) - matrix = getproperty(g.edata, feat) - return _topk(matrix, g.num_graphs, k; rev, sortby) + if g.num_graphs == 1 + matrix = getproperty(g.edata, feat) + return _topk_matrix(matrix, k; rev, sortby) + else + graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] + matrices = map(graph -> getproperty(graph.edata, feat), graphs) + return _topk_batch(matrices, k; rev, sortby) + end end From b02dcaa4fe61dd37f3d814963783b1cef12e036e Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Fri, 10 Mar 2023 13:10:02 +0100 Subject: [PATCH 11/17] Modify test arbitrary node number case --- test/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index bf34a2ca8..10541dce3 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -61,12 +61,12 @@ end @testset "topk_nodes" begin - A = [1.0 5.0 9.0; 2.0 6.0 10.0; 3.0 7.0 11.0; 4.0 8.0 12.0] + A = [1.0 5.0 9.0 2.0; 2.0 6.0 10.0 1.0; 3.0 7.0 11.0 2.0; 4.0 8.0 12.0 1.0] B = [0.318907 0.189981 0.991791; 0.547022 0.977349 0.680538; 0.921823 0.35132 0.494715; 0.451793 0.00704976 0.0189275] - g1 = rand_graph(3, 6, ndata = (x = A,)) + g1 = rand_graph(4, 6, ndata = (x = A,)) g2 = rand_graph(3, 6, ndata = B) output1 = topk_nodes(g1, :x, 2) output2 = topk_nodes(g2, :x, 1, sortby = 2) From 75b0b8fedfd0fb514f10f85de43328978b51fbf7 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 17:00:32 +0200 Subject: [PATCH 12/17] Add tests like to DGL --- test/utils.jl | 140 ++++++++++++++++++-------------------------------- 1 file changed, 50 insertions(+), 90 deletions(-) diff --git a/test/utils.jl b/test/utils.jl index daa304f7f..73621c820 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,98 +1,15 @@ De, Dx = 3, 2 g = Flux.batch([GNNGraph(erdos_renyi(10, 30), - ndata = rand(Dx, 10), - edata = rand(De, 30), - graph_type = GRAPH_T) for i in 1:5]) + ndata = rand(Dx, 10), + edata = rand(De, 30), + graph_type = GRAPH_T) for i in 1:5]) x = g.ndata.x e = g.edata.e - @testset "reduce_nodes" begin - r = reduce_nodes(mean, g, x) - @test size(r) == (Dx, g.num_graphs) - @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2) - end - - @testset "reduce_edges" begin - r = reduce_edges(mean, g, e) - @test size(r) == (De, g.num_graphs) - @test r[:, 2] ≈ mean(getgraph(g, 2).edata.e, dims = 2) - end - - @testset "softmax_nodes" begin - r = softmax_nodes(g, x) - @test size(r) == size(x) - @test r[:, 1:10] ≈ softmax(getgraph(g, 1).ndata.x, dims = 2) - end - - @testset "softmax_edges" begin - r = softmax_edges(g, e) - @test size(r) == size(e) - @test r[:, 1:60] ≈ softmax(getgraph(g, 1).edata.e, dims = 2) - end - - @testset "broadcast_nodes" begin - z = rand(4, g.num_graphs) - r = broadcast_nodes(g, z) - @test size(r) == (4, g.num_nodes) - @test r[:, 1] ≈ z[:, 1] - @test r[:, 10] ≈ z[:, 1] - @test r[:, 11] ≈ z[:, 2] - end - - @testset "broadcast_edges" begin - z = rand(4, g.num_graphs) - r = broadcast_edges(g, z) - @test size(r) == (4, g.num_edges) - @test r[:, 1] ≈ z[:, 1] - @test r[:, 60] ≈ z[:, 1] - @test r[:, 61] ≈ z[:, 2] - end - - @testset "softmax_edge_neighbors" begin - s = [1, 2, 3, 4] - t = [5, 5, 6, 6] - g2 = GNNGraph(s, t) - e2 = randn(Float32, 3, g2.num_edges) - z = softmax_edge_neighbors(g2, e2) - @test size(z) == size(e2) - @test z[:, 1:2] ≈ NNlib.softmax(e2[:, 1:2], dims = 2) - @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) - end - - @testset "topk_nodes" begin - A = [1.0 5.0 9.0 2.0; 2.0 6.0 10.0 1.0; 3.0 7.0 11.0 2.0; 4.0 8.0 12.0 1.0] - B = [0.318907 0.189981 0.991791; - 0.547022 0.977349 0.680538; - 0.921823 0.35132 0.494715; - 0.451793 0.00704976 0.0189275] - g1 = rand_graph(4, 6, ndata = (x = A,)) - g2 = rand_graph(3, 6, ndata = B) - output1 = topk_nodes(g1, :x, 2) - output2 = topk_nodes(g2, :x, 1, sortby = 2) - @test output1 == [9.0 5.0; - 10.0 6.0; - 11.0 7.0; - 12.0 8.0] - @test output2 == [0.189981; - 0.977349; - 0.35132; - 0.00704976;;] - g = Flux.batch([g1, g2]) - output3 = topk_nodes(g, :x, 2; sortby = 4) - @test output3 == [9.0 5.0 0.318907 0.991791; - 10.0 6.0 0.547022 0.680538; - 11.0 7.0 0.921823 0.494715; - 12.0 8.0 0.451793 0.0189275] - end - - @testset "topk_edges" begin - A = [0.157163 0.561874 0.886584 0.0475203 0.72576 0.815986; - 0.852048 0.974619 0.0345627 0.874303 0.614322 0.113491] - g1 = rand_graph(5, 6, edata = (x = A,)) - output1 = topk_edges(g1, :x, 2) - @test output1 == [0.886584 0.815986; - 0.974619 0.874303] - end +@testset "reduce_nodes" begin + r = reduce_nodes(mean, g, x) + @test size(r) == (Dx, g.num_graphs) + @test r[:, 2] ≈ mean(getgraph(g, 2).ndata.x, dims = 2) end @testset "reduce_edges" begin @@ -142,3 +59,46 @@ end @test z[:, 3:4] ≈ NNlib.softmax(e2[:, 3:4], dims = 2) end +@testset "topk_feature" begin + A = [0.0297 0.5901 0.088 0.5171; + 0.8307 0.303 0.6515 0.6379; + 0.914 0.928 0.4451 0.2695; + 0.6702 0.6893 0.7507 0.8954; + 0.3346 0.7997 0.5297 0.5197] + B = [0.3168 0.1323 0.1752 0.1931 0.5065; + 0.3174 0.2766 0.9105 0.4954 0.5182; + 0.5303 0.4318 0.5692 0.3455 0.5418; + 0.0804 0.6114 0.8489 0.3934 0.152; + 0.3808 0.1458 0.0539 0.0857 0.3872] + g1 = rand_graph(4, 2, ndata = (x = A,)) + g2 = rand_graph(5, 4, ndata = B) + g = Flux.batch([g1, g2]) + output1 = topk_feature(g, g.ndata.x, 3) + @test output1[1][:, :, 1] == [0.5901 0.5171 0.088; + 0.8307 0.6515 0.6379; + 0.928 0.914 0.4451; + 0.8954 0.7507 0.6893; + 0.7997 0.5297 0.5197] + @test output1[1][:, :, 2] == [0.5065 0.3168 0.1931; + 0.9105 0.5182 0.4954; + 0.5692 0.5418 0.5303; + 0.8489 0.6114 0.3934; + 0.3872 0.3808 0.1458] + @test output1[2][:, :, 1] == [2 4 3; + 1 3 4; + 2 1 3; + 4 3 2; + 2 3 4] + @test output1[2][:, :, 2] == [5 1 4; + 3 5 4; + 3 5 1; + 3 2 4; + 5 1 2] + output2 = topk_feature(g, g.ndata.x, 2; sortby = 5) + @test output2[1][:, :, 1] == [0.5901 0.088 + 0.303 0.6515; + 0.928 0.4451; + 0.6893 0.7507; + 0.7997 0.5297] + @test output2[2][:, :, 1] == [2; 3;;] +end From 226b07d643537c995355c3f092b70cd379018e44 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 17:01:12 +0200 Subject: [PATCH 13/17] Fix to return permutations --- src/utils.jl | 48 ++++++++++++++++++++++-------------------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e3578517b..76cec6262 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -102,51 +102,47 @@ end function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) index = sortperm(view(matrix, sortby, :); rev) - return matrix[:, index] + return matrix[:, index], index end function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) if sortby === nothing - return sort(matrix, dims = 2; rev)[:, 1:k] + sorted_matrix = sort(matrix, dims = 2; rev)[:, 1:k] + vector_indices = map(x -> sortperm(x; rev), eachrow(matrix)) + indices = reduce(vcat, vector_indices')[:, 1:k] + return sorted_matrix, indices else - return _sort_col(matrix; rev, sortby)[:, 1:k] + sorted_matrix, indices = _sort_col(matrix; rev, sortby) + return sorted_matrix[:, 1:k], indices[1:k] end end function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) - sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby), matrices) - return reduce(hcat, sorted_matrix) -end - -""" - topk_nodes(g, feat, k; rev = true, sortby = nothing) - -Graph-wise top-k on node features `feat` according to the `sortby` feature index. -""" -function topk_nodes(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) - if g.num_graphs == 1 - matrix = getproperty(g.ndata, feat) - return _topk_matrix(matrix, k; rev, sortby) + num_graphs = length(matrices) + num_feat = size(matrices[1], 1) + sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby)[1], matrices) + output_matrix = reshape(reduce(hcat, sorted_matrix), num_feat, k, num_graphs) + indices = map(x -> _topk_matrix(x, k; rev, sortby)[2], matrices) + if sortby === nothing + output_indices = reshape(reduce(hcat, indices), num_feat, k, num_graphs) else - graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] - matrices = map(graph -> getproperty(graph.ndata, feat), graphs) - return _topk_batch(matrices, k; rev, sortby) + output_indices = reshape(reduce(hcat, indices), k, 1, num_graphs) end + return output_matrix, output_indices end """ - topk_edges(g, feat, k; rev = true, sortby = nothing) + topk_feature(g, feat, k; rev = true, sortby = nothing) -Graph-wise top-k on edge features `feat` according to the `sortby` feature index. +Graph-wise top-k on feature array `x` according to the `sortby` index. """ -function topk_edges(g::GNNGraph, feat::Symbol, k::Int; rev = true, sortby = nothing) +function topk_feature(g::GNNGraph, x::AbstractArray, k::Int; rev::Bool = true, + sortby::Union{Nothing, Int} = nothing) if g.num_graphs == 1 - matrix = getproperty(g.edata, feat) - return _topk_matrix(matrix, k; rev, sortby) + return _topk_matrix(x, k; rev, sortby) else - graphs = [getgraph(g, i) for i in 1:(g.num_graphs)] - matrices = map(graph -> getproperty(graph.edata, feat), graphs) + matrices = [x[:, g.graph_indicator .== i] for i in 1:(g.num_graphs)] return _topk_batch(matrices, k; rev, sortby) end end From 24faa9ad0b0af01315548e7401a179970649d287 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 17:01:23 +0200 Subject: [PATCH 14/17] Change name --- src/GraphNeuralNetworks.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index b9d9f8dcb..7845e7280 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -27,8 +27,7 @@ export broadcast_nodes, broadcast_edges, softmax_edge_neighbors, - topk_nodes, - topk_edges, + topk_feature, # msgpass apply_edges, From 87f043055ef276ecc88ebf6ace118eb3995a0fda Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 17:34:37 +0200 Subject: [PATCH 15/17] Improve docs --- src/utils.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 76cec6262..ab407fb06 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -105,7 +105,7 @@ function _sort_col(matrix::AbstractArray; rev::Bool = true, sortby::Int = 1) return matrix[:, index], index end -function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = nothing) +function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby::Union{Nothing, Int} = nothing) if sortby === nothing sorted_matrix = sort(matrix, dims = 2; rev)[:, 1:k] vector_indices = map(x -> sortperm(x; rev), eachrow(matrix)) @@ -118,7 +118,7 @@ function _topk_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby = end function _topk_batch(matrices::AbstractArray, k::Int; rev::Bool = true, - sortby = nothing) + sortby::Union{Nothing, Int} = nothing) num_graphs = length(matrices) num_feat = size(matrices[1], 1) sorted_matrix = map(x -> _topk_matrix(x, k; rev, sortby)[1], matrices) @@ -135,7 +135,15 @@ end """ topk_feature(g, feat, k; rev = true, sortby = nothing) -Graph-wise top-k on feature array `x` according to the `sortby` index. +Graph-wise top-`k` on feature array `x` according to the `sortby` index. + +# Arguments + +- `g`: a `GNNGraph``. +- `x`: a feature array of size `(number_features, g.num_nodes)` or `(number_features, g.num_edges)` of the graph `g`. +- `k`: the number of top features to return. +- `rev`: if `true`, sort in descending order otherwise returns the `k` smallest elements. +- `sortby`: the index of the feature to sort by. If `nothing`, every row independently. """ function topk_feature(g::GNNGraph, x::AbstractArray, k::Int; rev::Bool = true, sortby::Union{Nothing, Int} = nothing) From da04648afbc8010ede4ce8cfbad2ee9b00592f36 Mon Sep 17 00:00:00 2001 From: aurorarossi Date: Tue, 28 Mar 2023 20:50:16 +0200 Subject: [PATCH 16/17] Add example --- src/utils.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index ab407fb06..eb28a26e1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -136,6 +136,7 @@ end topk_feature(g, feat, k; rev = true, sortby = nothing) Graph-wise top-`k` on feature array `x` according to the `sortby` index. +Returns a tuple of the top-`k` features and their indices. # Arguments @@ -144,6 +145,26 @@ Graph-wise top-`k` on feature array `x` according to the `sortby` index. - `k`: the number of top features to return. - `rev`: if `true`, sort in descending order otherwise returns the `k` smallest elements. - `sortby`: the index of the feature to sort by. If `nothing`, every row independently. + +# Examples + +```julia +julia> g = rand_graph(5, 4, ndata = rand(3,5)); + +julia> g.ndata.x +3×5 Matrix{Float64}: + 0.333661 0.683551 0.315145 0.794089 0.840085 + 0.263023 0.726028 0.626617 0.412247 0.0914052 + 0.296433 0.186584 0.960758 0.0999844 0.813808 + +julia> topk_feature(g, g.ndata.x, 2) +([0.8400845757074524 0.7940891040468462; 0.7260276789396128 0.6266174187625888; 0.9607582005024967 0.8138081223752274], [5 4; 2 3; 3 5]) + +julia> topk_feature(g, g.ndata.x, 2; sortby=3) +([0.3151452763177829 0.8400845757074524; 0.6266174187625888 0.09140519108918477; 0.9607582005024967 0.8138081223752274], [3, 5]) + +``` + """ function topk_feature(g::GNNGraph, x::AbstractArray, k::Int; rev::Bool = true, sortby::Union{Nothing, Int} = nothing) From caa3b6ee228f1dc9e8a25dfe826895f2ef028777 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 10 Mar 2024 11:34:23 +0100 Subject: [PATCH 17/17] Fix function signature --- src/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ed36416db..dd2a4a4b7 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -161,7 +161,7 @@ Returns a tuple of the top-`k` features and their indices. # Arguments - `g`: a `GNNGraph``. -- `x`: a feature array of size `(number_features, g.num_nodes)` or `(number_features, g.num_edges)` of the graph `g`. +- `feat`: a feature array of size `(number_features, g.num_nodes)` or `(number_features, g.num_edges)` of the graph `g`. - `k`: the number of top features to return. - `rev`: if `true`, sort in descending order otherwise returns the `k` smallest elements. - `sortby`: the index of the feature to sort by. If `nothing`, every row independently. @@ -186,12 +186,12 @@ julia> topk_feature(g, g.ndata.x, 2; sortby=3) ``` """ -function topk_feature(g::GNNGraph, x::AbstractArray, k::Int; rev::Bool = true, +function topk_feature(g::GNNGraph, feat::AbstractArray, k::Int; rev::Bool = true, sortby::Union{Nothing, Int} = nothing) if g.num_graphs == 1 - return _topk_matrix(x, k; rev, sortby) + return _topk_matrix(feat, k; rev, sortby) else - matrices = [x[:, g.graph_indicator .== i] for i in 1:(g.num_graphs)] + matrices = [feat[:, g.graph_indicator .== i] for i in 1:(g.num_graphs)] return _topk_batch(matrices, k; rev, sortby) end end