Skip to content

Commit

Permalink
Add SparseDiffTools extension (#9)
Browse files Browse the repository at this point in the history
* Add SparseDiffTools extension

* Remove Manifest

---------

Co-authored-by: adrhill <[email protected]>
  • Loading branch information
gdalle and adrhill authored Apr 9, 2024
1 parent 8b11d31 commit 121fef8
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 1,204 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
*.jl.cov
*.jl.mem
/Manifest.toml
/test/Manifest.toml
/docs/Manifest.toml
/docs/build/
7 changes: 7 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,12 @@ version = "1.0.0-DEV"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"

[extensions]
SparseConnectivityTracerSparseDiffToolsExt = "SparseDiffTools"

[compat]
SparseDiffTools = "2.17"
julia = "1.6"
35 changes: 35 additions & 0 deletions ext/SparseConnectivityTracerSparseDiffToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module SparseConnectivityTracerSparseDiffToolsExt

using SparseConnectivityTracer: connectivity
using SparseDiffTools:
AbstractSparseADType,
AbstractSparsityDetection,
ArrayInterface,
GreedyD1Color,
JacPrototypeSparsityDetection,
SparseDiffTools

Base.@kwdef struct ConnectivityTracerSparsityDetection{
A<:ArrayInterface.ColoringAlgorithm
} <: AbstractSparsityDetection
alg::A = GreedyD1Color()
end

function (alg::ConnectivityTracerSparsityDetection)(
ad::AbstractSparseADType, f, x; fx=nothing, kwargs...
)
fx = fx === nothing ? similar(f(x)) : dx
J = connectivity(f, x)
_alg = JacPrototypeSparsityDetection(J, alg.alg)
return _alg(ad, f, x; fx, kwargs...)
end

function (alg::ConnectivityTracerSparsityDetection)(
ad::AbstractSparseADType, f!, fx, x; kwargs...
)
J = connectivity(f!, fx, x)
_alg = JacPrototypeSparsityDetection(J, alg.alg)
return _alg(ad, f!, fx, x; kwargs...)
end

end
Loading

0 comments on commit 121fef8

Please sign in to comment.