Skip to content

Commit

Permalink
[GNNLux] updates for Lux v1.0 (#490)
Browse files Browse the repository at this point in the history
* updates for Lux 1.0

* naming
  • Loading branch information
CarloLucibello authored Sep 14, 2024
1 parent bd5e2f2 commit 5715b26
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 38 deletions.
6 changes: 4 additions & 2 deletions GNNLux/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ConcreteStructs = "0.2.3"
Lux = "0.5.61"
LuxCore = "0.1.20"
Lux = "1.0"
LuxCore = "1.0"
NNlib = "0.9.21"
Reexport = "1.2"
Static = "1.1"
julia = "1.10"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using Statistics: mean
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
using LuxCore: LuxCore, AbstractLuxLayer, AbstractLuxContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, GRUCell,
glorot_uniform, zeros32,
StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
using Static
@reexport using GNNGraphs

include("layers/basic.jl")
Expand Down
14 changes: 7 additions & 7 deletions GNNLux/src/layers/basic.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""
abstract type GNNLayer <: AbstractExplicitLayer end
abstract type GNNLayer <: AbstractLuxLayer end
An abstract type from which graph neural network layers are derived.
It is Derived from Lux's `AbstractExplicitLayer` type.
It is Derived from Lux's `AbstractLuxLayer` type.
See also [`GNNChain`](@ref GNNLux.GNNChain).
"""
abstract type GNNLayer <: AbstractExplicitLayer end
abstract type GNNLayer <: AbstractLuxLayer end

abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end
abstract type GNNContainerLayer{T} <: AbstractLuxContainerLayer{T} end

@concrete struct GNNChain <: GNNContainerLayer{(:layers,)}
layers <: NamedTuple
Expand All @@ -24,7 +24,7 @@ function GNNChain(; kw...)
return GNNChain(nt)
end

_wrapforchain(l::AbstractExplicitLayer) = l
_wrapforchain(l::AbstractLuxLayer) = l
_wrapforchain(l) = Lux.WrappedFunction(l)

Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers))
Expand All @@ -44,7 +44,7 @@ Base.firstindex(c::GNNChain) = firstindex(c.layers)

LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end])

(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st)
(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps.layers, st.layers)

function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times
newst = (;)
Expand All @@ -56,6 +56,6 @@ function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, help
end

_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;)
_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
_applylayer(l::AbstractLuxLayer, g::GNNGraph, x, ps, st) = l(x, ps, st)
_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st)
33 changes: 16 additions & 17 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
_getstate(s::StatefulLuxLayer{true}) = s.st
_getstate(s::StatefulLuxLayer{Static.True}) = s.st
_getstate(s::StatefulLuxLayer{false}) = s.st_any
_getstate(s::StatefulLuxLayer{Static.False}) = s.st_any


@concrete struct GCNConv <: GNNLayer
Expand All @@ -20,10 +22,9 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity;
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops::Bool = true,
use_edge_weight::Bool = false,
allow_fast_activation::Bool = true)
use_edge_weight::Bool = false)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
σ = NNlib.fast_act(σ)
return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ)
end

Expand Down Expand Up @@ -121,10 +122,9 @@ function GraphConv(ch::Pair{Int, Int}, σ = identity;
aggr = +,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
use_bias::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
σ = NNlib.fast_act(σ)
return GraphConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr)
end

Expand Down Expand Up @@ -212,11 +212,10 @@ end
CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)

function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32,
allow_fast_activation = true)
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32)
(nin, ein), out = ch
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation)
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation)
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias)
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias)
return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
end

Expand All @@ -232,7 +231,7 @@ function (l::CGConv)(g, x, e, ps, st)
end

@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
nn <: AbstractLuxLayer
aggr
end

Expand All @@ -246,10 +245,10 @@ end


function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
m = (; nn, l.aggr)
y = GNNlib.edge_conv(m, g, x)
stnew = _getstate(nn)
stnew = (; nn = _getstate(nn)) # TODO: support also aggr state if present
return y, stnew
end

Expand Down Expand Up @@ -608,18 +607,18 @@ function Base.show(io::IO, l::GatedGraphConv)
end

@concrete struct GINConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
nn <: AbstractLuxLayer
ϵ <: Real
aggr
end

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)

function (l::GINConv)(g, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn)
m = (; nn, l.ϵ, l.aggr)
y = GNNlib.gin_conv(m, g, x)
stnew = _getstate(nn)
stnew = (; nn = _getstate(nn))
return y, stnew
end

Expand Down Expand Up @@ -669,4 +668,4 @@ function Base.show(io::IO, l::MEGNetConv)
nout = l.out_dims
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
end
end
12 changes: 6 additions & 6 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@concrete struct StatefulRecurrentCell <: AbstractExplicitContainerLayer{(:cell,)}
@concrete struct StatefulRecurrentCell <: AbstractLuxContainerLayer{(:cell,)}
cell <: Union{<:Lux.AbstractRecurrentCell, <:GNNContainerLayer}
end

Expand All @@ -7,16 +7,16 @@ function LuxCore.initialstates(rng::AbstractRNG, r::GNNLux.StatefulRecurrentCell
end

function (r::StatefulRecurrentCell)(g, x::AbstractMatrix, ps, st::NamedTuple)
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps, st.cell, st.carry)
(out, carry), st = applyrecurrentcell(r.cell, g, x, ps.cell, st.cell, st.carry)
return out, (; cell=st, carry)
end

function (r::StatefulRecurrentCell)(g, x::AbstractVector, ps, st::NamedTuple)
st, carry = st.cell, st.carry
stcell, carry = st.cell, st.carry
for xᵢ in x
(out, carry), st = applyrecurrentcell(r.cell, g, xᵢ, ps, st, carry)
(out, carry), stcell = applyrecurrentcell(r.cell, g, xᵢ, ps.cell, stcell, carry)
end
return out, (; cell=st, carry)
return out, (; cell=stcell, carry)
end

function applyrecurrentcell(l, g, x, ps, st, carry)
Expand All @@ -35,7 +35,7 @@ end

function TGCNCell(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
in_dims, out_dims = ch
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight, allow_fast_activation= true)
conv = GCNConv(ch, sigmoid; init_weight, init_bias, use_bias, add_self_loops, use_edge_weight)
gru = Lux.GRUCell(out_dims => out_dims; use_bias, init_weight = (init_weight, init_weight, init_weight), init_bias = (init_bias, init_bias, init_bias), init_state = init_state)
return TGCNCell(in_dims, out_dims, conv, gru, init_state)
end
Expand Down
8 changes: 4 additions & 4 deletions GNNLux/test/layers/basic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
x = randn(rng, Float32, 3, 10)

@testset "GNNLayer" begin
@test GNNLayer <: LuxCore.AbstractExplicitLayer
@test GNNLayer <: LuxCore.AbstractLuxLayer
end

@testset "GNNContainerLayer" begin
@test GNNContainerLayer <: LuxCore.AbstractExplicitContainerLayer
@test GNNContainerLayer <: LuxCore.AbstractLuxContainerLayer
end

@testset "GNNChain" begin
@test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)}
c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3))
@test GNNChain <: LuxCore.AbstractLuxContainerLayer{(:layers,)}
c = GNNChain(GraphConv(3 => 5, tanh), GCNConv(5 => 3))
test_lux_layer(rng, c, g, x, outputsize=(3,), container=true)
end
end
2 changes: 1 addition & 1 deletion GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
end

@testset "GINConv" begin
nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims))
nn = Chain(Dense(in_dims => out_dims, tanh), Dense(out_dims => out_dims))
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end
Expand Down

0 comments on commit 5715b26

Please sign in to comment.