Skip to content

Commit

Permalink
Batched pushforward, pullback and hvp (#320)
Browse files Browse the repository at this point in the history
* Batched pushforward, pullback and hvp

* Fixes, add FromPrimitive for testing

* Typo

* Typo

* Typos

* More formatting

* Reduce code duplication

* Typos

* Better display

* Typos

* Uncomment

* Printing

* Typos

* Type stability

* Typo

* Typo

* Typo

* Log Zygote

* Forward-over-reverse HVP batched for Zygote

* Typo and coverage

* Chunksize

* Funny chunk size
  • Loading branch information
gdalle authored Jun 20, 2024
1 parent 004e934 commit f62c9dd
Show file tree
Hide file tree
Showing 22 changed files with 1,035 additions and 493 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using DifferentiationInterface:
NoJacobianExtras,
NoPullbackExtras,
NoPushforwardExtras,
pick_chunksize
pick_batchsize
using DocStringExtensions
using Enzyme:
Active,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ end

## Gradient

struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
shadow::O
end

function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x)
C = pick_chunksize(length(x))
shadow = chunkedonehot(x, Val(C))
return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow)
function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode}, x)
B = pick_batchsize(backend, length(x))
shadow = chunkedonehot(x, Val(B))
return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow)
end

function DI.gradient(
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
) where {B}
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
return reshape(collect(grad_tup), size(x))
end

Expand All @@ -81,38 +81,38 @@ function DI.value_and_gradient(
end

function DI.gradient!(
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
) where {B}
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
return copyto!(grad, grad_tup)
end

function DI.value_and_gradient!(
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
) where {C}
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
) where {B}
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
return f(x), copyto!(grad, grad_tup)
end

## Jacobian

struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
shadow::O
end

function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
C = pick_chunksize(length(x))
shadow = chunkedonehot(x, Val(C))
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
B = pick_batchsize(backend, length(x))
shadow = chunkedonehot(x, Val(B))
return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow)
end

function DI.jacobian(
f,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
extras::EnzymeForwardOneArgJacobianExtras{C},
) where {C}
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
extras::EnzymeForwardOneArgJacobianExtras{B},
) where {B}
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
nx = length(x)
ny = length(jac_wrongshape) ÷ length(x)
return reshape(jac_wrongshape, ny, nx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,22 +160,22 @@ end

#=
struct EnzymeReverseOneArgJacobianExtras{C,N} end
struct EnzymeReverseOneArgJacobianExtras{B,N} end
function DI.prepare_jacobian(f, ::AutoReverseEnzyme, x)
C = pick_chunksize(length(x))
function DI.prepare_jacobian(f, backend::AutoReverseEnzyme, x)
B = pick_batchsize(backend, length(x))
y = f(x)
N = length(y)
return EnzymeReverseOneArgJacobianExtras{C,N}()
return EnzymeReverseOneArgJacobianExtras{B,N}()
end
function DI.jacobian(
f,
backend::AutoReverseEnzyme,
x::AbstractArray,
::EnzymeReverseOneArgJacobianExtras{C,N},
) where {C,N}
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val{N}(), Val{C}())
) where {B,N}
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(N), Val(B))
nx = length(x)
ny = length(jac_wrongshape) ÷ length(x)
jac_rightshape = reshape(jac_wrongshape, ny, nx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,14 @@ using LinearAlgebra: dot, mul!

DI.check_available(::AutoForwardDiff) = true

function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C}
if isnothing(C)
return ForwardDiff.pickchunksize(dimension)
else
return min(dimension, C)
end
end

include("utils.jl")
include("onearg.jl")
include("twoarg.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module DifferentiationInterfaceZygoteExt
using ADTypes: AutoForwardDiff, AutoZygote
import DifferentiationInterface as DI
using DifferentiationInterface:
Batch,
HVPExtras,
NoGradientExtras,
NoHessianExtras,
Expand Down Expand Up @@ -103,20 +104,47 @@ struct ZygoteHVPExtras{G,PE} <: HVPExtras
pushforward_extras::PE
end

function DI.prepare_hvp(f, ::AutoZygote, x, v)
function DI.prepare_hvp(f, ::AutoZygote, x, dx)
∇f(x) = only(gradient(f, x))
pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, v)
pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, dx)
return ZygoteHVPExtras(∇f, pushforward_extras)
end

function DI.hvp(f, ::AutoZygote, x, v, extras::ZygoteHVPExtras)
function DI.hvp(f, ::AutoZygote, x, dx, extras::ZygoteHVPExtras)
@compat (; ∇f, pushforward_extras) = extras
return DI.pushforward(∇f, AutoForwardDiff(), x, v, pushforward_extras)
return DI.pushforward(∇f, AutoForwardDiff(), x, dx, pushforward_extras)
end

function DI.hvp!(f, p, ::AutoZygote, x, v, extras::ZygoteHVPExtras)
function DI.hvp!(f, dg, ::AutoZygote, x, dx, extras::ZygoteHVPExtras)
@compat (; ∇f, pushforward_extras) = extras
return DI.pushforward!(∇f, p, AutoForwardDiff(), x, v, pushforward_extras)
return DI.pushforward!(∇f, dg, AutoForwardDiff(), x, dx, pushforward_extras)
end

struct ZygoteHVPBatchedExtras{G,PE} <: HVPExtras
∇f::G
pushforward_batched_extras::PE
end

function DI.prepare_hvp_batched(f, ::AutoZygote, x, dx::Batch)
∇f(x) = only(gradient(f, x))
pushforward_batched_extras = DI.prepare_pushforward_batched(
∇f, AutoForwardDiff(), x, dx
)
return ZygoteHVPBatchedExtras(∇f, pushforward_batched_extras)
end

function DI.hvp_batched(f, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras)
@compat (; ∇f, pushforward_batched_extras) = extras
return DI.pushforward_batched(∇f, AutoForwardDiff(), x, dx, pushforward_batched_extras)
end

function DI.hvp_batched!(
f, dg::Batch, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras
)
@compat (; ∇f, pushforward_batched_extras) = extras
return DI.pushforward_batched!(
∇f, dg, AutoForwardDiff(), x, dx, pushforward_batched_extras
)
end

## Hessian
Expand Down
6 changes: 4 additions & 2 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ include("second_order/second_order.jl")

include("utils/traits.jl")
include("utils/basis.jl")
include("utils/printing.jl")
include("utils/chunk.jl")
include("utils/batch.jl")
include("utils/check.jl")
include("utils/exceptions.jl")
include("utils/maybe.jl")
Expand All @@ -73,6 +72,9 @@ include("sparse/hessian.jl")

include("misc/differentiate_with.jl")
include("misc/sparsity_detector.jl")
include("misc/from_primitive.jl")

include("utils/printing.jl")

function __init__()
@require_extensions
Expand Down
Loading

0 comments on commit f62c9dd

Please sign in to comment.