From 7c0aa9971fcae1cc88ad74be8676dbdefea11515 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 26 Sep 2024 09:50:43 +0200 Subject: [PATCH] Adapt to new Enzyme and DifferentiationInterface (#156) --- Project.toml | 14 +- ext/ImplicitDifferentiationEnzymeExt.jl | 3 +- src/ImplicitDifferentiation.jl | 6 +- src/implicit_function.jl | 2 + src/operators.jl | 163 ++++++++++-------------- test/systematic.jl | 2 +- 6 files changed, 81 insertions(+), 109 deletions(-) diff --git a/Project.toml b/Project.toml index 1bae515..8a2f235 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.6.1" +version = "0.6.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -21,14 +21,14 @@ ImplicitDifferentiationEnzymeExt = "Enzyme" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" [compat] -ADTypes = "1.7.1" -ChainRulesCore = "1.23.0" -DifferentiationInterface = "0.5.12" -Enzyme = "0.11.20,0.12" +ADTypes = "1.9.0" +ChainRulesCore = "1.25.0" +DifferentiationInterface = "0.6.1" +Enzyme = "0.13.3" ForwardDiff = "0.10.36" -Krylov = "0.9.5" +Krylov = "0.9.6" LinearAlgebra = "1.10" -LinearOperators = "2.7.0" +LinearOperators = "2.8.0" julia = "1.10" [extras] diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl index ff9e29f..aeb2143 100644 --- a/ext/ImplicitDifferentiationEnzymeExt.jl +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -2,12 +2,13 @@ module ImplicitDifferentiationEnzymeExt using ADTypes using Enzyme -using Enzyme.EnzymeCore +using Enzyme.EnzymeRules using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output const FORWARD_BACKEND = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const) function EnzymeRules.forward( + config::EnzymeRules.FwdConfig, func::Const{<:ImplicitFunction}, RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}}, func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 9c7b426..efe2f34 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -9,14 +9,16 @@ module ImplicitDifferentiation using ADTypes: AbstractADType using DifferentiationInterface: + Constant, jacobian, prepare_pushforward_same_point, prepare_pullback_same_point, pullback!, - pushforward! + pushforward!, + unwrap using Krylov: block_gmres, gmres using LinearOperators: LinearOperator -using LinearAlgebra: factorize, lu +using LinearAlgebra: axpby!, factorize, lu include("implicit_function.jl") include("operators.jl") diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 534176f..6201e2a 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -177,9 +177,11 @@ end output(y::AbstractVector) = y byproduct(::AbstractVector) = error("No byproduct") +rest(::AbstractVector) = () output(yz::Tuple{<:Any,<:Any}) = yz[1] byproduct(yz::Tuple{<:Any,<:Any}) = yz[2] +rest(yz::Tuple) = (byproduct(yz),) output((y, z)) = y byproduct((y, z)) = z diff --git a/src/operators.jl b/src/operators.jl index 9c77d79..6843dec 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1,112 +1,75 @@ -## Partial conditions - -struct ConditionsXNoByproduct{C,Y,A,K} +struct ConditionsX{C,K} conditions::C - y::Y - args::A kwargs::K end -function (conditions_x_nobyproduct::ConditionsXNoByproduct)(x::AbstractVector) - (; conditions, y, args, kwargs) = conditions_x_nobyproduct - return conditions(x, y, args...; kwargs...) -end - -struct ConditionsYNoByproduct{C,X,A,K} +struct ConditionsY{C,K} conditions::C - x::X - args::A kwargs::K end -function (conditions_y_nobyproduct::ConditionsYNoByproduct)(y::AbstractVector) - (; conditions, x, args, kwargs) = conditions_y_nobyproduct - return conditions(x, y, args...; kwargs...) +function (cx::ConditionsX)(x, y, args...) + return cx.conditions(x, y, args...; cx.kwargs...) end -struct ConditionsXByproduct{C,Y,Z,A,K} - conditions::C - y::Y - z::Z - args::A - kwargs::K -end - -function (conditions_x_byproduct::ConditionsXByproduct)(x::AbstractVector) - (; conditions, y, z, args, kwargs) = conditions_x_byproduct - return conditions(x, y, z, args...; kwargs...) +function (cy::ConditionsY)(y, x, args...) # order switch + return cy.conditions(x, y, args...; cy.kwargs...) end -struct ConditionsYByproduct{C,X,Z,A,K} - conditions::C +struct PushforwardOperator!{F,P,B,X,C,R} + f::F + prep::P + backend::B x::X - z::Z - args::A - kwargs::K -end - -function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector) - (; conditions, x, z, args, kwargs) = conditions_y_byproduct - return conditions(x, y, z, args...; kwargs...) -end - -function ConditionsX(conditions, x, y_or_yz, args...; kwargs...) - y = output(y_or_yz) - if y_or_yz isa Tuple - z = byproduct(y_or_yz) - return ConditionsXByproduct(conditions, y, z, args, kwargs) - else - return ConditionsXNoByproduct(conditions, y, args, kwargs) - end -end - -function ConditionsY(conditions, x, y_or_yz, args...; kwargs...) - if y_or_yz isa Tuple - z = byproduct(y_or_yz) - return ConditionsYByproduct(conditions, x, z, args, kwargs) - else - return ConditionsYNoByproduct(conditions, x, args, kwargs) - end + contexts::C + res_backup::R end -## Lazy operators - -struct PushforwardOperator!{F,B,X,E,R} +struct PullbackOperator!{F,P,B,X,C,R} f::F + prep::P backend::B x::X - extras::E + contexts::C res_backup::R end +function PushforwardOperator!(f, prep, backend, x, contexts) + res_backup = similar(f(x, map(unwrap, contexts)...)) + return PushforwardOperator!(f, prep, backend, x, contexts, res_backup) +end + +function PullbackOperator!(f, prep, backend, x, contexts) + res_backup = similar(x) + return PullbackOperator!(f, prep, backend, x, contexts, res_backup) +end + function (po::PushforwardOperator!)(res, v, α, β) + (; f, backend, x, contexts, prep, res_backup) = po if iszero(β) - pushforward!(po.f, res, po.backend, po.x, v, po.extras) - res .= α .* res + pushforward!(f, (res,), prep, backend, x, (v,), contexts...) + if !isone(α) + res .*= α + end else - po.res_backup .= res - pushforward!(po.f, res, po.backend, po.x, v, po.extras) - res .= α .* res .+ β .* po.res_backup + copyto!(res_backup, res) + pushforward!(f, (res,), prep, backend, x, (v,), contexts...) + axpby!(β, res_backup, α, res) end return res end -struct PullbackOperator!{F,B,X,E,R} - f::F - backend::B - x::X - extras::E - res_backup::R -end - function (po::PullbackOperator!)(res, v, α, β) + (; f, backend, x, contexts, prep, res_backup) = po if iszero(β) - pullback!(po.f, res, po.backend, po.x, v, po.extras) - res .= α .* res + pullback!(f, (res,), prep, backend, x, (v,), contexts...) + if !isone(α) + res .*= α + end else - po.res_backup .= res - pullback!(po.f, res, po.backend, po.x, v, po.extras) - res .= α .* res .+ β .+ po.res_backup + copyto!(res_backup, res) + pullback!(f, (res,), prep, backend, x, (v,), contexts...) + axpby!(β, res_backup, α, res) end return res end @@ -119,24 +82,25 @@ function build_A( suggested_backend, kwargs..., ) where {lazy} - (; conditions, linear_solver, conditions_y_backend) = implicit + (; conditions, conditions_y_backend) = implicit y = output(y_or_yz) n, m = length(x), length(y) back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend - cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) + cond_y = ConditionsY(conditions, kwargs) + contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...) if lazy - extras = prepare_pushforward_same_point(cond_y, back_y, y, zero(y)) + prep = prepare_pushforward_same_point(cond_y, back_y, y, (zero(y),), contexts...) A = LinearOperator( eltype(y), m, m, false, false, - PushforwardOperator!(cond_y, back_y, y, extras, similar(y)), + PushforwardOperator!(cond_y, prep, back_y, y, contexts), typeof(y), ) else - J = jacobian(cond_y, back_y, y) + J = jacobian(cond_y, back_y, y, contexts...) A = factorize(J) end return A @@ -150,24 +114,25 @@ function build_Aᵀ( suggested_backend, kwargs..., ) where {lazy} - (; conditions, linear_solver, conditions_y_backend) = implicit + (; conditions, conditions_y_backend) = implicit y = output(y_or_yz) n, m = length(x), length(y) back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend - cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) + cond_y = ConditionsY(conditions, kwargs) + contexts = (Constant(x), map(Constant, rest(y_or_yz))..., map(Constant, args)...) if lazy - extras = prepare_pullback_same_point(cond_y, back_y, y, zero(y)) + prep = prepare_pullback_same_point(cond_y, back_y, y, (zero(y),), contexts...) Aᵀ = LinearOperator( eltype(y), m, m, false, false, - PullbackOperator!(cond_y, back_y, y, extras, similar(y)), + PullbackOperator!(cond_y, prep, back_y, y, contexts), typeof(y), ) else - Jᵀ = transpose(jacobian(cond_y, back_y, y)) + Jᵀ = transpose(jacobian(cond_y, back_y, y, contexts...)) Aᵀ = factorize(Jᵀ) end return Aᵀ @@ -181,24 +146,25 @@ function build_B( suggested_backend, kwargs..., ) where {lazy} - (; conditions, linear_solver, conditions_x_backend) = implicit + (; conditions, conditions_x_backend) = implicit y = output(y_or_yz) n, m = length(x), length(y) back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend - cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) + cond_x = ConditionsX(conditions, kwargs) + contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...) if lazy - extras = prepare_pushforward_same_point(cond_x, back_x, x, zero(x)) + prep = prepare_pushforward_same_point(cond_x, back_x, x, (zero(x),), contexts...) B = LinearOperator( eltype(y), m, n, false, false, - PushforwardOperator!(cond_x, back_x, x, extras, similar(y)), + PushforwardOperator!(cond_x, prep, back_x, x, contexts), typeof(x), ) else - B = transpose(jacobian(cond_x, back_x, x)) + B = transpose(jacobian(cond_x, back_x, x, contexts...)) end return B end @@ -211,24 +177,25 @@ function build_Bᵀ( suggested_backend, kwargs..., ) where {lazy} - (; conditions, linear_solver, conditions_x_backend) = implicit + (; conditions, conditions_x_backend) = implicit y = output(y_or_yz) n, m = length(x), length(y) back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend - cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) + cond_x = ConditionsX(conditions, kwargs) + contexts = (Constant(y), map(Constant, rest(y_or_yz))..., map(Constant, args)...) if lazy - extras = prepare_pullback_same_point(cond_x, back_x, x, zero(y)) + prep = prepare_pullback_same_point(cond_x, back_x, x, (zero(y),), contexts...) Bᵀ = LinearOperator( eltype(y), n, m, false, false, - PullbackOperator!(cond_x, back_x, x, extras, similar(x)), + PullbackOperator!(cond_x, prep, back_x, x, contexts), typeof(x), ) else - Bᵀ = transpose(jacobian(cond_x, back_x, x)) + Bᵀ = transpose(jacobian(cond_x, back_x, x, contexts...)) end return Bᵀ end diff --git a/test/systematic.jl b/test/systematic.jl index bff1afd..10b7e80 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -29,7 +29,7 @@ conditions_backend_candidates = ( x_candidates = ( Float32[3, 4], # - MVector{2}(Float32[3, 4]), # + # MVector{2}(Float32[3, 4]), # ); ## Test loop