Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GNNLux] Adding NNConv Layer #491

Merged
merged 5 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ export AGNNConv,
# GMMConv,
GraphConv,
MEGNetConv,
# NNConv,
NNConv,
# ResGatedGraphConv,
# SAGEConv,
SGConv
Expand Down
59 changes: 59 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -669,3 +669,62 @@ function Base.show(io::IO, l::MEGNetConv)
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
end

@concrete struct NNConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractLuxLayer
aggr
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight
init_bias
σ
end

function NNConv(ch::Pair{Int, Int}, nn, σ = identity;
aggr = +,
init_bias = zeros32,
use_bias::Bool = true,
init_weight = glorot_uniform)
in_dims, out_dims = ch
σ = NNlib.fast_act(σ)
return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight)
if l.use_bias
ps = (; ps..., bias = l.init_bias(rng, l.out_dims))
end
return ps
end

function LuxCore.initialstates(rng::AbstractRNG, l::NNConv)
return (; nn = LuxCore.initialstates(rng, l.nn))
end

function LuxCore.parameterlength(l::NNConv)
n = parameterlength(l.nn) + l.in_dims * l.out_dims
if l.use_bias
n += l.out_dims
end
return n
end

LuxCore.statelength(l::NNConv) = statelength(l.nn)

function (l::NNConv)(g, x, e, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ)
y = GNNlib.nn_conv(m, g, x, e)
stnew = _getstate(nn)
return y, stnew
end

function Base.show(io::IO, l::NNConv)
print(io, "NNConv($(l.nn)")
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end
22 changes: 22 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,26 @@
@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)
end

@testset "NNConv" begin
n_in = 3
n_in_edge = 10
n_out = 5

s = [1,1,2,3]
t = [2,3,1,1]
g2 = GNNGraph(s, t)

nn = Dense(n_in_edge => n_out * n_in)
l = NNConv(n_in => n_out, nn, tanh, aggr = +)
x = randn(Float32, n_in, g2.num_nodes)
e = randn(Float32, n_in_edge, g2.num_edges)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

y, st′ = l(g2, x, e, ps, st)

@test size(y) == (n_out, g2.num_nodes)
end
end
10 changes: 7 additions & 3 deletions GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export test_lux_layer

function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
outputsize=nothing, sizey=nothing, container=false,
atol=1.0f-2, rtol=1.0f-2)
atol=1.0f-2, rtol=1.0f-2, e=nothing)

if container
@test l isa GNNContainerLayer
Expand All @@ -27,7 +27,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
@test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps)
@test LuxCore.statelength(l) == LuxCore.statelength(st)

y, st′ = l(g, x, ps, st)
if e !== nothing
y, st′ = l(g, x, e, ps, st)
else
y, st′ = l(g, x, ps, st)
end
@test eltype(y) == eltype(x)
if outputsize !== nothing
@test LuxCore.outputsize(l) == outputsize
Expand All @@ -42,4 +46,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

end
end