Skip to content

Commit f62c9dd

Browse files
authored
Batched pushforward, pullback and hvp (#320)
* Batched pushforward, pullback and hvp
* Fixes, add FromPrimitive for testing
* Typo
* Typo
* Typos
* More formatting
* Reduce code duplication
* Typos
* Better display
* Typos
* Uncomment
* Printing
* Typos
* Type stability
* Typo
* Typo
* Typo
* Log Zygote
* Forward-over-reverse HVP batched for Zygote
* Typo and coverage
* Chunksize
* Funny chunk size
1 parent 004e934 commit f62c9dd

File tree

22 files changed

+1035
-493
lines changed

22 files changed

+1035
-493
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ using DifferentiationInterface:
1313
NoJacobianExtras,
1414
NoPullbackExtras,
1515
NoPushforwardExtras,
16-
pick_chunksize
16+
pick_batchsize
1717
using DocStringExtensions
1818
using Enzyme:
1919
Active,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

+22-22
Original file line number | Diff line number | Diff line change
@@ -57,20 +57,20 @@ end
5757

5858
## Gradient
5959

60-
struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
60+
struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
6161
shadow::O
6262
end
6363

64-
function DI.prepare_gradient(f, ::AutoEnzyme{<:ForwardMode}, x)
65-
C = pick_chunksize(length(x))
66-
shadow = chunkedonehot(x, Val(C))
67-
return EnzymeForwardGradientExtras{C,typeof(shadow)}(shadow)
64+
function DI.prepare_gradient(f, backend::AutoEnzyme{<:ForwardMode}, x)
65+
B = pick_batchsize(backend, length(x))
66+
shadow = chunkedonehot(x, Val(B))
67+
return EnzymeForwardGradientExtras{B,typeof(shadow)}(shadow)
6868
end
6969

7070
function DI.gradient(
71-
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
72-
) where {C}
73-
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
71+
f, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
72+
) where {B}
73+
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
7474
return reshape(collect(grad_tup), size(x))
7575
end
7676

@@ -81,38 +81,38 @@ function DI.value_and_gradient(
8181
end
8282

8383
function DI.gradient!(
84-
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
85-
) where {C}
86-
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
84+
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
85+
) where {B}
86+
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
8787
return copyto!(grad, grad_tup)
8888
end
8989

9090
function DI.value_and_gradient!(
91-
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{C}
92-
) where {C}
93-
grad_tup = gradient(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
91+
f, grad, backend::AutoEnzyme{<:ForwardMode}, x, extras::EnzymeForwardGradientExtras{B}
92+
) where {B}
93+
grad_tup = gradient(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
9494
return f(x), copyto!(grad, grad_tup)
9595
end
9696

9797
## Jacobian
9898

99-
struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
99+
struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
100100
shadow::O
101101
end
102102

103-
function DI.prepare_jacobian(f, ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
104-
C = pick_chunksize(length(x))
105-
shadow = chunkedonehot(x, Val(C))
106-
return EnzymeForwardOneArgJacobianExtras{C,typeof(shadow)}(shadow)
103+
function DI.prepare_jacobian(f, backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, x)
104+
B = pick_batchsize(backend, length(x))
105+
shadow = chunkedonehot(x, Val(B))
106+
return EnzymeForwardOneArgJacobianExtras{B,typeof(shadow)}(shadow)
107107
end
108108

109109
function DI.jacobian(
110110
f,
111111
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
112112
x,
113-
extras::EnzymeForwardOneArgJacobianExtras{C},
114-
) where {C}
115-
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val{C}(); shadow=extras.shadow)
113+
extras::EnzymeForwardOneArgJacobianExtras{B},
114+
) where {B}
115+
jac_wrongshape = jacobian(forward_mode(backend), f, x, Val(B); shadow=extras.shadow)
116116
nx = length(x)
117117
ny = length(jac_wrongshape) ÷ length(x)
118118
return reshape(jac_wrongshape, ny, nx)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

+6-6
Original file line number | Diff line number | Diff line change
@@ -160,22 +160,22 @@ end
160160

161161
#=
162162
163-
struct EnzymeReverseOneArgJacobianExtras{C,N} end
163+
struct EnzymeReverseOneArgJacobianExtras{B,N} end
164164
165-
function DI.prepare_jacobian(f, ::AutoReverseEnzyme, x)
166-
C = pick_chunksize(length(x))
165+
function DI.prepare_jacobian(f, backend::AutoReverseEnzyme, x)
166+
B = pick_batchsize(backend, length(x))
167167
y = f(x)
168168
N = length(y)
169-
return EnzymeReverseOneArgJacobianExtras{C,N}()
169+
return EnzymeReverseOneArgJacobianExtras{B,N}()
170170
end
171171
172172
function DI.jacobian(
173173
f,
174174
backend::AutoReverseEnzyme,
175175
x::AbstractArray,
176176
::EnzymeReverseOneArgJacobianExtras{C,N},
177-
) where {C,N}
178-
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val{N}(), Val{C}())
177+
) where {B,N}
178+
jac_wrongshape = jacobian(reverse_mode(backend), f, x, Val(N), Val(B))
179179
nx = length(x)
180180
ny = length(jac_wrongshape) ÷ length(x)
181181
jac_rightshape = reshape(jac_wrongshape, ny, nx)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

