Skip to content

Commit

Permalink
Simpler tests that work with SparseArrays (#114)
Browse files Browse the repository at this point in the history
* Simpler tests that work with SparseArrays

* Add tests for sparse vector in addition to sparse matrix

* Document sparse array warning

* Skip sparse direct in 1.6
  • Loading branch information
gdalle authored Aug 10, 2023
1 parent e725c2f commit 6a317e0
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 64 deletions.
9 changes: 6 additions & 3 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ You can specify it with the `conditions_backend` keyword argument when construct

Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.

If the output is a small array (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.

### Scalars

Expand All @@ -34,7 +34,10 @@ Or better yet, wrap it in a static vector: `SVector(val)`.
### Sparse arrays

!!! danger "Danger"
Sparse arrays are not supported and might give incorrect values or `NaN`s!
Sparse arrays are not officially supported and might give incorrect values or `NaN`s!

With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658).
With Zygote.jl it appears to work, but this functionality is considered experimental and might evolve.

## Number of inputs and outputs

Expand All @@ -45,7 +48,7 @@ Well, it depends whether you want their derivatives or not.
| | Derivatives needed | Derivatives not needed |
| -------------------- | --------------------------------------- | --------------------------------------- |
| **Multiple inputs** | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` |
| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct |
| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct |

We now detail each of these options.

Expand Down
6 changes: 3 additions & 3 deletions test/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ end
end

@testset "Weird ChainRulesTestUtils behavior" begin
x = rand(2, 3)
forward(x) = sqrt.(abs.(x)), 2
conditions(x, y, z) = abs.(y) .^ z .- abs.(x)
x = rand(3)
forward(x) = sqrt.(abs.(x)), 1
conditions(x, y, z) = abs.(y ./ z) .- abs.(x)
implicit = ImplicitFunction(forward, conditions)
y, z = implicit(x)
dy = similar(y)
Expand Down
124 changes: 66 additions & 58 deletions test/systematic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,41 +27,38 @@ Random.seed!(63);
## Utils

change_shape(x::AbstractArray{T,3}) where {T} = x[:, :, 1]
change_shape(x::AbstractSparseArray) = x

function mysqrt(x::AbstractArray)
return identity_break_autodiff(sqrt.(abs.(change_shape(x))))
end

function mypower(x::AbstractArray, p)
return identity_break_autodiff(abs.(change_shape(x)) .^ p)
return identity_break_autodiff(sqrt.(abs.(x)))
end

## Various signatures

function make_implicit_sqrt(; kwargs...)
forward(x) = mysqrt(x)
conditions(x, y) = y .^ 2 .- abs.(change_shape(x))
forward(x) = mysqrt(change_shape(x))
conditions(x, y) = abs2.(y) .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_sqrt_byproduct(; kwargs...)
forward(x) = mysqrt(x), 2
conditions(x, y, z::Integer) = y .^ z .- abs.(change_shape(x))
forward(x) = 1 * mysqrt(change_shape(x)), 1
conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_power_args(; kwargs...)
forward(x, p::Integer) = mypower(x, one(eltype(x)) / p)
conditions(x, y, p::Integer) = y .^ p .- abs.(change_shape(x))
function make_implicit_sqrt_args(; kwargs...)
forward(x, p::Integer) = p * mysqrt(change_shape(x))
conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end

function make_implicit_power_kwargs(; kwargs...)
forward(x; p::Integer) = mypower(x, one(eltype(x)) / p)
conditions(x, y; p::Integer) = y .^ p .- abs.(change_shape(x))
function make_implicit_sqrt_kwargs(; kwargs...)
forward(x; p::Integer) = p .* mysqrt(change_shape(x))
conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x))
implicit = ImplicitFunction(forward, conditions; kwargs...)
return implicit
end
Expand All @@ -85,21 +82,21 @@ end
function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T}
imf1 = make_implicit_sqrt(; kwargs...)
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
imf3 = make_implicit_power_args(; kwargs...)
imf4 = make_implicit_power_kwargs(; kwargs...)
imf3 = make_implicit_sqrt_args(; kwargs...)
imf4 = make_implicit_sqrt_kwargs(; kwargs...)

y_true = mysqrt(x)
y_true = mysqrt(change_shape(x))
y1 = @inferred imf1(x)
y2, z2 = @inferred imf2(x)
y3 = @inferred imf3(x, 2)
y4 = @inferred imf4(x; p=2)
y3 = @inferred imf3(x, 1)
y4 = @inferred imf4(x; p=1)

@testset "Exact value" begin
@test y1 y_true
@test y2 y_true
@test y3 y_true
@test y4 y_true
@test z2 2
@test z2 1
end

@testset "Array type" begin
Expand All @@ -112,38 +109,38 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T}
@testset "JET" begin
@test_opt target_modules = (ID,) imf1(x)
@test_opt target_modules = (ID,) imf2(x)
@test_opt target_modules = (ID,) imf3(x, 2)
@test_opt target_modules = (ID,) imf4(x; p=2)
@test_opt target_modules = (ID,) imf3(x, 1)
@test_opt target_modules = (ID,) imf4(x; p=1)

@test_call target_modules = (ID,) imf1(x)
@test_call target_modules = (ID,) imf2(x)
@test_call target_modules = (ID,) imf3(x, 2)
@test_call target_modules = (ID,) imf4(x; p=2)
@test_call target_modules = (ID,) imf3(x, 1)
@test_call target_modules = (ID,) imf4(x; p=1)
end
end

function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T}
imf1 = make_implicit_sqrt(; kwargs...)
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
imf3 = make_implicit_power_args(; kwargs...)
imf4 = make_implicit_power_kwargs(; kwargs...)
imf3 = make_implicit_sqrt_args(; kwargs...)
imf4 = make_implicit_sqrt_kwargs(; kwargs...)

y_true = mysqrt(x)
y_true = mysqrt(change_shape(x))
dx = similar(x)
dx .= one(T)
x_and_dx = ForwardDiff.Dual.(x, dx)

y_and_dy1 = @inferred imf1(x_and_dx)
y_and_dy2, z2 = @inferred imf2(x_and_dx)
y_and_dy3 = @inferred imf3(x_and_dx, 2)
y_and_dy4 = @inferred imf4(x_and_dx; p=2)
y_and_dy3 = @inferred imf3(x_and_dx, 1)
y_and_dy4 = @inferred imf4(x_and_dx; p=1)

@testset "Dual numbers" begin
@test ForwardDiff.value.(y_and_dy1) y_true
@test ForwardDiff.value.(y_and_dy2) y_true
@test ForwardDiff.value.(y_and_dy3) y_true
@test ForwardDiff.value.(y_and_dy4) y_true
@test z2 2
@test z2 1
end

@testset "Static arrays" begin
Expand All @@ -156,31 +153,31 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T}
@testset "JET" begin
@test_opt target_modules = (ID,) imf1(x_and_dx)
@test_opt target_modules = (ID,) imf2(x_and_dx)
@test_opt target_modules = (ID,) imf3(x_and_dx, 2)
@test_opt target_modules = (ID,) imf4(x_and_dx; p=2)
@test_opt target_modules = (ID,) imf3(x_and_dx, 1)
@test_opt target_modules = (ID,) imf4(x_and_dx; p=1)

@test_call target_modules = (ID,) imf1(x_and_dx)
@test_call target_modules = (ID,) imf2(x_and_dx)
@test_call target_modules = (ID,) imf3(x_and_dx, 2)
@test_call target_modules = (ID,) imf4(x_and_dx; p=2)
@test_call target_modules = (ID,) imf3(x_and_dx, 1)
@test_call target_modules = (ID,) imf4(x_and_dx; p=1)
end
end

function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
imf1 = make_implicit_sqrt(; kwargs...)
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
imf3 = make_implicit_power_args(; kwargs...)
imf4 = make_implicit_power_kwargs(; kwargs...)
imf3 = make_implicit_sqrt_args(; kwargs...)
imf4 = make_implicit_sqrt_kwargs(; kwargs...)

y_true = mysqrt(x)
y_true = mysqrt(change_shape(x))
dy = similar(y_true)
dy .= one(eltype(y_true))
dz = nothing

y1, pb1 = @inferred rrule(rc, imf1, x)
(y2, z2), pb2 = @inferred rrule(rc, imf2, x)
y3, pb3 = @inferred rrule(rc, imf3, x, 2)
y4, pb4 = @inferred rrule(rc, imf4, x; p=2)
y3, pb3 = @inferred rrule(rc, imf3, x, 1)
y4, pb4 = @inferred rrule(rc, imf4, x; p=1)

dimf1, dx1 = @inferred pb1(dy)
dimf2, dx2 = @inferred pb2((dy, dz))
Expand All @@ -192,7 +189,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
@test y2 y_true
@test y3 y_true
@test y4 y_true
@test z2 2
@test z2 1

@test dimf1 isa NoTangent
@test dimf2 isa NoTangent
Expand Down Expand Up @@ -222,8 +219,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
@testset "JET" begin
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 2)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=2)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1)
@test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1)

@test_skip @test_opt target_modules = (ID,) pb1(dy)
@test_skip @test_opt target_modules = (ID,) pb2((dy, dz))
Expand All @@ -232,8 +229,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}

@test_call target_modules = (ID,) rrule(rc, imf1, x)
@test_call target_modules = (ID,) rrule(rc, imf2, x)
@test_call target_modules = (ID,) rrule(rc, imf3, x, 2)
@test_call target_modules = (ID,) rrule(rc, imf4, x; p=2)
@test_call target_modules = (ID,) rrule(rc, imf3, x, 1)
@test_call target_modules = (ID,) rrule(rc, imf4, x; p=1)

@test_call target_modules = (ID,) pb1(dy)
@test_call target_modules = (ID,) pb2((dy, dz))
Expand All @@ -244,8 +241,8 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T}
@testset "ChainRulesTestUtils" begin
test_rrule(rc, imf1, x; atol=1e-2)
test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0)) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112
test_rrule(rc, imf3, x, 2; atol=1e-2)
test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=2,))
test_rrule(rc, imf3, x, 1; atol=1e-2)
test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,))
end
end

Expand All @@ -254,13 +251,13 @@ end
function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T}
imf1 = make_implicit_sqrt(; kwargs...)
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
imf3 = make_implicit_power_args(; kwargs...)
imf4 = make_implicit_power_kwargs(; kwargs...)
imf3 = make_implicit_sqrt_args(; kwargs...)
imf4 = make_implicit_sqrt_kwargs(; kwargs...)

J1 = ForwardDiff.jacobian(imf1, x)
J2 = ForwardDiff.jacobian(first imf2, x)
J3 = ForwardDiff.jacobian(_x -> imf3(_x, 2), x)
J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=2), x)
J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x)
J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x)
J_true = ForwardDiff.jacobian(_x -> sqrt.(change_shape(_x)), x)

@testset "Exact Jacobian" begin
Expand All @@ -280,13 +277,13 @@ end
function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T}
imf1 = make_implicit_sqrt(; kwargs...)
imf2 = make_implicit_sqrt_byproduct(; kwargs...)
imf3 = make_implicit_power_args(; kwargs...)
imf4 = make_implicit_power_kwargs(; kwargs...)
imf3 = make_implicit_sqrt_args(; kwargs...)
imf4 = make_implicit_sqrt_kwargs(; kwargs...)

J1 = Zygote.jacobian(imf1, x)[1]
J2 = Zygote.jacobian(first imf2, x)[1]
J3 = Zygote.jacobian(imf3, x, 2)[1]
J4 = Zygote.jacobian(_x -> imf4(_x; p=2), x)[1]
J3 = Zygote.jacobian(imf3, x, 1)[1]
J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1]
J_true = Zygote.jacobian(_x -> sqrt.(change_shape(_x)), x)[1]

@testset "Exact Jacobian" begin
Expand All @@ -308,8 +305,10 @@ function test_implicit(x; kwargs...)
test_implicit_call(x; kwargs...)
end
@testset verbose = true "ForwardDiff.jl" begin
test_implicit_forwarddiff(x; kwargs...)
test_implicit_duals(x; kwargs...)
if !(x isa AbstractSparseArray)
test_implicit_forwarddiff(x; kwargs...)
test_implicit_duals(x; kwargs...)
end
end
@testset verbose = true "Zygote.jl" begin
rc = Zygote.ZygoteRuleConfig()
Expand Down Expand Up @@ -337,6 +336,8 @@ conditions_backend_candidates = (
x_candidates = (
rand(Float32, 2, 3, 2), #
SArray{Tuple{2,3,2}}(rand(Float32, 2, 3, 2)), #
sparse(rand(Float32, 2)), #
sparse(rand(Float32, 2, 3)), #
);

params_candidates = []
Expand Down Expand Up @@ -366,8 +367,15 @@ end

for (linear_solver, conditions_backend, x) in params_candidates
testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))"
if (
linear_solver isa DirectLinearSolver &&
x isa AbstractSparseArray &&
VERSION < v"1.9"
) # missing linalg function for sparse arrays in 1.6
continue
end
@info "$testsetname"
@testset "$testsetname" begin
@testset verbose = true "$testsetname" begin
test_implicit(x; linear_solver, conditions_backend)
end
end

0 comments on commit 6a317e0

Please sign in to comment.