From 9b992dcfdcb1b83f3da234f12e1280350235d180 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 9 Apr 2024 11:59:26 +0200 Subject: [PATCH 1/8] Switch to DifferentiationInterface --- .github/workflows/CI.yml | 6 +- CITATION.bib | 6 +- Project.toml | 52 ++-- benchmark/Project.toml | 13 - benchmark/analysis.jl | 168 ------------ benchmark/benchmarks.jl | 143 ---------- benchmark/judge.jl | 13 - docs/Project.toml | 4 +- examples/0_intro.jl | 21 -- examples/1_basic.jl | 6 +- examples/3_tricks.jl | 2 +- ...mplicitDifferentiationChainRulesCoreExt.jl | 92 ++----- ext/ImplicitDifferentiationEnzymeCoreExt.jl | 5 + ext/ImplicitDifferentiationForwardDiffExt.jl | 68 +---- ext/ImplicitDifferentiationStaticArraysExt.jl | 25 -- ext/ImplicitDifferentiationZygoteExt.jl | 24 -- src/ImplicitDifferentiation.jl | 36 +-- src/implicit_function.jl | 65 ++--- src/linear_solver.jl | 78 ------ src/operators.jl | 251 ++++++++++++------ src/utils.jl | 5 - test/systematic.jl | 9 - 22 files changed, 275 insertions(+), 817 deletions(-) delete mode 100644 benchmark/Project.toml delete mode 100644 benchmark/analysis.jl delete mode 100644 benchmark/benchmarks.jl delete mode 100644 benchmark/judge.jl create mode 100644 ext/ImplicitDifferentiationEnzymeCoreExt.jl delete mode 100644 ext/ImplicitDifferentiationStaticArraysExt.jl delete mode 100644 ext/ImplicitDifferentiationZygoteExt.jl delete mode 100644 src/linear_solver.jl delete mode 100644 src/utils.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f4e70da..113cff1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,13 +18,13 @@ jobs: strategy: fail-fast: false matrix: - version: ['1.6', '1'] + version: ['1.10', '1'] os: [ubuntu-latest] arch: [x64] allow_failure: [false] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} @@ -45,7 
+45,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: '1' - uses: julia-actions/cache@v1 diff --git a/CITATION.bib b/CITATION.bib index a9a43d6..981d199 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl author = {Guillaume Dalle, Mohamed Tarek and contributors}, title = {ImplicitDifferentiation.jl}, url = {https://github.com/gdalle/ImplicitDifferentiation.jl}, - version = {v0.5.0}, - year = {2023}, - month = {8} + version = {v0.6.0}, + year = {2024}, + month = {4} } diff --git a/Project.toml b/Project.toml index b7be278..f00f7ae 100644 --- a/Project.toml +++ b/Project.toml @@ -4,40 +4,31 @@ authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] version = "0.5.2" [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" -ImplicitDifferentiationStaticArraysExt = "StaticArrays" -ImplicitDifferentiationZygoteExt = "Zygote" +ImplicitDifferentiationEnzymeCoreExt = "EnzymeCore" [compat] -AbstractDifferentiation 
= "0.5, 0.6" -ChainRulesCore = "1.14" -ForwardDiff = "0.10" -Krylov = "0.8, 0.9" -LinearAlgebra = "1.6" -LinearOperators = "2.2" -PrecompileTools = "1.1" -Requires = "1.3" -SimpleUnPack = "1.1" -StaticArrays = "1.6" -Zygote = "0.6" -julia = "1.6" +ChainRulesCore = "1.23.0" +EnzymeCore = "0.6.5" +ForwardDiff = "0.10.36" +Krylov = "0.9.5" +LinearAlgebra = "1.10" +LinearOperators = "2.7.0" +julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" @@ -54,11 +45,28 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "ReverseDiff", "SparseArrays", "StaticArrays", "Test", "Zygote"] +test = [ + "Aqua", + "ChainRulesCore", + "ChainRulesTestUtils", + "ComponentArrays", + "Documenter", + "FiniteDifferences", + "ForwardDiff", + "JET", + "JuliaFormatter", + "NLsolve", + "Optim", + "Pkg", + "Random", + "SparseArrays", + "StaticArrays", + "Test", + "Zygote", +] diff --git a/benchmark/Project.toml b/benchmark/Project.toml deleted file mode 100644 index 9c1fae3..0000000 --- a/benchmark/Project.toml +++ /dev/null @@ -1,13 +0,0 @@ -[deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" 
-ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmark/analysis.jl b/benchmark/analysis.jl deleted file mode 100644 index be425e1..0000000 --- a/benchmark/analysis.jl +++ /dev/null @@ -1,168 +0,0 @@ - -## Benchmark analysis - -using CSV -using DataFrames -using Plots - -function export_results( - results; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, - path=joinpath(@__DIR__, "benchmark_results.csv"), -) - min_results = minimum(results) - - data = DataFrame() - - for sc in scenario_symbols, - ls in linear_solver_symbols, - ba in backend_symbols, - cb in conditions_backend_symbols, - is in input_sizes, - os in output_sizes - - try - perf = min_results[sc][ls][ba][cb][is][os] - @unpack time, gctime, memory, allocs = perf - row = (; - scenario=sc, - linear_solver=ls, - backend=ba, - conditions_backend=cb, - input_size=is, - output_size=os, - time, - gctime, - memory, - allocs, - ) - push!(data, row) - catch KeyError - nothing - end - end - - if !isnothing(path) - open(path, "w") do file - CSV.write(file, data) - end - end - return data -end - -function plot_results( - data; - scenario::Symbol, - linear_solver_symbols=unique(data[!, :linear_solver]), - backend_symbols=unique(data[!, :backend]), - conditions_backend_symbols=unique(data[!, :conditions_backend]), - input_size=nothing, - output_size=nothing, - path=joinpath( - @__DIR__, - "benchmark_plot_$(scenario)_$(linear_solver_symbols)_$(backend_symbols)_$(conditions_backend_symbols)_$(input_size)_$(output_size).png", - ), -) - pl = plot(; - size=(800, 400), - ylabel="Time [s] (log)", - legendtitle="lin. solver / AD / cond. 
AD", - legend=:outerright, - xaxis=:log10, - yaxis=:log10, - margin=5Plots.mm, - legendtitlefontsize=7, - legendfontsize=6, - ) - - data = subset(data, :scenario => _col -> _col .== scenario) - - if isnothing(input_size) && isnothing(output_size) - error("Cannot plot if neither input nor output size is fixed") - elseif !isnothing(input_size) && !isnothing(output_size) - error("Cannot plot if both input and output size are fixed") - elseif !isnothing(input_size) - plot!( - pl; - xlabel="Output dimension (log)", - title="Implicit diff. - $scenario - input size $input_size", - ) - data = subset(data, :input_size => _col -> _col .== Ref(input_size)) - else - plot!( - pl; - xlabel="Input dimension (log)", - title="Implicit diff. - $scenario - output size $output_size", - ) - data = subset(data, :output_size => _col -> _col .== Ref(output_size)) - end - - for ls in linear_solver_symbols, ba in backend_symbols, cb in conditions_backend_symbols - filtered_data = subset( - data, - :linear_solver => _col -> _col .== ls, - :backend => _col -> _col .== ba, - :conditions_backend => _col -> _col .== cb, - ) - - if !isempty(filtered_data) - x = nothing - if !isnothing(output_size) - x = map(prod, filtered_data[!, :input_size]) - elseif !isnothing(output_size) - x = map(prod, filtered_data[!, :output_size]) - end - y = filtered_data[!, :time] ./ 1e9 - plot!( - pl, - x, - y; - linestyle=:auto, - markershape=:auto, - label="$ls / $ba / $(cb == :nothing ? 
ba : cb)", - ) - end - end - - if !isnothing(path) - savefig(pl, path) - end - return pl -end - -# results = BenchmarkTools.run(SUITE; verbose=true, evals=1, seconds=1) - -# data = export_results( -# results; -# scenario_symbols, -# linear_solver_symbols, -# backend_symbols, -# conditions_backend_symbols, -# input_sizes, -# output_sizes, -# ) - -# plot_results(data; scenario=:pullback, input_size=(1,)) -# plot_results(data; scenario=:pullback, input_size=(10,)) -# plot_results(data; scenario=:pullback, input_size=(100,)) -# plot_results(data; scenario=:pullback, input_size=(1000,)) - -# plot_results(data; scenario=:pushforward, output_size=(1,)) -# plot_results(data; scenario=:pushforward, output_size=(10,)) -# plot_results(data; scenario=:pushforward, output_size=(100,)) -# plot_results(data; scenario=:pushforward, output_size=(1000,)) - -# plot_results(data; scenario=:rrule, input_size=(1,)) -# plot_results(data; scenario=:rrule, input_size=(10,)) -# plot_results(data; scenario=:rrule, input_size=(100,)) -# plot_results(data; scenario=:rrule, input_size=(1000,)) - -# plot_results(data; scenario=:jacobian, input_size=(1,)) -# plot_results(data; scenario=:jacobian, input_size=(10,)) -# plot_results(data; scenario=:jacobian, input_size=(100,)) -# plot_results(data; scenario=:jacobian, input_size=(1000,)) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl deleted file mode 100644 index eefeb23..0000000 --- a/benchmark/benchmarks.jl +++ /dev/null @@ -1,143 +0,0 @@ -using Pkg -Pkg.activate(@__DIR__) - -using AbstractDifferentiation: ForwardDiffBackend -using BenchmarkTools -using ForwardDiff: ForwardDiff, Dual -using ProgressMeter -using Random -using SimpleUnPack -using Zygote: Zygote - -using ImplicitDifferentiation - -## Benchmark definition - -forward(x; output_size) = fill(sqrt(sum(x)), output_size...) 
-conditions(x, y; output_size) = abs2.(y) .- sum(x) - -function get_linear_solver(linear_solver_symbol::Symbol) - if linear_solver_symbol == :direct - return DirectLinearSolver() - elseif linear_solver_symbol == :iterative - return IterativeLinearSolver() - end -end - -function get_conditions_backend(conditions_backend_symbol::Symbol) - if conditions_backend_symbol == :nothing - return nothing - elseif conditions_backend_symbol == :ForwardDiff - return ForwardDiffBackend() - end -end - -function create_benchmarkable(; - scenario_symbol, - linear_solver_symbol, - backend_symbol, - conditions_backend_symbol, - input_size, - output_size, -) - linear_solver = get_linear_solver(linear_solver_symbol) - conditions_backend = get_conditions_backend(conditions_backend_symbol) - - if scenario_symbol == :jacobian && prod(input_size) * prod(output_size) >= 10^5 - return nothing - end - - x = rand(input_size...) - implicit = ImplicitFunction( - x -> forward(x; output_size), - (x, y) -> conditions(x, y; output_size); - linear_solver, - conditions_backend, - ) - - dx = similar(x) - dx .= one(eltype(x)) - x_and_dx = Dual.(x, dx) - y = implicit(x) - dy = similar(y) - dy .= one(eltype(y)) - - if scenario_symbol == :jacobian && backend_symbol == :ForwardDiff - return @benchmarkable ForwardDiff.jacobian($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :jacobian && backend_symbol == :Zygote - return @benchmarkable Zygote.jacobian($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :rrule && backend_symbol == :Zygote - return @benchmarkable Zygote.pullback($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :pullback && backend_symbol == :Zygote - _, back = Zygote.pullback(implicit, x) - return @benchmarkable ($back)($dy) seconds = 1 samples = 100 - elseif scenario_symbol == :pushforward && backend_symbol == :ForwardDiff - return @benchmarkable $implicit($x_and_dx) seconds = 1 samples = 100 - else - return nothing - end -end - 
-function make_suite(; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, -) - SUITE = BenchmarkGroup() - - for sc in scenario_symbols, - ls in linear_solver_symbols, - ba in backend_symbols, - cb in conditions_backend_symbols, - is in input_sizes, - os in output_sizes - - bench = create_benchmarkable(; - scenario_symbol=sc, - linear_solver_symbol=ls, - backend_symbol=ba, - conditions_backend_symbol=cb, - input_size=is, - output_size=os, - ) - - isnothing(bench) && continue - - if !haskey(SUITE, sc) - SUITE[sc] = BenchmarkGroup() - end - if !haskey(SUITE[sc], ls) - SUITE[sc][ls] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls], ba) - SUITE[sc][ls][ba] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls][ba], cb) - SUITE[sc][ls][ba][cb] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls][ba][cb], is) - SUITE[sc][ls][ba][cb][is] = BenchmarkGroup() - end - SUITE[sc][ls][ba][cb][is][os] = bench - end - return SUITE -end - -scenario_symbols = (:jacobian, :rrule, :pullback, :pushforward) -linear_solver_symbols = (:direct, :iterative) -backend_symbols = (:Zygote, :ForwardDiff) -conditions_backend_symbols = (:nothing, :ForwardDiff) -input_sizes = [(n,) for n in floor.(Int, 10 .^ (0:1:3))]; -output_sizes = [(n,) for n in floor.(Int, 10 .^ (0:1:3))]; - -SUITE = make_suite(; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, -) diff --git a/benchmark/judge.jl b/benchmark/judge.jl deleted file mode 100644 index c361ba6..0000000 --- a/benchmark/judge.jl +++ /dev/null @@ -1,13 +0,0 @@ -using BenchmarkTools -using PkgBenchmark - -pkg = dirname(@__DIR__) # this git repo -baseline = "112d549" # commit id -target = "82242b9" # commit id - -results_baseline = benchmarkpkg(pkg, baseline; verbose=true, retune=false) -results_target = benchmarkpkg(pkg, target; verbose=true, retune=false) - -judgement = judge(results_target, 
results_baseline, minimum) - -export_markdown(joinpath(@__DIR__, "benchmark_judgement.md"), judgement) diff --git a/docs/Project.toml b/docs/Project.toml index b070058..dbd70ee 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,4 @@ [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -15,4 +13,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Documenter = "1.3" \ No newline at end of file +Documenter = "1.3" diff --git a/examples/0_intro.jl b/examples/0_intro.jl index ba5d3f3..262738d 100644 --- a/examples/0_intro.jl +++ b/examples/0_intro.jl @@ -4,7 +4,6 @@ We explain the basics of our package on a simple function that is not amenable to naive automatic differentiation. =# -using ChainRulesCore #src using ForwardDiff using ImplicitDifferentiation using JET #src @@ -139,23 +138,3 @@ And so does Zygote.jl. Hurray! Zygote.jacobian(implicit, x)[1] ≈ J @test Zygote.jacobian(implicit, x)[1] ≈ J #src - -# ## Second derivative - -#= -We can even go higher-order by mixing the two packages (forward-over-reverse mode). -The only technical requirement is to switch the linear solver to something that can handle dual numbers: -=# - -implicit_higher_order = ImplicitFunction( - forward, conditions; linear_solver=DirectLinearSolver() -) - -#= -Then the Jacobian itself is differentiable. 
-=# - -h = rand(2) -J_Z(t) = Zygote.jacobian(implicit_higher_order, x .+ t .* h)[1] -ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) -@test ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src diff --git a/examples/1_basic.jl b/examples/1_basic.jl index 942e68f..85bfb66 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -74,7 +74,7 @@ end We now have all the ingredients to construct our implicit function. =# -implicit_optim = ImplicitFunction(forward_optim, conditions_optim) +implicit_optim = ImplicitFunction(forward=forward_optim, conditions=conditions_optim) # And indeed, it behaves as it should when we call it: @@ -90,7 +90,7 @@ ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) In this instance, we could use ForwardDiff.jl directly on the solver, but it returns the wrong result (not sure why). =# -ForwardDiff.jacobian(_x -> forward_optim(x; method=LBFGS()), x) +ForwardDiff.jacobian(_x -> forward_optim(_x; method=LBFGS()), x) # Reverse mode autodiff @@ -102,7 +102,7 @@ In this instance, we cannot use Zygote.jl directly on the solver (due to unsuppo =# try - Zygote.jacobian(_x -> forward_optim(x; method=LBFGS()), x)[1] + Zygote.jacobian(_x -> forward_optim(_x; method=LBFGS()), x)[1] catch e e end diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index f4532f6..13c1026 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -55,7 +55,7 @@ Krylov.ktypeof(::ComponentVector{T,V}) where {T,V} = V # Now we're good to go. 
-a, b, m = rand(2), rand(3), 7 +a, b, m = rand(2), rand(3), 7.0 x = ComponentVector(; a=a, b=b, m=m) implicit_components(x) diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index cd24173..26e2d66 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -1,95 +1,35 @@ module ImplicitDifferentiationChainRulesCoreExt -using AbstractDifferentiation: AbstractBackend, ReverseRuleConfigBackend, ruleconfig -using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig +using ADTypes: AbstractADType, AutoChainRules +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig using ChainRulesCore: rrule, rrule_via_ad, unthunk, @not_implemented -using ImplicitDifferentiation: ImplicitDifferentiation -using ImplicitDifferentiation: ImplicitFunction -using ImplicitDifferentiation: conditions_reverse_operators -using ImplicitDifferentiation: get_output, presolve, solve +using ImplicitDifferentiation: ImplicitFunction, build_Aᵀ, build_Bᵀ, get_output using LinearAlgebra: mul! -using SimpleUnPack: @unpack -""" - rrule(rc, implicit, x, args...; kwargs...) - -Custom reverse rule for an [`ImplicitFunction`](@ref), to ensure compatibility with reverse mode autodiff. - -This is only available if ChainRulesCore.jl is loaded (extension), except on Julia < 1.9 where it is always available. - -We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`. -Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`. -""" function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::X, args...; kwargs... -) where {R,X<:AbstractArray{R}} - linear_solver = implicit.linear_solver + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args...; kwargs... +) y_or_yz = implicit(x, args...; kwargs...) 
- backend = reverse_conditions_backend(rc, implicit) - Aᵀ_vec, Bᵀ_vec = conditions_reverse_operators( - backend, implicit, x, y_or_yz, args; kwargs - ) - Aᵀ_vec_presolved = presolve(linear_solver, Aᵀ_vec, get_output(y_or_yz)) - - byproduct = y_or_yz isa Tuple - nbargs = length(args) - implicit_pullback = ImplicitPullback{byproduct,nbargs}( - Aᵀ_vec_presolved, Bᵀ_vec, linear_solver, vec(x), size(x) - ) - return y_or_yz, implicit_pullback -end - -function reverse_conditions_backend( - rc::RuleConfig, ::ImplicitFunction{F,C,L,Nothing} -) where {F,C,L} - return ReverseRuleConfigBackend(rc) -end - -function reverse_conditions_backend( - ::RuleConfig, implicit::ImplicitFunction{F,C,L,<:AbstractBackend} -) where {F,C,L} - return implicit.conditions_backend -end -struct ImplicitPullback{byproduct,nbargs,A,B,L,X,N} - Aᵀ_vec::A - Bᵀ_vec::B - linear_solver::L - x_vec::X - x_size::NTuple{N,Int} + suggested_backend = AutoChainRules(rc) + Aᵀ = build_Aᵀ(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) + Bᵀ = build_Bᵀ(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) + project_x = ProjectTo(x) - function ImplicitPullback{byproduct,nbargs}( - Aᵀ_vec::A, Bᵀ_vec::B, linear_solver::L, x_vec::X, x_size::NTuple{N,Int} - ) where {byproduct,nbargs,A,B,L,X,N} - return new{byproduct,nbargs,A,B,L,X,N}(Aᵀ_vec, Bᵀ_vec, linear_solver, x_vec, x_size) + function implicit_pullback(dy_or_dydz) + dy = get_output(unthunk(dy_or_dydz)) + dc = implicit.linear_solver(Aᵀ, -dy) + dx = Bᵀ * dc + return (NoTangent(), project_x(dx), ntuple(unimplemented_tangent, length(args))...) 
end -end -function (implicit_pullback::ImplicitPullback{true})((dy, dz)) - return apply_implicit_pullback(implicit_pullback, dy) -end - -function (implicit_pullback::ImplicitPullback{false})(dy) - return apply_implicit_pullback(implicit_pullback, dy) + return y_or_yz, implicit_pullback end function unimplemented_tangent(_) return @not_implemented( - "Tangents for positional arguments of an ImplicitFunction beyond x (the first one) are not implemented" + "Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented" ) end -function apply_implicit_pullback( - implicit_pullback::ImplicitPullback{byproduct,nbargs}, dy_thunk -) where {byproduct,nbargs} - @unpack Aᵀ_vec, Bᵀ_vec, linear_solver, x_vec, x_size = implicit_pullback - dy = unthunk(dy_thunk) - dy_vec = vec(dy) - dc_vec = solve(linear_solver, Aᵀ_vec, -dy_vec) - dx_vec = similar(x_vec) - mul!(dx_vec, Bᵀ_vec, dc_vec) - dx = reshape(dx_vec, x_size) - return (NoTangent(), dx, ntuple(unimplemented_tangent, nbargs)...) 
-end - end diff --git a/ext/ImplicitDifferentiationEnzymeCoreExt.jl b/ext/ImplicitDifferentiationEnzymeCoreExt.jl new file mode 100644 index 0000000..96580e5 --- /dev/null +++ b/ext/ImplicitDifferentiationEnzymeCoreExt.jl @@ -0,0 +1,5 @@ +module ImplicitDifferentiationEnzymeCoreExt + +using EnzymeCore + +end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index cb83390..3169638 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -1,53 +1,30 @@ module ImplicitDifferentiationForwardDiffExt -@static if isdefined(Base, :get_extension) - using ForwardDiff: Dual, Partials, jacobian, partials, value -else - using ..ForwardDiff: Dual, Partials, jacobian, partials, value -end - -using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend -using ImplicitDifferentiation: ImplicitFunction, DirectLinearSolver, IterativeLinearSolver -using ImplicitDifferentiation: conditions_forward_operators -using ImplicitDifferentiation: get_output, get_byproduct, presolve, solve -using ImplicitDifferentiation: identity_break_autodiff -using LinearAlgebra: mul! -using PrecompileTools: @compile_workload - -""" - implicit(x_and_dx::AbstractArray{<:Dual}, args...; kwargs...) +using ADTypes: AutoForwardDiff +using ForwardDiff: Chunk, Dual, Partials, jacobian, partials, value +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, get_byproduct, get_output -Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff. +chunksize(::Chunk{N}) where {N} = N -This is only available if ForwardDiff.jl is loaded (extension). - -We compute the Jacobian-vector product `Jv` by solving `Au = -Bv` and setting `Jv = u`. -Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`. -""" function (implicit::ImplicitFunction)( - x_and_dx::AbstractArray{Dual{T,R,N}}, args...; kwargs... 
+ x_and_dx::AbstractVector{Dual{T,R,N}}, args...; kwargs... ) where {T,R,N} - linear_solver = implicit.linear_solver - x = value.(x_and_dx) y_or_yz = implicit(x, args...; kwargs...) y = get_output(y_or_yz) - y_vec = vec(y) - backend = forward_conditions_backend(implicit) - A_vec, B_vec = conditions_forward_operators(backend, implicit, x, y_or_yz, args; kwargs) - A_vec_presolved = presolve(linear_solver, A_vec, y) + suggested_backend = AutoForwardDiff(; chunksize=chunksize(Chunk(x))) + A = build_A(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) + B = build_B(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) dy = ntuple(Val(N)) do k dₖx = partials.(x_and_dx, k) - dₖx_vec = vec(dₖx) - dₖc_vec = similar(y_vec) - mul!(dₖc_vec, B_vec, dₖx_vec) - dₖy_vec = solve(implicit.linear_solver, A_vec_presolved, -dₖc_vec) - reshape(dₖy_vec, size(y)) + dₖc = B * dₖx + dₖy = implicit.linear_solver(A, -dₖc) + return dₖy end - y_and_dy = map(eachindex(IndexCartesian(), y)) do i + y_and_dy = map(eachindex(y)) do i Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N)))) end @@ -58,25 +35,4 @@ function (implicit::ImplicitFunction)( end end -function forward_conditions_backend(::ImplicitFunction{F,C,L,Nothing}) where {F,C,L} - return ForwardDiffBackend() -end - -function forward_conditions_backend( - implicit::ImplicitFunction{F,C,L,<:AbstractBackend} -) where {F,C,L} - return implicit.conditions_backend -end - -@compile_workload begin - forward(x) = sqrt.(identity_break_autodiff(x)) - conditions(x, y) = y .^ 2 .- x - for linear_solver in (DirectLinearSolver(), IterativeLinearSolver()) - implicit = ImplicitFunction(forward, conditions; linear_solver) - x = rand(2) - implicit(x) - jacobian(implicit, x) - end -end - end diff --git a/ext/ImplicitDifferentiationStaticArraysExt.jl b/ext/ImplicitDifferentiationStaticArraysExt.jl deleted file mode 100644 index 90f3998..0000000 --- a/ext/ImplicitDifferentiationStaticArraysExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module 
ImplicitDifferentiationStaticArraysExt - -@static if isdefined(Base, :get_extension) - using StaticArrays: StaticArray, MMatrix, StaticVector -else - using ..StaticArrays: StaticArray, MMatrix, StaticVector -end - -import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver -using LinearAlgebra: lu, mul! - -function ImplicitDifferentiation.presolve(::DirectLinearSolver, A, y::StaticArray) - T = eltype(A) - m = length(y) - A_static = zero(MMatrix{m,m,T}) - v = vec(similar(y, T)) - for i in axes(A_static, 2) - v .= zero(T) - v[i] = one(T) - mul!(@view(A_static[:, i]), A, v) - end - return lu(A_static) -end - -end diff --git a/ext/ImplicitDifferentiationZygoteExt.jl b/ext/ImplicitDifferentiationZygoteExt.jl deleted file mode 100644 index fb64e4e..0000000 --- a/ext/ImplicitDifferentiationZygoteExt.jl +++ /dev/null @@ -1,24 +0,0 @@ -module ImplicitDifferentiationZygoteExt - -@static if isdefined(Base, :get_extension) - using Zygote: jacobian -else - using ..Zygote: jacobian -end - -using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff -using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver -using PrecompileTools: @compile_workload - -@compile_workload begin - forward(x) = sqrt.(identity_break_autodiff(x)) - conditions(x, y) = y .^ 2 .- x - for linear_solver in (DirectLinearSolver(), IterativeLinearSolver()) - implicit = ImplicitFunction(forward, conditions; linear_solver) - x = rand(2) - implicit(x) - jacobian(implicit, x) - end -end - -end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 0fc55d8..88cd0e8 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -7,38 +7,20 @@ Its main export is the type [`ImplicitFunction`](@ref). 
""" module ImplicitDifferentiation -using AbstractDifferentiation: AbstractBackend -using AbstractDifferentiation: pushforward_function, pullback_function, jacobian +using ADTypes: AbstractADType +using DifferentiationInterface: + jacobian, + prepare_pushforward, + prepare_pullback, + pushforward!!, + value_and_pullback!!_split using Krylov: gmres -using LinearOperators: LinearOperators, LinearOperator -using LinearAlgebra: issuccess, lu -using PrecompileTools: @compile_workload -using Requires: @require -using SimpleUnPack: @unpack +using LinearOperators: LinearOperator +using LinearAlgebra: factorize -include("utils.jl") -include("linear_solver.jl") include("implicit_function.jl") include("operators.jl") export ImplicitFunction -export AbstractLinearSolver, IterativeLinearSolver, DirectLinearSolver - -@static if !isdefined(Base, :get_extension) - # Loaded unconditionally on Julia < 1.9 - include("../ext/ImplicitDifferentiationChainRulesCoreExt.jl") - function __init__() - # Loaded conditionally on Julia < 1.9 - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - include("../ext/ImplicitDifferentiationForwardDiffExt.jl") - end - @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin - include("../ext/ImplicitDifferentiationStaticArraysExt.jl") - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/ImplicitDifferentiationZygoteExt.jl") - end - end -end end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index fd73f0a..3582a89 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -1,5 +1,5 @@ """ - ImplicitFunction{F,C,L,B} + ImplicitFunction Wrapper for an implicit function defined by a forward mapping `y` and a set of conditions `c`. 
@@ -12,10 +12,11 @@ This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = # Fields -- `forward::F`: a callable, does not need to be compatible with automatic differentiation -- `conditions::C`: a callable, must be compatible with automatic differentiation -- `linear_solver::L`: a subtype of `AbstractLinearSolver`, defines how the linear system will be solved -- `conditions_backend::B`: either `nothing` or a subtype of `AbstractDifferentiation.AbstractBackend`, defines how the conditions will be differentiated within the implicit function theorem +- `forward`: a callable, does not need to be compatible with automatic differentiation +- `conditions`: a callable, must be compatible with automatic differentiation +- `linear_solver`: a subtype of `AbstractLinearSolver`, defines how the linear system will be solved +- `conditions_x_backend`: either `nothing` or a subtype of `ADTypes.AbstractADType`, defines how the conditions will be differentiated with respect to the first argument `x` +- `conditions_y_backend`: same for the second argument `y` There are two possible signatures for `forward` and `conditions`, which must be consistent with one another: @@ -29,33 +30,26 @@ The positional arguments `args...` and keyword arguments `kwargs...` must be the !!! warning "Warning" The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes. 
""" -struct ImplicitFunction{F,C,L<:AbstractLinearSolver,B<:Union{Nothing,AbstractBackend}} +@kwdef struct ImplicitFunction{ + F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType} +} forward::F conditions::C - linear_solver::L - conditions_backend::B + linear_solver::L = first ∘ gmres + conditions_x_backend::B1 = nothing + conditions_y_backend::B2 = nothing end -""" - ImplicitFunction( - forward, - conditions; - linear_solver=IterativeLinearSolver(), - conditions_backend=nothing, - ) - -Construct an `ImplicitFunction` with default parameters. -""" -function ImplicitFunction( - forward, conditions; linear_solver=IterativeLinearSolver(), conditions_backend=nothing -) - return ImplicitFunction(forward, conditions, linear_solver, conditions_backend) +function ImplicitFunction(forward, conditions; kwargs...) + return ImplicitFunction(; forward, conditions, kwargs...) end function Base.show(io::IO, implicit::ImplicitFunction) - @unpack forward, conditions, linear_solver, conditions_backend = implicit + (; forward, conditions, linear_solver, conditions_x_backend, conditions_y_backend) = + implicit return print( - io, "ImplicitFunction($forward, $conditions, $linear_solver, $conditions_backend)" + io, + "ImplicitFunction($forward, $conditions, $linear_solver, $conditions_x_backend, $conditions_y_backend)", ) end @@ -64,24 +58,19 @@ end Return `implicit.forward(x, args...; kwargs...)`, which can be either an array `y` or a tuple `(y, z)`. -This call is differentiable. +This call is differentiable (except for `z`). """ -function (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...) +function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) y_or_yz = implicit.forward(x, args...; kwargs...) 
- valid = ( - y_or_yz isa AbstractArray || # - (y_or_yz isa Tuple && length(y_or_yz) == 2 && y_or_yz[1] isa AbstractArray) - ) - if !valid - throw( - DimensionMismatch( - "The forward mapping must return an array `y` or a tuple `(y, z)` where `y` is an array", - ), + if !(y_or_yz isa Union{AbstractArray,Tuple{<:AbstractVector,<:Any}}) + error( + "The forward mapping must return a vector `y` or a tuple `(y, z)` where `y` is a vector", ) end return y_or_yz end -get_output(y::AbstractArray) = y -get_output(yz::Tuple) = yz[1] -get_byproduct(yz::Tuple) = yz[2] +get_output(y::AbstractVector) = y +get_byproduct(y::AbstractVector) = error("No byproduct") +get_output(yz::Tuple{<:AbstractVector,<:Any}) = yz[1] +get_byproduct(yz::Tuple{<:AbstractVector,<:Any}) = yz[2] diff --git a/src/linear_solver.jl b/src/linear_solver.jl deleted file mode 100644 index fd39398..0000000 --- a/src/linear_solver.jl +++ /dev/null @@ -1,78 +0,0 @@ -""" - AbstractLinearSolver - -All linear solvers used within an `ImplicitFunction` must satisfy this interface. - -It can be useful to roll out your own solver if you need more fine-grained control on convergence / speed / behavior in case of singularity. -Check out the source code of `IterativeLinearSolver` and `DirectLinearSolver` for implementation examples. - -# Required methods - -- `presolve(linear_solver, A, y)`: Returns a matrix-like object `A` for which it is cheaper to solve several linear systems with different vectors `b` of type similar to `y` (a typical example would be to perform LU factorization). -- `solve(linear_solver, A, b)`: Returns a vector `x` satisfying `Ax = b`. If the linear system has not been solved to satisfaction, every element of `x` should be a `NaN` of the appropriate floating point type. -""" -abstract type AbstractLinearSolver end - -""" - IterativeLinearSolver - -An implementation of `AbstractLinearSolver` using `Krylov.gmres`, set as the default for `ImplicitFunction`. 
- -# Fields - -- `verbose::Bool`: Whether to display a warning when the solver fails and returns `NaN`s (defaults to `true`) -- `accept_inconsistent::Bool`: Whether to accept approximate least squares solutions for inconsistent systems, or fail and return `NaN`s (defaults to `false`) - -!!! note - If you find that your implicit gradients contains `NaN`s, try using this solver with `accept_inconsistent=true`. - However, beware that the implicit function theorem does not cover the case of inconsistent linear systems `AJ = B`, so it is unclear what the result will mean. -""" -Base.@kwdef struct IterativeLinearSolver <: AbstractLinearSolver - verbose::Bool = true - accept_inconsistent::Bool = false -end - -presolve(::IterativeLinearSolver, A, y) = A - -function solve(sol::IterativeLinearSolver, A, b) - x, stats = gmres(A, b) - if sol.accept_inconsistent - success = stats.solved || stats.inconsistent - else - success = stats.solved && !stats.inconsistent - end - if !success - if sol.verbose - @warn "IterativeLinearSolver failed, result contains NaNs" - @show stats - end - x .= NaN - end - return x -end - -""" - DirectLinearSolver - -An implementation of `AbstractLinearSolver` using the built-in backslash operator. 
- -# Fields - -- `verbose::Bool`: Whether to throw a warning when the solver fails (defaults to `true`) -""" -Base.@kwdef struct DirectLinearSolver <: AbstractLinearSolver - verbose::Bool = true -end - -function presolve(::DirectLinearSolver, A, y) - return lu(Matrix(A); check=false) -end - -function solve(sol::DirectLinearSolver, A_lu, b) - x = A_lu \ b - if !issuccess(A_lu) - sol.verbose && @warn "DirectLinearSolver failed, result contains NaNs" - x .= NaN - end - return x -end diff --git a/src/operators.jl b/src/operators.jl index 4b2fee7..9335f09 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1,115 +1,194 @@ -## Forward +## Partial conditions -function conditions_pushforwards( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - y::AbstractArray, - args; - kwargs, -) - conditions = implicit.conditions - pfA = only ∘ pushforward_function(ba, _y -> conditions(x, _y, args...; kwargs...), y) - pfB = only ∘ pushforward_function(ba, _x -> conditions(_x, y, args...; kwargs...), x) - return pfA, pfB +struct ConditionsXNoByproduct{C,Y,A,K} + conditions::C + y::Y + args::A + kwargs::K end -function conditions_pushforwards( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - yz::Tuple, - args; - kwargs, -) - conditions = implicit.conditions - y, z = yz - pfA = only ∘ pushforward_function(ba, _y -> conditions(x, _y, z, args...; kwargs...), y) - pfB = only ∘ pushforward_function(ba, _x -> conditions(_x, y, z, args...; kwargs...), x) - return pfA, pfB +function (conditions_x_nobyproduct::ConditionsXNoByproduct)(x::AbstractVector) + (; conditions, y, args, kwargs) = conditions_x_nobyproduct + return conditions(x, y, args...; kwargs...) 
end -struct PushforwardProd!{F,N} - pushforward::F - size::NTuple{N,Int} +struct ConditionsYNoByproduct{C,X,A,K} + conditions::C + x::X + args::A + kwargs::K end -function (pfp::PushforwardProd!)(dc_vec::AbstractVector, dy_vec::AbstractVector) - dy = reshape(dy_vec, pfp.size) - dc = pfp.pushforward(dy) - return dc_vec .= vec(dc) +function (conditions_y_nobyproduct::ConditionsYNoByproduct)(y::AbstractVector) + (; conditions, x, args, kwargs) = conditions_y_nobyproduct + return conditions(x, y, args...; kwargs...) end -function pushforwards_to_operators(x::AbstractArray, y::AbstractArray, pfA, pfB) - n, m = length(x), length(y) - A_vec = LinearOperator(eltype(y), m, m, false, false, PushforwardProd!(pfA, size(y))) - B_vec = LinearOperator(eltype(x), m, n, false, false, PushforwardProd!(pfB, size(x))) - return A_vec, B_vec +struct ConditionsXByproduct{C,Y,Z,A,K} + conditions::C + y::Y + z::Z + args::A + kwargs::K end -function conditions_forward_operators( - backend::AbstractBackend, implicit::ImplicitFunction, x, y_or_yz, args; kwargs -) +function (conditions_x_byproduct::ConditionsXByproduct)(x::AbstractVector) + (; conditions, y, z, args, kwargs) = conditions_x_byproduct + return conditions(x, y, z, args...; kwargs...) +end + +struct ConditionsYByproduct{C,X,Z,A,K} + conditions::C + x::X + z::Z + args::A + kwargs::K +end + +function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector) + (; conditions, x, z, args, kwargs) = conditions_y_byproduct + return conditions(x, y, z, args...; kwargs...) +end + +function ConditionsX(conditions, x, y_or_yz, args...; kwargs...) 
y = get_output(y_or_yz) - pfA, pfB = conditions_pushforwards(backend, implicit, x, y_or_yz, args; kwargs) - A_vec, B_vec = pushforwards_to_operators(x, y, pfA, pfB) - return A_vec, B_vec + if y_or_yz isa Tuple + z = get_byproduct(y_or_yz) + return ConditionsXByproduct(conditions, y, z, args, kwargs) + else + return ConditionsXNoByproduct(conditions, y, args, kwargs) + end +end + +function ConditionsY(conditions, x, y_or_yz, args...; kwargs...) + if y_or_yz isa Tuple + z = get_byproduct(y_or_yz) + return ConditionsYByproduct(conditions, x, z, args, kwargs) + else + return ConditionsYNoByproduct(conditions, x, args, kwargs) + end end -## Reverse +## Lazy operators -function conditions_pullbacks( - ba::AbstractBackend, +function build_A( implicit::ImplicitFunction, - x::AbstractArray, - y::AbstractArray, - args; - kwargs, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., ) - conditions = implicit.conditions - pbAᵀ = only ∘ pullback_function(ba, _y -> conditions(x, _y, args...; kwargs...), y) - pbBᵀ = only ∘ pullback_function(ba, _x -> conditions(_x, y, args...; kwargs...), x) - return pbAᵀ, pbBᵀ + (; conditions, linear_solver, conditions_y_backend) = implicit + y = get_output(y_or_yz) + n, m = length(x), length(y) + back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend + cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) 
+ if linear_solver isa typeof(\) + A = factorize(jacobian(cond_y, back_y, y)) + else + extras = prepare_pushforward(cond_y, back_y, y) + A = LinearOperator( + eltype(y), + m, + m, + false, + false, + (res, v) -> res .= pushforward!!(cond_y, res, back_y, y, v, extras), + typeof(y), + ) + end + return A end -function conditions_pullbacks( - ba::AbstractBackend, +function build_Aᵀ( implicit::ImplicitFunction, - x::AbstractArray, - yz::Tuple, - args; - kwargs, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., ) - conditions = implicit.conditions - y, z = yz - pbAᵀ = only ∘ pullback_function(ba, _y -> conditions(x, _y, z, args...; kwargs...), y) - pbBᵀ = only ∘ pullback_function(ba, _x -> conditions(_x, y, z, args...; kwargs...), x) - return pbAᵀ, pbBᵀ -end - -struct PullbackProd!{F,N} - pullback::F - size::NTuple{N,Int} -end - -function (pbp::PullbackProd!)(dy_vec::AbstractVector, dc_vec::AbstractVector) - dc = reshape(dc_vec, pbp.size) - dy = pbp.pullback(dc) - return dy_vec .= vec(dy) + (; conditions, linear_solver, conditions_y_backend) = implicit + y = get_output(y_or_yz) + n, m = length(x), length(y) + back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend + cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) + if linear_solver isa typeof(\) + Aᵀ = factorize(transpose(jacobian(cond_y, back_y, y))) + else + extras = prepare_pullback(cond_y, back_y, y) + _, pullbackfunc!! 
= value_and_pullback!!_split(cond_y, back_y, y, extras) + Aᵀ = LinearOperator( + eltype(y), + m, + m, + false, + false, + (res, v) -> res .= pullbackfunc!!(res, v), + typeof(y), + ) + end + return Aᵀ end -function pullbacks_to_operators(x::AbstractArray, y::AbstractArray, pbAᵀ, pbBᵀ) +function build_B( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., +) + (; conditions, linear_solver, conditions_x_backend) = implicit + y = get_output(y_or_yz) n, m = length(x), length(y) - Aᵀ_vec = LinearOperator(eltype(y), m, m, false, false, PullbackProd!(pbAᵀ, size(y))) - Bᵀ_vec = LinearOperator(eltype(y), n, m, false, false, PullbackProd!(pbBᵀ, size(y))) - return Aᵀ_vec, Bᵀ_vec + back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend + cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) + if linear_solver isa typeof(\) + B = factorize(transpose(jacobian(cond_x, back_x, x))) + else + extras = prepare_pushforward(cond_x, back_x, x) + B = LinearOperator( + eltype(y), + m, + n, + false, + false, + (res, v) -> res .= pushforward!!(cond_x, res, back_x, x, v, extras), + typeof(x), + ) + end + return B end -function conditions_reverse_operators( - backend::AbstractBackend, implicit::ImplicitFunction, x, y_or_yz, args; kwargs +function build_Bᵀ( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., ) + (; conditions, linear_solver, conditions_x_backend) = implicit y = get_output(y_or_yz) - pbAᵀ, pbBᵀ = conditions_pullbacks(backend, implicit, x, y_or_yz, args; kwargs) - Aᵀ_vec, Bᵀ_vec = pullbacks_to_operators(x, y, pbAᵀ, pbBᵀ) - return Aᵀ_vec, Bᵀ_vec + n, m = length(x), length(y) + back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend + cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) 
+ if linear_solver isa typeof(\) + Bᵀ = factorize(transpose(jacobian(cond_x, back_x, x))) + else + extras = prepare_pullback(cond_x, back_x, x) + _, pullbackfunc!! = value_and_pullback!!_split(cond_x, back_x, x, extras) + Bᵀ = LinearOperator( + eltype(y), + n, + m, + false, + false, + (res, v) -> res .= pullbackfunc!!(res, v), + typeof(x), + ) + end + return Bᵀ end diff --git a/src/utils.jl b/src/utils.jl deleted file mode 100644 index 3c32c2a..0000000 --- a/src/utils.jl +++ /dev/null @@ -1,5 +0,0 @@ -function identity_break_autodiff(x) - a = [0.0] - a[1] = float(first(x)) - return x -end diff --git a/test/systematic.jl b/test/systematic.jl index 2f022e2..9dfe954 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -13,15 +13,6 @@ using StaticArrays using Test using Zygote: Zygote, ZygoteRuleConfig -@static if VERSION < v"1.9" - macro test_opt(x...) - return :() - end - macro test_call(x...) - return :() - end -end - Random.seed!(63); ## Utils From a3b89f636115dfd0fa4f577795b164cc638e55ee Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 9 Apr 2024 12:46:53 +0200 Subject: [PATCH 2/8] Fix more stuff --- Project.toml | 5 +- examples/1_basic.jl | 2 +- ext/ImplicitDifferentiationForwardDiffExt.jl | 2 +- src/ImplicitDifferentiation.jl | 2 +- src/implicit_function.jl | 7 +- src/operators.jl | 47 +++--- test/runtests.jl | 16 +-- test/systematic.jl | 143 +++++-------------- 8 files changed, 80 insertions(+), 144 deletions(-) diff --git a/Project.toml b/Project.toml index f00f7ae..bafe5b4 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -31,6 +30,7 @@ 
LinearOperators = "2.7.0" julia = "1.10" [extras] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" @@ -43,7 +43,6 @@ JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -52,6 +51,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = [ + "ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", @@ -63,7 +63,6 @@ test = [ "JuliaFormatter", "NLsolve", "Optim", - "Pkg", "Random", "SparseArrays", "StaticArrays", diff --git a/examples/1_basic.jl b/examples/1_basic.jl index 85bfb66..6928660 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -74,7 +74,7 @@ end We now have all the ingredients to construct our implicit function. =# -implicit_optim = ImplicitFunction(forward=forward_optim, conditions=conditions_optim) +implicit_optim = ImplicitFunction(; forward=forward_optim, conditions=conditions_optim) # And indeed, it behaves as it should when we call it: diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 3169638..6a899f0 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -13,7 +13,7 @@ function (implicit::ImplicitFunction)( y_or_yz = implicit(x, args...; kwargs...) y = get_output(y_or_yz) - suggested_backend = AutoForwardDiff(; chunksize=chunksize(Chunk(x))) + suggested_backend = AutoForwardDiff{1,Nothing}(nothing) A = build_A(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) 
B = build_B(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 88cd0e8..ddc6a75 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -16,7 +16,7 @@ using DifferentiationInterface: value_and_pullback!!_split using Krylov: gmres using LinearOperators: LinearOperator -using LinearAlgebra: factorize +using LinearAlgebra: factorize, lu include("implicit_function.jl") include("operators.jl") diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 3582a89..872b3aa 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -71,6 +71,7 @@ function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) end get_output(y::AbstractVector) = y -get_byproduct(y::AbstractVector) = error("No byproduct") -get_output(yz::Tuple{<:AbstractVector,<:Any}) = yz[1] -get_byproduct(yz::Tuple{<:AbstractVector,<:Any}) = yz[2] +get_byproduct(::AbstractVector) = error("No byproduct") + +get_output((y, z)) = y +get_byproduct((y, z)) = z diff --git a/src/operators.jl b/src/operators.jl index 9335f09..ed31dc7 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -71,6 +71,25 @@ end ## Lazy operators +struct PushforwardOperator!{F,B,X,E} + f::F + backend::B + x::X + extras::E +end + +function (po::PushforwardOperator!)(res, v) + return res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) +end + +struct PullbackOperator!{PB} + pullbackfunc!!::PB +end + +function (po::PullbackOperator!)(res, v) + return res .= po.pullbackfunc!!(res, v) +end + function build_A( implicit::ImplicitFunction, x::AbstractVector, @@ -85,7 +104,8 @@ function build_A( back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) 
if linear_solver isa typeof(\) - A = factorize(jacobian(cond_y, back_y, y)) + J = jacobian(cond_y, back_y, y) + A = lu(J) else extras = prepare_pushforward(cond_y, back_y, y) A = LinearOperator( @@ -94,7 +114,7 @@ function build_A( m, false, false, - (res, v) -> res .= pushforward!!(cond_y, res, back_y, y, v, extras), + PushforwardOperator!(cond_y, back_y, y, extras), typeof(y), ) end @@ -115,18 +135,13 @@ function build_Aᵀ( back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) if linear_solver isa typeof(\) - Aᵀ = factorize(transpose(jacobian(cond_y, back_y, y))) + Jᵀ = transpose(jacobian(cond_y, back_y, y)) + Aᵀ = lu(Jᵀ) else extras = prepare_pullback(cond_y, back_y, y) _, pullbackfunc!! = value_and_pullback!!_split(cond_y, back_y, y, extras) Aᵀ = LinearOperator( - eltype(y), - m, - m, - false, - false, - (res, v) -> res .= pullbackfunc!!(res, v), - typeof(y), + eltype(y), m, m, false, false, PullbackOperator!(pullbackfunc!!), typeof(y) ) end return Aᵀ @@ -146,7 +161,7 @@ function build_B( back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) if linear_solver isa typeof(\) - B = factorize(transpose(jacobian(cond_x, back_x, x))) + B = transpose(jacobian(cond_x, back_x, x)) else extras = prepare_pushforward(cond_x, back_x, x) B = LinearOperator( @@ -176,18 +191,12 @@ function build_Bᵀ( back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) if linear_solver isa typeof(\) - Bᵀ = factorize(transpose(jacobian(cond_x, back_x, x))) + Bᵀ = transpose(jacobian(cond_x, back_x, x)) else extras = prepare_pullback(cond_x, back_x, x) _, pullbackfunc!! 
= value_and_pullback!!_split(cond_x, back_x, x, extras) Bᵀ = LinearOperator( - eltype(y), - n, - m, - false, - false, - (res, v) -> res .= pullbackfunc!!(res, v), - typeof(x), + eltype(y), n, m, false, false, PullbackOperator!(pullbackfunc!!), typeof(x) ) end return Bᵀ diff --git a/test/runtests.jl b/test/runtests.jl index 37bc32d..bb662fa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,6 @@ using ForwardDiff: ForwardDiff using ImplicitDifferentiation using JET using JuliaFormatter -using Pkg using Random using Test using Zygote: Zygote @@ -34,22 +33,15 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") @testset verbose = true "ImplicitDifferentiation.jl" begin @testset verbose = false "Code quality (Aqua.jl)" begin - if VERSION >= v"1.9" - Aqua.test_all(ImplicitDifferentiation; ambiguities=false, deps_compat=false) - Aqua.test_deps_compat(ImplicitDifferentiation; check_extras=false) - end + Aqua.test_all( + ImplicitDifferentiation; ambiguities=false, deps_compat=(check_extras = false) + ) end @testset verbose = true "Formatting (JuliaFormatter.jl)" begin @test format(ImplicitDifferentiation; verbose=false, overwrite=false) end @testset verbose = true "Static checking (JET.jl)" begin - if VERSION >= v"1.9" - JET.test_package( - ImplicitDifferentiation; - target_defined_modules=true, - toplevel_logger=nothing, - ) - end + JET.test_package(ImplicitDifferentiation; target_defined_modules=true) end @testset verbose = false "Doctests (Documenter.jl)" begin doctest(ImplicitDifferentiation) diff --git a/test/systematic.jl b/test/systematic.jl index 9dfe954..271054f 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,11 +1,11 @@ -import AbstractDifferentiation as AD +using ADTypes using ChainRulesCore using ChainRulesTestUtils using ForwardDiff: ForwardDiff import ImplicitDifferentiation as ID -using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff -using ImplicitDifferentiation: DirectLinearSolver, 
IterativeLinearSolver +using ImplicitDifferentiation: ImplicitFunction using JET +using Krylov using LinearAlgebra using Random using SparseArrays @@ -15,41 +15,42 @@ using Zygote: Zygote, ZygoteRuleConfig Random.seed!(63); -## Utils - -change_shape(x::AbstractArray{T,3}) where {T} = x[:, :, 1] -change_shape(x::AbstractSparseArray) = x +function identity_break_autodiff(x) + a = [0.0] + a[1] = float(first(x)) + return x +end -function mysqrt(x::AbstractArray) +function mysqrt(x::AbstractVector) return identity_break_autodiff(sqrt.(abs.(x))) end ## Various signatures function make_implicit_sqrt(; kwargs...) - forward(x) = mysqrt(change_shape(x)) - conditions(x, y) = abs2.(y) .- abs.(change_shape(x)) + forward(x) = mysqrt(x) + conditions(x, y) = abs2.(y) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_byproduct(; kwargs...) - forward(x) = 1 * mysqrt(change_shape(x)), 1 - conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(change_shape(x)) + forward(x) = 1 * mysqrt(x), 1 + conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_args(; kwargs...) - forward(x, p::Integer) = p * mysqrt(change_shape(x)) - conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) + forward(x, p::Integer) = p * mysqrt(x) + conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_kwargs(; kwargs...) - forward(x; p::Integer) = p .* mysqrt(change_shape(x)) - conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) + forward(x; p::Integer) = p .* mysqrt(x) + conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end @@ -76,7 +77,7 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) 
where {T} imf3 = make_implicit_sqrt_args(; kwargs...) imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(change_shape(x)) + y_true = mysqrt(x) y1 = @inferred imf1(x) y2, z2 = @inferred imf2(x) y3 = @inferred imf3(x, 1) @@ -116,24 +117,15 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} imf3 = make_implicit_sqrt_args(; kwargs...) imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(change_shape(x)) + y_true = mysqrt(x) dx = similar(x) dx .= one(T) x_and_dx = ForwardDiff.Dual.(x, dx) - #= - TODO: fix AbstractDifferentiation.jl 0.6 - y_and_dy1 = @inferred imf1(x_and_dx) y_and_dy2, z2 = @inferred imf2(x_and_dx) y_and_dy3 = @inferred imf3(x_and_dx, 1) y_and_dy4 = @inferred imf4(x_and_dx; p=1) - =# - - y_and_dy1 = imf1(x_and_dx) - y_and_dy2, z2 = imf2(x_and_dx) - y_and_dy3 = imf3(x_and_dx, 1) - y_and_dy4 = imf4(x_and_dx; p=1) @testset "Dual numbers" begin @test ForwardDiff.value.(y_and_dy1) ≈ y_true @@ -169,14 +161,11 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} imf3 = make_implicit_sqrt_args(; kwargs...) imf4 = make_implicit_sqrt_kwargs(; kwargs...) - y_true = mysqrt(change_shape(x)) + y_true = mysqrt(x) dy = similar(y_true) dy .= one(eltype(y_true)) dz = nothing - #= - # TODO: fix AbstractDifferentiation.jl 0.6 - y1, pb1 = @inferred rrule(rc, imf1, x) (y2, z2), pb2 = @inferred rrule(rc, imf2, x) y3, pb3 = @inferred rrule(rc, imf3, x, 1) @@ -186,17 +175,6 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) 
where {T} dimf2, dx2 = @inferred pb2((dy, dz)) dimf3, dx3, dp3 = @inferred pb3(dy) dimf4, dx4 = @inferred pb4(dy) - =# - - y1, pb1 = rrule(rc, imf1, x) - (y2, z2), pb2 = rrule(rc, imf2, x) - y3, pb3 = rrule(rc, imf3, x, 1) - y4, pb4 = rrule(rc, imf4, x; p=1) - - dimf1, dx1 = pb1(dy) - dimf2, dx2 = pb2((dy, dz)) - dimf3, dx3, dp3 = pb3(dy) - dimf4, dx4 = pb4(dy) @testset "Pullbacks" begin @test y1 ≈ y_true @@ -272,7 +250,7 @@ function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T} J2 = ForwardDiff.jacobian(first ∘ imf2, x) J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x) J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x) - J_true = ForwardDiff.jacobian(_x -> sqrt.(change_shape(_x)), x) + J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) @testset "Exact Jacobian" begin @test J1 ≈ J_true @@ -298,7 +276,7 @@ function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T} J2 = Zygote.jacobian(first ∘ imf2, x)[1] J3 = Zygote.jacobian(imf3, x, 1)[1] J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1] - J_true = Zygote.jacobian(_x -> sqrt.(change_shape(_x)), x)[1] + J_true = Zygote.jacobian(_x -> sqrt.(_x), x)[1] @testset "Exact Jacobian" begin @test J1 ≈ J_true @@ -324,72 +302,29 @@ function test_implicit(x; kwargs...) test_implicit_duals(x; kwargs...) end end - @testset verbose = true "Zygote.jl" begin - rc = Zygote.ZygoteRuleConfig() - test_implicit_zygote(x; kwargs...) - test_implicit_rrule(rc, x; kwargs...) - end + # @testset verbose = true "Zygote.jl" begin + # rc = Zygote.ZygoteRuleConfig() + # test_implicit_zygote(x; kwargs...) + # test_implicit_rrule(rc, x; kwargs...) 
+ # end return nothing end ## Parameter combinations -linear_solver_candidates = ( - IterativeLinearSolver(), # - DirectLinearSolver(), # -) - -conditions_backend_candidates = ( - nothing, # - AD.ForwardDiffBackend(), # - # AD.ZygoteBackend(), # TODO: failing - # AD.ReverseDiffBackend() # TODO: failing - # AD.FiniteDifferencesBackend() # TODO: failing -); - -x_candidates = ( - rand(Float32, 2, 3, 2), # - SArray{Tuple{2,3,2}}(rand(Float32, 2, 3, 2)), # - sparse(rand(Float32, 2)), # - sparse(rand(Float32, 2, 3)), # -); - -params_candidates = [] - -for linear_solver in linear_solver_candidates, x in x_candidates - push!( - params_candidates, (; - linear_solver=linear_solver, # - conditions_backend=nothing, # - x=x, # - ) - ) -end - -for conditions_backend in conditions_backend_candidates - push!( - params_candidates, - (; - linear_solver=linear_solver_candidates[1], # - conditions_backend=conditions_backend, # - x=x_candidates[1], # - ), - ) -end +linear_solver_candidates = (\, first ∘ Krylov.gmres) +backend_candidates = (nothing, AutoForwardDiff(; chunksize=1)); +x_candidates = (rand(Float32, 2), rand(Float64, 2)); ## Test loop -for (linear_solver, conditions_backend, x) in params_candidates - testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))" - if ( - linear_solver isa DirectLinearSolver && - x isa AbstractSparseArray && - VERSION < v"1.9" - ) # missing linalg function for sparse arrays in 1.6 - continue - end - @info "$testsetname" - @testset "$testsetname" begin - test_implicit(x; linear_solver, conditions_backend) - end +@testset "$linear_solver - $(typeof(backend)) - $(typeof(x))" for ( + linear_solver, backend, x +) in Iterators.product( + linear_solver_candidates, backend_candidates, x_candidates +) + @info "$linear_solver - $(typeof(backend)) - $(typeof(x))" + test_implicit( + x; linear_solver, conditions_x_backend=backend, conditions_y_backend=backend + ) end From 9695145df2fd46edb7fac0f725c5ba69283d4293 Mon Sep 17 
00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 9 Apr 2024 16:50:11 +0200 Subject: [PATCH 3/8] Fix tests and start Enzyme --- Project.toml | 12 +- docs/Project.toml | 2 + docs/make.jl | 81 ++------- ...mplicitDifferentiationChainRulesCoreExt.jl | 39 +++-- ext/ImplicitDifferentiationEnzymeCoreExt.jl | 5 - ext/ImplicitDifferentiationEnzymeExt.jl | 61 +++++++ ext/ImplicitDifferentiationForwardDiffExt.jl | 19 +- src/ImplicitDifferentiation.jl | 2 +- src/implicit_function.jl | 38 ++-- src/operators.jl | 20 ++- test/errors.jl | 80 --------- test/runtests.jl | 4 - test/systematic.jl | 164 +++++++----------- 13 files changed, 221 insertions(+), 306 deletions(-) delete mode 100644 ext/ImplicitDifferentiationEnzymeCoreExt.jl create mode 100644 ext/ImplicitDifferentiationEnzymeExt.jl delete mode 100644 test/errors.jl diff --git a/Project.toml b/Project.toml index bafe5b4..8a20f6b 100644 --- a/Project.toml +++ b/Project.toml @@ -12,17 +12,17 @@ LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" +ImplicitDifferentiationEnzymeExt = "Enzyme" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" -ImplicitDifferentiationEnzymeCoreExt = "EnzymeCore" [compat] ChainRulesCore = "1.23.0" -EnzymeCore = "0.6.5" +Enzyme = "0.11.20" ForwardDiff = "0.10.36" Krylov = "0.9.5" LinearAlgebra = "1.10" @@ -35,8 +35,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" 
-FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" @@ -56,8 +57,9 @@ test = [ "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", + "DifferentiationInterface", "Documenter", - "FiniteDifferences", + "Enzyme", "ForwardDiff", "JET", "JuliaFormatter", diff --git a/docs/Project.toml b/docs/Project.toml index dbd70ee..b310c93 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,6 +1,8 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" diff --git a/docs/make.jl b/docs/make.jl index 5f8b71e..5c53f61 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,3 @@ -using ChainRulesCore: ChainRulesCore using Documenter using ForwardDiff: ForwardDiff using ImplicitDifferentiation @@ -11,36 +10,11 @@ DocMeta.setdocmeta!( ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true ) -base_url = "https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/" - -open(joinpath(@__DIR__, "src", "index.md"), "w") do io - # Point to source license file - println( - io, - """ - ```@meta - EditURL = "$(base_url)README.md" - ``` - """, - ) - # Write the contents out below the meta block - for line in eachline(joinpath(dirname(@__DIR__), "README.md")) - println(io, line) - end -end - -function markdown_title(path) - title = "?" 
- open(path, "r") do file - for line in eachline(file) - if startswith(line, '#') - title = strip(line, [' ', '#']) - break - end - end - end - return title -end +cp( + joinpath(dirname(@__DIR__), "README.md"), + joinpath(@__DIR__, "src", "index.md"); + force=true, +) EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") EXAMPLES_DIR_MD = joinpath(@__DIR__, "src", "examples") @@ -60,53 +34,22 @@ for file in readdir(EXAMPLES_DIR_JL) ) end -example_pages = Pair{String,String}[] -for file in sort(readdir(EXAMPLES_DIR_MD)) - if endswith(file, ".md") - title = markdown_title(joinpath(EXAMPLES_DIR_MD, file)) - path = joinpath("examples", file) - push!(example_pages, title => path) - end -end - pages = [ "Home" => "index.md", - "FAQ" => "faq.md", - "Examples" => example_pages, - "API reference" => "api.md", + "faq.md", + "Examples" => [joinpath("examples", file) for file in sort(readdir(EXAMPLES_DIR_MD))], + "api.md", ] -fmt = Documenter.HTML(; - prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://gdalle.github.io/ImplicitDifferentiation.jl", - assets=String[], - edit_link=:commit, -) - -if isdefined(Base, :get_extension) - extension_modules = [ - Base.get_extension(ID, :ImplicitDifferentiationChainRulesCoreExt), - Base.get_extension(ID, :ImplicitDifferentiationForwardDiffExt), - ] -else - extension_modules = [ - ID.ImplicitDifferentiationChainRulesCoreExt, - ID.ImplicitDifferentiationForwardDiffExt, - ] -end - makedocs(; modules=[ImplicitDifferentiation], authors="Guillaume Dalle, Mohamed Tarek and contributors", repo=Documenter.Remotes.GitHub("gdalle", "ImplicitDifferentiation.jl"), sitename="ImplicitDifferentiation.jl", - format=fmt, + format=Documenter.HTML(; + canonical="https://gdalle.github.io/ImplicitDifferentiation.jl" + ), pages=pages, - linkcheck=true, -) - -deploydocs(; - repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main", push_preview=true ) -rm(joinpath(@__DIR__, "src", "index.md")) +deploydocs(; 
repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main") diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index 26e2d66..cea6a32 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -3,12 +3,15 @@ module ImplicitDifferentiationChainRulesCoreExt using ADTypes: AbstractADType, AutoChainRules using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig using ChainRulesCore: rrule, rrule_via_ad, unthunk, @not_implemented -using ImplicitDifferentiation: ImplicitFunction, build_Aᵀ, build_Bᵀ, get_output -using LinearAlgebra: mul! +using ImplicitDifferentiation: ImplicitFunction, build_Aᵀ, build_Bᵀ, output function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractVector, args...; kwargs... -) + rc::RuleConfig, + implicit::ImplicitFunction, + x::AbstractVector, + args::Vararg{T,N}; + kwargs..., +) where {T,N} y_or_yz = implicit(x, args...; kwargs...) suggested_backend = AutoChainRules(rc) @@ -16,16 +19,30 @@ function ChainRulesCore.rrule( Bᵀ = build_Bᵀ(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) project_x = ProjectTo(x) - function implicit_pullback(dy_or_dydz) - dy = get_output(unthunk(dy_or_dydz)) - dc = implicit.linear_solver(Aᵀ, -dy) - dx = Bᵀ * dc - return (NoTangent(), project_x(dx), ntuple(unimplemented_tangent, length(args))...) - end - + implicit_pullback = ImplicitPullback( + Aᵀ, Bᵀ, implicit.linear_solver, project_x, Val{N}() + ) return y_or_yz, implicit_pullback end +struct ImplicitPullback{N,M1,M2,L,P} + Aᵀ::M1 + Bᵀ::M2 + linear_solver::L + project_x::P + nargs::Val{N} +end + +function (ip::ImplicitPullback{N})(dy_or_dydz) where {N} + (; Aᵀ, Bᵀ, linear_solver, project_x) = ip + dy = output(unthunk(dy_or_dydz)) + dc = linear_solver(Aᵀ, -dy) + dx = Bᵀ * dc + df = NoTangent() + dargs = ntuple(unimplemented_tangent, N) + return (df, project_x(dx), dargs...) 
+end + function unimplemented_tangent(_) return @not_implemented( "Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented" diff --git a/ext/ImplicitDifferentiationEnzymeCoreExt.jl b/ext/ImplicitDifferentiationEnzymeCoreExt.jl deleted file mode 100644 index 96580e5..0000000 --- a/ext/ImplicitDifferentiationEnzymeCoreExt.jl +++ /dev/null @@ -1,5 +0,0 @@ -module ImplicitDifferentiationEnzymeCoreExt - -using EnzymeCore - -end diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl new file mode 100644 index 0000000..1efb87f --- /dev/null +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -0,0 +1,61 @@ +module ImplicitDifferentiationEnzymeExt + +using ADTypes +using Enzyme +using Enzyme.EnzymeCore +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, output + +# https://discourse.julialang.org/t/can-i-define-a-type-unstable-enzymerule/112732 + +function EnzymeRules.forward( + func::Const{<:ImplicitFunction}, + RT::Type{<:Union{Duplicated,DuplicatedNoNeed}}, + x::Union{Duplicated,DuplicatedNoNeed}, +) + implicit = func.val + @info "My Duplicated rule is used" + y_or_yz = implicit(x.val) + y = output(y_or_yz) + + suggested_backend = AutoEnzyme(Enzyme.Forward) + A = build_A(implicit, x.val, y_or_yz; suggested_backend) + B = build_B(implicit, x.val, y_or_yz; suggested_backend) + + dc = B * x.dval + dy = implicit.linear_solver(A, -dc) + if RT <: Duplicated + return Duplicated(y, dy) + elseif RT <: DuplicatedNoNeed + return dy + end +end + +function EnzymeRules.forward( + func::Const{<:ImplicitFunction}, + RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}}, + x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, +) where {T,N} + implicit = func.val + @info "My BatchDuplicated rule is used" + y_or_yz = implicit(x.val) + y = output(y_or_yz) + + suggested_backend = AutoEnzyme(Enzyme.Forward) + A = build_A(implicit, x.val, y_or_yz; suggested_backend) + B = 
build_B(implicit, x.val, y_or_yz; suggested_backend) + + dX = reduce(hcat, x.dval) + dC = mapreduce(hcat, eachcol(dX)) do dₖx + B * dₖx + end + dY = implicit.linear_solver(A, -dC) + + dy = ntuple(k -> dY[:, k], Val(N)) + if RT <: BatchDuplicated + return BatchDuplicated(y, dy) + elseif RT <: BatchDuplicatedNoNeed + return dy + end +end + +end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 6a899f0..376f64d 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -2,7 +2,7 @@ module ImplicitDifferentiationForwardDiffExt using ADTypes: AutoForwardDiff using ForwardDiff: Chunk, Dual, Partials, jacobian, partials, value -using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, get_byproduct, get_output +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output chunksize(::Chunk{N}) where {N} = N @@ -11,25 +11,26 @@ function (implicit::ImplicitFunction)( ) where {T,R,N} x = value.(x_and_dx) y_or_yz = implicit(x, args...; kwargs...) - y = get_output(y_or_yz) + y = output(y_or_yz) suggested_backend = AutoForwardDiff{1,Nothing}(nothing) A = build_A(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) B = build_B(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) 
- dy = ntuple(Val(N)) do k - dₖx = partials.(x_and_dx, k) - dₖc = B * dₖx - dₖy = implicit.linear_solver(A, -dₖc) - return dₖy + dX = mapreduce(hcat, 1:N) do k + partials.(x_and_dx, k) end + dC = mapreduce(hcat, eachcol(dX)) do dₖx + B * dₖx + end + dY = implicit.linear_solver(A, -dC) y_and_dy = map(eachindex(y)) do i - Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N)))) + Dual{T}(y[i], Partials(ntuple(k -> dY[i, k], Val(N)))) end if y_or_yz isa Tuple - return y_and_dy, get_byproduct(y_or_yz) + return y_and_dy, byproduct(y_or_yz) else return y_and_dy end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index ddc6a75..07c6d7b 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -14,7 +14,7 @@ using DifferentiationInterface: prepare_pullback, pushforward!!, value_and_pullback!!_split -using Krylov: gmres +using Krylov: block_gmres, gmres using LinearOperators: LinearOperator using LinearAlgebra: factorize, lu diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 872b3aa..a4124bf 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -1,3 +1,18 @@ +struct DefaultLinearSolver end + +function (::DefaultLinearSolver)(A, b::AbstractVector) + x, stats = gmres(A, b) + return x +end + +function (::DefaultLinearSolver)(A, B::AbstractMatrix) + # X, stats = block_gmres(A, B) # https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/854 + X = mapreduce(hcat, eachcol(B)) do b + first(gmres(A, b)) + end + return X +end + """ ImplicitFunction @@ -14,7 +29,9 @@ This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = - `forward`: a callable, does not need to be compatible with automatic differentiation - `conditions`: a callable, must be compatible with automatic differentiation -- `linear_solver`: a subtype of `AbstractLinearSolver`, defines how the linear system will be solved +- `linear_solver`: a callable with two methods: + - `(A, b::AbstractVector) -> 
s::AbstractVector` such that `A * s = b` + - `(A, B::AbstractMatrix) -> S::AbstractMatrix` such that `A * S = B` - `conditions_x_backend`: either `nothing` or a subtype of `ADTypes.AbstractADType`, defines how the conditions will be differentiated with respect to the first argument `x` - `conditions_y_backend`: same for the second argument `y` @@ -23,7 +40,7 @@ There are two possible signatures for `forward` and `conditions`, which must be 1. Standard: `forward(x, args...; kwargs...) = y` and `conditions(x, y, args...; kwargs...) = c` 2. Byproduct: `forward(x, args...; kwargs...) = (y, z)` and `conditions(x, y, z, args...; kwargs...) = c`. -In both cases, `x`, `y` and `c` must be arrays, with `size(y) = size(c)`. +In both cases, `x`, `y` and `c` must be vectors, with `length(y) = length(c)`. In the second case, the byproduct `z` can be an arbitrary object generated by `forward`. The positional arguments `args...` and keyword arguments `kwargs...` must be the same for both `forward` and `conditions`. @@ -35,7 +52,7 @@ The positional arguments `args...` and keyword arguments `kwargs...` must be the } forward::F conditions::C - linear_solver::L = first ∘ gmres + linear_solver::L = DefaultLinearSolver() conditions_x_backend::B1 = nothing conditions_y_backend::B2 = nothing end @@ -56,22 +73,17 @@ end """ (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...) -Return `implicit.forward(x, args...; kwargs...)`, which can be either an array `y` or a tuple `(y, z)`. +Return `implicit.forward(x, args...; kwargs...)`, which can be either a vector `y` or a tuple `(y, z)`. This call is differentiable (except for `z`). """ function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) y_or_yz = implicit.forward(x, args...; kwargs...) 
- if !(y_or_yz isa Union{AbstractArray,Tuple{<:AbstractVector,<:Any}}) - error( - "The forward mapping must return a vector `y` or a tuple `(y, z)` where `y` is a vector", - ) - end return y_or_yz end -get_output(y::AbstractVector) = y -get_byproduct(::AbstractVector) = error("No byproduct") +output(y::AbstractVector) = y +byproduct(::AbstractVector) = error("No byproduct") -get_output((y, z)) = y -get_byproduct((y, z)) = z +output((y, z)) = y +byproduct((y, z)) = z diff --git a/src/operators.jl b/src/operators.jl index ed31dc7..ecbe375 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -51,9 +51,9 @@ function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector) end function ConditionsX(conditions, x, y_or_yz, args...; kwargs...) - y = get_output(y_or_yz) + y = output(y_or_yz) if y_or_yz isa Tuple - z = get_byproduct(y_or_yz) + z = byproduct(y_or_yz) return ConditionsXByproduct(conditions, y, z, args, kwargs) else return ConditionsXNoByproduct(conditions, y, args, kwargs) @@ -62,7 +62,7 @@ end function ConditionsY(conditions, x, y_or_yz, args...; kwargs...) 
if y_or_yz isa Tuple - z = get_byproduct(y_or_yz) + z = byproduct(y_or_yz) return ConditionsYByproduct(conditions, x, z, args, kwargs) else return ConditionsYNoByproduct(conditions, x, args, kwargs) @@ -79,7 +79,8 @@ struct PushforwardOperator!{F,B,X,E} end function (po::PushforwardOperator!)(res, v) - return res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + return res end struct PullbackOperator!{PB} @@ -87,7 +88,8 @@ struct PullbackOperator!{PB} end function (po::PullbackOperator!)(res, v) - return res .= po.pullbackfunc!!(res, v) + res .= po.pullbackfunc!!(res, v) + return res end function build_A( @@ -99,7 +101,7 @@ function build_A( kwargs..., ) (; conditions, linear_solver, conditions_y_backend) = implicit - y = get_output(y_or_yz) + y = output(y_or_yz) n, m = length(x), length(y) back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) @@ -130,7 +132,7 @@ function build_Aᵀ( kwargs..., ) (; conditions, linear_solver, conditions_y_backend) = implicit - y = get_output(y_or_yz) + y = output(y_or_yz) n, m = length(x), length(y) back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) @@ -156,7 +158,7 @@ function build_B( kwargs..., ) (; conditions, linear_solver, conditions_x_backend) = implicit - y = get_output(y_or_yz) + y = output(y_or_yz) n, m = length(x), length(y) back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) @@ -186,7 +188,7 @@ function build_Bᵀ( kwargs..., ) (; conditions, linear_solver, conditions_x_backend) = implicit - y = get_output(y_or_yz) + y = output(y_or_yz) n, m = length(x), length(y) back_x = isnothing(conditions_x_backend) ? 
suggested_backend : conditions_x_backend cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) diff --git a/test/errors.jl b/test/errors.jl deleted file mode 100644 index 0a729f4..0000000 --- a/test/errors.jl +++ /dev/null @@ -1,80 +0,0 @@ -using ChainRulesCore -using ChainRulesTestUtils -using ForwardDiff -using ImplicitDifferentiation -using Test -using Zygote - -@testset "Byproduct handling" begin - f1 = (_) -> (1, 2) - f2 = (_) -> ([1.0], 2, 3) - c = nothing - imf1 = ImplicitFunction(f1, c) - imf2 = ImplicitFunction(f2, c) - @test_throws DimensionMismatch imf1(zeros(1)) - @test_throws DimensionMismatch imf2(zeros(1)) -end - -@testset "Only accept one array" begin - f = identity - c = nothing - imf = ImplicitFunction(f, c) - @test_throws MethodError imf((1.0,)) - @test_throws MethodError imf([1.0], [1.0]) -end - -@testset verbose = true "Derivative NaNs" begin - x = zeros(Float32, 2) - linear_solvers = ( - IterativeLinearSolver(; verbose=false), # - IterativeLinearSolver(; verbose=false, accept_inconsistent=true), # - DirectLinearSolver(; verbose=false), # - ) - function should_give_nan(linear_solver) - return linear_solver isa DirectLinearSolver || !linear_solver.accept_inconsistent - end - - @testset "Infinite derivative" begin - f = x -> sqrt.(x) # nondifferentiable at 0 - c = (x, y) -> y .^ 2 .- x - for linear_solver in linear_solvers - @testset "$(typeof(linear_solver))" begin - implicit = ImplicitFunction(f, c; linear_solver) - J1 = ForwardDiff.jacobian(implicit, x) - J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) == should_give_nan(linear_solver) - @test all(isnan, J2) == should_give_nan(linear_solver) - @test eltype(J1) == Float32 - @test eltype(J2) == Float32 - end - end - end - - @testset "Singular linear system" begin - f = x -> x # wrong solver - c = (x, y) -> (x .+ 1) .^ 2 .- y .^ 2 - for linear_solver in linear_solvers - @testset "$(typeof(linear_solver))" begin - implicit = ImplicitFunction(f, c; linear_solver) - J1 = 
ForwardDiff.jacobian(implicit, x) - J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) == should_give_nan(linear_solver) - @test all(isnan, J2) == should_give_nan(linear_solver) - @test eltype(J1) == Float32 - @test eltype(J2) == Float32 - end - end - end -end - -@testset "Weird ChainRulesTestUtils behavior" begin - x = rand(3) - forward(x) = sqrt.(abs.(x)), 1 - conditions(x, y, z) = abs.(y ./ z) .- abs.(x) - implicit = ImplicitFunction(forward, conditions) - y, z = implicit(x) - dy = similar(y) - rc = Zygote.ZygoteRuleConfig() - test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0)) - @test_skip test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent())) -end diff --git a/test/runtests.jl b/test/runtests.jl index bb662fa..3a6b739 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,10 +57,6 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") end end end - @testset verbose = true "Errors" begin - @info "Error tests" - include("errors.jl") - end @testset verbose = true "Systematic" begin @info "Systematic tests" include("systematic.jl") diff --git a/test/systematic.jl b/test/systematic.jl index 271054f..aef236b 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,6 +1,8 @@ using ADTypes using ChainRulesCore using ChainRulesTestUtils +using DifferentiationInterface: DifferentiationInterface +using Enzyme: Enzyme using ForwardDiff: ForwardDiff import ImplicitDifferentiation as ID using ImplicitDifferentiation: ImplicitFunction @@ -13,12 +15,14 @@ using StaticArrays using Test using Zygote: Zygote, ZygoteRuleConfig +## + Random.seed!(63); -function identity_break_autodiff(x) - a = [0.0] +function identity_break_autodiff(x::AbstractVector{R}) where {R} + a = [zero(R)] a[1] = float(first(x)) - return x + return copy(x) end function mysqrt(x::AbstractVector) @@ -71,17 +75,17 @@ function coherent_array_type(a, b) end end -function test_implicit_call(x::AbstractArray{T}; kwargs...) 
where {T} +function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) imf4 = make_implicit_sqrt_kwargs(; kwargs...) y_true = mysqrt(x) - y1 = @inferred imf1(x) - y2, z2 = @inferred imf2(x) - y3 = @inferred imf3(x, 1) - y4 = @inferred imf4(x; p=1) + y1 = imf1(x) + y2, z2 = imf2(x) + y3 = imf3(x, 1) + y4 = imf4(x; p=1) @testset "Exact value" begin @test y1 ≈ y_true @@ -98,20 +102,15 @@ function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T} @test coherent_array_type(x, y4) end - @testset "JET" begin + @testset "Type stability" begin @test_opt target_modules = (ID,) imf1(x) @test_opt target_modules = (ID,) imf2(x) @test_opt target_modules = (ID,) imf3(x, 1) @test_opt target_modules = (ID,) imf4(x; p=1) - - @test_call target_modules = (ID,) imf1(x) - @test_call target_modules = (ID,) imf2(x) - @test_call target_modules = (ID,) imf3(x, 1) - @test_call target_modules = (ID,) imf4(x; p=1) end end -function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} +function test_implicit_duals(x::AbstractVector{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) @@ -122,10 +121,10 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} dx .= one(T) x_and_dx = ForwardDiff.Dual.(x, dx) - y_and_dy1 = @inferred imf1(x_and_dx) - y_and_dy2, z2 = @inferred imf2(x_and_dx) - y_and_dy3 = @inferred imf3(x_and_dx, 1) - y_and_dy4 = @inferred imf4(x_and_dx; p=1) + y_and_dy1 = imf1(x_and_dx) + y_and_dy2, z2 = imf2(x_and_dx) + y_and_dy3 = imf3(x_and_dx, 1) + y_and_dy4 = imf4(x_and_dx; p=1) @testset "Dual numbers" begin @test ForwardDiff.value.(y_and_dy1) ≈ y_true @@ -135,27 +134,22 @@ function test_implicit_duals(x::AbstractArray{T}; kwargs...) 
where {T} @test z2 ≈ 1 end - @testset "Static arrays" begin + @testset "Array types" begin @test coherent_array_type(x, y_and_dy1) @test coherent_array_type(x, y_and_dy2) @test coherent_array_type(x, y_and_dy3) @test coherent_array_type(x, y_and_dy4) end - @testset "JET" begin - @test_opt target_modules = (ID,) imf1(x_and_dx) - @test_opt target_modules = (ID,) imf2(x_and_dx) - @test_opt target_modules = (ID,) imf3(x_and_dx, 1) - @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) - - @test_call target_modules = (ID,) imf1(x_and_dx) - @test_call target_modules = (ID,) imf2(x_and_dx) - @test_call target_modules = (ID,) imf3(x_and_dx, 1) - @test_call target_modules = (ID,) imf4(x_and_dx; p=1) + @testset "Type stability" begin + @test_skip @test_opt target_modules = (ID,) imf1(x_and_dx) + @test_skip @test_opt target_modules = (ID,) imf2(x_and_dx) + @test_skip @test_opt target_modules = (ID,) imf3(x_and_dx, 1) + @test_skip @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) end end -function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} +function test_implicit_rrule(rc, x::AbstractVector{T}; kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) @@ -166,15 +160,15 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) 
where {T} dy .= one(eltype(y_true)) dz = nothing - y1, pb1 = @inferred rrule(rc, imf1, x) - (y2, z2), pb2 = @inferred rrule(rc, imf2, x) - y3, pb3 = @inferred rrule(rc, imf3, x, 1) - y4, pb4 = @inferred rrule(rc, imf4, x; p=1) + y1, pb1 = rrule(rc, imf1, x) + (y2, z2), pb2 = rrule(rc, imf2, x) + y3, pb3 = rrule(rc, imf3, x, 1) + y4, pb4 = rrule(rc, imf4, x; p=1) - dimf1, dx1 = @inferred pb1(dy) - dimf2, dx2 = @inferred pb2((dy, dz)) - dimf3, dx3, dp3 = @inferred pb3(dy) - dimf4, dx4 = @inferred pb4(dy) + dimf1, dx1 = pb1(dy) + dimf2, dx2 = pb2((dy, dz)) + dimf3, dx3, dp3 = pb3(dy) + dimf4, dx4 = pb4(dy) @testset "Pullbacks" begin @test y1 ≈ y_true @@ -208,7 +202,7 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} @test coherent_array_type(x, dx4) end - @testset "JET" begin + @testset "Type stability" begin @test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x) @test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x) @test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) @@ -218,16 +212,6 @@ function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} @test_skip @test_opt target_modules = (ID,) pb2((dy, dz)) @test_skip @test_opt target_modules = (ID,) pb3(dy) @test_skip @test_opt target_modules = (ID,) pb4(dy) - - @test_call target_modules = (ID,) rrule(rc, imf1, x) - @test_call target_modules = (ID,) rrule(rc, imf2, x) - @test_call target_modules = (ID,) rrule(rc, imf3, x, 1) - @test_call target_modules = (ID,) rrule(rc, imf4, x; p=1) - - @test_call target_modules = (ID,) pb1(dy) - @test_call target_modules = (ID,) pb2((dy, dz)) - @test_call target_modules = (ID,) pb3(dy) - @test_call target_modules = (ID,) pb4(dy) end @testset "ChainRulesTestUtils" begin @@ -240,16 +224,18 @@ end ## High-level tests per backend -function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T} +function test_implicit_backend( + backend::ADTypes.AbstractADType, x::AbstractVector{T}; kwargs... 
+) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) imf4 = make_implicit_sqrt_kwargs(; kwargs...) - J1 = ForwardDiff.jacobian(imf1, x) - J2 = ForwardDiff.jacobian(first ∘ imf2, x) - J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x) - J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x) + J1 = DifferentiationInterface.jacobian(imf1, backend, x) + J2 = DifferentiationInterface.jacobian(first ∘ imf2, backend, x) + J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, 1), backend, x) + J4 = DifferentiationInterface.jacobian(_x -> imf4(_x; p=1), backend, x) J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) @testset "Exact Jacobian" begin @@ -266,65 +252,43 @@ function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T} return nothing end -function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) - - J1 = Zygote.jacobian(imf1, x)[1] - J2 = Zygote.jacobian(first ∘ imf2, x)[1] - J3 = Zygote.jacobian(imf3, x, 1)[1] - J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1] - J_true = Zygote.jacobian(_x -> sqrt.(_x), x)[1] - - @testset "Exact Jacobian" begin - @test J1 ≈ J_true - @test J2 ≈ J_true - @test J3 ≈ J_true - @test J4 ≈ J_true - - @test eltype(J1) == eltype(x) - @test eltype(J2) == eltype(x) - @test eltype(J3) == eltype(x) - @test eltype(J4) == eltype(x) - end - return nothing -end - -function test_implicit(x; kwargs...) +function test_implicit(backends, x; kwargs...) @testset verbose = true "Call" begin test_implicit_call(x; kwargs...) end - @testset verbose = true "ForwardDiff.jl" begin - if !(x isa AbstractSparseArray) - test_implicit_forwarddiff(x; kwargs...) - test_implicit_duals(x; kwargs...) 
- end + @testset verbose = true "Duals" begin + test_implicit_duals(x; kwargs...) + end + @testset verbose = true "ChainRule" begin + test_implicit_rrule(ZygoteRuleConfig(), x; kwargs...) + end + @testset "$backend" for backend in backends + test_implicit_backend(backend, x; kwargs...) end - # @testset verbose = true "Zygote.jl" begin - # rc = Zygote.ZygoteRuleConfig() - # test_implicit_zygote(x; kwargs...) - # test_implicit_rrule(rc, x; kwargs...) - # end return nothing end ## Parameter combinations -linear_solver_candidates = (\, first ∘ Krylov.gmres) -backend_candidates = (nothing, AutoForwardDiff(; chunksize=1)); +backends = [AutoForwardDiff(; chunksize=1), AutoZygote()] + +linear_solver_candidates = (\, ID.DefaultLinearSolver()) +conditions_backend_candidates = (nothing, AutoForwardDiff(; chunksize=1)); x_candidates = (rand(Float32, 2), rand(Float64, 2)); ## Test loop -@testset "$linear_solver - $(typeof(backend)) - $(typeof(x))" for ( - linear_solver, backend, x +@testset "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" for ( + linear_solver, conditions_backend, x ) in Iterators.product( - linear_solver_candidates, backend_candidates, x_candidates + linear_solver_candidates, conditions_backend_candidates, x_candidates ) - @info "$linear_solver - $(typeof(backend)) - $(typeof(x))" + @info "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" test_implicit( - x; linear_solver, conditions_x_backend=backend, conditions_y_backend=backend + backends, + x; + linear_solver, + conditions_x_backend=conditions_backend, + conditions_y_backend=conditions_backend, ) -end +end; From 6d11e7ca37323a2e727d94569b3e5d6202b19af6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 10 Apr 2024 07:51:39 +0200 Subject: [PATCH 4/8] More tests --- ext/ImplicitDifferentiationEnzymeExt.jl | 6 +- test/systematic.jl | 159 +++++++++++++++--------- 2 files changed, 101 insertions(+), 64 deletions(-) diff --git 
a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index 1efb87f..f058d2a 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -5,8 +5,6 @@ using Enzyme using Enzyme.EnzymeCore using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, output -# https://discourse.julialang.org/t/can-i-define-a-type-unstable-enzymerule/112732 - function EnzymeRules.forward( func::Const{<:ImplicitFunction}, RT::Type{<:Union{Duplicated,DuplicatedNoNeed}}, @@ -22,7 +20,7 @@ function EnzymeRules.forward( B = build_B(implicit, x.val, y_or_yz; suggested_backend) dc = B * x.dval - dy = implicit.linear_solver(A, -dc) + dy = convert(typeof(y), implicit.linear_solver(A, -dc)) if RT <: Duplicated return Duplicated(y, dy) elseif RT <: DuplicatedNoNeed @@ -50,7 +48,7 @@ function EnzymeRules.forward( end dY = implicit.linear_solver(A, -dC) - dy = ntuple(k -> dY[:, k], Val(N)) + dy = convert(NTuple{N,typeof(y)}, ntuple(k -> dY[:, k], Val(N))) if RT <: BatchDuplicated return BatchDuplicated(y, dy) elseif RT <: BatchDuplicatedNoNeed diff --git a/test/systematic.jl b/test/systematic.jl index aef236b..d2d5ba8 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -20,9 +20,13 @@ using Zygote: Zygote, ZygoteRuleConfig Random.seed!(63); function identity_break_autodiff(x::AbstractVector{R}) where {R} - a = [zero(R)] - a[1] = float(first(x)) - return copy(x) + float(first(x)) # break ForwardDiff + (Vector{R}(undef, 1))[1] = first(x) # break Zygote + try + throw(copy(x)) # break Enzyme + catch y + return y + end end function mysqrt(x::AbstractVector) @@ -61,21 +65,21 @@ end ## Low level tests -function coherent_array_type(a, b) +function test_coherent_array_type(a, b) + @test eltype(a) == eltype(b) if a isa Array - return b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) + @test b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) elseif a isa StaticArray - return b isa StaticArray 
|| - b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) + @test b isa StaticArray || b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) elseif a isa AbstractSparseArray - return b isa AbstractSparseArray || - b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) + @test b isa AbstractSparseArray || + b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) else error("New array type") end end -function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} +function test_implicit_call(x::AbstractVector{T}; type_stability=false, kwargs...) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) @@ -87,7 +91,7 @@ function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} y3 = imf3(x, 1) y4 = imf4(x; p=1) - @testset "Exact value" begin + @testset "Primal value" begin @test y1 ≈ y_true @test y2 ≈ y_true @test y3 ≈ y_true @@ -96,21 +100,27 @@ function test_implicit_call(x::AbstractVector{T}; kwargs...) where {T} end @testset "Array type" begin - @test coherent_array_type(x, y1) - @test coherent_array_type(x, y2) - @test coherent_array_type(x, y3) - @test coherent_array_type(x, y4) + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) end - @testset "Type stability" begin - @test_opt target_modules = (ID,) imf1(x) - @test_opt target_modules = (ID,) imf2(x) - @test_opt target_modules = (ID,) imf3(x, 1) - @test_opt target_modules = (ID,) imf4(x; p=1) + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x) + @test_opt target_modules = (ID,) imf2(x) + @test_opt target_modules = (ID,) imf3(x, 1) + @test_opt target_modules = (ID,) imf4(x; p=1) + end end end -function test_implicit_duals(x::AbstractVector{T}; kwargs...) 
where {T} +tag(::AbstractVector{<:ForwardDiff.Dual{T}}) where {T} = T + +function test_implicit_duals( + x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) @@ -118,7 +128,7 @@ function test_implicit_duals(x::AbstractVector{T}; kwargs...) where {T} y_true = mysqrt(x) dx = similar(x) - dx .= one(T) + dx .= 2 * one(T) x_and_dx = ForwardDiff.Dual.(x, dx) y_and_dy1 = imf1(x_and_dx) @@ -131,25 +141,37 @@ function test_implicit_duals(x::AbstractVector{T}; kwargs...) where {T} @test ForwardDiff.value.(y_and_dy2) ≈ y_true @test ForwardDiff.value.(y_and_dy3) ≈ y_true @test ForwardDiff.value.(y_and_dy4) ≈ y_true + @test ForwardDiff.extract_derivative(tag(y_and_dy1), y_and_dy1) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy2), y_and_dy2) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy3), y_and_dy3) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy4), y_and_dy4) ≈ + 2 .* inv.(2 .* sqrt.(x)) @test z2 ≈ 1 end @testset "Array types" begin - @test coherent_array_type(x, y_and_dy1) - @test coherent_array_type(x, y_and_dy2) - @test coherent_array_type(x, y_and_dy3) - @test coherent_array_type(x, y_and_dy4) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy1)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy2)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy3)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy4)) end - @testset "Type stability" begin - @test_skip @test_opt target_modules = (ID,) imf1(x_and_dx) - @test_skip @test_opt target_modules = (ID,) imf2(x_and_dx) - @test_skip @test_opt target_modules = (ID,) imf3(x_and_dx, 1) - @test_skip @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x_and_dx) + 
@test_opt target_modules = (ID,) imf2(x_and_dx) + @test_opt target_modules = (ID,) imf3(x_and_dx, 1) + @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) + end end end -function test_implicit_rrule(rc, x::AbstractVector{T}; kwargs...) where {T} +function test_implicit_rrule( + rc, x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) imf3 = make_implicit_sqrt_args(; kwargs...) @@ -191,27 +213,15 @@ function test_implicit_rrule(rc, x::AbstractVector{T}; kwargs...) where {T} end @testset "Array type" begin - @test coherent_array_type(x, y1) - @test coherent_array_type(x, y2) - @test coherent_array_type(x, y3) - @test coherent_array_type(x, y4) - - @test coherent_array_type(x, dx1) - @test coherent_array_type(x, dx2) - @test coherent_array_type(x, dx3) - @test coherent_array_type(x, dx4) - end - - @testset "Type stability" begin - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) - - @test_skip @test_opt target_modules = (ID,) pb1(dy) - @test_skip @test_opt target_modules = (ID,) pb2((dy, dz)) - @test_skip @test_opt target_modules = (ID,) pb3(dy) - @test_skip @test_opt target_modules = (ID,) pb4(dy) + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) + + test_coherent_array_type(x, dx1) + test_coherent_array_type(x, dx2) + test_coherent_array_type(x, dx3) + test_coherent_array_type(x, dx4) end @testset "ChainRulesTestUtils" begin @@ -220,12 +230,26 @@ function test_implicit_rrule(rc, x::AbstractVector{T}; kwargs...) 
where {T} test_rrule(rc, imf3, x, 1; atol=1e-2, check_inferred=false) test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,), check_inferred=false) end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) rrule(rc, imf1, x) + @test_opt target_modules = (ID,) rrule(rc, imf2, x) + @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) + @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) + + @test_opt target_modules = (ID,) pb1(dy) + @test_opt target_modules = (ID,) pb2((dy, dz)) + @test_opt target_modules = (ID,) pb3(dy) + @test_opt target_modules = (ID,) pb4(dy) + end + end end ## High-level tests per backend function test_implicit_backend( - backend::ADTypes.AbstractADType, x::AbstractVector{T}; kwargs... + backend::ADTypes.AbstractADType, x::AbstractVector{T}; type_stability=false, kwargs... ) where {T} imf1 = make_implicit_sqrt(; kwargs...) imf2 = make_implicit_sqrt_byproduct(; kwargs...) @@ -252,7 +276,7 @@ function test_implicit_backend( return nothing end -function test_implicit(backends, x; kwargs...) +function test_implicit(backends, x; type_stability=false, kwargs...) @testset verbose = true "Call" begin test_implicit_call(x; kwargs...) 
end @@ -270,15 +294,30 @@ end ## Parameter combinations -backends = [AutoForwardDiff(; chunksize=1), AutoZygote()] +backends = [ + AutoForwardDiff(; chunksize=1), # + # AutoEnzyme(Enzyme.Forward), + AutoZygote(), +] + +linear_solver_candidates = ( + \, # + ID.DefaultLinearSolver(), +) + +conditions_backend_candidates = ( + nothing, # + AutoForwardDiff(; chunksize=1), + # AutoEnzyme(Enzyme.Forward), +); -linear_solver_candidates = (\, ID.DefaultLinearSolver()) -conditions_backend_candidates = (nothing, AutoForwardDiff(; chunksize=1)); -x_candidates = (rand(Float32, 2), rand(Float64, 2)); +x_candidates = ( + rand(Float32, 2), # +); ## Test loop -@testset "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" for ( +@testset verbose = true "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" for ( linear_solver, conditions_backend, x ) in Iterators.product( linear_solver_candidates, conditions_backend_candidates, x_candidates From 2819cd24414738f4ee890a3a3697ebba4836c254 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 10 Apr 2024 10:55:03 +0200 Subject: [PATCH 5/8] Fix ENzyme --- README.md | 1 + docs/make.jl | 5 +- ext/ImplicitDifferentiationEnzymeExt.jl | 74 ++++++++++++------------- src/implicit_function.jl | 3 + test/systematic.jl | 51 ++++++++++------- 5 files changed, 73 insertions(+), 61 deletions(-) diff --git a/README.md b/README.md index fc10902..b513215 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ Please read the [documentation](https://gdalle.github.io/ImplicitDifferentiation In Julia: +- [SciML](https://sciml.ai/) ecosystem, especially [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl), [NonlinearSolve.jl](https://github.com/SciML/NonlinearSolve.jl) and [Optimization.jl](https://github.com/SciML/Optimization.jl) - [jump-dev/DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiation of convex optimization problems - 
[axelparmentier/InferOpt.jl](https://github.com/axelparmentier/InferOpt.jl): approximate differentiation of combinatorial optimization problems - [JuliaNonconvex/NonconvexUtils.jl](https://github.com/JuliaNonconvex/NonconvexUtils.jl): contains the original implementation from which this package drew inspiration diff --git a/docs/make.jl b/docs/make.jl index 5c53f61..de7c136 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -37,7 +37,10 @@ end pages = [ "Home" => "index.md", "faq.md", - "Examples" => [joinpath("examples", file) for file in sort(readdir(EXAMPLES_DIR_MD))], + "Examples" => [ + joinpath("examples", file) for + file in sort(readdir(EXAMPLES_DIR_MD)) if endswith(file, ".md") + ], "api.md", ] diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index f058d2a..2db9426 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -3,56 +3,52 @@ module ImplicitDifferentiationEnzymeExt using ADTypes using Enzyme using Enzyme.EnzymeCore -using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, output - -function EnzymeRules.forward( - func::Const{<:ImplicitFunction}, - RT::Type{<:Union{Duplicated,DuplicatedNoNeed}}, - x::Union{Duplicated,DuplicatedNoNeed}, -) - implicit = func.val - @info "My Duplicated rule is used" - y_or_yz = implicit(x.val) - y = output(y_or_yz) - - suggested_backend = AutoEnzyme(Enzyme.Forward) - A = build_A(implicit, x.val, y_or_yz; suggested_backend) - B = build_B(implicit, x.val, y_or_yz; suggested_backend) - - dc = B * x.dval - dy = convert(typeof(y), implicit.linear_solver(A, -dc)) - if RT <: Duplicated - return Duplicated(y, dy) - elseif RT <: DuplicatedNoNeed - return dy - end -end +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output function EnzymeRules.forward( func::Const{<:ImplicitFunction}, RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}}, - 
x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, -) where {T,N} + func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, + func_args::Vararg{Const,P}, +) where {T,N,P} + @info "My BatchDuplicated rule is used" RT typeof(func_x) typeof(func_args) implicit = func.val - @info "My BatchDuplicated rule is used" - y_or_yz = implicit(x.val) + args = map(a -> a.val, func_args) + x = func_x.val + dx = func_x.dval + + y_or_yz = implicit(x, args...) y = output(y_or_yz) + Y = typeof(y) suggested_backend = AutoEnzyme(Enzyme.Forward) - A = build_A(implicit, x.val, y_or_yz; suggested_backend) - B = build_B(implicit, x.val, y_or_yz; suggested_backend) + A = build_A(implicit, x, y_or_yz, args...; suggested_backend) + B = build_B(implicit, x, y_or_yz, args...; suggested_backend) - dX = reduce(hcat, x.dval) - dC = mareduce(hcat, eachcol(dX)) do dₖx + dx_batch = reduce(hcat, dx) + dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx B * dₖx end - dY = implicit.linear_solver(A, -dC) - - dy = convert(NTuple{N,typeof(y)}, ntuple(k -> dY[:, k], Val(N))) - if RT <: BatchDuplicated - return BatchDuplicated(y, dy) - elseif RT <: BatchDuplicatedNoNeed - return dy + dy_batch = implicit.linear_solver(A, -dc_batch) + + dy::NTuple{N,Y} = ntuple(k -> convert(Y, dy_batch[:, k]), Val(N)) + + if y_or_yz isa AbstractArray + if RT <: BatchDuplicated + return BatchDuplicated(y, dy) + elseif RT <: BatchDuplicatedNoNeed + return dy + end + elseif y_or_yz isa Tuple + yz = y_or_yz + z = byproduct(yz) + Z = typeof(z) + dyz::NTuple{N,Tuple{Y,Z}} = ntuple(k -> (dy[k], make_zero(z)), Val(N)) + if RT <: BatchDuplicated + return BatchDuplicated(yz, dyz) + elseif RT <: BatchDuplicatedNoNeed + return dyz + end end end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index a4124bf..279f768 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -85,5 +85,8 @@ end output(y::AbstractVector) = y byproduct(::AbstractVector) = error("No byproduct") 
+output(yz::Tuple{<:Any,<:Any}) = yz[1] +byproduct(yz::Tuple{<:Any,<:Any}) = yz[2] + output((y, z)) = y byproduct((y, z)) = z diff --git a/test/systematic.jl b/test/systematic.jl index d2d5ba8..5c4fe21 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -19,14 +19,15 @@ using Zygote: Zygote, ZygoteRuleConfig Random.seed!(63); -function identity_break_autodiff(x::AbstractVector{R}) where {R} +function identity_break_autodiff(x::X)::X where {R,X<:AbstractVector{R}} float(first(x)) # break ForwardDiff (Vector{R}(undef, 1))[1] = first(x) # break Zygote - try - throw(copy(x)) # break Enzyme + result = try + throw(copy(x)) catch y - return y + y end + return result end function mysqrt(x::AbstractVector) @@ -43,22 +44,22 @@ function make_implicit_sqrt(; kwargs...) end function make_implicit_sqrt_byproduct(; kwargs...) - forward(x) = 1 * mysqrt(x), 1 - conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(x) + forward(x) = one(eltype(x)) .* mysqrt(x), one(eltype(x)) + conditions(x, y, z) = abs2.(y ./ z) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_args(; kwargs...) - forward(x, p::Integer) = p * mysqrt(x) - conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(x) + forward(x, p) = p .* mysqrt(x) + conditions(x, y, p) = abs2.(y ./ p) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) return implicit end function make_implicit_sqrt_kwargs(; kwargs...) - forward(x; p::Integer) = p .* mysqrt(x) - conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(x) + forward(x; p) = p .* mysqrt(x) + conditions(x, y; p) = abs2.(y ./ p) .- abs.(x) implicit = ImplicitFunction(forward, conditions; kwargs...) 
return implicit end @@ -258,20 +259,29 @@ function test_implicit_backend( J1 = DifferentiationInterface.jacobian(imf1, backend, x) J2 = DifferentiationInterface.jacobian(first ∘ imf2, backend, x) - J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, 1), backend, x) - J4 = DifferentiationInterface.jacobian(_x -> imf4(_x; p=1), backend, x) + J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, one(eltype(x))), backend, x) + + J4 = if !(backend isa AutoEnzyme) + DifferentiationInterface.jacobian(_x -> imf4(_x; p=one(eltype(x))), backend, x) + else + nothing + end + J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) @testset "Exact Jacobian" begin @test J1 ≈ J_true @test J2 ≈ J_true @test J3 ≈ J_true - @test J4 ≈ J_true @test eltype(J1) == eltype(x) @test eltype(J2) == eltype(x) @test eltype(J3) == eltype(x) - @test eltype(J4) == eltype(x) + + if !(backend isa AutoEnzyme) + @test J4 ≈ J_true + @test eltype(J4) == eltype(x) + end end return nothing end @@ -296,18 +306,18 @@ end backends = [ AutoForwardDiff(; chunksize=1), # - # AutoEnzyme(Enzyme.Forward), + AutoEnzyme(Enzyme.Forward), AutoZygote(), ] linear_solver_candidates = ( \, # - ID.DefaultLinearSolver(), + # ID.DefaultLinearSolver(), ) conditions_backend_candidates = ( nothing, # - AutoForwardDiff(; chunksize=1), + # AutoForwardDiff(; chunksize=1), # AutoEnzyme(Enzyme.Forward), ); @@ -317,12 +327,11 @@ x_candidates = ( ## Test loop -@testset verbose = true "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" for ( - linear_solver, conditions_backend, x +@testset verbose = true "$(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" for ( + x, linear_solver, conditions_backend ) in Iterators.product( - linear_solver_candidates, conditions_backend_candidates, x_candidates + x_candidates, linear_solver_candidates, conditions_backend_candidates ) - @info "$linear_solver - $(typeof(conditions_backend)) - $(typeof(x))" test_implicit( backends, x; From 98d336e5bdec939e50a2be14e77ec35f96f7fa46 Mon 
Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:32:59 +0200 Subject: [PATCH 6/8] Update docs --- Project.toml | 2 +- docs/make.jl | 2 +- docs/src/faq.md | 25 +- examples/0_intro.jl | 23 +- examples/1_basic.jl | 39 ++- examples/2_advanced.jl | 13 +- examples/3_tricks.jl | 9 +- ext/ImplicitDifferentiationEnzymeExt.jl | 3 +- src/implicit_function.jl | 17 +- src/operators.jl | 52 +++- test/runtests.jl | 21 +- test/systematic.jl | 309 +----------------------- test/utils.jl | 298 +++++++++++++++++++++++ 13 files changed, 410 insertions(+), 403 deletions(-) create mode 100644 test/utils.jl diff --git a/Project.toml b/Project.toml index 8a20f6b..d1107d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.5.2" +version = "0.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/make.jl b/docs/make.jl index de7c136..b08b857 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -36,12 +36,12 @@ end pages = [ "Home" => "index.md", - "faq.md", "Examples" => [ joinpath("examples", file) for file in sort(readdir(EXAMPLES_DIR_MD)) if endswith(file, ".md") ], "api.md", + "faq.md", ] makedocs(; diff --git a/docs/src/faq.md b/docs/src/faq.md index 4451f57..6f52441 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -1,21 +1,18 @@ -# Frequently Asked Questions +# FAQ ## Supported autodiff backends To differentiate an `ImplicitFunction`, the following backends are supported. 
| Backend | Forward mode | Reverse mode | -| ---------------------------------------------------------------------- | ------------ | ------------ | +| :--------------------------------------------------------------------- | :----------- | :----------- | | [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | yes | - | -| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | soon | yes | -| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | someday | someday | +| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | no | yes | +| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | yes | soon | -By default, the conditions are differentiated with the same backend as the `ImplicitFunction` that contains them. -However, this can be switched to any backend compatible with [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) (i.e. a subtype of `AD.AbstractBackend`). -You can specify it with the `conditions_backend` keyword argument when constructing an `ImplicitFunction`. - -!!! warning "Warning" - At the moment, `conditions_backend` can only be `nothing` or `AD.ForwardDiffBackend()`. We are investigating why the other backends fail. +By default, the conditions are differentiated using the same "outer" backend that is trying to differentiate the `ImplicitFunction`. +However, this can be switched to any other "inner" backend compatible with [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) (i.e. a subtype of `ADTypes.AbstractADType`). +You can override the default with the `conditions_x_backend` and `conditions_y_backend` keyword arguments when constructing an `ImplicitFunction`. ## Input and output types @@ -37,7 +34,7 @@ Or better yet, wrap it in a static vector: `SVector(val)`. Sparse arrays are not officially supported and might give incorrect values or `NaN`s! 
With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658). -With Zygote.jl it appears to work, but this functionality is considered experimental and might evolve. +That is why we do not test behavior for sparse inputs. ## Number of inputs and outputs @@ -46,9 +43,9 @@ What can you do to handle multiple inputs or outputs? Well, it depends whether you want their derivatives or not. | | Derivatives needed | Derivatives not needed | -| -------------------- | --------------------------------------- | --------------------------------------- | +| :------------------- | :-------------------------------------- | :-------------------------------------- | | **Multiple inputs** | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` | -| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct | +| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct `z` | We now detail each of these options. @@ -100,7 +97,7 @@ A more advanced application is given by [DifferentiableFrankWolfe.jl](https://gi ### Writing conditions We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient. -Otherwise, you will need to make sure that nested autodiff works well in your case. +Otherwise, you will need to make sure that nested autodiff works well in your case (i.e. that the "outer" backend can differentiate through the "inner" backend). For instance, if you're differentiating your implicit function (and your conditions) in reverse mode with Zygote.jl, you may want to use ForwardDiff.jl mode to compute gradients inside the conditions. 
### Dealing with constraints diff --git a/examples/0_intro.jl b/examples/0_intro.jl index 262738d..051f381 100644 --- a/examples/0_intro.jl +++ b/examples/0_intro.jl @@ -8,12 +8,9 @@ using ForwardDiff using ImplicitDifferentiation using JET #src using LinearAlgebra -using Random using Test #src using Zygote -Random.seed!(63); - # ## Why do we bother? #= @@ -23,17 +20,17 @@ While they are very generic, there are simple language constructs that they cann function badsqrt(x::AbstractArray) a = [0.0] - a[1] = first(x) + a[1] = x[1] return sqrt.(x) -end +end; #= This is essentially the componentwise square root function but with an additional twist: `a::Vector{Float64}` is created internally, and its only element is replaced with the first element of `x`. We can check that it does what it's supposed to do. =# -x = rand(2) -badsqrt(x) ≈ sqrt.(x) +x = [4.0, 9.0] +badsqrt(x) @test badsqrt(x) ≈ sqrt.(x) #src #= @@ -78,9 +75,9 @@ x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m ``` whose output is defined by conditions ```math -F(x,y(x)) = 0 \in \mathbb{R}^m +c(x,y(x)) = 0 \in \mathbb{R}^m ``` -We represent it using a type called `ImplicitFunction`, which you will see in action shortly. +We represent it using a type called [`ImplicitFunction`](@ref), which you will see in action shortly. =# #= @@ -89,7 +86,7 @@ It returns the actual output $y(x)$ of the function, and can be thought of as a Importantly, this Julia callable _doesn't need to be differentiable by automatic differentiation packages but the underlying function still needs to be mathematically differentiable_. =# -forward(x) = badsqrt(x) +forward(x) = badsqrt(x); #= Then we define `conditions` $c(x, y) = 0$ that the output $y(x)$ is supposed to satisfy. @@ -101,7 +98,7 @@ Here the conditions are very obvious: the square of the square root should be eq function conditions(x, y) c = y .^ 2 .- x return c -end +end; #= Finally, we construct a wrapper `implicit` around the previous objects. 
@@ -113,10 +110,10 @@ implicit = ImplicitFunction(forward, conditions) #= What does this wrapper do? -When we call it as a function, it just falls back on `first ∘ implicit.forward`, so unsurprisingly we get the first output $y(x)$. +When we call it as a function, it just falls back on `implicit.forward`, so unsurprisingly we get the output $y(x)$. =# -implicit(x) ≈ sqrt.(x) +implicit(x) @test implicit(x) ≈ sqrt.(x) #src #= diff --git a/examples/1_basic.jl b/examples/1_basic.jl index 6928660..c713d17 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -5,6 +5,8 @@ We show how to differentiate through very common routines: - an unconstrained optimization problem - a nonlinear system of equations - a fixed point iteration + +Note that some packages from the [SciML](https://sciml.ai/) ecosystem provide a similar implicit differentiation mechanism. =# using ForwardDiff @@ -12,18 +14,15 @@ using ImplicitDifferentiation using LinearAlgebra using NLsolve using Optim -using Random using Test #src using Zygote -Random.seed!(63); - #= In all three cases, we will use the square root as our forward mapping, but expressed in three different ways. Here's our heroic test vector: =# -x = rand(2); +x = [4.0, 9.0]; #= Since we already know the mathematical expression of the Jacobian, we will be able to compare it with our numerical results. @@ -40,7 +39,7 @@ y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) ``` The optimality conditions are given by gradient stationarity: ```math -\nabla_2 f(x, y) = 0 +c(x, y) = \nabla_2 f(x, y) = 0 ``` =# @@ -58,7 +57,7 @@ function forward_optim(x; method) y0 = ones(eltype(x), size(x)) result = optimize(f, y0, method) return Optim.minimizer(result) -end +end; #= Even though they are defined as a gradient, it is better to provide optimality conditions explicitly: that way we avoid nesting autodiff calls. By default, the conditions should accept two arguments as input. 
@@ -68,7 +67,7 @@ The forward mapping and the conditions should accept the same set of keyword arg function conditions_optim(x, y; method) ∇₂f = @. 4 * (y^2 - x) * y return ∇₂f -end +end; #= We now have all the ingredients to construct our implicit function. @@ -112,18 +111,18 @@ end #= Next, we show how to differentiate through the solution of a nonlinear system of equations: ```math -\text{find} \quad y(x) \quad \text{such that} \quad F(x, y(x)) = 0 +\text{find} \quad y(x) \quad \text{such that} \quad c(x, y(x)) = 0 ``` The optimality conditions are pretty obvious: ```math -F(x, y) = 0 +c(x, y) = 0 ``` =# #= To make verification easy, we solve the following system: ```math -F(x, y) = y \odot y - x = 0 +c(x, y) = y \odot y - x = 0 ``` In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). =# @@ -134,14 +133,14 @@ function forward_nlsolve(x; method) initial_y .= 1 result = nlsolve(F!, initial_y; method) return result.zero -end +end; #- function conditions_nlsolve(x, y; method) c = y .^ 2 .- x return c -end +end; #- @@ -179,18 +178,18 @@ end #= Finally, we show how to differentiate through the limit of a fixed point iteration: ```math -y \longmapsto T(x, y) +y \longmapsto g(x, y) ``` The optimality conditions are pretty obvious: ```math -y = T(x, y) +c(x, y) = g(x, y) - y = 0 ``` =# #= To make verification easy, we consider [Heron's method](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method): ```math -T(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) +g(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) ``` In this case, the fixed point algorithm boils down to the componentwise square root function, but we implement it manually. 
=# @@ -198,17 +197,17 @@ In this case, the fixed point algorithm boils down to the componentwise square r function forward_fixedpoint(x; iterations) y = ones(eltype(x), size(x)) for _ in 1:iterations - y .= 0.5 .* (y .+ x ./ y) + y .= (y .+ x ./ y) ./ 2 end return y -end +end; #- function conditions_fixedpoint(x, y; iterations) - T = 0.5 .* (y .+ x ./ y) - return T .- y -end + g = (y .+ x ./ y) ./ 2 + return g .- y +end; #- diff --git a/examples/2_advanced.jl b/examples/2_advanced.jl index d1d48ff..2309c27 100644 --- a/examples/2_advanced.jl +++ b/examples/2_advanced.jl @@ -9,12 +9,9 @@ using ForwardDiff using ImplicitDifferentiation using LinearAlgebra using Optim -using Random using Test #src using Zygote -Random.seed!(63); - # ## Constrained optimization #= @@ -46,19 +43,17 @@ function forward_cstr_optim(x) res = optimize(f, lower, upper, y0, Fminbox(GradientDescent())) y = Optim.minimizer(res) return y -end +end; #- -function proj_hypercube(p) - return max.(0, min.(1, p)) -end +proj_hypercube(p) = max.(0, min.(1, p)) function conditions_cstr_optim(x, y) ∇₂f = @. 4 * (y^2 - x) * y η = 0.1 return y .- proj_hypercube(y .- η .* ∇₂f) -end +end; # We now have all the ingredients to construct our implicit function. @@ -66,7 +61,7 @@ implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim # And indeed, it behaves as it should when we call it: -x = rand(2) .+ [0, 1] +x = [0.3, 1.4] #= The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$. diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index 13c1026..cc1c524 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -9,12 +9,9 @@ using ForwardDiff using ImplicitDifferentiation using Krylov using LinearAlgebra -using Random using Test #src using Zygote -Random.seed!(63); - # ## ComponentArrays # For when you need derivatives with respect to multiple inputs or outputs. 
@@ -55,7 +52,7 @@ Krylov.ktypeof(::ComponentVector{T,V}) where {T,V} = V # Now we're good to go. -a, b, m = rand(2), rand(3), 7.0 +a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0 x = ComponentVector(; a=a, b=b, m=m) implicit_components(x) @@ -83,7 +80,7 @@ end; # For when you need an additional output but don't care about its derivative. function forward_byproduct(x) - z = rand((2, 2)) # "randomized" choice + z = 2 # "randomized" choice y = x .^ (1 / z) return y, z end @@ -99,7 +96,7 @@ implicit_byproduct = ImplicitFunction(forward_byproduct, conditions_byproduct); # But this time the return value is a tuple `(y, z)` -x = rand(3) +x = [4.0, 9.0] implicit_byproduct(x) # And it works with both ForwardDiff.jl and Zygote.jl diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index 2db9426..eaeec45 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -11,11 +11,10 @@ function EnzymeRules.forward( func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, func_args::Vararg{Const,P}, ) where {T,N,P} - @info "My BatchDuplicated rule is used" RT typeof(func_x) typeof(func_args) implicit = func.val - args = map(a -> a.val, func_args) x = func_x.val dx = func_x.dval + args = map(a -> a.val, func_args) y_or_yz = implicit(x, args...) 
y = output(y_or_yz) diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 279f768..6a91012 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -32,20 +32,21 @@ This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = - `linear_solver`: a callable with two methods: - `(A, b::AbstractVector) -> s::AbstractVector` such that `A * s = b` - `(A, B::AbstractVector) -> S::AbstractMatrix` such that `A * S = B` -- `conditions_x_backend`: either `nothing` or a subtype of `ADTypes.AbstractADType`, defines how the conditions will be differentiated with respect to the first argument `x` +- `conditions_x_backend`: either `nothing` or an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl), defines how the conditions will be differentiated with respect to the first argument `x` - `conditions_y_backend`: same for the second argument `y` There are two possible signatures for `forward` and `conditions`, which must be consistent with one another: -1. Standard: `forward(x, args...; kwargs...) = y` and `conditions(x, y, args...; kwargs...) = c` -2. Byproduct: `forward(x, args...; kwargs...) = (y, z)` and `conditions(x, y, z, args...; kwargs...) = c`. +| standard | byproduct | +|:---|:---| +| `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` | +| `forward(x, args...; kwargs...) = (y, z)` | `conditions(x, y, z, args...; kwargs...) = c` | -In both cases, `x`, `y` and `c` must be vectors, with `length(y) = length(c)`. +In both cases, `x`, `y` and `c` must be `AbstractVector`s, with `length(y) = length(c)`. In the second case, the byproduct `z` can be an arbitrary object generated by `forward`. The positional arguments `args...` and keyword arguments `kwargs...` must be the same for both `forward` and `conditions`. -!!! warning "Warning" - The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes. 
+The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes. """ @kwdef struct ImplicitFunction{ F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType} @@ -73,9 +74,9 @@ end """ (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...) -Return `implicit.forward(x, args...; kwargs...)`, which can be either a vector `y` or a tuple `(y, z)`. +Return `implicit.forward(x, args...; kwargs...)`, which can be either an `AbstractVector` `y` or a tuple `(y, z)`. -This call is differentiable (except for `z`). +This call makes `y` differentiable with respect to `x`. """ function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) y_or_yz = implicit.forward(x, args...; kwargs...) diff --git a/src/operators.jl b/src/operators.jl index ecbe375..bb3f617 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -71,24 +71,40 @@ end ## Lazy operators -struct PushforwardOperator!{F,B,X,E} +struct PushforwardOperator!{F,B,X,E,R} f::F backend::B x::X extras::E + res_backup::R end -function (po::PushforwardOperator!)(res, v) - res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) +function (po::PushforwardOperator!)(res, v, α, β) + if iszero(β) + res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + res .= α .* res + else + po.res_backup .= res + res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + res .= α .* res .+ β .* po.res_backup + end return res end -struct PullbackOperator!{PB} +struct PullbackOperator!{PB,R} pullbackfunc!!::PB + res_backup::R end -function (po::PullbackOperator!)(res, v) - res .= po.pullbackfunc!!(res, v) +function (po::PullbackOperator!)(res, v, α, β) + if iszero(β) + res .= po.pullbackfunc!!(res, v) + res .= α .* res + else + po.res_backup .= res + res .= po.pullbackfunc!!(res, v) + res .= α .* res .+ β .+ po.res_backup + end return res end @@ -107,7 +123,7 @@ function build_A( cond_y = ConditionsY(conditions, x, 
y_or_yz, args...; kwargs...) if linear_solver isa typeof(\) J = jacobian(cond_y, back_y, y) - A = lu(J) + A = factorize(J) else extras = prepare_pushforward(cond_y, back_y, y) A = LinearOperator( @@ -116,7 +132,7 @@ function build_A( m, false, false, - PushforwardOperator!(cond_y, back_y, y, extras), + PushforwardOperator!(cond_y, back_y, y, extras, similar(y)), typeof(y), ) end @@ -138,12 +154,18 @@ function build_Aᵀ( cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) if linear_solver isa typeof(\) Jᵀ = transpose(jacobian(cond_y, back_y, y)) - Aᵀ = lu(Jᵀ) + Aᵀ = factorize(Jᵀ) else extras = prepare_pullback(cond_y, back_y, y) _, pullbackfunc!! = value_and_pullback!!_split(cond_y, back_y, y, extras) Aᵀ = LinearOperator( - eltype(y), m, m, false, false, PullbackOperator!(pullbackfunc!!), typeof(y) + eltype(y), + m, + m, + false, + false, + PullbackOperator!(pullbackfunc!!, similar(y)), + typeof(y), ) end return Aᵀ @@ -172,7 +194,7 @@ function build_B( n, false, false, - (res, v) -> res .= pushforward!!(cond_x, res, back_x, x, v, extras), + PushforwardOperator!(cond_x, back_x, x, extras, similar(y)), typeof(x), ) end @@ -198,7 +220,13 @@ function build_Bᵀ( extras = prepare_pullback(cond_x, back_x, x) _, pullbackfunc!! = value_and_pullback!!_split(cond_x, back_x, x, extras) Bᵀ = LinearOperator( - eltype(y), n, m, false, false, PullbackOperator!(pullbackfunc!!), typeof(x) + eltype(y), + n, + m, + false, + false, + PullbackOperator!(pullbackfunc!!, similar(y)), + typeof(x), ) end return Bᵀ diff --git a/test/runtests.jl b/test/runtests.jl index 3a6b739..338eb57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,19 +14,6 @@ DocMeta.setdocmeta!( ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true ) -function markdown_title(path) - title = "?" 
- open(path, "r") do file - for line in eachline(file) - if startswith(line, '#') - title = strip(line, [' ', '#']) - break - end - end - end - return title -end - EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") ## Test sets @@ -49,11 +36,9 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") @testset verbose = true "Examples" begin @info "Example tests" for file in readdir(EXAMPLES_DIR_JL) - path = joinpath(EXAMPLES_DIR_JL, file) - title = markdown_title(path) - @info "$title" - @testset verbose = true "$title" begin - include(path) + @info "$file" + @testset "$file" begin + include(joinpath(EXAMPLES_DIR_JL, file)) end end end diff --git a/test/systematic.jl b/test/systematic.jl index 5c4fe21..29ad82e 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,306 +1,12 @@ using ADTypes -using ChainRulesCore -using ChainRulesTestUtils -using DifferentiationInterface: DifferentiationInterface using Enzyme: Enzyme using ForwardDiff: ForwardDiff -import ImplicitDifferentiation as ID -using ImplicitDifferentiation: ImplicitFunction -using JET -using Krylov -using LinearAlgebra -using Random using SparseArrays using StaticArrays using Test using Zygote: Zygote, ZygoteRuleConfig -## - -Random.seed!(63); - -function identity_break_autodiff(x::X)::X where {R,X<:AbstractVector{R}} - float(first(x)) # break ForwardDiff - (Vector{R}(undef, 1))[1] = first(x) # break Zygote - result = try - throw(copy(x)) - catch y - y - end - return result -end - -function mysqrt(x::AbstractVector) - return identity_break_autodiff(sqrt.(abs.(x))) -end - -## Various signatures - -function make_implicit_sqrt(; kwargs...) - forward(x) = mysqrt(x) - conditions(x, y) = abs2.(y) .- abs.(x) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_byproduct(; kwargs...) 
- forward(x) = one(eltype(x)) .* mysqrt(x), one(eltype(x)) - conditions(x, y, z) = abs2.(y ./ z) .- abs.(x) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_args(; kwargs...) - forward(x, p) = p .* mysqrt(x) - conditions(x, y, p) = abs2.(y ./ p) .- abs.(x) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_kwargs(; kwargs...) - forward(x; p) = p .* mysqrt(x) - conditions(x, y; p) = abs2.(y ./ p) .- abs.(x) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -## Low level tests - -function test_coherent_array_type(a, b) - @test eltype(a) == eltype(b) - if a isa Array - @test b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) - elseif a isa StaticArray - @test b isa StaticArray || b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) - elseif a isa AbstractSparseArray - @test b isa AbstractSparseArray || - b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) - else - error("New array type") - end -end - -function test_implicit_call(x::AbstractVector{T}; type_stability=false, kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - y_true = mysqrt(x) - y1 = imf1(x) - y2, z2 = imf2(x) - y3 = imf3(x, 1) - y4 = imf4(x; p=1) - - @testset "Primal value" begin - @test y1 ≈ y_true - @test y2 ≈ y_true - @test y3 ≈ y_true - @test y4 ≈ y_true - @test z2 ≈ 1 - end - - @testset "Array type" begin - test_coherent_array_type(x, y1) - test_coherent_array_type(x, y2) - test_coherent_array_type(x, y3) - test_coherent_array_type(x, y4) - end - - if type_stability - @testset "Type stability" begin - @test_opt target_modules = (ID,) imf1(x) - @test_opt target_modules = (ID,) imf2(x) - @test_opt target_modules = (ID,) imf3(x, 1) - @test_opt target_modules = (ID,) imf4(x; p=1) - end - end -end - -tag(::AbstractVector{<:ForwardDiff.Dual{T}}) where {T} = T - -function test_implicit_duals( - x::AbstractVector{T}; type_stability=false, kwargs... -) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - y_true = mysqrt(x) - dx = similar(x) - dx .= 2 * one(T) - x_and_dx = ForwardDiff.Dual.(x, dx) - - y_and_dy1 = imf1(x_and_dx) - y_and_dy2, z2 = imf2(x_and_dx) - y_and_dy3 = imf3(x_and_dx, 1) - y_and_dy4 = imf4(x_and_dx; p=1) - - @testset "Dual numbers" begin - @test ForwardDiff.value.(y_and_dy1) ≈ y_true - @test ForwardDiff.value.(y_and_dy2) ≈ y_true - @test ForwardDiff.value.(y_and_dy3) ≈ y_true - @test ForwardDiff.value.(y_and_dy4) ≈ y_true - @test ForwardDiff.extract_derivative(tag(y_and_dy1), y_and_dy1) ≈ - 2 .* inv.(2 .* sqrt.(x)) - @test ForwardDiff.extract_derivative(tag(y_and_dy2), y_and_dy2) ≈ - 2 .* inv.(2 .* sqrt.(x)) - @test ForwardDiff.extract_derivative(tag(y_and_dy3), y_and_dy3) ≈ - 2 .* inv.(2 .* sqrt.(x)) - @test ForwardDiff.extract_derivative(tag(y_and_dy4), y_and_dy4) ≈ - 2 .* inv.(2 .* sqrt.(x)) - @test z2 ≈ 1 - end - - @testset "Array types" begin - test_coherent_array_type(x, ForwardDiff.value.(y_and_dy1)) - test_coherent_array_type(x, ForwardDiff.value.(y_and_dy2)) - test_coherent_array_type(x, ForwardDiff.value.(y_and_dy3)) - test_coherent_array_type(x, ForwardDiff.value.(y_and_dy4)) - end - - if type_stability - @testset "Type stability" begin - @test_opt target_modules = (ID,) imf1(x_and_dx) - @test_opt target_modules = (ID,) imf2(x_and_dx) - @test_opt target_modules = (ID,) imf3(x_and_dx, 1) - @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) - end - end -end - -function test_implicit_rrule( - rc, x::AbstractVector{T}; type_stability=false, kwargs... -) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - y_true = mysqrt(x) - dy = similar(y_true) - dy .= one(eltype(y_true)) - dz = nothing - - y1, pb1 = rrule(rc, imf1, x) - (y2, z2), pb2 = rrule(rc, imf2, x) - y3, pb3 = rrule(rc, imf3, x, 1) - y4, pb4 = rrule(rc, imf4, x; p=1) - - dimf1, dx1 = pb1(dy) - dimf2, dx2 = pb2((dy, dz)) - dimf3, dx3, dp3 = pb3(dy) - dimf4, dx4 = pb4(dy) - - @testset "Pullbacks" begin - @test y1 ≈ y_true - @test y2 ≈ y_true - @test y3 ≈ y_true - @test y4 ≈ y_true - @test z2 ≈ 1 - - @test dimf1 isa NoTangent - @test dimf2 isa NoTangent - @test dimf3 isa NoTangent - @test dimf4 isa NoTangent - - @test size(dx1) == size(x) - @test size(dx2) == size(x) - @test size(dx3) == size(x) - @test size(dx4) == size(x) - - @test dp3 isa ChainRulesCore.NotImplemented - end - - @testset "Array type" begin - test_coherent_array_type(x, y1) - test_coherent_array_type(x, y2) - test_coherent_array_type(x, y3) - test_coherent_array_type(x, y4) - - test_coherent_array_type(x, dx1) - test_coherent_array_type(x, dx2) - test_coherent_array_type(x, dx3) - test_coherent_array_type(x, dx4) - end - - @testset "ChainRulesTestUtils" begin - test_rrule(rc, imf1, x; atol=1e-2, check_inferred=false) - test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0), check_inferred=false) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112 - test_rrule(rc, imf3, x, 1; atol=1e-2, check_inferred=false) - test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,), check_inferred=false) - end - - if type_stability - @testset "Type stability" begin - @test_opt target_modules = (ID,) rrule(rc, imf1, x) - @test_opt target_modules = (ID,) rrule(rc, imf2, x) - @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) - @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) - - @test_opt target_modules = (ID,) pb1(dy) - @test_opt target_modules = (ID,) pb2((dy, dz)) - @test_opt target_modules = (ID,) pb3(dy) - @test_opt target_modules = (ID,) pb4(dy) - end - end -end - -## High-level tests per backend - -function 
test_implicit_backend( - backend::ADTypes.AbstractADType, x::AbstractVector{T}; type_stability=false, kwargs... -) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) - - J1 = DifferentiationInterface.jacobian(imf1, backend, x) - J2 = DifferentiationInterface.jacobian(first ∘ imf2, backend, x) - J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, one(eltype(x))), backend, x) - - J4 = if !(backend isa AutoEnzyme) - DifferentiationInterface.jacobian(_x -> imf4(_x; p=one(eltype(x))), backend, x) - else - nothing - end - - J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) - - @testset "Exact Jacobian" begin - @test J1 ≈ J_true - @test J2 ≈ J_true - @test J3 ≈ J_true - - @test eltype(J1) == eltype(x) - @test eltype(J2) == eltype(x) - @test eltype(J3) == eltype(x) - - if !(backend isa AutoEnzyme) - @test J4 ≈ J_true - @test eltype(J4) == eltype(x) - end - end - return nothing -end - -function test_implicit(backends, x; type_stability=false, kwargs...) - @testset verbose = true "Call" begin - test_implicit_call(x; kwargs...) - end - @testset verbose = true "Duals" begin - test_implicit_duals(x; kwargs...) - end - @testset verbose = true "ChainRule" begin - test_implicit_rrule(ZygoteRuleConfig(), x; kwargs...) - end - @testset "$backend" for backend in backends - test_implicit_backend(backend, x; kwargs...) 
- end - return nothing -end +include("utils.jl") ## Parameter combinations @@ -312,26 +18,31 @@ backends = [ linear_solver_candidates = ( \, # - # ID.DefaultLinearSolver(), + ID.DefaultLinearSolver(), ) conditions_backend_candidates = ( nothing, # - # AutoForwardDiff(; chunksize=1), + AutoForwardDiff(; chunksize=1), # AutoEnzyme(Enzyme.Forward), ); x_candidates = ( - rand(Float32, 2), # + Float32[3, 4], # + MVector{2}(Float32[3, 4]), # ); ## Test loop -@testset verbose = true "$(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" for ( +@testset verbose = false "$(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" for ( x, linear_solver, conditions_backend ) in Iterators.product( x_candidates, linear_solver_candidates, conditions_backend_candidates ) + if x isa StaticArray && (linear_solver != \) + continue + end + @info "Testing $(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" test_implicit( backends, x; diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..373a503 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,298 @@ +using ADTypes +using ChainRulesCore +using ChainRulesTestUtils +using DifferentiationInterface: DifferentiationInterface +using ForwardDiff: ForwardDiff +import ImplicitDifferentiation as ID +using ImplicitDifferentiation: ImplicitFunction +using JET +using LinearAlgebra +using SparseArrays +using StaticArrays +using Test +using Zygote: Zygote, ZygoteRuleConfig + +## + +function identity_break_autodiff(x::X)::X where {R,X<:AbstractVector{R}} + float(first(x)) # break ForwardDiff + (Vector{R}(undef, 1))[1] = first(x) # break Zygote + result = try + throw(copy(x)) + catch y + y + end + return result +end + +function mysqrt(x::AbstractVector) + return identity_break_autodiff(sqrt.(abs.(x))) +end + +## Various signatures + +function make_implicit_sqrt(; kwargs...) + forward(x) = mysqrt(x) + conditions(x, y) = abs2.(y) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) 
+ return implicit +end + +function make_implicit_sqrt_byproduct(; kwargs...) + forward(x) = one(eltype(x)) .* mysqrt(x), one(eltype(x)) + conditions(x, y, z) = abs2.(y ./ z) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +function make_implicit_sqrt_args(; kwargs...) + forward(x, p) = p .* mysqrt(x) + conditions(x, y, p) = abs2.(y ./ p) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +function make_implicit_sqrt_kwargs(; kwargs...) + forward(x; p) = p .* mysqrt(x) + conditions(x, y; p) = abs2.(y ./ p) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +## Low level tests + +function test_coherent_array_type(a, b) + @test eltype(a) == eltype(b) + if a isa Array + @test b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) + elseif a isa StaticArray + @test b isa StaticArray || b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) + elseif a isa AbstractSparseArray + @test b isa AbstractSparseArray || + b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) + else + error("New array type") + end +end + +function test_implicit_call(x::AbstractVector{T}; type_stability=false, kwargs...) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + y1 = imf1(x) + y2, z2 = imf2(x) + y3 = imf3(x, 1) + y4 = imf4(x; p=1) + + @testset "Primal value" begin + @test y1 ≈ y_true + @test y2 ≈ y_true + @test y3 ≈ y_true + @test y4 ≈ y_true + @test z2 ≈ 1 + end + + @testset "Array type" begin + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x) + @test_opt target_modules = (ID,) imf2(x) + @test_opt target_modules = (ID,) imf3(x, 1) + @test_opt target_modules = (ID,) imf4(x; p=1) + end + end +end + +tag(::AbstractVector{<:ForwardDiff.Dual{T}}) where {T} = T + +function test_implicit_duals( + x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + dx = similar(x) + dx .= 2 * one(T) + x_and_dx = ForwardDiff.Dual.(x, dx) + + y_and_dy1 = imf1(x_and_dx) + y_and_dy2, z2 = imf2(x_and_dx) + y_and_dy3 = imf3(x_and_dx, 1) + y_and_dy4 = imf4(x_and_dx; p=1) + + @testset "Dual numbers" begin + @test ForwardDiff.value.(y_and_dy1) ≈ y_true + @test ForwardDiff.value.(y_and_dy2) ≈ y_true + @test ForwardDiff.value.(y_and_dy3) ≈ y_true + @test ForwardDiff.value.(y_and_dy4) ≈ y_true + @test ForwardDiff.extract_derivative(tag(y_and_dy1), y_and_dy1) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy2), y_and_dy2) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy3), y_and_dy3) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy4), y_and_dy4) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test z2 ≈ 1 + end + + @testset "Array type" begin + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy1)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy2)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy3)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy4)) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x_and_dx) + @test_opt target_modules = (ID,) imf2(x_and_dx) + @test_opt target_modules = (ID,) imf3(x_and_dx, 1) + @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) + end + end +end + +function test_implicit_rrule( + rc, x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + dy = similar(y_true) + dy .= one(eltype(y_true)) + dz = nothing + + y1, pb1 = rrule(rc, imf1, x) + (y2, z2), pb2 = rrule(rc, imf2, x) + y3, pb3 = rrule(rc, imf3, x, 1) + y4, pb4 = rrule(rc, imf4, x; p=1) + + dimf1, dx1 = pb1(dy) + dimf2, dx2 = pb2((dy, dz)) + dimf3, dx3, dp3 = pb3(dy) + dimf4, dx4 = pb4(dy) + + @testset "Pullbacks" begin + @test y1 ≈ y_true + @test y2 ≈ y_true + @test y3 ≈ y_true + @test y4 ≈ y_true + @test z2 ≈ 1 + + @test dimf1 isa NoTangent + @test dimf2 isa NoTangent + @test dimf3 isa NoTangent + @test dimf4 isa NoTangent + + @test size(dx1) == size(x) + @test size(dx2) == size(x) + @test size(dx3) == size(x) + @test size(dx4) == size(x) + + @test dp3 isa ChainRulesCore.NotImplemented + end + + @testset "Array type" begin + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) + + test_coherent_array_type(x, dx1) + test_coherent_array_type(x, dx2) + test_coherent_array_type(x, dx3) + test_coherent_array_type(x, dx4) + end + + @testset "ChainRulesTestUtils" begin + test_rrule(rc, imf1, x; atol=1e-2, check_inferred=false) + test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0), check_inferred=false) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112 + test_rrule(rc, imf3, x, 1; atol=1e-2, check_inferred=false) + test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,), check_inferred=false) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) rrule(rc, imf1, x) + @test_opt target_modules = (ID,) rrule(rc, imf2, x) + @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) + @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) + + @test_opt target_modules = (ID,) pb1(dy) + @test_opt target_modules = (ID,) pb2((dy, dz)) + @test_opt target_modules = (ID,) pb3(dy) + @test_opt target_modules = (ID,) pb4(dy) + end + end +end + +## High-level tests per backend + +function 
test_implicit_backend( + backend::ADTypes.AbstractADType, x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) + + J1 = DifferentiationInterface.jacobian(imf1, backend, x) + J2 = DifferentiationInterface.jacobian(first ∘ imf2, backend, x) + J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, one(eltype(x))), backend, x) + + J4 = if !(backend isa AutoEnzyme) + DifferentiationInterface.jacobian(_x -> imf4(_x; p=one(eltype(x))), backend, x) + else + nothing + end + + J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) + + @testset "Exact Jacobian" begin + @test J1 ≈ J_true + @test J2 ≈ J_true + @test J3 ≈ J_true + + @test eltype(J1) == eltype(x) + @test eltype(J2) == eltype(x) + @test eltype(J3) == eltype(x) + + if !(backend isa AutoEnzyme) + @test J4 ≈ J_true + @test eltype(J4) == eltype(x) + end + end + return nothing +end + +function test_implicit(backends, x; type_stability=false, kwargs...) + @testset "Call" begin + test_implicit_call(x; kwargs...) + end + @testset "Duals" begin + test_implicit_duals(x; kwargs...) + end + @testset "ChainRule" begin + test_implicit_rrule(ZygoteRuleConfig(), x; kwargs...) + end + @testset "$backend" for backend in backends + test_implicit_backend(backend, x; kwargs...) 
+ end + return nothing +end From d71e3408d97ebe819b02c0ee49cfad9a560032db Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:52:45 +0200 Subject: [PATCH 7/8] Select tag and chunksize for AutoForwardDiff --- ext/ImplicitDifferentiationForwardDiffExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 376f64d..5537105 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -13,7 +13,7 @@ function (implicit::ImplicitFunction)( y_or_yz = implicit(x, args...; kwargs...) y = output(y_or_yz) - suggested_backend = AutoForwardDiff{1,Nothing}(nothing) + suggested_backend = AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(x))) A = build_A(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) B = build_B(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) From 4b023f3dbd5e5abee9ed2235e8289925e2bd4204 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 10 Apr 2024 13:57:16 +0200 Subject: [PATCH 8/8] Right chunksize --- ext/ImplicitDifferentiationForwardDiffExt.jl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 5537105..3337a92 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -13,9 +13,22 @@ function (implicit::ImplicitFunction)( y_or_yz = implicit(x, args...; kwargs...) y = output(y_or_yz) - suggested_backend = AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(x))) - A = build_A(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) - B = build_B(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) 
+ A = build_A( + implicit, + x, + y_or_yz, + args...; + suggested_backend=AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(y))), + kwargs..., + ) + B = build_B( + implicit, + x, + y_or_yz, + args...; + suggested_backend=AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(x))), + kwargs..., + ) dX = mapreduce(hcat, 1:N) do k partials.(x_and_dx, k)