[GNNLux] Adding NNConv Layer (#491)
* nnlux

* Update conv_tests.jl: test

* fix

* Update conv.jl: show

* Update shared_testsetup.jl: changed to e
rbSparky authored Sep 15, 2024
1 parent 5715b26 commit c896eda
Showing 4 changed files with 89 additions and 4 deletions.
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
@@ -33,7 +33,7 @@ export AGNNConv,
     # GMMConv,
     GraphConv,
     MEGNetConv,
-    # NNConv,
+    NNConv,
     # ResGatedGraphConv,
     # SAGEConv,
     SGConv
59 changes: 59 additions & 0 deletions GNNLux/src/layers/conv.jl
@@ -669,3 +669,62 @@ function Base.show(io::IO, l::MEGNetConv)
    print(io, "MEGNetConv(", nin, " => ", nout)
    print(io, ")")
end

@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
    nn <: AbstractLuxLayer
    aggr
    in_dims::Int
    out_dims::Int
    use_bias::Bool
    init_weight
    init_bias
    σ
end

function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
                aggr = +,
                init_bias = zeros32,
                use_bias::Bool = true,
                init_weight = glorot_uniform)
    in_dims, out_dims = ch
    σ = NNlib.fast_act(σ)
    return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
    # dense weight for the node's own features; the edge-network parameters live under `nn`
    weight = l.init_weight(rng, l.out_dims, l.in_dims)
    ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight)
    if l.use_bias
        ps = (; ps..., bias = l.init_bias(rng, l.out_dims))
    end
    return ps
end

function LuxCore.initialstates(rng::AbstractRNG, l::NNConv)
    return (; nn = LuxCore.initialstates(rng, l.nn))
end

function LuxCore.parameterlength(l::NNConv)
    n = parameterlength(l.nn) + l.in_dims * l.out_dims
    if l.use_bias
        n += l.out_dims
    end
    return n
end

LuxCore.statelength(l::NNConv) = statelength(l.nn)

function (l::NNConv)(g, x, e, ps, st)
    # make the wrapped edge network stateful so its updated state can be collected after the call
    nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
    m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
    y = GNNlib.nn_conv(m, g, x, e)
    stnew = _getstate(nn)
    return y, stnew
end

function Base.show(io::IO, l::NNConv)
    print(io, "NNConv($(l.nn)")
    l.σ == identity || print(io, ", ", l.σ)
    l.use_bias || print(io, ", use_bias=false")
    print(io, ")")
end
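For orientation, the new layer follows the usual NNConv (edge-conditioned convolution) pattern: the wrapped network `nn` maps each edge's feature vector to `out_dims * in_dims` values, which are reshaped into a matrix applied to the neighboring node's features before aggregation with `aggr`. A minimal end-to-end sketch mirroring the test below; it assumes `GNNGraph` is available (e.g. from GNNGraphs.jl) and uses `edge_net` as a hypothetical name:

using GNNLux, GNNGraphs, Lux, LuxCore, Random

rng = Random.default_rng()
g = GNNGraph([1, 1, 2, 3], [2, 3, 1, 1])    # 3 nodes, 4 edges

n_in, n_edge, n_out = 3, 10, 5
edge_net = Dense(n_edge => n_out * n_in)    # emits one (n_out × n_in) matrix per edge
l = NNConv(n_in => n_out, edge_net, tanh; aggr = +)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

# by the parameterlength method above: Dense params + in*out weight + out bias
# = (10*15 + 15) + 3*5 + 5 = 185
@assert LuxCore.parameterlength(l) == 185

x = randn(Float32, n_in, g.num_nodes)
e = randn(Float32, n_edge, g.num_edges)
y, st = l(g, x, e, ps, st)                  # size(y) == (n_out, g.num_nodes)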
22 changes: 22 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
@@ -106,4 +106,26 @@
        @test size(x_new) == (out_dims, g.num_nodes)
        @test size(e_new) == (out_dims, g.num_edges)
    end

    @testset "NNConv" begin
        n_in = 3
        n_in_edge = 10
        n_out = 5

        s = [1, 1, 2, 3]
        t = [2, 3, 1, 1]
        g2 = GNNGraph(s, t)

        nn = Dense(n_in_edge => n_out * n_in)
        l = NNConv(n_in => n_out, nn, tanh, aggr = +)
        x = randn(Float32, n_in, g2.num_nodes)
        e = randn(Float32, n_in_edge, g2.num_edges)

        ps = LuxCore.initialparameters(rng, l)
        st = LuxCore.initialstates(rng, l)

        y, st′ = l(g2, x, e, ps, st)

        @test size(y) == (n_out, g2.num_nodes)
    end
end
10 changes: 7 additions & 3 deletions GNNLux/test/shared_testsetup.jl
@@ -14,7 +14,7 @@ export test_lux_layer

 function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
                         outputsize=nothing, sizey=nothing, container=false,
-                        atol=1.0f-2, rtol=1.0f-2)
+                        atol=1.0f-2, rtol=1.0f-2, e=nothing)

     if container
         @test l isa GNNContainerLayer
@@ -27,7 +27,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
     @test LuxCore.statelength(l) == LuxCore.statelength(st)

-    y, st′ = l(g, x, ps, st)
+    if e !== nothing
+        y, st′ = l(g, x, e, ps, st)
+    else
+        y, st′ = l(g, x, ps, st)
+    end
     @test eltype(y) == eltype(x)
     if outputsize !== nothing
         @test LuxCore.outputsize(l) == outputsize
@@ -42,4 +46,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
     test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
 end

-end
+end
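With the new `e` keyword, layers that consume edge features can reuse the shared harness instead of hand-rolled checks. A hypothetical invocation, reusing the names from the NNConv test above:

test_lux_layer(rng, l, g2, x; sizey = (n_out, g2.num_nodes), container = true, e = e)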
