Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Lux tests #372

Merged
merged 14 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ jobs:
- Down/Detector
- Down/DifferentiateWith
- Down/Flux
- Down/Lux
exclude:
# lts
- version: 'lts'
Expand All @@ -74,6 +75,8 @@ jobs:
group: Down/Detector
- version: 'lts'
group: Down/Flux
- version: 'lts'
group: Down/Lux
# pre-release
- version: 'pre'
group: Formalities
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Down/Flux/test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Pkg
Pkg.add(["Enzyme", "FiniteDifferences", "Flux", "Zygote"])
Pkg.add(["FiniteDifferences", "Enzyme", "Flux", "Zygote"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
Expand Down
22 changes: 22 additions & 0 deletions DifferentiationInterface/test/Down/Lux/test.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Downstream test script for Lux.jl support in DifferentiationInterfaceTest.
# NOTE: statement order matters — the test-only dependencies must be added via
# Pkg *before* any `using` statement tries to load them.
using Pkg
Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils", "Zygote"])

using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using Lux: Lux
using LuxTestUtils: LuxTestUtils
using Random

# Seed the global RNG for reproducibility of anything not using the explicit
# Xoshiro generator below (e.g. Lux parameter setup inside the scenarios).
Random.seed!(0)

# Run the Zygote backend against the Lux scenarios, comparing gradients with
# the Lux-specific (in)equality helpers. Tolerances are Float32 since the
# scenarios use Float32 parameters; ground truth comes from finite differences.
# NOTE(review): `LOGGING` is assumed to be defined by the surrounding test
# harness before this script is included — confirm against the runner setup.
test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
logging=LOGGING,
)
11 changes: 9 additions & 2 deletions DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,20 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"]
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDiff", "Lux", "LuxTestUtils"]
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"

[compat]
Expand All @@ -40,8 +45,8 @@ ComponentArrays = "0.15"
DataFrames = "1.6.1"
DifferentiationInterface = "0.5.6"
DocStringExtensions = "0.8,0.9"
Flux = "0.13,0.14"
FiniteDifferences = "0.12"
Flux = "0.13,0.14"
Functors = "0.4"
JET = "0.4 - 0.8, 0.9"
JLArrays = "0.1"
Expand All @@ -68,6 +73,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -78,4 +85,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "LuxTestUtils", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
module DifferentiationInterfaceTestLuxExt

# Package extension implementing the Lux.jl test scenarios declared as stubs in
# DifferentiationInterfaceTest. It is loaded only when ComponentArrays,
# FiniteDiff, Lux and LuxTestUtils are all present.

using Compat: @compat
using ComponentArrays: ComponentArray
import DifferentiationInterface as DI
using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using Lux
using LuxTestUtils
using LuxTestUtils: check_approx
using Random: AbstractRNG, default_rng

#=
Relevant discussions:

- https://github.com/LuxDL/Lux.jl/issues/769
=#

# Exact comparison for Lux gradient structures: delegates to
# LuxTestUtils.check_approx with zero tolerances, so only exact matches pass.
function DIT.lux_isequal(a, b)
return check_approx(a, b; atol=0, rtol=0)
end

# Approximate comparison for Lux gradient structures with caller-supplied
# absolute and relative tolerances, forwarded to LuxTestUtils.check_approx.
function DIT.lux_isapprox(a, b; atol, rtol)
return check_approx(a, b; atol, rtol)
end

# Callable struct that closes over a Lux model, a fixed input `x` and a fixed
# state `st`, turning the loss into a function of the parameters `ps` alone
# (the form expected by the gradient scenarios below).
struct SquareLoss{M,X,S}
model::M
x::X
st::S
end

# Loss value: sum of squared entries of the model output. `first` keeps the
# output and drops the updated state returned by the Lux model call.
function (sql::SquareLoss)(ps)
@compat (; model, x, st) = sql
return sum(abs2, first(model(x, ps, st)))
end

# Build gradient scenarios for a spread of Lux layers (dense, conv, pooling,
# recurrent cells, normalization). NOTE(review): `rng` is consumed sequentially
# by the input constructors and by Lux.setup, so the order of this list is part
# of the reproducibility contract — do not reorder entries.
function DIT.lux_scenarios(rng::AbstractRNG=default_rng())
models_and_xs = [
(Dense(2, 4), randn(rng, Float32, 2, 3)),
(Dense(2, 4, gelu), randn(rng, Float32, 2, 3)),
(Dense(2, 4, gelu; use_bias=false), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4, relu), Dense(4, 3)), randn(rng, Float32, 2, 3)),
(Scale(2), randn(rng, Float32, 2, 3)),
(Conv((3, 3), 2 => 3), randn(rng, Float32, 3, 3, 2, 2)),
(Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(rng, Float32, 3, 3, 2, 2)),
(
Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()),
randn(rng, Float32, 3, 3, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)),
rand(rng, Float32, 5, 5, 2, 2),
),
(
Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())),
rand(rng, Float32, 5, 5, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))),
rand(rng, Float32, 5, 5, 2, 2),
),
(Maxout(() -> Dense(5 => 4, tanh), 3), randn(rng, Float32, 5, 2)),
(Bilinear((2, 2) => 3), randn(rng, Float32, 2, 3)),
(SkipConnection(Dense(2 => 2), vcat), randn(rng, Float32, 2, 3)),
(ConvTranspose((3, 3), 3 => 2; stride=2), rand(rng, Float32, 5, 5, 3, 1)),
(StatefulRecurrentCell(RNNCell(3 => 5)), rand(rng, Float32, 3, 2)),
(StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(rng, Float32, 3, 2)),
(
StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)),
rand(rng, Float32, 3, 2),
),
(
Chain(
StatefulRecurrentCell(RNNCell(3 => 5)),
StatefulRecurrentCell(RNNCell(5 => 3)),
),
rand(rng, Float32, 3, 2),
),
(StatefulRecurrentCell(LSTMCell(3 => 5)), rand(rng, Float32, 3, 2)),
(
Chain(
StatefulRecurrentCell(LSTMCell(3 => 5)),
StatefulRecurrentCell(LSTMCell(5 => 3)),
),
rand(rng, Float32, 3, 2),
),
(StatefulRecurrentCell(GRUCell(3 => 5)), rand(rng, Float32, 3, 10)),
(
Chain(
StatefulRecurrentCell(GRUCell(3 => 5)),
StatefulRecurrentCell(GRUCell(5 => 3)),
),
rand(rng, Float32, 3, 10),
),
(Chain(Dense(2, 4), BatchNorm(4)), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(rng, Float32, 2, 3)),
(
Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)),
randn(rng, Float32, 2, 3),
),
(Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(rng, Float32, 2, 3)),
(Chain(Dense(2, 4), GroupNorm(4, 2)), randn(rng, Float32, 2, 3)),
(Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(rng, Float32, 6, 6, 2, 2)),
(
Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)),
randn(rng, Float32, 6, 6, 2, 2),
),
(
Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))),
randn(rng, Float32, 4, 4, 2, 2),
),
(Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(rng, Float32, 6, 6, 2, 2)),
(
Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)),
randn(rng, Float32, 6, 6, 2, 2),
),
]

