Skip to content

Commit

Permalink
[GNNLux] add GMMConv, ResGatedGraphConv (#494)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello authored Sep 23, 2024
1 parent 2313a96 commit d1831e7
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 13 deletions.
4 changes: 2 additions & 2 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ export AGNNConv,
GatedGraphConv,
GCNConv,
GINConv,
# GMMConv,
GMMConv,
GraphConv,
MEGNetConv,
NNConv,
# ResGatedGraphConv,
ResGatedGraphConv,
# SAGEConv,
SGConv
# TAGConv,
Expand Down
118 changes: 117 additions & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,68 @@ function Base.show(io::IO, l::GINConv)
print(io, ")")
end

@concrete struct GMMConv <: GNNLayer
σ
ch::Pair{NTuple{2, Int}, Int}
K::Int
residual::Bool
init_weight
init_bias
use_bias::Bool
dense_x
end

function GMMConv(ch::Pair{NTuple{2, Int}, Int},
σ = identity;
K::Int = 1,
residual = false,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias = true)
dense_x = Dense(ch[1][1] => ch[2] * K, use_bias = false)
return GMMConv(σ, ch, K, residual, init_weight, init_bias, use_bias, dense_x)
end


function LuxCore.initialparameters(rng::AbstractRNG, l::GMMConv)
ein = l.ch[1][2]
mu = l.init_weight(rng, ein, l.K)
sigma_inv = l.init_weight(rng, ein, l.K)
ps = (; mu, sigma_inv, dense_x = LuxCore.initialparameters(rng, l.dense_x))
if l.use_bias
bias = l.init_bias(rng, l.ch[2])
ps = (; ps..., bias)
end
return ps
end

LuxCore.outputsize(l::GMMConv) = (l.ch[2],)

function LuxCore.parameterlength(l::GMMConv)
n = 2 * l.ch[1][2] * l.K
n += parameterlength(l.dense_x)
if l.use_bias
n += l.ch[2]
end
return n
end

function (l::GMMConv)(g::GNNGraph, x, e, ps, st)
dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps))
return GNNlib.gmm_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GMMConv)
(nin, ein), out = l.ch
print(io, "GMMConv((", nin, ",", ein, ")=>", out)
l.σ == identity || print(io, ", σ=", l.dense_s.σ)
print(io, ", K=", l.K)
print(io, ", residual=", l.residual)
l.use_bias == true || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
Expand Down Expand Up @@ -712,6 +774,8 @@ function LuxCore.parameterlength(l::NNConv)
return n
end

LuxCore.outputsize(l::NNConv) = (l.out_dims,)

LuxCore.statelength(l::NNConv) = statelength(l.nn)

function (l::NNConv)(g, x, e, ps, st)
Expand All @@ -723,7 +787,59 @@ function (l::NNConv)(g, x, e, ps, st)
end

function Base.show(io::IO, l::NNConv)
print(io, "NNConv($(l.nn)")
print(io, "NNConv($(l.in_dims) => $(l.out_dims), $(l.nn)")
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct ResGatedGraphConv <: GNNLayer
in_dims::Int
out_dims::Int
σ
init_bias
init_weight
use_bias::Bool
end

function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true)
in_dims, out_dims = ch
return ResGatedGraphConv(in_dims, out_dims, σ, init_bias, init_weight, use_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::ResGatedGraphConv)
A = l.init_weight(rng, l.out_dims, l.in_dims)
B = l.init_weight(rng, l.out_dims, l.in_dims)
U = l.init_weight(rng, l.out_dims, l.in_dims)
V = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; A, B, U, V, bias)
else
return (; A, B, U, V)
end
end

function LuxCore.parameterlength(l::ResGatedGraphConv)
n = 4 * l.in_dims * l.out_dims
if l.use_bias
n += l.out_dims
end
return n
end

LuxCore.outputsize(l::ResGatedGraphConv) = (l.out_dims,)

function (l::ResGatedGraphConv)(g, x, ps, st)
m = (; ps.A, ps.B, ps.U, ps.V, bias = _getbias(ps), l.σ)
return GNNlib.res_gated_graph_conv(m, g, x), st
end

function Base.show(io::IO, l::ResGatedGraphConv)
print(io, "ResGatedGraphConv(", l.in_dims, " => ", l.out_dims)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
Expand Down
18 changes: 12 additions & 6 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,18 @@
l = NNConv(n_in => n_out, nn, tanh, aggr = +)
x = randn(Float32, n_in, g2.num_nodes)
e = randn(Float32, n_in_edge, g2.num_edges)
test_lux_layer(rng, l, g2, x; outputsize=(n_out,), e, container=true)
end

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
@testset "GMMConv" begin
ein_dims = 4
e = randn(rng, Float32, ein_dims, g.num_edges)
l = GMMConv((in_dims, ein_dims) => out_dims, tanh; K = 2, residual = false)
test_lux_layer(rng, l, g, x; outputsize=(out_dims,), e)
end

y, st′ = l(g2, x, e, ps, st)

@test size(y) == (n_out, g2.num_nodes)
end
@testset "ResGatedGraphConv" begin
l = ResGatedGraphConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end
end
6 changes: 5 additions & 1 deletion GNNLux/test/shared_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
@test size(y) == (outputsize..., g.num_nodes)
end

loss = (x, ps) -> sum(first(l(g, x, ps, st)))
if e !== nothing
loss = (x, ps) -> sum(first(l(g, x, e, ps, st)))
else
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
end
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

Expand Down
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)

m = l.σ(m .+ l.bias)
m = l.σ.(m .+ l.bias)

if l.residual
if size(x, 1) == size(m, 1)
Expand Down
6 changes: 4 additions & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,9 @@ end
function Base.show(io::IO, l::NNConv)
out, in = size(l.weight)
print(io, "NNConv($in => $out")
print(io, ", aggr=", l.aggr)
print(io, ", ", l.nn)
l.σ == identity || print(io, ", ", l.σ)
(l.aggr == +) || print(io, "; aggr=", l.aggr)
print(io, ")")
end

Expand Down Expand Up @@ -1136,7 +1138,7 @@ function Base.show(io::IO, l::GMMConv)
print(io, "GMMConv((", nin, ",", ein, ")=>", out)
l.σ == identity || print(io, ", σ=", l.dense_s.σ)
print(io, ", K=", l.K)
l.residual == true || print(io, ", residual=", l.residual)
print(io, ", residual=", l.residual)
print(io, ")")
end

Expand Down

0 comments on commit d1831e7

Please sign in to comment.