Skip to content

Commit

Permalink
mapreduce instead of map + reduce for Jacobians & Hessians (#565)
Browse files Browse the repository at this point in the history
* Improve type stability tests and benchmarking

* Remove `first_order` and `second_order`

* Docs

* Zero allocs

* Fixes

* Call count

* Fix

* Fix

* Add count calls

* Default count calls

* Fix

* Custom stacking for StaticArrays

* Bump

* Clearer modulo

* Woops

* Undo mo1

* Mapreduce

* Add function filter to type stability checks
  • Loading branch information
gdalle authored Oct 10, 2024
1 parent 45fbdd6 commit 9ee81a4
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 170 deletions.
8 changes: 2 additions & 6 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ function _jacobian_aux(
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
)

jac_blocks = map(eachindex(batched_seeds)) do a
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
dy_batch = pushforward(
f_or_f!y...,
pushforward_prep_same,
Expand All @@ -247,8 +247,6 @@ function _jacobian_aux(
end
block
end

jac = reduce(hcat, jac_blocks)
return jac
end

Expand All @@ -265,7 +263,7 @@ function _jacobian_aux(
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
)

jac_blocks = map(eachindex(batched_seeds)) do a
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
dx_batch = pullback(
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
)
Expand All @@ -275,8 +273,6 @@ function _jacobian_aux(
end
block
end

jac = reduce(vcat, jac_blocks)
return jac
end

Expand Down
4 changes: 1 addition & 3 deletions DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,14 @@ function hessian(
f, hvp_prep, backend, x, batched_seeds[1], contexts...
)

hess_blocks = map(eachindex(batched_seeds)) do a
hess = mapreduce(hcat, eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
block = stack_vec_col(dg_batch)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
block
end

hess = reduce(hcat, hess_blocks)
return hess
end

Expand Down
9 changes: 8 additions & 1 deletion DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ For `type_stability` and `benchmark`, the possible values are `:none`, `:prepare
**Type stability options:**
- `ignored_modules=nothing`: list of modules that JET.jl should ignore
- `function_filter`: filter for functions that JET.jl should ignore (with a reasonable default)
**Benchmark options:**
Expand All @@ -72,6 +73,11 @@ function test_differentiation(
sparsity::Bool=false,
# type stability options
ignored_modules=nothing,
function_filter=if VERSION >= v"1.11"
@nospecialize(f) -> true
else
@nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian
end,
# benchmark options
count_calls::Bool=true,
)
Expand Down Expand Up @@ -136,7 +142,8 @@ function test_differentiation(
adapted_backend,
scen;
subset=type_stability,
ignored_modules=ignored_modules,
ignored_modules,
function_filter,
)
end
yield()
Expand Down
Loading

0 comments on commit 9ee81a4

Please sign in to comment.