diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index bf9fdad7..61cc2884 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -57,12 +57,13 @@ jobs: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 with: - version: '1' + version: '1.10' - uses: julia-actions/cache@v1 - name: Configure doc environment shell: julia --project=docs --color=yes {0} run: | using Pkg + Pkg.Registry.update() Pkg.develop(PackageSpec(path=pwd())) Pkg.instantiate() - uses: julia-actions/julia-buildpkg@v1 diff --git a/Project.toml b/Project.toml index 882c8dff..0cc7316e 100644 --- a/Project.toml +++ b/Project.toml @@ -4,15 +4,11 @@ authors = ["Adrian Hill "] version = "0.1.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -[weakdeps] -SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804" - -[extensions] -SparseConnectivityTracerSparseDiffToolsExt = "SparseDiffTools" - [compat] -SparseDiffTools = "2.17" +ADTypes = "1" +SparseArrays = "1" julia = "1.6" diff --git a/docs/Project.toml b/docs/Project.toml index 38e22acd..cf62baaa 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,7 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" + +[compat] +ADTypes = "1" \ No newline at end of file diff --git a/docs/src/api.md b/docs/src/api.md index 7778f00e..baac9133 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -11,6 +11,7 @@ CollapsedDocStrings = true ## Interface ```@docs connectivity +TracerSparsityDetector ``` ## Internals diff --git a/docs/src/index.md b/docs/src/index.md deleted file mode 100644 index 47a233ba..00000000 --- a/docs/src/index.md +++ /dev/null @@ -1,29 +0,0 @@ -```@meta -CurrentModule = SparseConnectivityTracer -``` - -# SparseConnectivityTracer - -Documentation for [SparseConnectivityTracer](https://github.com/adrhill/SparseConnectivityTracer.jl). - -## API reference -```@index -``` - -### Interface -```@docs -connectivity -``` - -### Internals -SparseConnectivityTracer works by pushing a `Number` type called [`Tracer`](@ref) through generic functions: -```@docs -Tracer -tracer -trace_input -``` - -The following utilities can be used to extract input indices from [`Tracer`](@ref)s: -```@docs -inputs -``` diff --git a/ext/SparseConnectivityTracerSparseDiffToolsExt.jl b/ext/SparseConnectivityTracerSparseDiffToolsExt.jl deleted file mode 100644 index e646414c..00000000 --- a/ext/SparseConnectivityTracerSparseDiffToolsExt.jl +++ /dev/null @@ -1,35 +0,0 @@ -module SparseConnectivityTracerSparseDiffToolsExt - -using SparseConnectivityTracer: connectivity -using SparseDiffTools: - AbstractSparseADType, - AbstractSparsityDetection, - ArrayInterface, - GreedyD1Color, - JacPrototypeSparsityDetection, - SparseDiffTools - -Base.@kwdef struct ConnectivityTracerSparsityDetection{ - A<:ArrayInterface.ColoringAlgorithm -} <: AbstractSparsityDetection - alg::A = GreedyD1Color() -end - -function (alg::ConnectivityTracerSparsityDetection)( - ad::AbstractSparseADType, f, x; fx=nothing, kwargs... -) - fx = fx === nothing ? similar(f(x)) : dx - J = connectivity(f, x) - _alg = JacPrototypeSparsityDetection(J, alg.alg) - return _alg(ad, f, x; fx, kwargs...) -end - -function (alg::ConnectivityTracerSparsityDetection)( - ad::AbstractSparseADType, f!, fx, x; kwargs... -) - J = connectivity(f!, fx, x) - _alg = JacPrototypeSparsityDetection(J, alg.alg) - return _alg(ad, f!, fx, x; kwargs...) -end - -end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index f837adc6..e84fac90 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,4 +1,6 @@ module SparseConnectivityTracer + +using ADTypes: ADTypes import Random: rand, AbstractRNG, SamplerType import SparseArrays: sparse @@ -7,10 +9,12 @@ include("conversion.jl") include("operators.jl") include("overload_tracer.jl") include("connectivity.jl") +include("adtypes.jl") export Tracer export tracer, trace_input export inputs export connectivity +export TracerSparsityDetector end # module diff --git a/src/adtypes.jl b/src/adtypes.jl new file mode 100644 index 00000000..eabea79a --- /dev/null +++ b/src/adtypes.jl @@ -0,0 +1,30 @@ +""" + TracerSparsityDetector <: ADTypes.AbstractSparsityDetector + +Singleton struct for integration with the sparsity detection framework of ADTypes.jl. + +# Example + +```jldoctest +julia> using ADTypes, SparseConnectivityTracer + +julia> ADTypes.jacobian_sparsity(diff, rand(4), TracerSparsityDetector()) +3×4 SparseArrays.SparseMatrixCSC{Bool, UInt64} with 6 stored entries: + 1 1 ⋅ ⋅ + ⋅ 1 1 ⋅ + ⋅ ⋅ 1 1 +``` +""" +struct TracerSparsityDetector <: ADTypes.AbstractSparsityDetector end + +function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector) + return connectivity(f, x) +end + +function ADTypes.jacobian_sparsity(f!, y, x, ::TracerSparsityDetector) + return connectivity(f!, y, x) +end + +function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector) + return error("Hessian sparsity is not yet implemented for `TracerSparsityDetector`.") +end diff --git a/test/Project.toml b/test/Project.toml index 542c8868..70887b3f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" diff --git a/test/adtypes.jl b/test/adtypes.jl new file mode 100644 index 00000000..f7cd0ad4 --- /dev/null +++ b/test/adtypes.jl @@ -0,0 +1,16 @@ +using ADTypes +using SparseConnectivityTracer +using SparseArrays +using Test + +sd = TracerSparsityDetector() + +x = rand(10) +y = zeros(9) +J1 = ADTypes.jacobian_sparsity(diff, x, sd) +J2 = ADTypes.jacobian_sparsity((y, x) -> y .= diff(x), y, x, sd) +@test J1 == J2 +@test J1 isa SparseMatrixCSC +@test J2 isa SparseMatrixCSC +@test nnz(J1) == nnz(J2) == 18 +@test_throws ErrorException ADTypes.hessian_sparsity(sum, x, sd) diff --git a/test/runtests.jl b/test/runtests.jl index 3d404279..7853d8e8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -30,7 +30,8 @@ DocMeta.setdocmeta!( Aqua.test_all( SparseConnectivityTracer; ambiguities=false, - deps_compat=(ignore=[:Random, :SparseArrays],), + deps_compat=(ignore=[:Random, :SparseArrays], check_extras=false), + persistent_tasks=false, ) end @testset "JET tests" begin @@ -91,8 +92,8 @@ DocMeta.setdocmeta!( @test C == C_ref end end - @testset "SparseDiffTools integration" begin - include("sparsedifftools.jl") + @testset "ADTypes integration" begin + include("adtypes.jl") end @testset "Doctests" begin Documenter.doctest(SparseConnectivityTracer) diff --git a/test/sparsedifftools.jl b/test/sparsedifftools.jl deleted file mode 100644 index fa4caf32..00000000 --- a/test/sparsedifftools.jl +++ /dev/null @@ -1,23 +0,0 @@ -using Base: get_extension -using ForwardDiff: ForwardDiff -using SparseArrays -using SparseConnectivityTracer -using SparseDiffTools -using Test - -ext = Base.get_extension( - SparseConnectivityTracer, :SparseConnectivityTracerSparseDiffToolsExt -) -@test !isnothing(ext) - -sd = ext.ConnectivityTracerSparsityDetection() -adtype = SparseDiffTools.AutoSparseForwardDiff() - -x = rand(10) -y = zeros(9) -J1 = sparse_jacobian(adtype, sd, diff, x) -J2 = sparse_jacobian(adtype, sd, (y, x) -> y .= diff(x), y, x) -@test J1 == J2 -@test J1 isa SparseMatrixCSC -@test J2 isa SparseMatrixCSC -@test nnz(J1) == nnz(J2) == 18