From 6a317e0dfbf737e9705ebfe2f3518683029471be Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 10 Aug 2023 18:29:55 +0200 Subject: [PATCH] Simpler tests that work with SparseArrays (#114) * Simpler tests that work with SparseArrays * Add tests for sparse vector in addition to sparse matrix * Document sparse array warning * Skip sparse direct in 1.6 --- docs/src/faq.md | 9 ++-- test/errors.jl | 6 +-- test/systematic.jl | 124 ++++++++++++++++++++++++--------------------- 3 files changed, 75 insertions(+), 64 deletions(-) diff --git a/docs/src/faq.md b/docs/src/faq.md index 3b1328d..4c6b043 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -23,7 +23,7 @@ You can specify it with the `conditions_backend` keyword argument when construct Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size. -If the output is a small array (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance. +If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance. ### Scalars @@ -34,7 +34,10 @@ Or better yet, wrap it in a static vector: `SVector(val)`. ### Sparse arrays !!! danger "Danger" - Sparse arrays are not supported and might give incorrect values or `NaN`s! + Sparse arrays are not officially supported and might give incorrect values or `NaN`s! + +With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658). +With Zygote.jl it appears to work, but this functionality is considered experimental and might evolve. ## Number of inputs and outputs @@ -45,7 +48,7 @@ Well, it depends whether you want their derivatives or not. 
| | Derivatives needed | Derivatives not needed | | -------------------- | --------------------------------------- | --------------------------------------- | | **Multiple inputs** | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` | -| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct | +| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct | We now detail each of these options. diff --git a/test/errors.jl b/test/errors.jl index 1ca201a..b0040bc 100644 --- a/test/errors.jl +++ b/test/errors.jl @@ -59,9 +59,9 @@ end end @testset "Weird ChainRulesTestUtils behavior" begin - x = rand(2, 3) - forward(x) = sqrt.(abs.(x)), 2 - conditions(x, y, z) = abs.(y) .^ z .- abs.(x) + x = rand(3) + forward(x) = sqrt.(abs.(x)), 1 + conditions(x, y, z) = abs.(y ./ z) .- abs.(x) implicit = ImplicitFunction(forward, conditions) y, z = implicit(x) dy = similar(y) diff --git a/test/systematic.jl b/test/systematic.jl index 1b20019..7dd3151 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -27,41 +27,38 @@ Random.seed!(63); ## Utils change_shape(x::AbstractArray{T,3}) where {T} = x[:, :, 1] +change_shape(x::AbstractSparseArray) = x function mysqrt(x::AbstractArray) - return identity_break_autodiff(sqrt.(abs.(change_shape(x)))) -end - -function mypower(x::AbstractArray, p) - return identity_break_autodiff(abs.(change_shape(x)) .^ p) + return identity_break_autodiff(sqrt.(abs.(x))) end ## Various signatures function make_implicit_sqrt(; kwargs...) - forward(x) = mysqrt(x) - conditions(x, y) = y .^ 2 .- abs.(change_shape(x)) + forward(x) = mysqrt(change_shape(x)) + conditions(x, y) = abs2.(y) .- abs.(change_shape(x)) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_byproduct(; kwargs...) 
- forward(x) = mysqrt(x), 2 - conditions(x, y, z::Integer) = y .^ z .- abs.(change_shape(x)) + forward(x) = 1 * mysqrt(change_shape(x)), 1 + conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(change_shape(x)) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end -function make_implicit_power_args(; kwargs...) - forward(x, p::Integer) = mypower(x, one(eltype(x)) / p) - conditions(x, y, p::Integer) = y .^ p .- abs.(change_shape(x)) +function make_implicit_sqrt_args(; kwargs...) + forward(x, p::Integer) = p * mysqrt(change_shape(x)) + conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end -function make_implicit_power_kwargs(; kwargs...) - forward(x; p::Integer) = mypower(x, one(eltype(x)) / p) - conditions(x, y; p::Integer) = y .^ p .- abs.(change_shape(x)) +function make_implicit_sqrt_kwargs(; kwargs...) + forward(x; p::Integer) = p .* mysqrt(change_shape(x)) + conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end @@ -85,21 +82,21 @@ end function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_power_args(; kwargs...) - imf4 = make_implicit_power_kwargs(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(x) + y_true = mysqrt(change_shape(x)) y1 = @inferred imf1(x) y2, z2 = @inferred imf2(x) - y3 = @inferred imf3(x, 2) - y4 = @inferred imf4(x; p=2) + y3 = @inferred imf3(x, 1) + y4 = @inferred imf4(x; p=1) @testset "Exact value" begin @test y1 ≈ y_true @test y2 ≈ y_true @test y3 ≈ y_true @test y4 ≈ y_true - @test z2 ≈ 2 + @test z2 ≈ 1 end @testset "Array type" begin @@ -112,38 +109,38 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) 
where {T} @testset "JET" begin @test_opt target_modules = (ID,) imf1(x) @test_opt target_modules = (ID,) imf2(x) - @test_opt target_modules = (ID,) imf3(x, 2) - @test_opt target_modules = (ID,) imf4(x; p=2) + @test_opt target_modules = (ID,) imf3(x, 1) + @test_opt target_modules = (ID,) imf4(x; p=1) @test_call target_modules = (ID,) imf1(x) @test_call target_modules = (ID,) imf2(x) - @test_call target_modules = (ID,) imf3(x, 2) - @test_call target_modules = (ID,) imf4(x; p=2) + @test_call target_modules = (ID,) imf3(x, 1) + @test_call target_modules = (ID,) imf4(x; p=1) end end function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_power_args(; kwargs...) - imf4 = make_implicit_power_kwargs(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(x) + y_true = mysqrt(change_shape(x)) dx = similar(x) dx .= one(T) x_and_dx = ForwardDiff.Dual.(x, dx) y_and_dy1 = @inferred imf1(x_and_dx) y_and_dy2, z2 = @inferred imf2(x_and_dx) - y_and_dy3 = @inferred imf3(x_and_dx, 2) - y_and_dy4 = @inferred imf4(x_and_dx; p=2) + y_and_dy3 = @inferred imf3(x_and_dx, 1) + y_and_dy4 = @inferred imf4(x_and_dx; p=1) @testset "Dual numbers" begin @test ForwardDiff.value.(y_and_dy1) ≈ y_true @test ForwardDiff.value.(y_and_dy2) ≈ y_true @test ForwardDiff.value.(y_and_dy3) ≈ y_true @test ForwardDiff.value.(y_and_dy4) ≈ y_true - @test z2 ≈ 2 + @test z2 ≈ 1 end @testset "Static arrays" begin @@ -156,31 +153,31 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) 
where {T} @testset "JET" begin @test_opt target_modules = (ID,) imf1(x_and_dx) @test_opt target_modules = (ID,) imf2(x_and_dx) - @test_opt target_modules = (ID,) imf3(x_and_dx, 2) - @test_opt target_modules = (ID,) imf4(x_and_dx; p=2) + @test_opt target_modules = (ID,) imf3(x_and_dx, 1) + @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) @test_call target_modules = (ID,) imf1(x_and_dx) @test_call target_modules = (ID,) imf2(x_and_dx) - @test_call target_modules = (ID,) imf3(x_and_dx, 2) - @test_call target_modules = (ID,) imf4(x_and_dx; p=2) + @test_call target_modules = (ID,) imf3(x_and_dx, 1) + @test_call target_modules = (ID,) imf4(x_and_dx; p=1) end end function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_power_args(; kwargs...) - imf4 = make_implicit_power_kwargs(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(x) + y_true = mysqrt(change_shape(x)) dy = similar(y_true) dy .= one(eltype(y_true)) dz = nothing y1, pb1 = @inferred rrule(rc, imf1, x) (y2, z2), pb2 = @inferred rrule(rc, imf2, x) - y3, pb3 = @inferred rrule(rc, imf3, x, 2) - y4, pb4 = @inferred rrule(rc, imf4, x; p=2) + y3, pb3 = @inferred rrule(rc, imf3, x, 1) + y4, pb4 = @inferred rrule(rc, imf4, x; p=1) dimf1, dx1 = @inferred pb1(dy) dimf2, dx2 = @inferred pb2((dy, dz)) @@ -192,7 +189,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} @test y2 ≈ y_true @test y3 ≈ y_true @test y4 ≈ y_true - @test z2 ≈ 2 + @test z2 ≈ 1 @test dimf1 isa NoTangent @test dimf2 isa NoTangent @@ -222,8 +219,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) 
where {T} @testset "JET" begin @test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x) @test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 2) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=2) + @test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) + @test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) @test_skip @test_opt target_modules = (ID,) pb1(dy) @test_skip @test_opt target_modules = (ID,) pb2((dy, dz)) @@ -232,8 +229,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} @test_call target_modules = (ID,) rrule(rc, imf1, x) @test_call target_modules = (ID,) rrule(rc, imf2, x) - @test_call target_modules = (ID,) rrule(rc, imf3, x, 2) - @test_call target_modules = (ID,) rrule(rc, imf4, x; p=2) + @test_call target_modules = (ID,) rrule(rc, imf3, x, 1) + @test_call target_modules = (ID,) rrule(rc, imf4, x; p=1) @test_call target_modules = (ID,) pb1(dy) @test_call target_modules = (ID,) pb2((dy, dz)) @@ -244,8 +241,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} @testset "ChainRulesTestUtils" begin test_rrule(rc, imf1, x; atol=1e-2) test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0)) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112 - test_rrule(rc, imf3, x, 2; atol=1e-2) - test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=2,)) + test_rrule(rc, imf3, x, 1; atol=1e-2) + test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,)) end end @@ -254,13 +251,13 @@ end function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_power_args(; kwargs...) - imf4 = make_implicit_power_kwargs(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
J1 = ForwardDiff.jacobian(imf1, x) J2 = ForwardDiff.jacobian(first ∘ imf2, x) - J3 = ForwardDiff.jacobian(_x -> imf3(_x, 2), x) - J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=2), x) + J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x) + J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x) J_true = ForwardDiff.jacobian(_x -> sqrt.(change_shape(_x)), x) @testset "Exact Jacobian" begin @@ -280,13 +277,13 @@ end function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_power_args(; kwargs...) - imf4 = make_implicit_power_kwargs(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) J1 = Zygote.jacobian(imf1, x)[1] J2 = Zygote.jacobian(first ∘ imf2, x)[1] - J3 = Zygote.jacobian(imf3, x, 2)[1] - J4 = Zygote.jacobian(_x -> imf4(_x; p=2), x)[1] + J3 = Zygote.jacobian(imf3, x, 1)[1] + J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1] J_true = Zygote.jacobian(_x -> sqrt.(change_shape(_x)), x)[1] @testset "Exact Jacobian" begin @@ -308,8 +305,10 @@ function test_implicit(x; kwargs...) test_implicit_call(x; kwargs...) end @testset verbose = true "ForwardDiff.jl" begin - test_implicit_forwarddiff(x; kwargs...) - test_implicit_duals(x; kwargs...) + if !(x isa AbstractSparseArray) + test_implicit_forwarddiff(x; kwargs...) + test_implicit_duals(x; kwargs...) 
+ end end @testset verbose = true "Zygote.jl" begin rc = Zygote.ZygoteRuleConfig() @@ -337,6 +336,8 @@ conditions_backend_candidates = ( x_candidates = ( rand(Float32, 2, 3, 2), # SArray{Tuple{2,3,2}}(rand(Float32, 2, 3, 2)), # + sparse(rand(Float32, 2)), # + sparse(rand(Float32, 2, 3)), # ); params_candidates = [] @@ -366,8 +367,15 @@ end for (linear_solver, conditions_backend, x) in params_candidates testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))" + if ( + linear_solver isa DirectLinearSolver && + x isa AbstractSparseArray && + VERSION < v"1.9" + ) # skipped before 1.9: missing linalg function for sparse arrays (seen in 1.6) + continue + end @info "$testsetname" - @testset "$testsetname" begin + @testset verbose = true "$testsetname" begin test_implicit(x; linear_solver, conditions_backend) end end