From 4cef2888ef8debd044a149b167d6f9e4f71cba91 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Nov 2024 12:03:13 -0500 Subject: [PATCH 1/5] test: use TestExtras in Lux testing --- test/Project.toml | 2 ++ test/helpers/compact_tests.jl | 6 ++-- test/helpers/loss_tests.jl | 50 ++++++++++++++++++++------------- test/helpers/training_tests.jl | 18 ++++++------ test/layers/containers_tests.jl | 8 ++++-- test/shared_testsetup.jl | 2 +- test/zygote_type_stability.jl | 10 +++---- 7 files changed, 56 insertions(+), 40 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index aca27bdbf..c9e33f67d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,6 +37,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -79,5 +80,6 @@ Static = "1" StaticArrays = "1.9" Statistics = "1.11.1" Test = "1.10" +TestExtras = "0.3.1" Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/test/helpers/compact_tests.jl b/test/helpers/compact_tests.jl index 31b8fd52b..aa13daa9c 100644 --- a/test/helpers/compact_tests.jl +++ b/test/helpers/compact_tests.jl @@ -329,7 +329,7 @@ @test st_new.incr == 10 _, st_new = model(x, ps, st_new) @test st_new.incr == 100 - @test @inferred(model(x, ps, st)) isa Any + @constinferred model(x, ps, st) function ScaledDense2(; d_in=5, d_out=7, act=relu) @compact(W=randn(d_out, d_in), b=zeros(d_out), incr=1) do x @@ -349,10 +349,10 @@ _, st_new = model(x, ps, st_new) @test st_new.incr == 100 - @test @inferred(model(x, ps, st)) isa Any + @constinferred model(x, ps, st) __f = (m, x, ps, st) -> sum(abs2, first(m(x, ps, st))) - @test @inferred(Zygote.gradient(__f, model, x, ps, st)) isa Any + @constinferred Zygote.gradient(__f, model, x, ps, st) end @testset "Multiple @return" begin diff --git a/test/helpers/loss_tests.jl b/test/helpers/loss_tests.jl index ba7934314..7598620fc 100644 --- a/test/helpers/loss_tests.jl +++ b/test/helpers/loss_tests.jl @@ -9,8 +9,8 @@ ∂x2 = Zygote.gradient(LuxOps.xlogx, 2.0)[1] @test ∂x1 ≈ ∂x2 - @test @inferred(LuxOps.xlogx(2)) isa Number - @test @inferred(LuxOps.xlogx(0)) isa Number + @constinferred LuxOps.xlogx(2) + @constinferred LuxOps.xlogx(0) @jet LuxOps.xlogx(2) @test iszero(LuxOps.xlogy(0, 1)) @@ -33,13 +33,13 @@ @test_broken false end - @test @inferred(LuxOps.xlogy(2, 3)) isa Number - @test @inferred(LuxOps.xlogy(0, 1)) isa Number + @constinferred LuxOps.xlogy(2, 3) + @constinferred LuxOps.xlogy(0, 1) @jet LuxOps.xlogy(2, 3) if LuxTestUtils.ENZYME_TESTING_ENABLED - @test @inferred(Enzyme.autodiff( - Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0))) isa Any + @constinferred Enzyme.autodiff( + Enzyme.Reverse, LuxOps.xlogy, Active, Active(2.0), Active(3.0)) else @test_broken false end @@ -74,7 +74,7 @@ end @test loss_sum(ŷ, y) ≈ loss_res * 4 @test loss_sum2(ŷ, y) ≈ loss_res * 4 - @test @inferred(Zygote.gradient(loss_mean, ŷ, y)) isa Any + @constinferred Zygote.gradient(loss_mean, ŷ, y) @jet loss_mean(ŷ, y) @jet loss_sum(ŷ, y) @@ -91,7 +91,11 @@ end @jet MSLELoss()(ŷ, y) - @test @inferred(Zygote.gradient(MSLELoss(), ŷ, y)) isa Any broken=ongpu + if ongpu + @constinferred_broken Zygote.gradient(MSLELoss(), ŷ, y) + else + @constinferred Zygote.gradient(MSLELoss(), ŷ, y) + end __f = Base.Fix2(MSLELoss(), y) @test_gradients(__f, ŷ; 
atol=1.0f-3, rtol=1.0f-3) @@ -150,7 +154,7 @@ end @jet celoss(ŷ, y) @jet celoss_smooth(ŷ, y) - @test @inferred(Zygote.gradient(celoss, ŷ, y)) isa Any + @constinferred Zygote.gradient(celoss, ŷ, y) @test_gradients(Base.Fix2(celoss, y), ŷ; atol=1.0f-3, rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : []) @@ -173,7 +177,7 @@ end @jet logitceloss(logŷ, y) @jet logitceloss_smooth(logŷ, y) - @test @inferred(Zygote.gradient(logitceloss, logŷ, y)) isa Any + @constinferred Zygote.gradient(logitceloss, logŷ, y) @test_gradients(Base.Fix2(logitceloss, y), logŷ; atol=1.0f-3, rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? [AutoEnzyme()] : []) @@ -201,7 +205,7 @@ end @jet bceloss(σ.(logŷ), y) @jet bceloss_smooth(σ.(logŷ), y) - @test @inferred(Zygote.gradient(bceloss, σ.(logŷ), y)) isa Any + @constinferred Zygote.gradient(bceloss, σ.(logŷ), y) __f = Base.Fix2(bceloss, y) σlogŷ = σ.(logŷ) @@ -223,7 +227,7 @@ end @jet logitbceloss(logŷ, y) @jet logitbceloss_smooth(logŷ, y) - @test @inferred(Zygote.gradient(logitbceloss, logŷ, y)) isa Any + @constinferred Zygote.gradient(logitbceloss, logŷ, y) __f = Base.Fix2(logitbceloss, y) @test_gradients(__f, logŷ; atol=1.0f-3, rtol=1.0f-3) @@ -246,7 +250,11 @@ end @jet BinaryFocalLoss()(ŷ, y) - @test @inferred(Zygote.gradient(BinaryFocalLoss(), ŷ, y)) isa Any broken=ongpu + if ongpu + @constinferred_broken Zygote.gradient(BinaryFocalLoss(), ŷ, y) + else + @constinferred Zygote.gradient(BinaryFocalLoss(), ŷ, y) + end __f = Base.Fix2(BinaryFocalLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -270,7 +278,11 @@ end @jet FocalLoss()(ŷ, y) - @test @inferred(Zygote.gradient(FocalLoss(), ŷ, y)) isa Any broken=ongpu + if ongpu + @constinferred_broken Zygote.gradient(FocalLoss(), ŷ, y) + else + @constinferred Zygote.gradient(FocalLoss(), ŷ, y) + end __f = Base.Fix2(FocalLoss(), y) # FD will lead to out of domain errors @@ -301,7 +313,7 @@ end @test KLDivergenceLoss()(y, y) ≈ 0 @jet KLDivergenceLoss()(ŷ, y) - @test @inferred(Zygote.gradient(KLDivergenceLoss(), ŷ, y)) isa Any + @constinferred Zygote.gradient(KLDivergenceLoss(), ŷ, y) @test_gradients(Base.Fix2(KLDivergenceLoss(), y), ŷ; atol=1.0f-3, rtol=1.0f-3, skip_backends=VERSION ≥ v"1.11-" ? 
[AutoEnzyme()] : []) @@ -315,7 +327,7 @@ end @test Lux.HingeLoss()(y, 0.5 .* y) ≈ 0.125 @jet Lux.HingeLoss()(ŷ, y) - @test @inferred(Zygote.gradient(Lux.HingeLoss(), ŷ, y)) isa Any + @constinferred Zygote.gradient(Lux.HingeLoss(), ŷ, y) __f = Base.Fix2(Lux.HingeLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -329,7 +341,7 @@ end @test SquaredHingeLoss()(y, 0.5 .* y) ≈ 0.0625 @jet SquaredHingeLoss()(ŷ, y) - @inferred Zygote.gradient(SquaredHingeLoss(), ŷ, y) + @constinferred Zygote.gradient(SquaredHingeLoss(), ŷ, y) __f = Base.Fix2(SquaredHingeLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -343,7 +355,7 @@ end @test Lux.PoissonLoss()(y, y) ≈ 0.5044459776946685 @jet Lux.PoissonLoss()(ŷ, y) - @test @inferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) isa Any + @constinferred Zygote.gradient(Lux.PoissonLoss(), ŷ, y) __f = Base.Fix2(Lux.PoissonLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3) @@ -357,7 +369,7 @@ end @test DiceCoeffLoss()(y, y) ≈ 0.0 @jet DiceCoeffLoss()(ŷ, y) - @test @inferred(Zygote.gradient(DiceCoeffLoss(), ŷ, y)) isa Any broken=true + @constinferred_broken Zygote.gradient(DiceCoeffLoss(), ŷ, y) __f = Base.Fix2(DiceCoeffLoss(), y) @test_gradients(__f, ŷ; atol=1.0f-3, rtol=1.0f-3, diff --git a/test/helpers/training_tests.jl b/test/helpers/training_tests.jl index 222bb7eb2..66716c3b0 100644 --- a/test/helpers/training_tests.jl +++ b/test/helpers/training_tests.jl @@ -150,7 +150,7 @@ end tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Training.compute_gradients( + _, _, _, tstate_new = @constinferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) @test tstate_new.states !== tstate.states @@ -160,13 +160,12 @@ end tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Training.compute_gradients( + _, _, _, tstate_new = @constinferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) - @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa - Any + @constinferred Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new) - _, _, _, tstate_new2 = @inferred Training.compute_gradients( + _, _, _, tstate_new2 = @constinferred Training.compute_gradients( AutoEnzyme(), mse2, (x, x), tstate_new) @test hasfield(typeof(tstate_new2.cache.extras), :forward) @test hasfield(typeof(tstate_new2.cache.extras), :reverse) @@ -180,7 +179,7 @@ end tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Training.compute_gradients( + _, _, _, tstate_new = @constinferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) @test tstate_new.states !== tstate.states @@ -190,13 +189,12 @@ end tstate = Training.TrainState(model, ps, st, opt) - _, _, _, tstate_new = @inferred Training.compute_gradients( + _, _, _, tstate_new = @constinferred Training.compute_gradients( AutoEnzyme(), mse, (x, x), tstate) - @test @inferred(Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new)) isa - Any + @constinferred Training.compute_gradients(AutoEnzyme(), mse, (x, x), tstate_new) - _, _, _, tstate_new2 = @inferred Training.compute_gradients( + _, _, _, tstate_new2 = @constinferred Training.compute_gradients( AutoEnzyme(), mse2, (x, x), tstate_new) @test hasfield(typeof(tstate_new2.cache.extras), :forward) @test hasfield(typeof(tstate_new2.cache.extras), :reverse) diff --git a/test/layers/containers_tests.jl b/test/layers/containers_tests.jl index cfa7e77d9..887f10c74 100644 --- 
a/test/layers/containers_tests.jl +++ b/test/layers/containers_tests.jl @@ -430,9 +430,13 @@ end st = st |> dev ps_nt = ps |> dev - @test @inferred(froggie(x, ps_nt, st)) isa Any + @constinferred froggie(x, ps_nt, st) ps_ca = ps |> ComponentArray |> dev - @test @inferred(froggie(x, ps_ca, st)) isa Any broken=ongpu + if ongpu + @constinferred_broken froggie(x, ps_ca, st) + else + @constinferred froggie(x, ps_ca, st) + end end end diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl index e5d853744..db3b0125d 100644 --- a/test/shared_testsetup.jl +++ b/test/shared_testsetup.jl @@ -8,7 +8,7 @@ using Lux, Functors using Setfield: @set using DispatchDoctor: allow_unstable @reexport using ComponentArrays, LuxCore, LuxLib, LuxTestUtils, Random, StableRNGs, Test, - Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff + Zygote, Statistics, Enzyme, LinearAlgebra, ForwardDiff, TestExtras using MLDataDevices: default_device_rng, CPUDevice, CUDADevice, AMDGPUDevice using LuxTestUtils: check_approx using Static: True diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index 1338ca229..8a37d52c0 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -75,13 +75,13 @@ include("setup_modes.jl") ps, st = Lux.setup(rng, model) |> dev x = input |> dev - @test @inferred(model(x, ps, Lux.testmode(st))) isa Any - @test @inferred(loss_function(model, x, ps, Lux.testmode(st))) isa Number + @constinferred model(x, ps, Lux.testmode(st)) + @constinferred loss_function(model, x, ps, Lux.testmode(st)) + if mode == "amdgpu" && model isa Conv - @test_broken @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa - Any + @constinferred_broken Zygote.gradient(loss_function, model, x, ps, st) else - @test @inferred(Zygote.gradient(loss_function, model, x, ps, st)) isa Any + @constinferred Zygote.gradient(loss_function, model, x, ps, st) end end end From 56a48e2d64e72822f670d2ffb95cacc389581fb5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Nov 2024 12:10:40 -0500 Subject: [PATCH 2/5] test: use TestExtras in MLDataDevices testing --- lib/MLDataDevices/test/Project.toml | 2 ++ lib/MLDataDevices/test/amdgpu_tests.jl | 4 ++-- lib/MLDataDevices/test/cuda_tests.jl | 4 ++-- lib/MLDataDevices/test/metal_tests.jl | 6 +++--- lib/MLDataDevices/test/misc_tests.jl | 6 +++--- lib/MLDataDevices/test/oneapi_tests.jl | 6 +++--- lib/MLDataDevices/test/xla_tests.jl | 4 ++-- 7 files changed, 17 insertions(+), 15 deletions(-) diff --git a/lib/MLDataDevices/test/Project.toml b/lib/MLDataDevices/test/Project.toml index e1e1f1e10..d14f76a58 100644 --- a/lib/MLDataDevices/test/Project.toml +++ b/lib/MLDataDevices/test/Project.toml @@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -39,5 +40,6 @@ ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" +TestExtras = "0.3.1" Tracker = "0.2.36" Zygote = "0.6.69" diff --git a/lib/MLDataDevices/test/amdgpu_tests.jl b/lib/MLDataDevices/test/amdgpu_tests.jl index a771ada6e..9099ceb08 100644 --- a/lib/MLDataDevices/test/amdgpu_tests.jl +++ b/lib/MLDataDevices/test/amdgpu_tests.jl @@ -1,4 +1,4 @@ -using MLDataDevices, Random, Test +using MLDataDevices, Random, Test, TestExtras using 
ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -122,7 +122,7 @@ using FillArrays, Zygote # Extensions ps = (; weight=x, bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + @constinferred Val{parameterless_type(typeof(device))} return_val(ps) end end diff --git a/lib/MLDataDevices/test/cuda_tests.jl b/lib/MLDataDevices/test/cuda_tests.jl index 2fce4806a..cc0b7ff23 100644 --- a/lib/MLDataDevices/test/cuda_tests.jl +++ b/lib/MLDataDevices/test/cuda_tests.jl @@ -1,4 +1,4 @@ -using MLDataDevices, Random, Functors, Test +using MLDataDevices, Random, Functors, Test, TestExtras using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -144,7 +144,7 @@ using FillArrays, Zygote # Extensions ps = (; weight=x, bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + @constinferred Val{parameterless_type(typeof(device))} return_val(ps) return_val2(x) = Val(get_device(x)) @test_throws ErrorException @inferred(return_val2(ps)) diff --git a/lib/MLDataDevices/test/metal_tests.jl b/lib/MLDataDevices/test/metal_tests.jl index 2bc884553..25411ebab 100644 --- a/lib/MLDataDevices/test/metal_tests.jl +++ b/lib/MLDataDevices/test/metal_tests.jl @@ -1,4 +1,4 @@ -using MLDataDevices, Random, Test +using MLDataDevices, Random, Test, TestExtras using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions ps = (; weight=x, bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + @constinferred Val{parameterless_type(typeof(device))} return_val(ps) return_val2(x) = Val(get_device(x)) - @test @inferred(return_val2(ps)) isa Val{get_device(x)} + @constinferred Val{get_device(x)} return_val2(ps) end end diff --git a/lib/MLDataDevices/test/misc_tests.jl b/lib/MLDataDevices/test/misc_tests.jl index 65f63c9a9..4573dbe0b 100644 --- a/lib/MLDataDevices/test/misc_tests.jl +++ b/lib/MLDataDevices/test/misc_tests.jl @@ -1,4 +1,4 @@ -using Adapt, MLDataDevices, ComponentArrays, Random +using Adapt, MLDataDevices, ComponentArrays, Random, TestExtras using ArrayInterface: parameterless_type using ChainRulesTestUtils: test_rrule using ReverseDiff, Tracker, ForwardDiff @@ -148,10 +148,10 @@ end ps = (; weight=x, bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{typeof(cpu_device())} + @constinferred Val{typeof(cpu_device())} return_val(ps) return_val2(x) = Val(get_device(x)) - @test @inferred(return_val2(ps)) isa Val{cpu_device()} + @constinferred Val{cpu_device()} return_val2(ps) end @testset "undefined references array" begin diff --git a/lib/MLDataDevices/test/oneapi_tests.jl b/lib/MLDataDevices/test/oneapi_tests.jl index 2169869d3..355241e54 100644 --- a/lib/MLDataDevices/test/oneapi_tests.jl +++ b/lib/MLDataDevices/test/oneapi_tests.jl @@ -1,4 +1,4 @@ -using MLDataDevices, Random, Test +using MLDataDevices, Random, Test, TestExtras using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -108,10 +108,10 @@ using FillArrays, Zygote # Extensions ps = (; weight=x, 
bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + @constinferred Val{parameterless_type(typeof(device))} return_val(ps) return_val2(x) = Val(get_device(x)) - @test @inferred(return_val2(ps)) isa Val{get_device(x)} + @constinferred Val{get_device(x)} return_val2(ps) end end diff --git a/lib/MLDataDevices/test/xla_tests.jl b/lib/MLDataDevices/test/xla_tests.jl index dd59af96e..2853a06cd 100644 --- a/lib/MLDataDevices/test/xla_tests.jl +++ b/lib/MLDataDevices/test/xla_tests.jl @@ -1,4 +1,4 @@ -using MLDataDevices, Random, Test +using MLDataDevices, Random, Test, TestExtras using ArrayInterface: parameterless_type @testset "CPU Fallback" begin @@ -108,7 +108,7 @@ using FillArrays, Zygote # Extensions ps = (; weight=x, bias=x, d=(x, x)) return_val(x) = Val(get_device_type(x)) # If it is a compile time constant then type inference will work - @test @inferred(return_val(ps)) isa Val{parameterless_type(typeof(device))} + @constinferred Val{parameterless_type(typeof(device))} return_val(ps) return_val2(x) = Val(get_device(x)) @test_throws TypeError @inferred(return_val2(ps)) From 74e32588870602b7caf79937ff1b62081326a8ba Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Nov 2024 12:20:39 -0500 Subject: [PATCH 3/5] test: use TestExtras in LuxLib testing --- lib/LuxLib/test/Project.toml | 2 + .../test/common_ops/activation_tests.jl | 12 +++--- lib/LuxLib/test/common_ops/bias_act_tests.jl | 38 +++++++------------ lib/LuxLib/test/common_ops/conv_tests.jl | 15 ++++---- lib/LuxLib/test/common_ops/dense_tests.jl | 14 +++---- lib/LuxLib/test/common_ops/dropout_tests.jl | 25 ++++++------ .../test/normalization/batchnorm_tests.jl | 8 ++-- .../test/normalization/groupnorm_tests.jl | 6 +-- .../test/normalization/instancenorm_tests.jl | 12 +++--- .../test/normalization/layernorm_tests.jl | 6 +-- lib/LuxLib/test/shared_testsetup.jl | 2 +- 11 files changed, 62 insertions(+), 78 deletions(-) diff --git a/lib/LuxLib/test/Project.toml b/lib/LuxLib/test/Project.toml index 403bc57fb..1e1b5c58b 100644 --- a/lib/LuxLib/test/Project.toml +++ b/lib/LuxLib/test/Project.toml @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -61,5 +62,6 @@ Static = "0.8.4, 1" StaticArrays = "1.9.7" Statistics = "1.10" Test = "1.10" +TestExtras = "0.3.1" Tracker = "0.2.36" Zygote = "0.6.70" diff --git a/lib/LuxLib/test/common_ops/activation_tests.jl b/lib/LuxLib/test/common_ops/activation_tests.jl index 2789e7d4c..8a2a56def 100644 --- a/lib/LuxLib/test/common_ops/activation_tests.jl +++ b/lib/LuxLib/test/common_ops/activation_tests.jl @@ -30,18 +30,18 @@ @test eltype(y2) == T @test eltype(y3) == T - @test @inferred(apply_act(f, x)) isa Any - @test @inferred(apply_act_fast(f, x)) isa Any - @test @inferred(apply_act_fast2(f, x)) isa Any + @constinferred apply_act(f, x) + @constinferred apply_act_fast(f, x) + @constinferred apply_act_fast2(f, x) @jet apply_act_fast(f, x) @jet apply_act_fast2(f, x) - @test @inferred(Zygote.gradient(apply_act, f, x)) isa Any + @constinferred Zygote.gradient(apply_act, f, x) if f !== lisht - @test @inferred(Zygote.gradient(apply_act_fast, f, x)) 
isa Any + @constinferred Zygote.gradient(apply_act_fast, f, x) end - @test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any + @constinferred Zygote.gradient(apply_act_fast2, f, x) @test_gradients(apply_act, f, x; atol, rtol) @test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()]) diff --git a/lib/LuxLib/test/common_ops/bias_act_tests.jl b/lib/LuxLib/test/common_ops/bias_act_tests.jl index 4e0e51ced..1e932f3d9 100644 --- a/lib/LuxLib/test/common_ops/bias_act_tests.jl +++ b/lib/LuxLib/test/common_ops/bias_act_tests.jl @@ -5,12 +5,6 @@ bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b)) bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b)) - struct __Fix1{F, A} - f::F - act::A - end - (f::__Fix1)(x, b) = f.f(f.act, x, b) - @testset "$mode" for (mode, aType, ongpu, fp64) in MODES @testset "$act, $T, $sz" for act in [ identity, relu, sigmoid, sigmoid_fast, softplus, @@ -27,9 +21,8 @@ y2 = bias_act_loss2(act, x, b) y3 = bias_act_loss3(act, x, b) - fp16 = T == Float16 - atol = fp16 ? 1.0f-2 : 1.0f-3 - rtol = fp16 ? 1.0f-2 : 1.0f-3 + atol = 1.0f-3 + rtol = 1.0f-3 @test y1≈y2 atol=atol rtol=rtol @test y1≈y3 atol=atol rtol=rtol @@ -37,28 +30,25 @@ @test eltype(y2) == T @test eltype(y3) == T - @test @inferred(bias_act_loss1(act, x, b)) isa Any - @test @inferred(bias_act_loss2(act, x, b)) isa Any - @test @inferred(bias_act_loss3(act, x, b)) isa Any + @constinferred bias_act_loss1(act, x, b) + @constinferred bias_act_loss2(act, x, b) + @constinferred bias_act_loss3(act, x, b) @jet bias_act_loss2(act, x, b) @jet bias_act_loss3(act, x, b) - if act !== lisht && T != Float16 - @test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any - @test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any + if act !== lisht + @constinferred Zygote.gradient(bias_act_loss2, act, x, b) + @constinferred Zygote.gradient(bias_act_loss3, act, x, b) end - @test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - @test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol, - soft_fail=fp16 ? [AutoFiniteDiff()] : []) - @test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol, - soft_fail=fp16 ? 
[AutoFiniteDiff()] : []) + @test_gradients(bias_act_loss1, act, x, b; atol, rtol) + @test_gradients(bias_act_loss2, act, x, b; atol, rtol) + @test_gradients(bias_act_loss3, act, x, b; atol, rtol) - ∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b) - ∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b) - ∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b) + _, ∂x1, ∂b1 = Zygote.pullback(bias_act_loss1, act, x, b) + _, ∂x2, ∂b2 = Zygote.pullback(bias_act_loss2, act, x, b) + _, ∂x3, ∂b3 = Zygote.pullback(bias_act_loss3, act, x, b) @test ∂x1≈∂x2 atol=atol rtol=rtol @test ∂x1≈∂x3 atol=atol rtol=rtol diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index b58aafcd3..9fff43364 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,5 +1,5 @@ @testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras expand(_, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -43,20 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, @test eltype(y) == promote_type(Tw, Tx) - @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any + @constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims) @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) if mode != "amdgpu" && activation !== anonact - @test @inferred(Zygote.gradient( - sumabs2conv, activation, weight, x, bias, cdims - )) isa Any + @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) else try - @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) - @test true + @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) catch e e isa ErrorException || rethrow() - @test_broken false + @constinferred_broken Zygote.gradient( + sumabs2conv, activation, weight, x, bias, cdims + ) end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index bc4d40e55..6e65b4654 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,5 +1,5 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras anonact = x -> x^3 @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any + @constinferred fused_dense_bias_activation(activation, w, x, bias) @jet fused_dense_bias_activation(activation, w, x, bias) atol = 1.0f-3 rtol = 1.0f-3 if activation !== anonact - @test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any + @constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias) end skip_backends = [] @@ -117,23 +117,23 @@ end end @testitem "Fused Dense: StaticArrays" tags=[:dense] begin - using StaticArrays, NNlib + using StaticArrays, NNlib, TestExtras x = @SArray rand(2, 4) weight = @SArray rand(3, 2) bias = @SArray rand(3) - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray + @constinferred fused_dense_bias_activation(relu, weight, x, bias) end @testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin - using JLArrays, 
NNlib + using JLArrays, NNlib, TestExtras x = JLArray(rand(Float32, 2, 4)) weight = JLArray(rand(Float32, 3, 2)) bias = JLArray(rand(Float32, 3)) - @test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray + @constinferred fused_dense_bias_activation(relu, weight, x, bias) @test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp end diff --git a/lib/LuxLib/test/common_ops/dropout_tests.jl b/lib/LuxLib/test/common_ops/dropout_tests.jl index 1ec9b4618..e1de98c7e 100644 --- a/lib/LuxLib/test/common_ops/dropout_tests.jl +++ b/lib/LuxLib/test/common_ops/dropout_tests.jl @@ -10,7 +10,7 @@ x = randn(rng, T, x_shape) |> aType - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims) y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims) @@ -21,10 +21,10 @@ @test rng != rng_ @jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims))) - @test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any + @constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims) __f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims))) - @test @inferred(Zygote.gradient(__f, x)) isa Any + @constinferred Zygote.gradient(__f, x) @test_gradients(sumabs2first, dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3) @@ -54,8 +54,7 @@ end mask = rand(T, x_shape) |> aType # Update mask - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(true), T(2), :) @@ -69,7 +68,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :))) - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + @constinferred Zygote.gradient(__f, x, mask) @test_gradients(sumabs2first, dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true), @@ -79,8 +78,7 @@ end rng, x, mask, T(0.5), Val(true), Val(true), T(2), :))) # Try using mask if possible (possible!!) 
- @test @inferred(dropout( - rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(true), Val(false), T(2), :) @@ -94,7 +92,7 @@ end __f = (x, mask) -> sum(first(dropout( StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :))) - @test @inferred(Zygote.gradient(__f, x, mask)) isa Any + @constinferred Zygote.gradient(__f, x, mask) @test_gradients(sumabs2first, dropout, rng, x, LuxTestUtils.Constant(mask), @@ -107,8 +105,7 @@ end mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType # Testing Mode - @test @inferred(dropout( - rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any + @constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) y, mask_, rng_ = dropout( rng, x, mask, T(0.5), Val(false), Val(false), T(2), :) @@ -135,7 +132,7 @@ end x = randn(rng, T, x_shape) |> aType - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any + @constinferred alpha_dropout(rng, x, T(0.5), Val(true)) y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true)) @@ -146,13 +143,13 @@ end @test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2 __f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true)))) - @test @inferred(Zygote.gradient(__f, x)) isa Any + @constinferred Zygote.gradient(__f, x) @test_gradients(sumabs2first, alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3) @jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true)))) - @test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any + @constinferred alpha_dropout(rng, x, T(0.5), Val(false)) y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false)) diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index 58b6196c1..d47c542d6 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module BatchNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool) x = gen_f(T, sz) |> aType @@ -69,8 +69,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end end - @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa - Any + @constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @@ -91,8 +90,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, if anonact !== act lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( x, sc, b, rm, rv, tr, act, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon) end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index c103595f9..891c68715 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras using LuxTestUtils: 
check_approx function setup_groupnorm(rng, aType, T, sz, affine) @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) @test ∂bias≈∂bias_simple atol=atol rtol=rtol end - @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any + @constinferred groupnorm(x, scale, bias, groups, act, epsilon) @jet groupnorm(x, scale, bias, groups, act, epsilon) if anonact !== act lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon) end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index dd999ff09..cc8b1e81b 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras is_training(::Val{training}) where {training} = training @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) atol = 1.0f-2 rtol = 1.0f-2 - @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any + @constinferred instancenorm(x, scale, bias, training, act, epsilon) @jet instancenorm(x, scale, bias, training, act, epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon) end @test y isa aType{T, length(sz)} @@ -46,15 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) - @test @inferred(instancenorm( - x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any + @constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( x, sc, b, rm, rv, Val(true), act, m, ϵ))) - @test @inferred(Zygote.gradient( - lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon) end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 6b82390a4..940e95c06 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module LayerNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Statistics, TestExtras using LuxTestUtils: check_approx function setup_layernorm(gen_f, aType, T, x_size, affine_shape, expand_dims::Bool=true) @@ -40,7 +40,7 @@ function run_layernorm_testing_core( epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) 
-> layernorm(args..., act, dims, epsilon) - @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any + @constinferred layernorm(x, scale, bias, act, dims, epsilon) @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) @@ -61,7 +61,7 @@ function run_layernorm_testing_core( if anonact !== act lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon)) isa Any + @constinferred Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon) end end diff --git a/lib/LuxLib/test/shared_testsetup.jl b/lib/LuxLib/test/shared_testsetup.jl index 77cdab470..c2072420f 100644 --- a/lib/LuxLib/test/shared_testsetup.jl +++ b/lib/LuxLib/test/shared_testsetup.jl @@ -2,7 +2,7 @@ import Reexport: @reexport using LuxLib, MLDataDevices -@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib +@reexport using LuxTestUtils, StableRNGs, Test, Enzyme, Zygote, NNlib, TestExtras LuxTestUtils.jet_target_modules!(["LuxLib"]) From a5ce2019cb5c066220fa3fa274488288fe2831b3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 22 Nov 2024 13:11:21 -0500 Subject: [PATCH 4/5] fix: missing import --- test/zygote_type_stability.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/zygote_type_stability.jl b/test/zygote_type_stability.jl index 8a37d52c0..cd7af49d1 100644 --- a/test/zygote_type_stability.jl +++ b/test/zygote_type_stability.jl @@ -1,4 +1,4 @@ -using Lux, Random, Zygote, StableRNGs, Test +using Lux, Random, Zygote, StableRNGs, Test, TestExtras include("setup_modes.jl") From 21dde5449a5b0a53be4e586bfd7d3bb51aa6bee5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 27 Nov 2024 01:21:16 -0500 Subject: [PATCH 5/5] test: try bypassing the world age issues --- lib/LuxLib/test/common_ops/conv_tests.jl | 12 ++++++------ lib/LuxLib/test/common_ops/dense_tests.jl | 6 +++--- lib/LuxLib/test/normalization/batchnorm_tests.jl | 8 ++++---- lib/LuxLib/test/normalization/groupnorm_tests.jl | 8 ++++---- .../test/normalization/instancenorm_tests.jl | 15 ++++++++------- lib/LuxLib/test/normalization/layernorm_tests.jl | 6 +++--- test/Project.toml | 2 +- 7 files changed, 29 insertions(+), 28 deletions(-) diff --git a/lib/LuxLib/test/common_ops/conv_tests.jl b/lib/LuxLib/test/common_ops/conv_tests.jl index 9fff43364..ee223c09b 100644 --- a/lib/LuxLib/test/common_ops/conv_tests.jl +++ b/lib/LuxLib/test/common_ops/conv_tests.jl @@ -1,5 +1,5 @@ @testsetup module ConvSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib expand(_, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) @@ -43,19 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding, @test eltype(y) == promote_type(Tw, Tx) - @constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims) + @test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any @jet fused_conv_bias_activation(activation, weight, x, bias, cdims) if mode != "amdgpu" && activation !== anonact - @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any else try - @constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims) + @test @inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)) isa Any catch e e isa ErrorException || rethrow() - @constinferred_broken 
Zygote.gradient( + @test_broken @inferred(Zygote.gradient( sumabs2conv, activation, weight, x, bias, cdims - ) + )) end end diff --git a/lib/LuxLib/test/common_ops/dense_tests.jl b/lib/LuxLib/test/common_ops/dense_tests.jl index 6e65b4654..9689a5ca8 100644 --- a/lib/LuxLib/test/common_ops/dense_tests.jl +++ b/lib/LuxLib/test/common_ops/dense_tests.jl @@ -1,5 +1,5 @@ @testsetup module DenseSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs anonact = x -> x^3 @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu @test y ≈ y_generic @test eltype(y) == promote_type(Tw, Tx) - @constinferred fused_dense_bias_activation(activation, w, x, bias) + @test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any @jet fused_dense_bias_activation(activation, w, x, bias) atol = 1.0f-3 rtol = 1.0f-3 if activation !== anonact - @constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias) + @test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any end skip_backends = [] diff --git a/lib/LuxLib/test/normalization/batchnorm_tests.jl b/lib/LuxLib/test/normalization/batchnorm_tests.jl index d47c542d6..ea8cb02e2 100644 --- a/lib/LuxLib/test/normalization/batchnorm_tests.jl +++ b/lib/LuxLib/test/normalization/batchnorm_tests.jl @@ -69,7 +69,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end end - @constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) + @test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa + Any @jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon) @test y isa aType{T, length(sz)} @@ -88,9 +89,8 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act, end if anonact !== act - lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm( - x, sc, b, rm, rv, tr, act, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon) + @test @inferred(Zygote.gradient( + sumabs2first, x, scale, bias, rm, rv, training, act, epsilon)) isa Any end end diff --git a/lib/LuxLib/test/normalization/groupnorm_tests.jl b/lib/LuxLib/test/normalization/groupnorm_tests.jl index 891c68715..6302bc6dd 100644 --- a/lib/LuxLib/test/normalization/groupnorm_tests.jl +++ b/lib/LuxLib/test/normalization/groupnorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module GroupNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs using LuxTestUtils: check_approx function setup_groupnorm(rng, aType, T, sz, affine) @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu) @test ∂bias≈∂bias_simple atol=atol rtol=rtol end - @constinferred groupnorm(x, scale, bias, groups, act, epsilon) + @test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any @jet groupnorm(x, scale, bias, groups, act, epsilon) if anonact !== act - lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ)) - @constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon) + @test @inferred(Zygote.gradient( + sumabs2groupnorm, x, scale, bias, groups, act, epsilon)) isa Any end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/instancenorm_tests.jl b/lib/LuxLib/test/normalization/instancenorm_tests.jl index 
cc8b1e81b..a0e9e2130 100644 --- a/lib/LuxLib/test/normalization/instancenorm_tests.jl +++ b/lib/LuxLib/test/normalization/instancenorm_tests.jl @@ -1,5 +1,5 @@ @testsetup module InstanceNormSetup -using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras +using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib is_training(::Val{training}) where {training} = training @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) atol = 1.0f-2 rtol = 1.0f-2 - @constinferred instancenorm(x, scale, bias, training, act, epsilon) + @test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any @jet instancenorm(x, scale, bias, training, act, epsilon) if anonact !== act && is_training(training) lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon) + @test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any end @test y isa aType{T, length(sz)} @@ -46,13 +46,14 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType) y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) - @constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) + @test @inferred(instancenorm( + x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any @jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon) if anonact !== act && is_training(training) - lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm( - x, sc, b, rm, rv, Val(true), act, m, ϵ))) - @constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon) + @test @inferred(Zygote.gradient( + sumabs2instancenorm, x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa + Any end @test y isa aType{T, length(sz)} diff --git a/lib/LuxLib/test/normalization/layernorm_tests.jl b/lib/LuxLib/test/normalization/layernorm_tests.jl index 940e95c06..43b989615 100644 --- a/lib/LuxLib/test/normalization/layernorm_tests.jl +++ b/lib/LuxLib/test/normalization/layernorm_tests.jl @@ -40,7 +40,7 @@ function run_layernorm_testing_core( epsilon = LuxLib.Utils.default_epsilon(T) _f = (args...) -> layernorm(args..., act, dims, epsilon) - @constinferred layernorm(x, scale, bias, act, dims, epsilon) + @test @inferred(layernorm(x, scale, bias, act, dims, epsilon)) isa Any @jet layernorm(x, scale, bias, act, dims, epsilon) y = _f(x, scale, bias) @@ -60,8 +60,8 @@ function run_layernorm_testing_core( soft_fail=[AutoFiniteDiff()]) if anonact !== act - lfn = (x, sc, b, act, dim, ϵ) -> sum(layernorm(x, sc, b, act, dim, ϵ)) - @constinferred Zygote.gradient(lfn, x, scale, bias, act, dims, epsilon) + @test @inferred(Zygote.gradient( + sumabs2layernorm, x, scale, bias, act, dims, epsilon)) isa Any end end diff --git a/test/Project.toml b/test/Project.toml index c9e33f67d..9466d196e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -78,7 +78,7 @@ SimpleChains = "0.4.7" StableRNGs = "1.0.2" Static = "1" StaticArrays = "1.9" -Statistics = "1.11.1" +Statistics = "1.10" Test = "1.10" TestExtras = "0.3.1" Tracker = "0.2.36"
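
Not part of the patches above -- a minimal sketch of the idiom change this series makes, assuming TestExtras' `@constinferred` and `@constinferred_broken` behave as they are used in the hunks (the optional expected-return-type form mirrors the MLDataDevices tests); the `FLAG` and `unstable` names are hypothetical helpers introduced only for this example:

    using Test, TestExtras

    # A deliberately type-unstable helper: the return type depends on a runtime
    # value read from a Ref, so inference only gives Union{Int, Float64}.
    const FLAG = Ref(true)
    unstable() = FLAG[] ? 1 : 1.0

    @testset "inference-check idioms" begin
        # Old idiom used throughout these test suites: Test.@inferred throws
        # when inference fails, so it was wrapped in `@test ... isa Any`
        # purely to record a pass/fail entry in the testset.
        @test @inferred(hypot(3.0, 4.0)) isa Any

        # New idiom: @constinferred records the inference check directly,
        # with the call's arguments treated as constants.
        @constinferred hypot(3.0, 4.0)

        # As in the MLDataDevices hunks, an expected return type may be given
        # before the call being checked.
        @constinferred Float64 hypot(3.0, 4.0)

        # Calls known to be type-unstable are recorded as broken instead of
        # erroring out of the testset.
        @constinferred_broken unstable()
    end

Note that PATCH 5/5 reverts several LuxLib tests back to the `@test @inferred(...) isa Any` form; its subject line attributes this to world-age issues, presumably arising from how `@constinferred` constructs the call it checks.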