diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl index 3d19b5fe4..dd2942687 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl @@ -13,7 +13,7 @@ using DifferentiationInterface: NoJacobianExtras, NoPullbackExtras, NoPushforwardExtras, - pick_chunksize + pick_batchsize using DocStringExtensions using Enzyme: Active, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl index 95ccbb5fd..e5d86f9a8 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl @@ -57,20 +57,20 @@ end ## Gradient -struct EnzymeForwardGradientExtras{C,O} <: GradientExtras +struct EnzymeForwardGradientExtras{B,O} <: GradientExtras shadow::O end -function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x) - C = pick_chunksize(length(x)) - shadow = chunkedonehot(x, Val(C)) - return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow) +function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode}, x) + B = pick_batchsize(backend, length(x)) + shadow = chunkedonehot(x, Val(B)) + return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow) end function DI.gradient( - f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} -) where {C} - grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) + f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} +) where {B} + grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return reshape(collect(grad_tup), size(x)) end @@ -81,38 +81,38 @@ function DI.value_and_gradient( end function DI.gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} -) where {C} - grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) + f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} +) where {B} + grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return copyto!(grad, grad_tup) end function DI.value_and_gradient!( - f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C} -) where {C} - grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) + f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B} +) where {B} + grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) return f(x), copyto!(grad, grad_tup) end ## Jacobian -struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras +struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras shadow::O end -function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x) - C = pick_chunksize(length(x)) - shadow = chunkedonehot(x, Val(C)) - return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow) +function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x) + B = pick_batchsize(backend, length(x)) + shadow = chunkedonehot(x, Val(B)) + return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow) end function DI.jacobian( f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x, - extras::EnzymeForwardOneArgJacobianExtras{C}, -) where {C} - jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow) + extras::EnzymeForwardOneArgJacobianExtras{B}, +) where {B} + jac_wrongshape = jacobian(forward_mode(backend), f, x, Val(B); shadow=extras.shadow) nx = length(x) ny = length(jac_wrongshape) ÷ length(x) return reshape(jac_wrongshape, ny, nx) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl index 5c520d435..5a6f03b4a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl @@ -160,13 +160,13 @@ end #= -struct EnzymeReverseOneArgJacobianExtras{C,N} end +struct EnzymeReverseOneArgJacobianExtras{B,N} end -function DI.prepare_jacobian(f, ::AutoReverseEnzyme, x) - C = pick_chunksize(length(x)) +function DI.prepare_jacobian(f, backend::AutoReverseEnzyme, x) + B = pick_batchsize(backend, length(x)) y = f(x) N = length(y) - return EnzymeReverseOneArgJacobianExtras{C,N}() + return EnzymeReverseOneArgJacobianExtras{B,N}() end function DI.jacobian( @@ -174,8 +174,8 @@ function DI.jacobian( backend::AutoReverseEnzyme, x::AbstractArray, ::EnzymeReverseOneArgJacobianExtras{C,N}, -) where {C,N} - jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val{N}(), Val{C}()) +) where {B,N} + jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(N), Val(B)) nx = length(x) ny = length(jac_wrongshape) ÷ length(x) jac_rightshape = reshape(jac_wrongshape, ny, nx) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index a0c545657..3273bfcaa 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -38,6 +38,14 @@ using LinearAlgebra: dot, mul! DI.check_available(::AutoForwardDiff) = true +function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C} + if isnothing(C) + return ForwardDiff.pickchunksize(dimension) + else + return min(dimension, C) + end +end + include("utils.jl") include("onearg.jl") include("twoarg.jl") diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 2ce4ba6cb..81be52b4f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -3,6 +3,7 @@ module DifferentiationInterfaceZygoteExt using ADTypes: AutoForwardDiff, AutoZygote import DifferentiationInterface as DI using DifferentiationInterface: + Batch, HVPExtras, NoGradientExtras, NoHessianExtras, @@ -103,20 +104,47 @@ struct ZygoteHVPExtras{G,PE} <: HVPExtras pushforward_extras::PE end -function DI.prepare_hvp(f, ::AutoZygote, x, v) +function DI.prepare_hvp(f, ::AutoZygote, x, dx) ∇f(x) = only(gradient(f, x)) - pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, v) + pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, dx) return ZygoteHVPExtras(∇f, pushforward_extras) end -function DI.hvp(f, ::AutoZygote, x, v, extras::ZygoteHVPExtras) +function DI.hvp(f, ::AutoZygote, x, dx, extras::ZygoteHVPExtras) @compat (; ∇f, pushforward_extras) = extras - return DI.pushforward(∇f, AutoForwardDiff(), x, v, pushforward_extras) + return DI.pushforward(∇f, AutoForwardDiff(), x, dx, pushforward_extras) end -function DI.hvp!(f, p, ::AutoZygote, x, v, extras::ZygoteHVPExtras) +function DI.hvp!(f, dg, ::AutoZygote, x, dx, extras::ZygoteHVPExtras) @compat (; ∇f, pushforward_extras) = extras - return DI.pushforward!(∇f, p, AutoForwardDiff(), x, v, pushforward_extras) + return DI.pushforward!(∇f, dg, AutoForwardDiff(), x, dx, pushforward_extras) +end + +struct ZygoteHVPBatchedExtras{G,PE} <: HVPExtras + ∇f::G + pushforward_batched_extras::PE +end + +function DI.prepare_hvp_batched(f, ::AutoZygote, x, dx::Batch) + ∇f(x) = only(gradient(f, x)) + pushforward_batched_extras = DI.prepare_pushforward_batched( + ∇f, AutoForwardDiff(), x, dx + ) + return ZygoteHVPBatchedExtras(∇f, pushforward_batched_extras) +end + +function DI.hvp_batched(f, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras) + @compat (; ∇f, pushforward_batched_extras) = extras + return DI.pushforward_batched(∇f, AutoForwardDiff(), x, dx, pushforward_batched_extras) +end + +function DI.hvp_batched!( + f, dg::Batch, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras +) + @compat (; ∇f, pushforward_batched_extras) = extras + return DI.pushforward_batched!( + ∇f, dg, AutoForwardDiff(), x, dx, pushforward_batched_extras + ) end ## Hessian diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 504171bbe..b36940432 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -50,8 +50,7 @@ include("second_order/second_order.jl") include("utils/traits.jl") include("utils/basis.jl") -include("utils/printing.jl") -include("utils/chunk.jl") +include("utils/batch.jl") include("utils/check.jl") include("utils/exceptions.jl") include("utils/maybe.jl") @@ -73,6 +72,9 @@ include("sparse/hessian.jl") include("misc/differentiate_with.jl") include("misc/sparsity_detector.jl") +include("misc/from_primitive.jl") + +include("utils/printing.jl") function __init__() @require_extensions diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 523c19fae..a7fc7666a 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -55,48 +55,62 @@ abstract type JacobianExtras <: Extras end struct NoJacobianExtras <: JacobianExtras end -struct PushforwardJacobianExtras{E<:PushforwardExtras} <: JacobianExtras - pushforward_extras::E +struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras + seeds::D + pushforward_batched_extras::E + y_example::Y end -struct PullbackJacobianExtras{E<:PullbackExtras} <: JacobianExtras - pullback_extras::E +struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras + seeds::D + pullback_batched_extras::E + y_example::Y end function prepare_jacobian(f::F, backend::AbstractADType, x) where {F} - return prepare_jacobian_aux(f, backend, x, pushforward_performance(backend)) + y = f(x) + return prepare_jacobian_aux((f,), backend, x, y, pushforward_performance(backend)) end function prepare_jacobian(f!::F, y, backend::AbstractADType, x) where {F} - return prepare_jacobian_aux(f!, y, backend, x, pushforward_performance(backend)) + return prepare_jacobian_aux((f!, y), backend, x, y, pushforward_performance(backend)) end -function prepare_jacobian_aux(f::F, backend, x, ::PushforwardFast) where {F} - dx = basis(backend, x, first(CartesianIndices(x))) - pushforward_extras = prepare_pushforward(f, backend, x, dx) - return PushforwardJacobianExtras(pushforward_extras) +function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) where {FY} + N = length(x) + B = pick_batchsize(backend, N) + seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)] + pushforward_batched_extras = prepare_pushforward_batched( + f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))) + ) + D = typeof(seeds) + E = typeof(pushforward_batched_extras) + Y = typeof(y) + return PushforwardJacobianExtras{B,D,E,Y}(seeds, pushforward_batched_extras, copy(y)) end -function prepare_jacobian_aux(f!::F, y, backend, x, ::PushforwardFast) where {F} - dx = basis(backend, x, first(CartesianIndices(x))) - pushforward_extras = prepare_pushforward(f!, y, backend, x, dx) - return PushforwardJacobianExtras(pushforward_extras) +function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY} + M = length(y) + B = pick_batchsize(backend, M) + seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)] + pullback_batched_extras = prepare_pullback_batched( + f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))) + ) + D = typeof(seeds) + E = typeof(pullback_batched_extras) + Y = typeof(y) + return PullbackJacobianExtras{B,D,E,Y}(seeds, pullback_batched_extras, copy(y)) end -function prepare_jacobian_aux(f::F, backend, x, ::PushforwardSlow) where {F} - y = f(x) - dy = basis(backend, y, first(CartesianIndices(y))) - pullback_extras = prepare_pullback(f, backend, x, dy) - return PullbackJacobianExtras(pullback_extras) -end +## One argument -function prepare_jacobian_aux(f!::F, y, backend, x, ::PushforwardSlow) where {F} - dy = basis(backend, y, first(CartesianIndices(y))) - pullback_extras = prepare_pullback(f!, y, backend, x, dy) - return PullbackJacobianExtras(pullback_extras) +function jacobian(f::F, backend::AbstractADType, x) where {F} + return jacobian(f, backend, x, prepare_jacobian(f, backend, x)) end -## One argument +function jacobian!(f::F, jac, backend::AbstractADType, x) where {F} + return jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) +end function value_and_jacobian(f::F, backend::AbstractADType, x) where {F} return value_and_jacobian(f, backend, x, prepare_jacobian(f, backend, x)) @@ -106,92 +120,36 @@ function value_and_jacobian!(f::F, jac, backend::AbstractADType, x) where {F} return value_and_jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) end -function jacobian(f::F, backend::AbstractADType, x) where {F} - return jacobian(f, backend, x, prepare_jacobian(f, backend, x)) -end - -function jacobian!(f::F, jac, backend::AbstractADType, x) where {F} - return jacobian!(f, jac, backend, x, prepare_jacobian(f, backend, x)) +function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F} + return jacobian_aux((f,), backend, x, extras) end -function value_and_jacobian( - f::F, backend, x::AbstractArray, extras::PushforwardJacobianExtras -) where {F} - y = f(x) # TODO: remove - pushforward_extras_same = prepare_pushforward_same_point( - f, - backend, - x, - basis(backend, x, first(CartesianIndices(x))), - extras.pushforward_extras, - ) - jac = stack(CartesianIndices(x); dims=2) do j - dx_j = basis(backend, x, j) - jac_col_j = pushforward(f, backend, x, dx_j, pushforward_extras_same) - vec(jac_col_j) - end - return y, jac +function jacobian!(f::F, jac, backend::AbstractADType, x, extras::JacobianExtras) where {F} + return jacobian_aux!((f,), jac, backend, x, extras) end function value_and_jacobian( - f::F, backend, x::AbstractArray, extras::PullbackJacobianExtras + f::F, backend::AbstractADType, x, extras::JacobianExtras ) where {F} - y = f(x) # TODO: remove - pullback_extras_same = prepare_pullback_same_point( - f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras - ) - jac = stack(CartesianIndices(y); dims=1) do i - dy_i = basis(backend, y, i) - jac_row_i = pullback(f, backend, x, dy_i, pullback_extras_same) - vec(jac_row_i) - end - return y, jac + return f(x), jacobian(f, backend, x, extras) end function value_and_jacobian!( - f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PushforwardJacobianExtras + f::F, jac, backend::AbstractADType, x, extras::JacobianExtras ) where {F} - y = f(x) # TODO: remove - pushforward_extras_same = prepare_pushforward_same_point( - f, - backend, - x, - basis(backend, x, first(CartesianIndices(x))), - extras.pushforward_extras, - ) - for (k, j) in enumerate(CartesianIndices(x)) - dx_j = basis(backend, x, j) - jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(f, jac_col_j, backend, x, dx_j, pushforward_extras_same) - end - return y, jac + return f(x), jacobian!(f, jac, backend, x, extras) end -function value_and_jacobian!( - f::F, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras -) where {F} - y = f(x) # TODO: remove - pullback_extras_same = prepare_pullback_same_point( - f, backend, x, basis(backend, y, first(CartesianIndices(y))), extras.pullback_extras - ) - for (k, i) in enumerate(CartesianIndices(y)) - dy_i = basis(backend, y, i) - jac_row_i = reshape(view(jac, k, :), size(x)) - pullback!(f, jac_row_i, backend, x, dy_i, pullback_extras_same) - end - return y, jac -end +## Two arguments -function jacobian(f::F, backend::AbstractADType, x, extras::JacobianExtras) where {F} - return value_and_jacobian(f, backend, x, extras)[2] +function jacobian(f!::F, y, backend::AbstractADType, x) where {F} + return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) end -function jacobian!(f::F, jac, backend::AbstractADType, x, extras::JacobianExtras) where {F} - return value_and_jacobian!(f, jac, backend, x, extras)[2] +function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} + return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) end -## Two arguments - function value_and_jacobian(f!::F, y, backend::AbstractADType, x) where {F} return value_and_jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) end @@ -200,105 +158,176 @@ function value_and_jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F return value_and_jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) end -function jacobian(f!::F, y, backend::AbstractADType, x) where {F} - return jacobian(f!, y, backend, x, prepare_jacobian(f!, y, backend, x)) +function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F} + return jacobian_aux((f!, y), backend, x, extras) end -function jacobian!(f!::F, y, jac, backend::AbstractADType, x) where {F} - return jacobian!(f!, y, jac, backend, x, prepare_jacobian(f!, y, backend, x)) +function jacobian!( + f!::F, y, jac, backend::AbstractADType, x, extras::JacobianExtras +) where {F} + return jacobian_aux!((f!, y), jac, backend, x, extras) end function value_and_jacobian( - f!::F, y, backend, x::AbstractArray, extras::PushforwardJacobianExtras + f!::F, y, backend::AbstractADType, x, extras::JacobianExtras +) where {F} + jac = jacobian(f!, y, backend, x, extras) + f!(y, x) + return y, jac +end + +function value_and_jacobian!( + f!::F, y, jac, backend::AbstractADType, x, extras::JacobianExtras ) where {F} - pushforward_extras_same = prepare_pushforward_same_point( - f!, - y, + jacobian!(f!, y, jac, backend, x, extras) + f!(y, x) + return y, jac +end + +## Common auxiliaries + +function jacobian_aux( + f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B} +) where {FY,B} + @compat (; seeds, pushforward_batched_extras, y_example) = extras + N = length(x) + + pushforward_batched_extras_same = prepare_pushforward_batched_same_point( + f_or_f!y..., backend, x, - basis(backend, x, first(CartesianIndices(x))), - extras.pushforward_extras, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pushforward_batched_extras, ) - jac = stack(CartesianIndices(x); dims=2) do j - dx_j = basis(backend, x, j) - jac_col_j = pushforward(f!, y, backend, x, dx_j, pushforward_extras_same) - vec(jac_col_j) + + jac_blocks = map(1:div(N, B, RoundUp)) do a + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % N] + end + dy_batch = pushforward_batched( + f_or_f!y..., + backend, + x, + Batch(dx_batch_elements), + pushforward_batched_extras_same, + ) + stack(vec, dy_batch.elements; dims=2) end - f!(y, x) # TODO: remove - return y, jac + + jac = reduce(hcat, jac_blocks) + if N < size(jac, 2) + jac = jac[:, 1:N] + end + return jac end -function value_and_jacobian( - f!::F, y, backend, x::AbstractArray, extras::PullbackJacobianExtras -) where {F} - pullback_extras_same = prepare_pullback_same_point( - f!, - y, +function jacobian_aux( + f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B} +) where {FY,B} + @compat (; seeds, pullback_batched_extras, y_example) = extras + M = length(y_example) + + pullback_batched_extras_same = prepare_pullback_batched_same_point( + f_or_f!y..., backend, x, - basis(backend, y, first(CartesianIndices(y))), - extras.pullback_extras, + Batch(ntuple(Returns(seeds[1]), Val(B))), + extras.pullback_batched_extras, ) - jac = stack(CartesianIndices(y); dims=1) do i - dy_i = basis(backend, y, i) - jac_row_i = pullback(f!, y, backend, x, dy_i, pullback_extras_same) - vec(jac_row_i) + + jac_blocks = map(1:div(M, B, RoundUp)) do a + dy_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % M] + end + dx_batch = pullback_batched( + f_or_f!y..., + backend, + x, + Batch(dy_batch_elements), + pullback_batched_extras_same, + ) + stack(vec, dx_batch.elements; dims=1) end - f!(y, x) # TODO: remove - return y, jac + + jac = reduce(vcat, jac_blocks) + if M < size(jac, 1) + jac = jac[1:M, :] + end + return jac end -function value_and_jacobian!( - f!::F, - y, +function jacobian_aux!( + f_or_f!y::FY, jac::AbstractMatrix, backend, x::AbstractArray, - extras::PushforwardJacobianExtras, -) where {F} - pushforward_extras_same = prepare_pushforward_same_point( - f!, - y, + extras::PushforwardJacobianExtras{B}, +) where {FY,B} + @compat (; seeds, pushforward_batched_extras, y_example) = extras + N = length(x) + + pushforward_batched_extras_same = prepare_pushforward_batched_same_point( + f_or_f!y..., backend, x, - basis(backend, x, first(CartesianIndices(x))), - extras.pushforward_extras, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pushforward_batched_extras, ) - for (k, j) in enumerate(CartesianIndices(x)) - dx_j = basis(backend, x, j) - jac_col_j = reshape(view(jac, :, k), size(y)) - pushforward!(f!, y, jac_col_j, backend, x, dx_j, pushforward_extras_same) + + for a in 1:div(N, B, RoundUp) + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % N] + end + dy_batch_elements = ntuple(Val(B)) do b + reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example)) + end + pushforward_batched!( + f_or_f!y..., + Batch(dy_batch_elements), + backend, + x, + Batch(dx_batch_elements), + pushforward_batched_extras_same, + ) end - f!(y, x) # TODO: remove - return y, jac + + return jac end -function value_and_jacobian!( - f!::F, y, jac::AbstractMatrix, backend, x::AbstractArray, extras::PullbackJacobianExtras -) where {F} - pullback_extras_same = prepare_pullback_same_point( - f!, - y, +function jacobian_aux!( + f_or_f!y::FY, + jac::AbstractMatrix, + backend, + x::AbstractArray, + extras::PullbackJacobianExtras{B}, +) where {FY,B} + @compat (; seeds, pullback_batched_extras, y_example) = extras + M = length(y_example) + + pullback_batched_extras_same = prepare_pullback_batched_same_point( + f_or_f!y..., backend, x, - basis(backend, y, first(CartesianIndices(y))), - extras.pullback_extras, + Batch(ntuple(Returns(seeds[1]), Val(B))), + extras.pullback_batched_extras, ) - for (k, i) in enumerate(CartesianIndices(y)) - dy_i = basis(backend, y, i) - jac_row_i = reshape(view(jac, k, :), size(x)) - pullback!(f!, y, jac_row_i, backend, x, dy_i, pullback_extras_same) - end - f!(y, x) # TODO: remove - return y, jac -end -function jacobian(f!::F, y, backend::AbstractADType, x, extras::JacobianExtras) where {F} - return value_and_jacobian(f!, y, backend, x, extras)[2] -end + for a in 1:div(M, B, RoundUp) + dy_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % M] + end + dx_batch_elements = ntuple(Val(B)) do b + reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x)) + end + pullback_batched!( + f_or_f!y..., + Batch(dx_batch_elements), + backend, + x, + Batch(dy_batch_elements), + pullback_batched_extras_same, + ) + end -function jacobian!( - f!::F, y, jac, backend::AbstractADType, x, extras::JacobianExtras -) where {F} - return value_and_jacobian!(f!, y, jac, backend, x, extras)[2] + return jac end diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 9b464a085..ce23a6f78 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -12,6 +12,8 @@ Create an `extras` object that can be given to [`pullback`](@ref) and its varian """ function prepare_pullback end +function prepare_pullback_batched end + """ prepare_pullback_same_point(f, backend, x, dy) -> extras_same prepare_pullback_same_point(f!, y, backend, x, dy) -> extras_same @@ -24,6 +26,8 @@ Create an `extras_same` object that can be given to [`pullback`](@ref) and its v """ function prepare_pullback_same_point end +function prepare_pullback_batched_same_point end + """ value_and_pullback(f, backend, x, dy, [extras]) -> (y, dx) value_and_pullback(f!, y, backend, x, dy, [extras]) -> (y, dx) @@ -51,6 +55,8 @@ Compute the pullback of the function `f` at point `x` with seed `dy`. """ function pullback end +function pullback_batched end + """ pullback!(f, dx, backend, x, dy, [extras]) -> dx pullback!(f!, y, dx, backend, x, dy, [extras]) -> dx @@ -59,8 +65,12 @@ Compute the pullback of the function `f` at point `x` with seed `dy`, overwritin """ function pullback! end +function pullback_batched! end + ## Preparation +### Extras types + """ PullbackExtras @@ -74,6 +84,8 @@ struct PushforwardPullbackExtras{E} <: PullbackExtras pushforward_extras::E end +## Standard + function prepare_pullback(f::F, backend::AbstractADType, x, dy) where {F} return prepare_pullback_aux(f, backend, x, dy, pullback_performance(backend)) end @@ -94,8 +106,6 @@ function prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackSlow) where {F return PushforwardPullbackExtras(pushforward_extras) end -# Throw error if backend is missing - function prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) throw(MissingBackendError(backend)) end @@ -104,7 +114,7 @@ function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) throw(MissingBackendError(backend)) end -## Preparation (same point) +### Standard, same point function prepare_pullback_same_point( f::F, backend::AbstractADType, x, dy, extras::PullbackExtras @@ -128,8 +138,38 @@ function prepare_pullback_same_point(f!::F, y, backend::AbstractADType, x, dy) w return prepare_pullback_same_point(f!, y, backend, x, dy, extras) end +### Batched + +function prepare_pullback_batched( + f::F, backend::AbstractADType, x, dy::Batch{B} +) where {F,B} + return prepare_pullback(f, backend, x, first(dy.elements)) +end + +function prepare_pullback_batched( + f!::F, y, backend::AbstractADType, x, dy::Batch{B} +) where {F,B} + return prepare_pullback(f!, y, backend, x, first(dy.elements)) +end + +### Batched, same point + +function prepare_pullback_batched_same_point( + f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + return prepare_pullback_same_point(f, backend, x, first(dy.elements), extras) +end + +function prepare_pullback_batched_same_point( + f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + return prepare_pullback_same_point(f!, y, backend, x, first(dy.elements), extras) +end + ## One argument +### Standard + function value_and_pullback(f::F, backend::AbstractADType, x, dy) where {F} return value_and_pullback(f, backend, x, dy, prepare_pullback(f, backend, x, dy)) end @@ -184,8 +224,30 @@ function pullback!( return value_and_pullback!(f, dx, backend, x, dy, extras)[2] end +### Batched + +function pullback_batched( + f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + dx_elements = ntuple(Val(B)) do l + pullback(f, backend, x, dy.elements[l], extras) + end + return Batch(dx_elements) +end + +function pullback_batched!( + f::F, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + for l in 1:B + pullback!(f, dx.elements[l], backend, x, dy.elements[l], extras) + end + return dx +end + ## Two arguments +### Standard + function value_and_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} return value_and_pullback( f!, y, backend, x, dy, prepare_pullback(f!, y, backend, x, dy) @@ -239,3 +301,23 @@ function pullback!( ) where {F} return value_and_pullback!(f!, y, dx, backend, x, dy, extras)[2] end + +### Batched + +function pullback_batched( + f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + dx_elements = ntuple(Val(B)) do l + pullback(f!, y, backend, x, dy.elements[l], extras) + end + return Batch(dx_elements) +end + +function pullback_batched!( + f!::F, y, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras +) where {F,B} + for l in 1:B + pullback!(f!, y, dx.elements[l], backend, x, dy.elements[l], extras) + end + return dx +end diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index a80b3c4dd..31bcad38a 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -12,6 +12,8 @@ Create an `extras` object that can be given to [`pushforward`](@ref) and its var """ function prepare_pushforward end +function prepare_pushforward_batched end + """ prepare_pushforward_same_point(f, backend, x, dx) -> extras_same prepare_pushforward_same_point(f!, y, backend, x, dx) -> extras_same @@ -24,6 +26,8 @@ Create an `extras_same` object that can be given to [`pushforward`](@ref) and it """ function prepare_pushforward_same_point end +function prepare_pushforward_same_point_batched end + """ value_and_pushforward(f, backend, x, dx, [extras]) -> (y, dy) value_and_pushforward(f!, y, backend, x, dx, [extras]) -> (y, dy) @@ -51,6 +55,8 @@ Compute the pushforward of the function `f` at point `x` with seed `dx`. """ function pushforward end +function pushforward_batched end + """ pushforward!(f, dy, backend, x, dx, [extras]) -> dy pushforward!(f!, y, dy, backend, x, dx, [extras]) -> dy @@ -59,8 +65,12 @@ Compute the pushforward of the function `f` at point `x` with seed `dx`, overwri """ function pushforward! end +function pushforward_batched! end + ## Preparation +### Extras types + """ PushforwardExtras @@ -74,6 +84,8 @@ struct PullbackPushforwardExtras{E} <: PushforwardExtras pullback_extras::E end +### Standard + function prepare_pushforward(f::F, backend::AbstractADType, x, dx) where {F} return prepare_pushforward_aux(f, backend, x, dx, pushforward_performance(backend)) end @@ -95,8 +107,6 @@ function prepare_pushforward_aux(f!::F, y, backend, x, dx, ::PushforwardSlow) wh return PullbackPushforwardExtras(pullback_extras) end -# Throw error if backend is missing - function prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) throw(MissingBackendError(backend)) end @@ -105,7 +115,7 @@ function prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) throw(MissingBackendError(backend)) end -## Preparation (same point) +### Standard, same point function prepare_pushforward_same_point( f::F, backend::AbstractADType, x, dx, extras::PushforwardExtras @@ -129,8 +139,38 @@ function prepare_pushforward_same_point(f!::F, y, backend::AbstractADType, x, dx return prepare_pushforward_same_point(f!, y, backend, x, dx, extras) end +### Batched + +function prepare_pushforward_batched( + f::F, backend::AbstractADType, x, dx::Batch{B} +) where {F,B} + return prepare_pushforward(f, backend, x, first(dx.elements)) +end + +function prepare_pushforward_batched( + f!::F, y, backend::AbstractADType, x, dx::Batch{B} +) where {F,B} + return prepare_pushforward(f!, y, backend, x, first(dx.elements)) +end + +### Batched, same point + +function prepare_pushforward_batched_same_point( + f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras +) where {F,B} + return prepare_pushforward_same_point(f, backend, x, first(dx.elements), extras) +end + +function prepare_pushforward_batched_same_point( + f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras +) where {F,B} + return prepare_pushforward_same_point(f!, y, backend, x, first(dx.elements), extras) +end + ## One argument +### Standard + function value_and_pushforward(f::F, backend::AbstractADType, x, dx) where {F} return value_and_pushforward(f, backend, x, dx, prepare_pushforward(f, backend, x, dx)) end @@ -189,8 +229,30 @@ function pushforward!( return value_and_pushforward!(f, dy, backend, x, dx, extras)[2] end +### Batched + +function pushforward_batched( + f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras +) where {F,B} + dy_elements = ntuple(Val(B)) do l + pushforward(f, backend, x, dx.elements[l], extras) + end + return Batch(dy_elements) +end + +function pushforward_batched!( + f::F, dy::Batch{B}, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras +) where {F,B} + for l in 1:B + pushforward!(f, dy.elements[l], backend, x, dx.elements[l], extras) + end + return dy +end + ## Two arguments +### Standard + function value_and_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} return value_and_pushforward( f!, y, backend, x, dx, prepare_pushforward(f!, y, backend, x, dx) @@ -248,3 +310,29 @@ function pushforward!( ) where {F} return value_and_pushforward!(f!, y, dy, backend, x, dx, extras)[2] end + +### Batched + +function pushforward_batched( + f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras +) where {F,B} + dy_elements = ntuple(Val(B)) do l + pushforward(f!, y, backend, x, dx.elements[l], extras) + end + return Batch(dy_elements) +end + +function pushforward_batched!( + f!::F, + y, + dy::Batch{B}, + backend::AbstractADType, + x, + dx::Batch{B}, + extras::PushforwardExtras, +) where {F,B} + for l in 1:B + pushforward!(f!, y, dy.elements[l], backend, x, dx.elements[l], extras) + end + return dy +end diff --git a/DifferentiationInterface/src/misc/from_primitive.jl b/DifferentiationInterface/src/misc/from_primitive.jl new file mode 100644 index 000000000..329a2fca2 --- /dev/null +++ b/DifferentiationInterface/src/misc/from_primitive.jl @@ -0,0 +1,84 @@ +abstract type FromPrimitive <: AbstractADType end + +check_available(fromprim::FromPrimitive) = check_available(fromprim.backend) +twoarg_support(fromprim::FromPrimitive) = twoarg_support(fromprim.backend) + +## Forward + +struct AutoForwardFromPrimitive{B} <: FromPrimitive + backend::B +end + +ADTypes.mode(::AutoForwardFromPrimitive) = ADTypes.ForwardMode() + +function prepare_pushforward(f, fromprim::AutoForwardFromPrimitive, x, dx) + return prepare_pushforward(f, fromprim.backend, x, dx) +end + +function prepare_pushforward(f!, y, fromprim::AutoForwardFromPrimitive, x, dx) + return prepare_pushforward(f!, y, fromprim.backend, x, dx) +end + +function value_and_pushforward( + f, fromprim::AutoForwardFromPrimitive, x, dx, extras::PushforwardExtras +) + return value_and_pushforward(f, fromprim.backend, x, dx, extras) +end + +function value_and_pushforward( + f!, y, fromprim::AutoForwardFromPrimitive, x, dx, extras::PushforwardExtras +) + return value_and_pushforward(f!, y, fromprim.backend, x, dx, extras) +end + +function value_and_pushforward!( + f, dy, fromprim::AutoForwardFromPrimitive, x, dx, extras::PushforwardExtras +) + return value_and_pushforward!(f, dy, fromprim.backend, x, dx, extras) +end + +function value_and_pushforward!( + f!, y, dy, fromprim::AutoForwardFromPrimitive, x, dx, extras::PushforwardExtras +) + return value_and_pushforward!(f!, y, dy, fromprim.backend, x, dx, extras) +end + +## Reverse + +struct AutoReverseFromPrimitive{B} <: FromPrimitive + backend::B +end + +ADTypes.mode(::AutoReverseFromPrimitive) = ADTypes.ReverseMode() + +function prepare_pullback(f, fromprim::AutoReverseFromPrimitive, x, dy) + return prepare_pullback(f, fromprim.backend, x, dy) +end + +function prepare_pullback(f!, y, fromprim::AutoReverseFromPrimitive, x, dy) + return prepare_pullback(f!, y, fromprim.backend, x, dy) +end + +function value_and_pullback( + f, fromprim::AutoReverseFromPrimitive, x, dy, extras::PullbackExtras +) + return value_and_pullback(f, fromprim.backend, x, dy, extras) +end + +function value_and_pullback( + f!, y, fromprim::AutoReverseFromPrimitive, x, dy, extras::PullbackExtras +) + return value_and_pullback(f!, y, fromprim.backend, x, dy, extras) +end + +function value_and_pullback!( + f, dx, fromprim::AutoReverseFromPrimitive, x, dy, extras::PullbackExtras +) + return value_and_pullback!(f, dx, fromprim.backend, x, dy, extras) +end + +function value_and_pullback!( + f!, y, dx, fromprim::AutoReverseFromPrimitive, x, dy, extras::PullbackExtras +) + return value_and_pullback!(f!, y, dx, fromprim.backend, x, dy, extras) +end diff --git a/DifferentiationInterface/src/second_order/hessian.jl b/DifferentiationInterface/src/second_order/hessian.jl index 4501dd817..1ab13e219 100644 --- a/DifferentiationInterface/src/second_order/hessian.jl +++ b/DifferentiationInterface/src/second_order/hessian.jl @@ -49,16 +49,23 @@ abstract type HessianExtras <: Extras end struct NoHessianExtras <: HessianExtras end -struct HVPGradientHessianExtras{E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras - hvp_extras::E2 +struct HVPGradientHessianExtras{B,D,E2<:HVPExtras,E1<:GradientExtras} <: HessianExtras + seeds::D + hvp_batched_extras::E2 gradient_extras::E1 end function prepare_hessian(f::F, backend::AbstractADType, x) where {F} - v = basis(backend, x, first(CartesianIndices(x))) - hvp_extras = prepare_hvp(f, backend, x, v) + N = length(x) + B = pick_batchsize(maybe_outer(backend), N) + seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)] + hvp_batched_extras = prepare_hvp_batched( + f, backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))) + ) gradient_extras = prepare_gradient(f, maybe_inner(backend), x) - return HVPGradientHessianExtras(hvp_extras, gradient_extras) + D = typeof(seeds) + E2, E1 = typeof(hvp_batched_extras), typeof(gradient_extras) + return HVPGradientHessianExtras{B,D,E2,E1}(seeds, hvp_batched_extras, gradient_extras) end ## One argument @@ -82,28 +89,60 @@ function hessian!(f::F, hess, backend::AbstractADType, x) where {F} end function hessian( - f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras -) where {F} - hvp_extras_same = prepare_hvp_same_point( - f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras + f::F, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B} +) where {F,B} + @compat (; seeds, hvp_batched_extras) = extras + N = length(x) + + hvp_batched_extras_same = prepare_hvp_batched_same_point( + f, backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))), hvp_batched_extras ) - hess = stack(vec(CartesianIndices(x))) do j - hess_col_j = hvp(f, backend, x, basis(backend, x, j), hvp_extras_same) - vec(hess_col_j) + + hess_blocks = map(1:div(N, B, RoundUp)) do a + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % N] + end + dg_batch = hvp_batched( + f, backend, x, Batch(dx_batch_elements), hvp_batched_extras_same + ) + stack(vec, dg_batch.elements; dims=2) + end + + hess = reduce(hcat, hess_blocks) + if N < size(hess, 2) + hess = hess[:, 1:N] end return hess end function hessian!( - f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras -) where {F} - hvp_extras_same = prepare_hvp_same_point( - f, backend, x, basis(backend, x, first(CartesianIndices(x))), extras.hvp_extras + f::F, hess, backend::AbstractADType, x, extras::HVPGradientHessianExtras{B} +) where {F,B} + xinds = CartesianIndices(x) + N = length(x) + + dx_batch_elements = ntuple(Returns(basis(backend, x, xinds[1])), Val(B)) + hvp_batched_extras_same = prepare_hvp_batched_same_point( + f, backend, x, Batch(dx_batch_elements), extras.hvp_batched_extras ) - for (k, j) in enumerate(CartesianIndices(x)) - hess_col_j = reshape(view(hess, :, k), size(x)) - hvp!(f, hess_col_j, backend, x, basis(backend, x, j), hvp_extras_same) + + for a in 1:div(N, B, RoundUp) + dx_batch_elements = ntuple(Val(B)) do b + basis(backend, x, xinds[1 + ((a - 1) * B + (b - 1)) % N]) + end + dg_batch_elements = ntuple(Val(B)) do b + reshape(view(hess, :, 1 + ((a - 1) * B + (b - 1)) % N), size(x)) + end + hvp_batched!( + f, + Batch(dg_batch_elements), + backend, + x, + Batch(dx_batch_elements), + hvp_batched_extras_same, + ) end + return hess end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index a7745ece8..03fc47ba3 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -10,6 +10,8 @@ Create an `extras` object that can be given to [`hvp`](@ref) and its variants. """ function prepare_hvp end +function prepare_hvp_batched end + """ prepare_hvp_same_point(f, backend, x, dx) -> extras_same @@ -20,6 +22,8 @@ Create an `extras_same` object that can be given to [`hvp`](@ref) and its varian """ function prepare_hvp_same_point end +function prepare_hvp_batched_same_point end + """ hvp(f, backend, x, dx, [extras]) -> dg @@ -27,6 +31,8 @@ Compute the Hessian-vector product of `f` at point `x` with seed `dx`. """ function hvp end +function hvp_batched end + """ hvp!(f, dg, backend, x, dx, [extras]) -> dg @@ -34,8 +40,12 @@ Compute the Hessian-vector product of `f` at point `x` with seed `dx`, overwriti """ function hvp! end +function hvp_batched! end + ## Preparation +### Extras types + """ HVPExtras @@ -85,6 +95,8 @@ struct ReverseOverReverseHVPExtras{IG<:InnerGradient,E<:PullbackExtras} <: HVPEx outer_pullback_extras::E end +### Standard + function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F} return prepare_hvp(f, SecondOrder(backend, backend), x, dx) end @@ -122,7 +134,7 @@ function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) wh return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras) end -## Preparation (same point) +### Standard, same point function prepare_hvp_same_point( f::F, backend::AbstractADType, x, dx, extras::HVPExtras @@ -135,16 +147,28 @@ function prepare_hvp_same_point(f::F, backend::AbstractADType, x, dx) where {F} return prepare_hvp_same_point(f, backend, x, dx, extras) end +### Batched + +function prepare_hvp_batched(f::F, backend::AbstractADType, x, dx::Batch{B}) where {F,B} + return prepare_hvp(f, backend, x, first(dx.elements)) +end + +### Batched, same point + +function prepare_hvp_batched_same_point( + f::F, backend::AbstractADType, x, dx::Batch{B}, extras::HVPExtras +) where {F,B} + return prepare_hvp_same_point(f, backend, x, first(dx.elements), extras) +end + ## One argument +### Standard + function hvp(f::F, backend::AbstractADType, x, dx) where {F} return hvp(f, backend, x, dx, prepare_hvp(f, backend, x, dx)) end -function hvp!(f::F, dg, backend::AbstractADType, x, dx) where {F} - return hvp!(f, dg, backend, x, dx, prepare_hvp(f, backend, x, dx)) -end - function hvp(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F} return hvp(f, SecondOrder(backend, backend), x, dx, extras) end @@ -178,6 +202,10 @@ function hvp( return pullback(inner_gradient, outer(backend), x, dx, outer_pullback_extras) end +function hvp!(f::F, dg, backend::AbstractADType, x, dx) where {F} + return hvp!(f, dg, backend, x, dx, prepare_hvp(f, backend, x, dx)) +end + function hvp!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F} return hvp!(f, dg, SecondOrder(backend, backend), x, dx, extras) end @@ -210,3 +238,30 @@ function hvp!( @compat (; inner_gradient, outer_pullback_extras) = extras return pullback!(inner_gradient, dg, outer(backend), x, dx, outer_pullback_extras) end + +### Batched + +function hvp_batched(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F} + return hvp_batched(f, SecondOrder(backend, backend), x, dx, extras) +end + +function hvp_batched( + f::F, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras +) where {F,B} + dg_elements = ntuple(Val(B)) do l + hvp(f, backend, x, dx.elements[l], extras) + end + return Batch(dg_elements) +end + +function hvp_batched!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F} + return hvp_batched!(f, dg, SecondOrder(backend, backend), x, dx, extras) +end + +function hvp_batched!( + f::F, dg::Batch{B}, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras +) where {F,B} + for l in 1:B + hvp!(f, dg.elements[l], backend, x, dx.elements[l], extras) + end +end diff --git a/DifferentiationInterface/src/sparse/hessian.jl b/DifferentiationInterface/src/sparse/hessian.jl index 074c6fdf0..b7673ee0b 100644 --- a/DifferentiationInterface/src/sparse/hessian.jl +++ b/DifferentiationInterface/src/sparse/hessian.jl @@ -1,4 +1,5 @@ -Base.@kwdef struct SparseHessianExtras{ +struct SparseHessianExtras{ + B, S<:AbstractMatrix{Bool}, C<:AbstractMatrix{<:Real}, K<:AbstractVector{<:Integer}, @@ -12,10 +13,26 @@ Base.@kwdef struct SparseHessianExtras{ colors::K seeds::D products::P - hvp_extras::E2 + hvp_batched_extras::E2 gradient_extras::E1 end +function SparseHessianExtras{B}(; + sparsity::S, + compressed::C, + colors::K, + seeds::D, + products::P, + hvp_batched_extras::E2, + gradient_extras::E1, +) where {B,S,C,K,D,P,E2,E1} + @assert length(seeds) == length(products) == size(compressed, 2) + @assert size(sparsity, 1) == size(sparsity, 2) == size(compressed, 1) == length(colors) + return SparseHessianExtras{B,S,C,K,D,P,E2,E1}( + sparsity, compressed, colors, seeds, products, hvp_batched_extras, gradient_extras + ) +end + ## Hessian, one argument function prepare_hessian(f::F, backend::AutoSparse, x) where {F} @@ -24,42 +41,79 @@ function prepare_hessian(f::F, backend::AutoSparse, x) where {F} sparsity = col_major(initial_sparsity) colors = symmetric_coloring(sparsity, coloring_algorithm(backend)) groups = color_groups(colors) - seeds = map(groups) do group - seed = zero(x) - seed[group] .= one(eltype(x)) - seed - end - hvp_extras = prepare_hvp(f, dense_backend, x, first(seeds)) - products = map(seeds) do _ - similar(x) - end + seeds = map(group -> make_seed(x, group), groups) + G = length(seeds) + B = pick_batchsize(maybe_outer(dense_backend), G) + dx_batch = Batch(ntuple(Returns(seeds[1]), Val(B))) + hvp_batched_extras = prepare_hvp_batched(f, dense_backend, x, dx_batch) + products = map(_ -> similar(x), seeds) compressed = stack(vec, products; dims=2) gradient_extras = prepare_gradient(f, maybe_inner(dense_backend), x) - return SparseHessianExtras(; - sparsity, compressed, colors, seeds, products, hvp_extras, gradient_extras + return SparseHessianExtras{B}(; + sparsity, compressed, colors, seeds, products, hvp_batched_extras, gradient_extras ) end -function hessian!(f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras) where {F} - @compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras +function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras{B}) where {F,B} + @compat (; sparsity, compressed, colors, seeds, products, hvp_batched_extras) = extras + G = length(seeds) dense_backend = dense_ad(backend) - hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras) - for k in eachindex(seeds, products) - hvp!(f, products[k], dense_backend, x, seeds[k], hvp_extras_same) - copyto!(view(compressed, :, k), vec(products[k])) + + hvp_batched_extras_same = prepare_hvp_batched_same_point( + f, dense_backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))), hvp_batched_extras + ) + + compressed_blocks = map(1:div(G, B, RoundUp)) do a + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] + end + dg_batch = hvp_batched( + f, dense_backend, x, Batch(dx_batch_elements), hvp_batched_extras_same + ) + stack(vec, dg_batch.elements; dims=2) end - decompress_symmetric!(hess, sparsity, compressed, colors) - return hess + + compressed = reduce(hcat, compressed_blocks) + if G < size(compressed, 2) + compressed = compressed[:, 1:G] + end + return decompress_symmetric(sparsity, compressed, colors) end -function hessian(f::F, backend::AutoSparse, x, extras::SparseHessianExtras) where {F} - @compat (; sparsity, compressed, colors, seeds, products, hvp_extras) = extras +function hessian!( + f::F, hess, backend::AutoSparse, x, extras::SparseHessianExtras{B} +) where {F,B} + @compat (; sparsity, compressed, colors, seeds, products, hvp_batched_extras) = extras dense_backend = dense_ad(backend) - hvp_extras_same = prepare_hvp_same_point(f, dense_backend, x, seeds[1], hvp_extras) - compressed = stack(eachindex(seeds, products); dims=2) do k - vec(hvp(f, dense_backend, x, seeds[k], hvp_extras_same)) + G = length(seeds) + + hvp_batched_extras_same = prepare_hvp_batched_same_point( + f, dense_backend, x, Batch(ntuple(Returns(seeds[1]), Val(B))), hvp_batched_extras + ) + + for a in 1:div(G, B, RoundUp) + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] + end + dg_batch_elements = ntuple(Val(B)) do b + products[1 + ((a - 1) * B + (b - 1)) % G] + end + hvp_batched!( + f, + Batch(dg_batch_elements), + dense_backend, + x, + Batch(dx_batch_elements), + hvp_batched_extras_same, + ) end - return decompress_symmetric(sparsity, compressed, colors) + + for k in eachindex(products) + copyto!(view(compressed, :, k), vec(products[k])) + end + + decompress_symmetric!(hess, sparsity, compressed, colors) + return hess end function value_gradient_and_hessian!( diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/src/sparse/jacobian.jl index 2937723d4..88b024f6c 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/src/sparse/jacobian.jl @@ -1,265 +1,331 @@ -struct SparseJacobianExtras{ - args, - dir, +## Preparation + +abstract type SparseJacobianExtras <: JacobianExtras end + +struct PushforwardSparseJacobianExtras{ + B, S<:AbstractMatrix{Bool}, C<:AbstractMatrix{<:Real}, K<:AbstractVector{<:Integer}, D<:AbstractVector, P<:AbstractVector, - E<:Extras, -} <: JacobianExtras + E<:PushforwardExtras, +} <: SparseJacobianExtras sparsity::S compressed::C colors::K seeds::D products::P - jp_extras::E + pushforward_batched_extras::E end -function SparseJacobianExtras{args,dir}(; - sparsity::S, compressed::C, colors::K, seeds::D, products::P, jp_extras::E -) where {args,dir,S,C,K,D,P,E} - if dir == :col - @assert jp_extras isa PushforwardExtras - elseif dir == :row - @assert jp_extras isa PullbackExtras - end - return SparseJacobianExtras{args,dir,S,C,K,D,P,E}( - sparsity, compressed, colors, seeds, products, jp_extras +struct PullbackSparseJacobianExtras{ + B, + S<:AbstractMatrix{Bool}, + C<:AbstractMatrix{<:Real}, + K<:AbstractVector{<:Integer}, + D<:AbstractVector, + P<:AbstractVector, + E<:PullbackExtras, +} <: SparseJacobianExtras + sparsity::S + compressed::C + colors::K + seeds::D + products::P + pullback_batched_extras::E +end + +function PushforwardSparseJacobianExtras{B}(; + sparsity::S, + compressed::C, + colors::K, + seeds::D, + products::P, + pushforward_batched_extras::E, +) where {B,S,C,K,D,P,E} + @assert length(seeds) == length(products) == size(compressed, 2) + @assert size(sparsity, 1) == size(compressed, 1) + @assert size(sparsity, 2) == length(colors) + return PushforwardSparseJacobianExtras{B,S,C,K,D,P,E}( + sparsity, compressed, colors, seeds, products, pushforward_batched_extras ) end -## Jacobian, one argument +function PullbackSparseJacobianExtras{B}(; + sparsity::S, compressed::C, colors::K, seeds::D, products::P, pullback_batched_extras::E +) where {B,S,C,K,D,P,E} + @assert length(seeds) == length(products) == size(compressed, 1) + @assert size(sparsity, 2) == size(compressed, 2) + @assert size(sparsity, 1) == length(colors) + return PullbackSparseJacobianExtras{B,S,C,K,D,P,E}( + sparsity, compressed, colors, seeds, products, pullback_batched_extras + ) +end function prepare_jacobian(f::F, backend::AutoSparse, x) where {F} - dense_backend = dense_ad(backend) y = f(x) - initial_sparsity = jacobian_sparsity(f, x, sparsity_detector(backend)) - if Bool(pushforward_performance(backend)) - sparsity = col_major(initial_sparsity) - colors = column_coloring(sparsity, coloring_algorithm(backend)) - groups = color_groups(colors) - seeds = map(groups) do group - seed = zero(x) - seed[group] .= one(eltype(x)) - seed - end - jp_extras = prepare_pushforward(f, dense_backend, x, first(seeds)) - products = map(seeds) do _ - similar(y) - end - compressed = stack(vec, products; dims=2) - return SparseJacobianExtras{1,:col}(; - sparsity, compressed, colors, seeds, products, jp_extras - ) - else - sparsity = row_major(initial_sparsity) - colors = row_coloring(sparsity, coloring_algorithm(backend)) - groups = color_groups(colors) - seeds = map(groups) do group - seed = zero(y) - seed[group] .= one(eltype(y)) - seed - end - jp_extras = prepare_pullback(f, dense_backend, x, first(seeds)) - products = map(seeds) do _ - similar(x) - end - compressed = stack(vec, products; dims=1) - return SparseJacobianExtras{1,:row}(; - sparsity, compressed, colors, seeds, products, jp_extras - ) - end + return prepare_sparse_jacobian_aux( + (f,), backend, x, y, pushforward_performance(backend) + ) end -function jacobian!( - f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:col} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras - dense_backend = dense_ad(backend) - pushforward_extras_same = prepare_pushforward_same_point( - f, dense_backend, x, seeds[1], jp_extras +function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} + return prepare_sparse_jacobian_aux( + (f!, y), backend, x, y, pushforward_performance(backend) ) - for k in eachindex(seeds, products) - pushforward!(f, products[k], dense_backend, x, seeds[k], pushforward_extras_same) - copyto!(view(compressed, :, k), vec(products[k])) - end - decompress_columns!(jac, sparsity, compressed, colors) - return jac end -function jacobian!( - f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:row} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras +function prepare_sparse_jacobian_aux( + f_or_f!y::FY, backend, x, y, ::PushforwardFast +) where {FY} dense_backend = dense_ad(backend) - pullback_extras_same = prepare_pullback_same_point( - f, dense_backend, x, seeds[1], jp_extras + initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend)) + sparsity = col_major(initial_sparsity) + colors = column_coloring(sparsity, coloring_algorithm(backend)) + groups = color_groups(colors) + seeds = map(group -> make_seed(x, group), groups) + G = length(seeds) + B = pick_batchsize(dense_backend, G) + dx_batch = Batch(ntuple(Returns(seeds[1]), Val(B))) + pushforward_batched_extras = prepare_pushforward_batched( + f_or_f!y..., dense_backend, x, dx_batch + ) + products = map(_ -> similar(y), seeds) + compressed = stack(vec, products; dims=2) + return PushforwardSparseJacobianExtras{B}(; + sparsity, compressed, colors, seeds, products, pushforward_batched_extras ) - for k in eachindex(seeds, products) - pullback!(f, products[k], dense_backend, x, seeds[k], pullback_extras_same) - copyto!(view(compressed, k, :), vec(products[k])) - end - decompress_rows!(jac, sparsity, compressed, colors) - return jac end -function jacobian( - f::F, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:col} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras +function prepare_sparse_jacobian_aux( + f_or_f!y::FY, backend, x, y, ::PushforwardSlow +) where {FY} dense_backend = dense_ad(backend) - pushforward_extras_same = prepare_pushforward_same_point( - f, dense_backend, x, seeds[1], jp_extras + initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend)) + sparsity = row_major(initial_sparsity) + colors = row_coloring(sparsity, coloring_algorithm(backend)) + groups = color_groups(colors) + seeds = map(group -> make_seed(y, group), groups) + G = length(seeds) + B = pick_batchsize(dense_backend, G) + dx_batch = Batch(ntuple(Returns(seeds[1]), Val(B))) + pullback_batched_extras = prepare_pullback_batched( + f_or_f!y..., dense_backend, x, dx_batch + ) + products = map(_ -> similar(x), seeds) + compressed = stack(vec, products; dims=1) + return PullbackSparseJacobianExtras{B}(; + sparsity, compressed, colors, seeds, products, pullback_batched_extras ) - compressed = stack(eachindex(seeds, products); dims=2) do k - vec(pushforward(f, dense_backend, x, seeds[k], pushforward_extras_same)) - end - return decompress_columns(sparsity, compressed, colors) end -function jacobian( - f::F, backend::AutoSparse, x, extras::SparseJacobianExtras{1,:row} +## One argument + +function jacobian(f::F, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F} + return sparse_jacobian_aux((f,), backend, x, extras) +end + +function jacobian!( + f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras - dense_backend = dense_ad(backend) - pullback_extras_same = prepare_pullback_same_point( - f, dense_backend, x, seeds[1], jp_extras - ) - compressed = stack(eachindex(seeds, products); dims=1) do k - vec(pullback(f, dense_backend, x, seeds[k], pullback_extras_same)) - end - return decompress_rows(sparsity, compressed, colors) + return sparse_jacobian_aux!((f,), jac, backend, x, extras) +end + +function value_and_jacobian( + f::F, backend::AutoSparse, x, extras::SparseJacobianExtras +) where {F} + return f(x), jacobian(f, backend, x, extras) end function value_and_jacobian!( - f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{1} + f::F, jac, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} return f(x), jacobian!(f, jac, backend, x, extras) end +## Two arguments + +function jacobian(f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras) where {F} + return sparse_jacobian_aux((f!, y), backend, x, extras) +end + +function jacobian!( + f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras +) where {F} + return sparse_jacobian_aux!((f!, y), jac, backend, x, extras) +end + function value_and_jacobian( - f::F, backend::AutoSparse, x, extras::SparseJacobianExtras{1} + f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras ) where {F} - return f(x), jacobian(f, backend, x, extras) + jac = jacobian(f!, y, backend, x, extras) + f!(y, x) + return y, jac +end + +function value_and_jacobian!( + f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras +) where {F} + jacobian!(f!, y, jac, backend, x, extras) + f!(y, x) + return y, jac end -## Jacobian, two arguments +## Common auxiliaries -function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} +function sparse_jacobian_aux( + f_or_f!y::FY, backend::AutoSparse, x, extras::PushforwardSparseJacobianExtras{B} +) where {FY,B} + @compat (; sparsity, compressed, colors, seeds, products, pushforward_batched_extras) = + extras dense_backend = dense_ad(backend) - initial_sparsity = jacobian_sparsity(f!, y, x, sparsity_detector(backend)) - if Bool(pushforward_performance(backend)) - sparsity = col_major(initial_sparsity) - colors = column_coloring(sparsity, coloring_algorithm(backend)) - groups = color_groups(colors) - seeds = map(groups) do group - seed = zero(x) - seed[group] .= one(eltype(x)) - seed - end - jp_extras = prepare_pushforward(f!, y, dense_backend, x, first(seeds)) - products = map(seeds) do _ - similar(y) - end - compressed = stack(vec, products; dims=2) - return SparseJacobianExtras{2,:col}(; - sparsity, compressed, colors, seeds, products, jp_extras - ) - else - sparsity = row_major(initial_sparsity) - colors = row_coloring(sparsity, coloring_algorithm(backend)) - groups = color_groups(colors) - seeds = map(groups) do group - seed = zero(y) - seed[group] .= one(eltype(y)) - seed - end - jp_extras = prepare_pullback(f!, y, dense_backend, x, first(seeds)) - products = map(seeds) do _ - similar(x) + G = length(seeds) + + pushforward_batched_extras_same = prepare_pushforward_batched_same_point( + f_or_f!y..., + dense_backend, + x, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pushforward_batched_extras, + ) + + compressed_blocks = map(1:div(G, B, RoundUp)) do a + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] end - compressed = stack(vec, products; dims=1) - return SparseJacobianExtras{2,:row}(; - sparsity, compressed, colors, seeds, products, jp_extras + dy_batch = pushforward_batched( + f_or_f!y..., + dense_backend, + x, + Batch(dx_batch_elements), + pushforward_batched_extras_same, ) + stack(vec, dy_batch.elements; dims=2) end + + compressed = reduce(hcat, compressed_blocks) + if G < size(compressed, 2) + compressed = compressed[:, 1:G] + end + return decompress_columns(sparsity, compressed, colors) end -function jacobian!( - f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:col} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras +function sparse_jacobian_aux( + f_or_f!y::FY, backend::AutoSparse, x, extras::PullbackSparseJacobianExtras{B} +) where {FY,B} + @compat (; sparsity, compressed, colors, seeds, products, pullback_batched_extras) = + extras dense_backend = dense_ad(backend) - pushforward_extras_same = prepare_pushforward_same_point( - f!, y, dense_backend, x, seeds[1], jp_extras + G = length(seeds) + + pullback_batched_extras_same = prepare_pullback_batched_same_point( + f_or_f!y..., + dense_backend, + x, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pullback_batched_extras, ) - for k in eachindex(seeds, products) - pushforward!( - f!, y, products[k], dense_backend, x, seeds[k], pushforward_extras_same + + compressed_blocks = map(1:div(G, B, RoundUp)) do a + dy_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] + end + dx_batch = pullback_batched( + f_or_f!y..., + dense_backend, + x, + Batch(dy_batch_elements), + pullback_batched_extras_same, ) - copyto!(view(compressed, :, k), vec(products[k])) + stack(vec, dx_batch.elements; dims=1) end - decompress_columns!(jac, sparsity, compressed, colors) - return jac -end -function jacobian!( - f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:row} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras - dense_backend = dense_ad(backend) - pullback_extras_same = prepare_pullback_same_point( - f!, y, dense_backend, x, seeds[1], jp_extras - ) - for k in eachindex(seeds, products) - pullback!(f!, y, products[k], dense_backend, x, seeds[k], pullback_extras_same) - copyto!(view(compressed, k, :), vec(products[k])) + compressed = reduce(vcat, compressed_blocks) + if G < size(compressed, 1) + compressed = compressed[1:G, :] end - decompress_rows!(jac, sparsity, compressed, colors) - return jac + return decompress_rows(sparsity, compressed, colors) end -function jacobian( - f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:col} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras +function sparse_jacobian_aux!( + f_or_f!y::FY, jac, backend::AutoSparse, x, extras::PushforwardSparseJacobianExtras{B} +) where {FY,B} + @compat (; sparsity, compressed, colors, seeds, products, pushforward_batched_extras) = + extras dense_backend = dense_ad(backend) - pushforward_extras_same = prepare_pushforward_same_point( - f!, y, dense_backend, x, seeds[1], jp_extras + G = length(seeds) + + pushforward_batched_extras_same = prepare_pushforward_batched_same_point( + f_or_f!y..., + dense_backend, + x, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pushforward_batched_extras, ) - compressed = stack(eachindex(seeds, products); dims=2) do k - vec(pushforward(f!, y, dense_backend, x, seeds[k], pushforward_extras_same)) + + for a in 1:div(G, B, RoundUp) + dx_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] + end + dy_batch_elements = ntuple(Val(B)) do b + products[1 + ((a - 1) * B + (b - 1)) % G] + end + pushforward_batched!( + f_or_f!y..., + Batch(dy_batch_elements), + dense_backend, + x, + Batch(dx_batch_elements), + pushforward_batched_extras_same, + ) end - return decompress_columns(sparsity, compressed, colors) + + for k in eachindex(products) + copyto!(view(compressed, :, k), vec(products[k])) + end + + decompress_columns!(jac, sparsity, compressed, colors) + return jac end -function jacobian( - f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras{2,:row} -) where {F} - @compat (; sparsity, compressed, colors, seeds, products, jp_extras) = extras +function sparse_jacobian_aux!( + f_or_f!y::FY, jac, backend::AutoSparse, x, extras::PullbackSparseJacobianExtras{B} +) where {FY,B} + @compat (; sparsity, compressed, colors, seeds, products, pullback_batched_extras) = + extras dense_backend = dense_ad(backend) - pullback_extras_same = prepare_pullback_same_point( - f!, y, dense_backend, x, seeds[1], jp_extras + G = length(seeds) + + pullback_batched_extras_same = prepare_pullback_batched_same_point( + f_or_f!y..., + dense_backend, + x, + Batch(ntuple(Returns(seeds[1]), Val(B))), + pullback_batched_extras, ) - compressed = stack(eachindex(seeds, products); dims=1) do k - vec(pullback(f!, y, dense_backend, x, seeds[k], pullback_extras_same)) + + for a in 1:div(G, B, RoundUp) + dy_batch_elements = ntuple(Val(B)) do b + seeds[1 + ((a - 1) * B + (b - 1)) % G] + end + dx_batch_elements = ntuple(Val(B)) do b + products[1 + ((a - 1) * B + (b - 1)) % G] + end + pullback_batched!( + f_or_f!y..., + Batch(dx_batch_elements), + dense_backend, + x, + Batch(dy_batch_elements), + pullback_batched_extras_same, + ) end - return decompress_rows(sparsity, compressed, colors) -end -function value_and_jacobian!( - f!::F, y, jac, backend::AutoSparse, x, extras::SparseJacobianExtras{2} -) where {F} - jacobian!(f!, y, jac, backend, x, extras) - f!(y, x) - return y, jac -end + for k in eachindex(products) + copyto!(view(compressed, k, :), vec(products[k])) + end -function value_and_jacobian( - f!::F, y, backend::AutoSparse, x, extras::SparseJacobianExtras{2} -) where {F} - jac = jacobian(f!, y, backend, x, extras) - f!(y, x) - return y, jac + decompress_rows!(jac, sparsity, compressed, colors) + return jac end diff --git a/DifferentiationInterface/src/utils/basis.jl b/DifferentiationInterface/src/utils/basis.jl index 95b83fa91..3f7063f31 100644 --- a/DifferentiationInterface/src/utils/basis.jl +++ b/DifferentiationInterface/src/utils/basis.jl @@ -13,3 +13,9 @@ basis(::AbstractADType, a::AbstractArray, i) = basis(a, i) function basis(a::AbstractArray{T,N}, i::CartesianIndex{N}) where {T,N} return OneElement(one(T), Tuple(i), axes(a)) end + +function make_seed(x::AbstractArray, group::AbstractVector{<:Integer}) + seed = zero(x) + seed[group] .= one(eltype(x)) + return seed +end diff --git a/DifferentiationInterface/src/utils/batch.jl b/DifferentiationInterface/src/utils/batch.jl new file mode 100644 index 000000000..641c00573 --- /dev/null +++ b/DifferentiationInterface/src/utils/batch.jl @@ -0,0 +1,24 @@ +""" + pick_batchsize(backend::AbstractADType, dimension::Integer) + +Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`. +""" +function pick_batchsize(::AbstractADType, dimension::Integer) + return min(dimension, 8) +end + +""" + Batch{B,T} + +Efficient storage for `B` elements of type `T` (`NTuple` wrapper). + +A `Batch` can be used as seed to trigger batched-mode `pushforward`, `pullback` and `hvp`. + +# Fields + +- `elements::NTuple{B,T}` +""" +struct Batch{B,T} + elements::NTuple{B,T} + Batch(elements::NTuple) = new{length(elements),eltype(elements)}(elements) +end diff --git a/DifferentiationInterface/src/utils/chunk.jl b/DifferentiationInterface/src/utils/chunk.jl deleted file mode 100644 index b3fbca6ab..000000000 --- a/DifferentiationInterface/src/utils/chunk.jl +++ /dev/null @@ -1,22 +0,0 @@ -#= -This heuristic is taken from ForwardDiff.jl. -Source file: https://github.com/JuliaDiff/ForwardDiff.jl/blob/master/src/prelude.jl -=# - -const DEFAULT_CHUNKSIZE = 8 - -""" - pick_chunksize(input_length) - -Pick a reasonable chunk size for chunked derivative evaluation with an input of length `input_length`. - -The result cannot be larger than `DEFAULT_CHUNKSIZE=$DEFAULT_CHUNKSIZE`. -""" -function pick_chunksize(input_length::Integer; threshold::Integer=DEFAULT_CHUNKSIZE) - if input_length <= threshold - return input_length - else - nchunks = round(Int, input_length / threshold, RoundUp) - return round(Int, input_length / nchunks, RoundUp) - end -end diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index 181793cf1..4c9512c05 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -15,6 +15,9 @@ backend_package_name(::AutoTracker) = "Tracker" backend_package_name(::AutoZygote) = "Zygote" backend_package_name(::AutoReverseDiff) = "ReverseDiff" +backend_package_name(::AF) where {AF<:AutoForwardFromPrimitive} = string(AF) +backend_package_name(::AR) where {AR<:AutoReverseFromPrimitive} = string(AR) + function backend_str(backend::AbstractADType) bs = backend_package_name(backend) if mode(backend) isa ForwardMode diff --git a/DifferentiationInterface/test/Internals/batch.jl b/DifferentiationInterface/test/Internals/batch.jl new file mode 100644 index 000000000..9895acb97 --- /dev/null +++ b/DifferentiationInterface/test/Internals/batch.jl @@ -0,0 +1,2 @@ +import DifferentiationInterface as DI +using Test diff --git a/DifferentiationInterface/test/Internals/chunk.jl b/DifferentiationInterface/test/Internals/chunk.jl deleted file mode 100644 index 0418ed9a8..000000000 --- a/DifferentiationInterface/test/Internals/chunk.jl +++ /dev/null @@ -1,12 +0,0 @@ -import DifferentiationInterface as DI -using Test - -@test DI.pick_chunksize.(1:(DI.DEFAULT_CHUNKSIZE)) == 1:(DI.DEFAULT_CHUNKSIZE) -@test all( - DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .<= - DI.DEFAULT_CHUNKSIZE, -) -@test all( - DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .>= - DI.DEFAULT_CHUNKSIZE / 2, -) diff --git a/DifferentiationInterface/test/Single/ForwardDiff/test.jl b/DifferentiationInterface/test/Single/ForwardDiff/test.jl index 472f449d1..793ae8676 100644 --- a/DifferentiationInterface/test/Single/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Single/ForwardDiff/test.jl @@ -1,10 +1,13 @@ using DifferentiationInterface, DifferentiationInterfaceTest +using DifferentiationInterface: AutoForwardFromPrimitive using ForwardDiff: ForwardDiff using SparseConnectivityTracer, SparseMatrixColorings using Test dense_backends = [AutoForwardDiff(), AutoForwardDiff(; chunksize=2, tag=:hello)] +fromprimitive_backends = [AutoForwardFromPrimitive(AutoForwardDiff(; chunksize=5))] + sparse_backends = [ AutoSparse( AutoForwardDiff(); @@ -13,7 +16,7 @@ sparse_backends = [ ), ] -for backend in vcat(dense_backends, sparse_backends) +for backend in vcat(dense_backends, fromprimitive_backends, sparse_backends) @test check_available(backend) @test check_twoarg(backend) @test check_hessian(backend) @@ -21,10 +24,10 @@ end ## Dense backends -test_differentiation(dense_backends; logging=LOGGING); +test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING); test_differentiation( - dense_backends; + vcat(dense_backends, fromprimitive_backends); correctness=false, type_stability=true, second_order=false, diff --git a/DifferentiationInterface/test/Single/ReverseDiff/test.jl b/DifferentiationInterface/test/Single/ReverseDiff/test.jl index 5277a3e93..c293b6ae7 100644 --- a/DifferentiationInterface/test/Single/ReverseDiff/test.jl +++ b/DifferentiationInterface/test/Single/ReverseDiff/test.jl @@ -1,13 +1,16 @@ using DifferentiationInterface, DifferentiationInterfaceTest +using DifferentiationInterface: AutoReverseFromPrimitive using ReverseDiff: ReverseDiff using Test -backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] +dense_backends = [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)] -for backend in backends +fromprimitive_backends = [AutoReverseFromPrimitive(AutoReverseDiff())] + +for backend in vcat(dense_backends, fromprimitive_backends) @test check_available(backend) @test check_twoarg(backend) @test check_hessian(backend) end -test_differentiation(backends; logging=LOGGING); +test_differentiation(vcat(dense_backends, fromprimitive_backends); logging=LOGGING);