Add Lux tests (#372)
* Lux tests

* Set up extension

* Test deps

* LuxTestUtils

* Add scenarios

* Fix x

* Lux tests working with ComponentArrays

* Fix Flux

* More tolerant lux tests

* Adapt to extensions

* Rng

* Chill rtol
gdalle authored Jul 31, 2024
1 parent 2ee0b56 commit a65f228
Showing 7 changed files with 214 additions and 4 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/Test.yml
@@ -50,6 +50,7 @@ jobs:
- Down/Detector
- Down/DifferentiateWith
- Down/Flux
- Down/Lux
exclude:
# lts
- version: 'lts'
@@ -74,6 +75,8 @@ jobs:
group: Down/Detector
- version: 'lts'
group: Down/Flux
- version: 'lts'
group: Down/Lux
# pre-release
- version: 'pre'
group: Formalities
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Down/Flux/test.jl
@@ -1,5 +1,5 @@
using Pkg
Pkg.add(["Enzyme", "FiniteDifferences", "Flux", "Zygote"])
Pkg.add(["FiniteDifferences", "Enzyme", "Flux", "Zygote"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
22 changes: 22 additions & 0 deletions DifferentiationInterface/test/Down/Lux/test.jl
@@ -0,0 +1,22 @@
using Pkg
Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils", "Zygote"])

using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using Lux: Lux
using LuxTestUtils: LuxTestUtils
using Random

Random.seed!(0)

test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
logging=LOGGING,
)
11 changes: 9 additions & 2 deletions DifferentiationInterfaceTest/Project.toml
@@ -21,15 +21,20 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"]
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDiff", "Lux", "LuxTestUtils"]
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"

[compat]
@@ -40,8 +45,8 @@ ComponentArrays = "0.15"
DataFrames = "1.6.1"
DifferentiationInterface = "0.5.6"
DocStringExtensions = "0.8,0.9"
-Flux = "0.13,0.14"
 FiniteDifferences = "0.12"
+Flux = "0.13,0.14"
Functors = "0.4"
JET = "0.4 - 0.8, 0.9"
JLArrays = "0.1"
@@ -68,6 +73,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
@@ -78,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
-test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
+test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "LuxTestUtils", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
138 changes: 138 additions & 0 deletions DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt.jl
@@ -0,0 +1,138 @@
module DifferentiationInterfaceTestLuxExt

using Compat: @compat
using ComponentArrays: ComponentArray
import DifferentiationInterface as DI
using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using Lux
using LuxTestUtils
using LuxTestUtils: check_approx
using Random: AbstractRNG, default_rng

#=
Relevant discussions:
- https://github.com/LuxDL/Lux.jl/issues/769
=#
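
# check_approx with atol = rtol = 0 reduces to an exact comparison;
# with nonzero tolerances it gives the approximate comparison.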

function DIT.lux_isequal(a, b)
return check_approx(a, b; atol=0, rtol=0)
end

function DIT.lux_isapprox(a, b; atol, rtol)
return check_approx(a, b; atol, rtol)
end

struct SquareLoss{M,X,S}
model::M
x::X
st::S
end

function (sql::SquareLoss)(ps)
@compat (; model, x, st) = sql
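# Lux models return an (output, state) tuple; keep only the output.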
return sum(abs2, first(model(x, ps, st)))
end

function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
models_and_xs = [
(Dense(2, 4), randn(rng, Float32, 2, 3)),
(Dense(2, 4, gelu), randn(rng, Float32, 2, 3)),
(Dense(2, 4, gelu; use_bias=false), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4, relu), Dense(4, 3)), randn(rng, Float32, 2, 3)),
(Scale(2), randn(rng, Float32, 2, 3)),
(Conv((3, 3), 2 => 3), randn(rng, Float32, 3, 3, 2, 2)),
(Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(rng, Float32, 3, 3, 2, 2)),
(
Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()),
randn(rng, Float32, 3, 3, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)),
rand(rng, Float32, 5, 5, 2, 2),
),
(
Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())),
rand(rng, Float32, 5, 5, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))),
rand(rng, Float32, 5, 5, 2, 2),
),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(rng, Float32, 5, 2)),
(Bilinear((2, 2) => 3), randn(rng, Float32, 2, 3)),
(SkipConnection(Dense(2 => 2), vcat), randn(rng, Float32, 2, 3)),
(ConvTranspose((3, 3), 3 => 2; stride=2), rand(rng, Float32, 5, 5, 3, 1)),
(StatefulRecurrentCell(RNNCell(3 => 5)), rand(rng, Float32, 3, 2)),
(StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(rng, Float32, 3, 2)),
(
StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)),
rand(rng, Float32, 3, 2),
),
(
Chain(
StatefulRecurrentCell(RNNCell(3 => 5)),
StatefulRecurrentCell(RNNCell(5 => 3)),
),
rand(rng, Float32, 3, 2),
),
(StatefulRecurrentCell(LSTMCell(3 => 5)), rand(rng, Float32, 3, 2)),
(
Chain(
StatefulRecurrentCell(LSTMCell(3 => 5)),
StatefulRecurrentCell(LSTMCell(5 => 3)),
),
rand(rng, Float32, 3, 2),
),
(StatefulRecurrentCell(GRUCell(3 => 5)), rand(rng, Float32, 3, 10)),
(
Chain(
StatefulRecurrentCell(GRUCell(3 => 5)),
StatefulRecurrentCell(GRUCell(5 => 3)),
),
rand(rng, Float32, 3, 10),
),
(Chain(Dense(2, 4), BatchNorm(4)), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(rng, Float32, 2, 3)),
(
Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)),
randn(rng, Float32, 2, 3),
),
(Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4), GroupNorm(4, 2)), randn(rng, Float32, 2, 3)),
(Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(rng, Float32, 6, 6, 2, 2)),
(
Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)),
randn(rng, Float32, 6, 6, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))),
randn(rng, Float32, 4, 4, 2, 2),
),
(Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(
Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)),
randn(rng, Float32, 6, 6, 2, 2),
),
]

scens = Scenario[]

for (model, x) in models_and_xs
ps, st = Lux.setup(rng, model)
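# Flatten the NamedTuple parameters into a ComponentArray so each
# backend can treat them as a single array input.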
ps = ComponentArray(ps)
loss = SquareLoss(model, x, st)
l = loss(ps)
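# Ground-truth gradient via finite differences (hence the loose test tolerances).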
g = DI.gradient(loss, DI.AutoFiniteDiff(), ps)
scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace)
push!(scens, scen)
end

return scens
end

end
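
For orientation, here is a minimal standalone sketch of what each generated scenario exercises, assuming ComponentArrays, DifferentiationInterface, FiniteDiff, Lux, and Zygote are installed (the specific model and tolerances are illustrative, not part of the commit):

using ComponentArrays: ComponentArray
import DifferentiationInterface as DI
using FiniteDiff: FiniteDiff
using Lux
using Random
using Zygote: Zygote

rng = Random.Xoshiro(63)
model = Dense(2, 4, gelu)
x = randn(rng, Float32, 2, 3)
ps, st = Lux.setup(rng, model)   # ps: NamedTuple parameters, st: layer state
ps = ComponentArray(ps)          # parameters as a single flat array

# Squared-output loss, as in SquareLoss above.
loss(p) = sum(abs2, first(model(x, p, st)))

g_ref = DI.gradient(loss, DI.AutoFiniteDiff(), ps)  # finite-difference ground truth
g_ad = DI.gradient(loss, DI.AutoZygote(), ps)       # backend under test
isapprox(g_ad, g_ref; rtol=1.0f-2, atol=1.0f-3)

Each scenario stores exactly this triple (parameters, loss value, reference gradient), which test_differentiation then compares against each backend's output.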
30 changes: 29 additions & 1 deletion DifferentiationInterfaceTest/src/scenarios/extensions.jl
@@ -34,7 +34,7 @@ function gpu_scenarios end
Create a vector of [`Scenario`](@ref)s with neural networks from [Flux.jl](https://github.com/FluxML/Flux.jl).
!!! warning
-This function requires Flux.jl and FiniteDifferences.jl to be loaded (it is implemented in a package extension).
+This function requires FiniteDifferences.jl and Flux.jl to be loaded (it is implemented in a package extension).
!!! danger
These scenarios are still experimental and not part of the public API.
@@ -55,3 +55,31 @@ function flux_isapprox end
Exact comparison function to use in correctness tests with gradients of Flux.jl networks.
"""
function flux_isequal end

"""
lux_scenarios(rng=Random.default_rng())
Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl).
!!! warning
This function requires ComponentArrays.jl, FiniteDiff.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension).
!!! danger
These scenarios are still experimental and not part of the public API.
Their ground truth values are computed with finite differences, and thus subject to imprecision.
"""
function lux_scenarios end

"""
lux_isapprox(x, y; atol, rtol)
Approximate comparison function to use in correctness tests with gradients of Lux.jl networks.
"""
function lux_isapprox end

"""
lux_isequal(x, y)
Exact comparison function to use in correctness tests with gradients of Lux.jl networks.
"""
function lux_isequal end
12 changes: 12 additions & 0 deletions DifferentiationInterfaceTest/test/weird.jl
@@ -8,6 +8,8 @@ using FiniteDifferences: FiniteDifferences
using Flux: Flux
using ForwardDiff: ForwardDiff
using JLArrays: JLArrays
using Lux: Lux
using LuxTestUtils: LuxTestUtils
using Random
using SparseConnectivityTracer
using SparseMatrixColorings
@@ -43,3 +45,13 @@ test_differentiation(
atol=1e-6,
logging=LOGGING,
)

test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
logging=LOGGING,
)
