Skip to content

Commit

Permalink
Custom stacking for StaticArrays (#564)
Browse files Browse the repository at this point in the history
* Improve type stability tests and benchmarking

* Remove `first_order` and `second_order`

* Docs

* Zero allocs

* Fixes

* Call count

* Fix

* Fix

* Add count calls

* Default count calls

* Fix

* Custom stacking for StaticArrays

* Bump

* Clearer modulo

* Woops

* Undo mo1
  • Loading branch information
gdalle authored Oct 10, 2024
1 parent 3698dbe commit 7607ec2
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 4 deletions.
5 changes: 4 additions & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.9"
version = "0.6.10"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand All @@ -20,6 +20,7 @@ PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -37,6 +38,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
Expand All @@ -56,6 +58,7 @@ PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
StaticArrays = "1.9.7"
SparseMatrixColorings = "0.4.5"
Symbolics = "5.27.1, 6"
Tracker = "0.2.33"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module DifferentiationInterfaceStaticArraysExt

import DifferentiationInterface as DI
using StaticArrays: SArray

function DI.stack_vec_col(t::NTuple{B,<:SArray}) where {B}
return hcat(map(vec, t)...)
end

end
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ include("utils/check.jl")
include("utils/exceptions.jl")
include("utils/printing.jl")
include("utils/context.jl")
include("utils/linalg.jl")

include("first_order/pushforward.jl")
include("first_order/pullback.jl")
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ function _jacobian_aux(
batched_seeds[a],
contexts...,
)
block = stack(vec, dy_batch; dims=2)
block = stack_vec_col(dy_batch)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
Expand Down Expand Up @@ -269,7 +269,7 @@ function _jacobian_aux(
dx_batch = pullback(
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
)
block = stack(vec, dx_batch; dims=1)
block = stack_vec_row(dx_batch)
if M % B != 0 && a == lastindex(batched_seeds)
block = block[1:(M - (a - 1) * B), :]
end
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function hessian(

hess_blocks = map(eachindex(batched_seeds)) do a
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
block = stack(vec, dg_batch; dims=2)
block = stack_vec_col(dg_batch)
if N % B != 0 && a == lastindex(batched_seeds)
block = block[:, 1:(N - (a - 1) * B)]
end
Expand Down
2 changes: 2 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
stack_vec_row(t::NTuple) = stack(vec, t; dims=1)

0 comments on commit 7607ec2

Please sign in to comment.