+8
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,14 @@ using LinearAlgebra: dot, mul!
3838

3939
DI.check_available(::AutoForwardDiff) = true
4040

41+
function DI.pick_batchsize(::AutoForwardDiff{C}, dimension::Integer) where {C}
42+
if isnothing(C)
43+
return ForwardDiff.pickchunksize(dimension)
44+
else
45+
return min(dimension, C)
46+
end
47+
end
48+
4149
include("utils.jl")
4250
include("onearg.jl")
4351
include("twoarg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

+34-6
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@ module DifferentiationInterfaceZygoteExt
33
using ADTypes: AutoForwardDiff, AutoZygote
44
import DifferentiationInterface as DI
55
using DifferentiationInterface:
6+
Batch,
67
HVPExtras,
78
NoGradientExtras,
89
NoHessianExtras,
@@ -103,20 +104,47 @@ struct ZygoteHVPExtras{G,PE} <: HVPExtras
103104
pushforward_extras::PE
104105
end
105106

106-
function DI.prepare_hvp(f, ::AutoZygote, x, v)
107+
function DI.prepare_hvp(f, ::AutoZygote, x, dx)
107108
∇f(x) = only(gradient(f, x))
108-
pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, v)
109+
pushforward_extras = DI.prepare_pushforward(∇f, AutoForwardDiff(), x, dx)
109110
return ZygoteHVPExtras(∇f, pushforward_extras)
110111
end
111112

112-
function DI.hvp(f, ::AutoZygote, x, v, extras::ZygoteHVPExtras)
113+
function DI.hvp(f, ::AutoZygote, x, dx, extras::ZygoteHVPExtras)
113114
@compat (; ∇f, pushforward_extras) = extras
114-
return DI.pushforward(∇f, AutoForwardDiff(), x, v, pushforward_extras)
115+
return DI.pushforward(∇f, AutoForwardDiff(), x, dx, pushforward_extras)
115116
end
116117

117-
function DI.hvp!(f, p, ::AutoZygote, x, v, extras::ZygoteHVPExtras)
118+
function DI.hvp!(f, dg, ::AutoZygote, x, dx, extras::ZygoteHVPExtras)
118119
@compat (; ∇f, pushforward_extras) = extras
119-
return DI.pushforward!(∇f, p, AutoForwardDiff(), x, v, pushforward_extras)
120+
return DI.pushforward!(∇f, dg, AutoForwardDiff(), x, dx, pushforward_extras)
121+
end
122+
123+
struct ZygoteHVPBatchedExtras{G,PE} <: HVPExtras
124+
∇f::G
125+
pushforward_batched_extras::PE
126+
end
127+
128+
function DI.prepare_hvp_batched(f, ::AutoZygote, x, dx::Batch)
129+
∇f(x) = only(gradient(f, x))
130+
pushforward_batched_extras = DI.prepare_pushforward_batched(
131+
∇f, AutoForwardDiff(), x, dx
132+
)
133+
return ZygoteHVPBatchedExtras(∇f, pushforward_batched_extras)
134+
end
135+
136+
function DI.hvp_batched(f, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras)
137+
@compat (; ∇f, pushforward_batched_extras) = extras
138+
return DI.pushforward_batched(∇f, AutoForwardDiff(), x, dx, pushforward_batched_extras)
139+
end
140+
141+
function DI.hvp_batched!(
142+
f, dg::Batch, ::AutoZygote, x, dx::Batch, extras::ZygoteHVPBatchedExtras
143+
)
144+
@compat (; ∇f, pushforward_batched_extras) = extras
145+
return DI.pushforward_batched!(
146+
∇f, dg, AutoForwardDiff(), x, dx, pushforward_batched_extras
147+
)
120148
end
121149

122150
## Hessian

DifferentiationInterface/src/DifferentiationInterface.jl

+4-2
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,7 @@ include("second_order/second_order.jl")
5050

5151
include("utils/traits.jl")
5252
include("utils/basis.jl")
53-
include("utils/printing.jl")
54-
include("utils/chunk.jl")
53+
include("utils/batch.jl")
5554
include("utils/check.jl")
5655
include("utils/exceptions.jl")
5756
include("utils/maybe.jl")
@@ -73,6 +72,9 @@ include("sparse/hessian.jl")
7372

7473
include("misc/differentiate_with.jl")
7574
include("misc/sparsity_detector.jl")
75+
include("misc/from_primitive.jl")
76+
77+
include("utils/printing.jl")
7678

7779
function __init__()
7880
@require_extensions

0 commit comments

Comments
 (0)