scens = Scenario[]

# For each model: initialize parameters/state, flatten the parameters into a
# ComponentArray (a flat vector view suitable for generic AD backends), and
# compute a finite-difference gradient as the imprecise-but-independent
# ground truth for the scenario.
for (model, x) in models_and_xs
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)
loss = SquareLoss(model, x, st)
l = loss(ps)
g = DI.gradient(loss, DI.AutoFiniteDiff(), ps)
scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace)
push!(scens, scen)
end

return scens
end

end
30 changes: 29 additions & 1 deletion DifferentiationInterfaceTest/src/scenarios/extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function gpu_scenarios end
Create a vector of [`Scenario`](@ref)s with neural networks from [Flux.jl](https://github.com/FluxML/Flux.jl).

!!! warning
This function requires Flux.jl and FiniteDifferences.jl to be loaded (it is implemented in a package extension).
This function requires FiniteDifferences.jl and Flux.jl to be loaded (it is implemented in a package extension).

!!! danger
These scenarios are still experimental and not part of the public API.
Expand All @@ -55,3 +55,31 @@ function flux_isapprox end
Exact comparison function to use in correctness tests with gradients of Flux.jl networks.
"""
function flux_isequal end

"""
lux_scenarios(rng=Random.default_rng())

Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl).

The `rng` is consumed sequentially to generate inputs and initialize model
parameters, so a fixed seed makes the scenarios reproducible.

!!! warning
    This function requires ComponentArrays.jl, FiniteDiff.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension).

!!! danger
    These scenarios are still experimental and not part of the public API.
    Their ground truth values are computed with finite differences, and thus subject to imprecision.
"""
function lux_scenarios end

"""
lux_isapprox(x, y; atol, rtol)

Approximate comparison function to use in correctness tests with gradients of Lux.jl networks.

Both tolerance keywords are required; the implementation lives in the same
package extension as [`lux_scenarios`](@ref).
"""
function lux_isapprox end

"""
lux_isequal(x, y)

Exact comparison function to use in correctness tests with gradients of Lux.jl networks.

The implementation lives in the same package extension as
[`lux_scenarios`](@ref).
"""
function lux_isequal end
12 changes: 12 additions & 0 deletions DifferentiationInterfaceTest/test/weird.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using FiniteDifferences: FiniteDifferences
using Flux: Flux
using ForwardDiff: ForwardDiff
using JLArrays: JLArrays
using Lux: Lux
using LuxTestUtils: LuxTestUtils
using Random
using SparseConnectivityTracer
using SparseMatrixColorings
Expand Down Expand Up @@ -43,3 +45,13 @@ test_differentiation(
atol=1e-6,
logging=LOGGING,
)

# Run the Zygote backend over the Lux.jl scenarios, using the Lux-specific
# comparison helpers for structured gradients. A fixed Xoshiro seed makes the
# generated models and inputs reproducible; Float32 tolerances match the
# Float32 parameters used by the scenarios. NOTE(review): `LOGGING` is assumed
# to be defined earlier in this test file's setup — confirm against the runner.
test_differentiation(
AutoZygote(),
DIT.lux_scenarios(Random.Xoshiro(63));
isequal=DIT.lux_isequal,
isapprox=DIT.lux_isapprox,
rtol=1.0f-2,
atol=1.0f-3,
logging=LOGGING,
)
Loading