From 857f8207c587982d2c82180b9a5d1941d155b074 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 19 Jul 2024 09:59:22 +0200 Subject: [PATCH 01/12] Lux tests --- .github/workflows/Test.yml | 3 +++ DifferentiationInterface/test/Down/Lux/Project.toml | 8 ++++++++ DifferentiationInterface/test/Down/Lux/test.jl | 2 ++ DifferentiationInterfaceTest/Project.toml | 2 ++ .../DifferentiationInterfaceTestLuxExt.jl | 8 ++++++++ 5 files changed, 23 insertions(+) create mode 100644 DifferentiationInterface/test/Down/Lux/Project.toml create mode 100644 DifferentiationInterface/test/Down/Lux/test.jl create mode 100644 DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index aa92fe45e..e0890c84e 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -49,6 +49,7 @@ jobs: - Down/Detector - Down/DifferentiateWith - Down/Flux + - Down/Lux exclude: # lts - version: 'lts' @@ -73,6 +74,8 @@ jobs: group: Down/Detector - version: 'lts' group: Down/Flux + - version: 'lts' + group: Down/Lux # pre-release - version: 'pre' group: Formalities diff --git a/DifferentiationInterface/test/Down/Lux/Project.toml b/DifferentiationInterface/test/Down/Lux/Project.toml new file mode 100644 index 000000000..94c54310b --- /dev/null +++ b/DifferentiationInterface/test/Down/Lux/Project.toml @@ -0,0 +1,8 @@ +[deps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl new file mode 100644 index 000000000..45096f033 --- /dev/null +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -0,0 +1,2 @@ +using DifferentiationInterface, DifferentiationInterfaceTest +using Lux: Lux diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 8cb7764d2..cbaad856a 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -24,11 +24,13 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] +DifferentiationInterfaceTestLuxExt = ["FiniteDifferences", "Lux"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl new file mode 100644 index 000000000..f4e992fdc --- /dev/null +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -0,0 +1,8 @@ +module DifferentiationInterfaceTestLuxExt + +using DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT +using FiniteDifferences: FiniteDifferences +using Lux + +end From 846bcf7645dc0cc1543902bae39b42d63d7bdc4b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 19 Jul 2024 10:18:30 +0200 Subject: [PATCH 02/12] Set up extension --- DifferentiationInterfaceTest/Project.toml | 7 +++++-- .../DifferentiationInterfaceTestLuxExt.jl | 8 ++++++++ .../src/scenarios/extensions.jl | 14 ++++++++++++++ DifferentiationInterfaceTest/test/weird.jl | 4 ++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index cbaad856a..82df6f3bb 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -25,12 +25,13 @@ FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] -DifferentiationInterfaceTestLuxExt = ["FiniteDifferences", "Lux"] +DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDifferences", "Lux", "LuxTestUtils"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" @@ -69,6 +70,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" @@ -78,4 +81,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "LuxTestUtils", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index f4e992fdc..1c64ded62 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -4,5 +4,13 @@ using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using FiniteDifferences: FiniteDifferences using Lux +using LuxTestUtils +using LuxTestUtils: check_approx +using Random: AbstractRNG, default_rng + +function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) + scens = Scenario[] + return scens +end end diff --git a/DifferentiationInterfaceTest/src/scenarios/extensions.jl b/DifferentiationInterfaceTest/src/scenarios/extensions.jl index 9c16c857e..4f0e1c275 100644 --- a/DifferentiationInterfaceTest/src/scenarios/extensions.jl +++ b/DifferentiationInterfaceTest/src/scenarios/extensions.jl @@ -55,3 +55,17 @@ function flux_isapprox end Exact comparison function to use in correctness tests with gradients of Flux.jl networks. """ function flux_isequal end + +""" + lux_scenarios(rng=Random.default_rng()) + +Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl). + +!!! warning + This function requires Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). + +!!! danger + These scenarios are still experimental and not part of the public API. + Their ground truth values are computed with finite differences, and thus subject to imprecision. +""" +function lux_scenarios end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 0be6680ee..66b2af573 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -7,6 +7,8 @@ using FiniteDifferences: FiniteDifferences using Flux: Flux using ForwardDiff: ForwardDiff using JLArrays: JLArrays +using Lux: Lux +using LuxTestUtils: LuxTestUtils using SparseConnectivityTracer using SparseMatrixColorings using StaticArrays: StaticArrays @@ -32,3 +34,5 @@ test_differentiation( atol=1e-2, logging=LOGGING, ) + +test_differentiation(AutoZygote(), DIT.lux_scenarios(); logging=LOGGING) From b776839de4d1fcfc78bf430ef0cf1b0ca49e8f1c Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:21:04 +0200 Subject: [PATCH 03/12] Test deps --- DifferentiationInterface/test/Down/Lux/Project.toml | 8 -------- DifferentiationInterface/test/Down/Lux/test.jl | 3 +++ 2 files changed, 3 insertions(+), 8 deletions(-) delete mode 100644 DifferentiationInterface/test/Down/Lux/Project.toml diff --git a/DifferentiationInterface/test/Down/Lux/Project.toml b/DifferentiationInterface/test/Down/Lux/Project.toml deleted file mode 100644 index 94c54310b..000000000 --- a/DifferentiationInterface/test/Down/Lux/Project.toml +++ /dev/null @@ -1,8 +0,0 @@ -[deps] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index 45096f033..02a14c4b3 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -1,2 +1,5 @@ +using Pkg +Pkg.add(["FiniteDifferences", "Lux", "LuxTestUtils", "Zygote"]) + using DifferentiationInterface, DifferentiationInterfaceTest using Lux: Lux From 9d6ca4ec47826d52aecb868edc8daf8e554ac7d6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 10:22:15 +0200 Subject: [PATCH 04/12] LuxTestUtils --- DifferentiationInterfaceTest/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 10b6911c6..562a82cf6 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -83,4 +83,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] +test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "DifferentiationInterface", "FiniteDiff", "FiniteDifferences", "Flux", "ForwardDiff", "JET", "JLArrays", "JuliaFormatter", "Lux", "LuxTestUtils", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StaticArrays", "Test", "Zygote"] From 73a1ec6c8805b9c9c60b44d75c866b9a7c155b74 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:16:19 +0200 Subject: [PATCH 05/12] Add scenarios --- DifferentiationInterfaceTest/Project.toml | 3 +- .../DifferentiationInterfaceTestLuxExt.jl | 124 +++++++++++++++++- .../src/scenarios/extensions.jl | 14 ++ DifferentiationInterfaceTest/test/weird.jl | 10 +- 4 files changed, 147 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 562a82cf6..ba13bf877 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -27,11 +27,12 @@ JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] -DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDifferences", "Lux", "LuxTestUtils"] +DifferentiationInterfaceTestLuxExt = ["FiniteDifferences", "Lux", "LuxTestUtils", "Zygote"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index 1c64ded62..c7c71664b 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -1,15 +1,135 @@ module DifferentiationInterfaceTestLuxExt +using Compat: @compat using DifferentiationInterfaceTest -import DifferentiationInterfaceTest as DIT using FiniteDifferences: FiniteDifferences using Lux using LuxTestUtils using LuxTestUtils: check_approx using Random: AbstractRNG, default_rng +using Zygote: Zygote + +#= +Relevant discussions: + +- https://github.com/LuxDL/Lux.jl/issues/769 +=# + +function DifferentiationInterfaceTest.lux_isequal(a, b) + return check_approx(a, b; atol=0, rtol=0) +end + +function DifferentiationInterfaceTest.lux_isapprox(a, b; atol, rtol) + return check_approx(a, b; atol, rtol) +end + +struct SquareLoss{M,X,S} + model::M + x::X + st::S +end + +function (sql::SquareLoss)(ps) + @compat (; model, x, st) = sql + # TODO: use deepcopy(st)? + return sum(abs2, first(model(x, ps, st))) +end + +function DifferentiationInterfaceTest.lux_scenarios(rng::AbstractRNG=default_rng()) + models_and_xs = [ + (Dense(2, 4), randn(rng, Float32, 2, 3)), + (Dense(2, 4, gelu), randn(rng, Float32, 2, 3)), + (Dense(2, 4, gelu; use_bias=false), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4, relu), Dense(4, 3)), randn(rng, Float32, 2, 3)), + (Scale(2), randn(rng, Float32, 2, 3)), + (Conv((3, 3), 2 => 3), randn(rng, Float32, 3, 3, 2, 2)), + (Conv((3, 3), 2 => 3, gelu; pad=SamePad()), randn(rng, Float32, 3, 3, 2, 2)), + ( + Conv((3, 3), 2 => 3, relu; use_bias=false, pad=SamePad()), + randn(rng, Float32, 3, 3, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, gelu), Conv((3, 3), 3 => 1, gelu)), + rand(rng, Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((4, 4), 2 => 2; pad=SamePad()), MeanPool((5, 5); pad=SamePad())), + rand(rng, Float32, 5, 5, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, relu; pad=SamePad()), MaxPool((2, 2))), + rand(rng, Float32, 5, 5, 2, 2), + ), + (Maxout(() -> Dense(5 => 4, tanh), 3), randn(rng, Float32, 5, 2)), + (Bilinear((2, 2) => 3), randn(rng, Float32, 2, 3)), + (SkipConnection(Dense(2 => 2), vcat), randn(rng, Float32, 2, 3)), + (ConvTranspose((3, 3), 3 => 2; stride=2), rand(rng, Float32, 5, 5, 3, 1)), + (StatefulRecurrentCell(RNNCell(3 => 5)), rand(rng, Float32, 3, 2)), + (StatefulRecurrentCell(RNNCell(3 => 5, gelu)), rand(rng, Float32, 3, 2)), + ( + StatefulRecurrentCell(RNNCell(3 => 5, gelu; use_bias=false)), + rand(rng, Float32, 3, 2), + ), + ( + Chain( + StatefulRecurrentCell(RNNCell(3 => 5)), + StatefulRecurrentCell(RNNCell(5 => 3)), + ), + rand(rng, Float32, 3, 2), + ), + (StatefulRecurrentCell(LSTMCell(3 => 5)), rand(rng, Float32, 3, 2)), + ( + Chain( + StatefulRecurrentCell(LSTMCell(3 => 5)), + StatefulRecurrentCell(LSTMCell(5 => 3)), + ), + rand(rng, Float32, 3, 2), + ), + (StatefulRecurrentCell(GRUCell(3 => 5)), rand(rng, Float32, 3, 10)), + ( + Chain( + StatefulRecurrentCell(GRUCell(3 => 5)), + StatefulRecurrentCell(GRUCell(5 => 3)), + ), + rand(rng, Float32, 3, 10), + ), + (Chain(Dense(2, 4), BatchNorm(4)), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4), BatchNorm(4, gelu)), randn(rng, Float32, 2, 3)), + ( + Chain(Dense(2, 4), BatchNorm(4, gelu; track_stats=false)), + randn(rng, Float32, 2, 3), + ), + (Chain(Conv((3, 3), 2 => 6), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)), + (Chain(Conv((3, 3), 2 => 6, tanh), BatchNorm(6)), randn(rng, Float32, 6, 6, 2, 2)), + (Chain(Dense(2, 4), GroupNorm(4, 2, gelu)), randn(rng, Float32, 2, 3)), + (Chain(Dense(2, 4), GroupNorm(4, 2)), randn(rng, Float32, 2, 3)), + (Chain(Conv((3, 3), 2 => 6), GroupNorm(6, 3)), randn(rng, Float32, 6, 6, 2, 2)), + ( + Chain(Conv((3, 3), 2 => 6, tanh), GroupNorm(6, 3)), + randn(rng, Float32, 6, 6, 2, 2), + ), + ( + Chain(Conv((3, 3), 2 => 3, gelu), LayerNorm((1, 1, 3))), + randn(rng, Float32, 4, 4, 2, 2), + ), + (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(rng, Float32, 6, 6, 2, 2)), + ( + Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), + randn(rng, Float32, 6, 6, 2, 2), + ), + ] -function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) scens = Scenario[] + + for (model, x) in models_and_xs + ps, st = Lux.setup(rng, model) + loss = SquareLoss(model, x, st) + l = loss(ps) + g = Zygote.gradient(loss, ps) # TODO: replace with FiniteDifferences + scen = GradientScenario(loss; x=model, y=l, grad=g, nb_args=1, place=:outofplace) + push!(scens, scen) + end + return scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/extensions.jl b/DifferentiationInterfaceTest/src/scenarios/extensions.jl index 4f0e1c275..0de4d9bd6 100644 --- a/DifferentiationInterfaceTest/src/scenarios/extensions.jl +++ b/DifferentiationInterfaceTest/src/scenarios/extensions.jl @@ -69,3 +69,17 @@ Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https: Their ground truth values are computed with finite differences, and thus subject to imprecision. """ function lux_scenarios end + +""" + lux_isapprox(x, y; atol, rtol) + +Approximate comparison function to use in correctness tests with gradients of Lux.jl networks. +""" +function lux_isapprox end + +""" + lux_isequal(x, y) + +Exact comparison function to use in correctness tests with gradients of Lux.jl networks. +""" +function lux_isequal end diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index e44405b5e..5036ecf8d 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -46,4 +46,12 @@ test_differentiation( logging=LOGGING, ) -test_differentiation(AutoZygote(), DIT.lux_scenarios(); logging=LOGGING) +test_differentiation( + AutoZygote(), + DIT.lux_scenarios(); + isequal=DIT.lux_isequal, + isapprox=DIT.lux_isapprox, + atol=1e-3, + rtol=1e-3, + logging=LOGGING, +) From 18750e1bfc4b8952aeabd63e2b2f46282d5d139b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:16:40 +0200 Subject: [PATCH 06/12] Fix x --- .../DifferentiationInterfaceTestLuxExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index c7c71664b..736114bd1 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -126,7 +126,7 @@ function DifferentiationInterfaceTest.lux_scenarios(rng::AbstractRNG=default_rng loss = SquareLoss(model, x, st) l = loss(ps) g = Zygote.gradient(loss, ps) # TODO: replace with FiniteDifferences - scen = GradientScenario(loss; x=model, y=l, grad=g, nb_args=1, place=:outofplace) + scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace) push!(scens, scen) end From 4b583a2a232f732e734d8b60b21b25f743572c0f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 15:50:19 +0200 Subject: [PATCH 07/12] Lux tests working with ComponentArrays --- .../test/Down/Flux/test.jl | 2 +- .../test/Down/Lux/test.jl | 15 ++++++++++++++- DifferentiationInterfaceTest/Project.toml | 8 ++++---- .../DifferentiationInterfaceTestLuxExt.jl | 19 ++++++++++++------- .../src/DifferentiationInterfaceTest.jl | 1 + .../src/scenarios/extensions.jl | 4 ++-- DifferentiationInterfaceTest/test/weird.jl | 4 ++-- 7 files changed, 36 insertions(+), 17 deletions(-) diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index 4c21cd9ec..c26044f3f 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(["Enzyme", "FiniteDifferences", "Flux", "Zygote"]) +Pkg.add(["Enzyme", "Flux", "Zygote"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index 02a14c4b3..d1c5be7c1 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -1,5 +1,18 @@ using Pkg -Pkg.add(["FiniteDifferences", "Lux", "LuxTestUtils", "Zygote"]) +Pkg.add(["Lux", "LuxTestUtils", "Zygote"]) +using ComponentArrays: ComponentArrays using DifferentiationInterface, DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT using Lux: Lux +using LuxTestUtils: LuxTestUtils + +test_differentiation( + AutoZygote(), + DIT.lux_scenarios(); + isequal=DIT.lux_isequal, + isapprox=DIT.lux_isapprox, + rtol=1.0f-2, + atol=1.0f-3, + logging=LOGGING, +) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index ba13bf877..375f86b2b 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -10,6 +10,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -21,7 +22,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -31,9 +31,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" -DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] -DifferentiationInterfaceTestLuxExt = ["FiniteDifferences", "Lux", "LuxTestUtils", "Zygote"] +DifferentiationInterfaceTestFluxExt = ["Flux"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" +DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "Lux", "LuxTestUtils"] DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" [compat] @@ -44,8 +44,8 @@ ComponentArrays = "0.15" DataFrames = "1.6.1" DifferentiationInterface = "0.5.6" DocStringExtensions = "0.8,0.9" -Flux = "0.13,0.14" FiniteDifferences = "0.12" +Flux = "0.13,0.14" Functors = "0.4" JET = "0.4 - 0.8, 0.9" JLArrays = "0.1" diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index 736114bd1..4e288df8d 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -1,13 +1,15 @@ module DifferentiationInterfaceTestLuxExt using Compat: @compat +using ComponentArrays: ComponentArray +import DifferentiationInterface as DI using DifferentiationInterfaceTest +import DifferentiationInterfaceTest as DIT using FiniteDifferences: FiniteDifferences using Lux using LuxTestUtils using LuxTestUtils: check_approx using Random: AbstractRNG, default_rng -using Zygote: Zygote #= Relevant discussions: @@ -15,11 +17,11 @@ Relevant discussions: - https://github.com/LuxDL/Lux.jl/issues/769 =# -function DifferentiationInterfaceTest.lux_isequal(a, b) +function DIT.lux_isequal(a, b) return check_approx(a, b; atol=0, rtol=0) end -function DifferentiationInterfaceTest.lux_isapprox(a, b; atol, rtol) +function DIT.lux_isapprox(a, b; atol, rtol) return check_approx(a, b; atol, rtol) end @@ -31,11 +33,11 @@ end function (sql::SquareLoss)(ps) @compat (; model, x, st) = sql - # TODO: use deepcopy(st)? - return sum(abs2, first(model(x, ps, st))) + # TODO: get rid of deepcopy(st)? + return sum(abs2, first(model(x, ps, deepcopy(st)))) end -function DifferentiationInterfaceTest.lux_scenarios(rng::AbstractRNG=default_rng()) +function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) models_and_xs = [ (Dense(2, 4), randn(rng, Float32, 2, 3)), (Dense(2, 4, gelu), randn(rng, Float32, 2, 3)), @@ -123,9 +125,12 @@ function DifferentiationInterfaceTest.lux_scenarios(rng::AbstractRNG=default_rng for (model, x) in models_and_xs ps, st = Lux.setup(rng, model) + ps = ComponentArray(ps) loss = SquareLoss(model, x, st) l = loss(ps) - g = Zygote.gradient(loss, ps) # TODO: replace with FiniteDifferences + g = DI.gradient( + loss, DI.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1)), ps + ) scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 4c7ef0fc2..29d1ed2cf 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -61,6 +61,7 @@ using DifferentiationInterface: SecondDerivativeExtras using DocStringExtensions import DifferentiationInterface as DI +using FiniteDifferences: FiniteDifferences using Functors: fmap using JET: JET using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent diff --git a/DifferentiationInterfaceTest/src/scenarios/extensions.jl b/DifferentiationInterfaceTest/src/scenarios/extensions.jl index 0de4d9bd6..47e1cf9c0 100644 --- a/DifferentiationInterfaceTest/src/scenarios/extensions.jl +++ b/DifferentiationInterfaceTest/src/scenarios/extensions.jl @@ -34,7 +34,7 @@ function gpu_scenarios end Create a vector of [`Scenario`](@ref)s with neural networks from [Flux.jl](https://github.com/FluxML/Flux.jl). !!! warning - This function requires Flux.jl and FiniteDifferences.jl to be loaded (it is implemented in a package extension). + This function requires Flux.jl to be loaded (it is implemented in a package extension). !!! danger These scenarios are still experimental and not part of the public API. @@ -62,7 +62,7 @@ function flux_isequal end Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl). !!! warning - This function requires Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). + This function requires ComponentArrays.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). !!! danger These scenarios are still experimental and not part of the public API. diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 5036ecf8d..aaae016cd 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -51,7 +51,7 @@ test_differentiation( DIT.lux_scenarios(); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, - atol=1e-3, - rtol=1e-3, + rtol=1.0f-2, + atol=1.0f-3, logging=LOGGING, ) From 79c50ce4d39858a753e7865535ee60195457a206 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:02:06 +0200 Subject: [PATCH 08/12] Fix Flux --- DifferentiationInterface/test/Down/Flux/test.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index c26044f3f..b99c3cdf2 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -4,7 +4,6 @@ Pkg.add(["Enzyme", "Flux", "Zygote"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using Enzyme: Enzyme -using FiniteDifferences: FiniteDifferences using Flux: Flux using Random using Zygote: Zygote From b469c494753ef9c91889abb7ba87709d1a24525b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 30 Jul 2024 16:22:06 +0200 Subject: [PATCH 09/12] More tolerant lux tests --- DifferentiationInterface/test/Down/Lux/test.jl | 5 ++++- DifferentiationInterfaceTest/test/weird.jl | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index d1c5be7c1..2bd6994ed 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -6,6 +6,9 @@ using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using Lux: Lux using LuxTestUtils: LuxTestUtils +using Random + +Random.seed!(0) test_differentiation( AutoZygote(), @@ -13,6 +16,6 @@ test_differentiation( isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, rtol=1.0f-2, - atol=1.0f-3, + atol=1.0f-2, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index aaae016cd..55bcd8fa9 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -52,6 +52,6 @@ test_differentiation( isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, rtol=1.0f-2, - atol=1.0f-3, + atol=1.0f-2, logging=LOGGING, ) From ce237e108bc8bf82dd46c969e754b79da81010d0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 31 Jul 2024 07:50:15 +0200 Subject: [PATCH 10/12] Adapt to extensions --- DifferentiationInterface/test/Down/Flux/test.jl | 3 ++- DifferentiationInterface/test/Down/Lux/test.jl | 7 ++++--- DifferentiationInterfaceTest/Project.toml | 7 ++++--- .../DifferentiationInterfaceTestLuxExt.jl | 9 +++------ .../src/DifferentiationInterfaceTest.jl | 1 - DifferentiationInterfaceTest/src/scenarios/extensions.jl | 4 ++-- DifferentiationInterfaceTest/test/weird.jl | 4 ++-- 7 files changed, 17 insertions(+), 18 deletions(-) diff --git a/DifferentiationInterface/test/Down/Flux/test.jl b/DifferentiationInterface/test/Down/Flux/test.jl index b99c3cdf2..c7ef89c1e 100644 --- a/DifferentiationInterface/test/Down/Flux/test.jl +++ b/DifferentiationInterface/test/Down/Flux/test.jl @@ -1,9 +1,10 @@ using Pkg -Pkg.add(["Enzyme", "Flux", "Zygote"]) +Pkg.add(["FiniteDifferences", "Enzyme", "Flux", "Zygote"]) using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using Enzyme: Enzyme +using FiniteDifferences: FiniteDifferences using Flux: Flux using Random using Zygote: Zygote diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index 2bd6994ed..7f8782350 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -1,9 +1,10 @@ using Pkg -Pkg.add(["Lux", "LuxTestUtils", "Zygote"]) +Pkg.add(["FiniteDiff", "Lux", "LuxTestUtils", "Zygote"]) using ComponentArrays: ComponentArrays using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT +using FiniteDiff: FiniteDiff using Lux: Lux using LuxTestUtils: LuxTestUtils using Random @@ -15,7 +16,7 @@ test_differentiation( DIT.lux_scenarios(); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, - rtol=1.0f-2, - atol=1.0f-2, + rtol=1.0f-3, + atol=1.0f-3, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/Project.toml b/DifferentiationInterfaceTest/Project.toml index 09d02cec2..35fea5311 100644 --- a/DifferentiationInterfaceTest/Project.toml +++ b/DifferentiationInterfaceTest/Project.toml @@ -10,7 +10,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -22,6 +21,8 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [weakdeps] ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -31,9 +32,9 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays" -DifferentiationInterfaceTestFluxExt = ["Flux"] +DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"] DifferentiationInterfaceTestJLArraysExt = "JLArrays" -DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "Lux", "LuxTestUtils"] +DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "FiniteDiff", "Lux", "LuxTestUtils"] DifferentiationInterfaceTestStaticArraysExt = "StaticArrays" [compat] diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl index 4e288df8d..cd0b7ae66 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl @@ -5,7 +5,7 @@ using ComponentArrays: ComponentArray import DifferentiationInterface as DI using DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT -using FiniteDifferences: FiniteDifferences +using FiniteDiff: FiniteDiff using Lux using LuxTestUtils using LuxTestUtils: check_approx @@ -33,8 +33,7 @@ end function (sql::SquareLoss)(ps) @compat (; model, x, st) = sql - # TODO: get rid of deepcopy(st)? - return sum(abs2, first(model(x, ps, deepcopy(st)))) + return sum(abs2, first(model(x, ps, st))) end function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) @@ -128,9 +127,7 @@ function DIT.lux_scenarios(rng::AbstractRNG=default_rng()) ps = ComponentArray(ps) loss = SquareLoss(model, x, st) l = loss(ps) - g = DI.gradient( - loss, DI.AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(5, 1)), ps - ) + g = DI.gradient(loss, DI.AutoFiniteDiff(), ps) scen = GradientScenario(loss; x=ps, y=l, grad=g, nb_args=1, place=:outofplace) push!(scens, scen) end diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index 29d1ed2cf..4c7ef0fc2 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -61,7 +61,6 @@ using DifferentiationInterface: SecondDerivativeExtras using DocStringExtensions import DifferentiationInterface as DI -using FiniteDifferences: FiniteDifferences using Functors: fmap using JET: JET using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent diff --git a/DifferentiationInterfaceTest/src/scenarios/extensions.jl b/DifferentiationInterfaceTest/src/scenarios/extensions.jl index 47e1cf9c0..83c1ce284 100644 --- a/DifferentiationInterfaceTest/src/scenarios/extensions.jl +++ b/DifferentiationInterfaceTest/src/scenarios/extensions.jl @@ -34,7 +34,7 @@ function gpu_scenarios end Create a vector of [`Scenario`](@ref)s with neural networks from [Flux.jl](https://github.com/FluxML/Flux.jl). !!! warning - This function requires Flux.jl to be loaded (it is implemented in a package extension). + This function requires FiniteDifferences.jl and Flux.jl to be loaded (it is implemented in a package extension). !!! danger These scenarios are still experimental and not part of the public API. @@ -62,7 +62,7 @@ function flux_isequal end Create a vector of [`Scenario`](@ref)s with neural networks from [Lux.jl](https://github.com/LuxDL/Lux.jl). !!! warning - This function requires ComponentArrays.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). + This function requires ComponentArrays.jl, FiniteDiff.jl, Lux.jl and LuxTestUtils.jl to be loaded (it is implemented in a package extension). !!! danger These scenarios are still experimental and not part of the public API. diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index 55bcd8fa9..a48a22f5b 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -51,7 +51,7 @@ test_differentiation( DIT.lux_scenarios(); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, - rtol=1.0f-2, - atol=1.0f-2, + rtol=1.0f-3, + atol=1.0f-3, logging=LOGGING, ) From d6060948d8a94e64378ede62cc7555f02f8575d1 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 31 Jul 2024 07:50:54 +0200 Subject: [PATCH 11/12] Rng --- DifferentiationInterface/test/Down/Lux/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index 7f8782350..aca510fba 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -13,7 +13,7 @@ Random.seed!(0) test_differentiation( AutoZygote(), - DIT.lux_scenarios(); + DIT.lux_scenarios(Random.Xoshiro(63)); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, rtol=1.0f-3, From a07ff9b69157e810a1788dc4f133d1222aaf9168 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 31 Jul 2024 08:03:20 +0200 Subject: [PATCH 12/12] Chill rtol --- DifferentiationInterface/test/Down/Lux/test.jl | 2 +- DifferentiationInterfaceTest/test/weird.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/DifferentiationInterface/test/Down/Lux/test.jl b/DifferentiationInterface/test/Down/Lux/test.jl index aca510fba..bbe294e34 100644 --- a/DifferentiationInterface/test/Down/Lux/test.jl +++ b/DifferentiationInterface/test/Down/Lux/test.jl @@ -16,7 +16,7 @@ test_differentiation( DIT.lux_scenarios(Random.Xoshiro(63)); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, - rtol=1.0f-3, + rtol=1.0f-2, atol=1.0f-3, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index a48a22f5b..b8475b528 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -48,10 +48,10 @@ test_differentiation( test_differentiation( AutoZygote(), - DIT.lux_scenarios(); + DIT.lux_scenarios(Random.Xoshiro(63)); isequal=DIT.lux_isequal, isapprox=DIT.lux_isapprox, - rtol=1.0f-3, + rtol=1.0f-2, atol=1.0f-3, logging=LOGGING, )