From 9e1c55f0d9b8da5f83831d75f1e8b4171ce49996 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Wed, 26 Jun 2024 11:45:03 +0200 Subject: [PATCH 1/4] Add array overloads (#131) --- Project.toml | 5 + src/SparseConnectivityTracer.jl | 5 + src/overloads/arrays.jl | 232 ++++++++++++++++++++++++++++++++ test/runtests.jl | 5 +- test/test_arrays.jl | 179 ++++++++++++++++++++++++ 5 files changed, 425 insertions(+), 1 deletion(-) create mode 100644 src/overloads/arrays.jl create mode 100644 test/test_arrays.jl diff --git a/Project.toml b/Project.toml index 7b20760b..5b83a9c5 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,8 @@ version = "0.6.0-DEV" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -25,7 +27,10 @@ SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" ADTypes = "1" Compat = "3,4" DocStringExtensions = "0.9" +FillArrays = "1" +LinearAlgebra = "<0.0.1, 1" NNlib = "0.8, 0.9" +Random = "<0.0.1, 1" Requires = "1.3" SparseArrays = "<0.0.1, 1" SpecialFunctions = "2.4" diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 07b8514e..846f8d81 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -2,9 +2,13 @@ module SparseConnectivityTracer using ADTypes: ADTypes using Compat: Returns +using SparseArrays: SparseArrays using SparseArrays: sparse using Random: AbstractRNG, SamplerType +using LinearAlgebra: LinearAlgebra +using FillArrays: Fill + using DocStringExtensions if !isdefined(Base, :get_extension) @@ -26,6 +30,7 @@ include("overloads/hessian_tracer.jl") include("overloads/ifelse_global.jl") include("overloads/dual.jl") include("overloads/overload_all.jl") +include("overloads/arrays.jl") include("interface.jl") include("adtypes.jl") diff --git a/src/overloads/arrays.jl b/src/overloads/arrays.jl new file mode 100644 index 00000000..3fc5f4b0 --- /dev/null +++ b/src/overloads/arrays.jl @@ -0,0 +1,232 @@ +""" + second_order_or(tracers) + +Compute the most conservative elementwise OR of tracer sparsity patterns, +including second-order interactions to update the `hessian` field of `HessianTracer`. + +This is functionally equivalent to: +```julia +reduce(^, tracers) +``` +""" +function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer} + # TODO: improve performance + return reduce(second_order_or, ts; init=myempty(T)) +end + +function second_order_or(a::T, b::T) where {T<:ConnectivityTracer} + return connectivity_tracer_2_to_1(a, b, false, false) +end +function second_order_or(a::T, b::T) where {T<:GradientTracer} + return gradient_tracer_2_to_1(a, b, false, false) +end +function second_order_or(a::T, b::T) where {T<:HessianTracer} + return hessian_tracer_2_to_1(a, b, false, false, false, false, false) +end + +""" + first_order_or(tracers) + +Compute the most conservative elementwise OR of tracer sparsity patterns, +excluding second-order interactions of `HessianTracer`. 
+ +This is functionally equivalent to: +```julia +reduce(+, tracers) +``` +""" +function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer} + # TODO: improve performance + return reduce(first_order_or, ts; init=myempty(T)) +end +function first_order_or(a::T, b::T) where {T<:ConnectivityTracer} + return connectivity_tracer_2_to_1(a, b, false, false) +end +function first_order_or(a::T, b::T) where {T<:GradientTracer} + return gradient_tracer_2_to_1(a, b, false, false) +end +function first_order_or(a::T, b::T) where {T<:HessianTracer} + return hessian_tracer_2_to_1(a, b, false, true, false, true, true) +end + +#===========# +# Utilities # +#===========# + +function split_dual_array(A::AbstractArray{D}) where {D<:Dual} + primals = getproperty.(A, :primal) + tracers = getproperty.(A, :tracer) + return primals, tracers +end +function split_dual_array(A::SparseArrays.SparseMatrixCSC{D}) where {D<:Dual} + primals = getproperty.(A, :primal) + tracers = getproperty.(A, :tracer) + return primals, tracers +end + +#==================# +# LinearAlgebra.jl # +#==================# + +# TODO: replace `second_order_or` by less conservative sparsity patterns when possible + +## Determinant +LinearAlgebra.det(A::AbstractMatrix{T}) where {T<:AbstractTracer} = second_order_or(A) +LinearAlgebra.logdet(A::AbstractMatrix{T}) where {T<:AbstractTracer} = second_order_or(A) +function LinearAlgebra.logabsdet(A::AbstractMatrix{T}) where {T<:AbstractTracer} + t1 = second_order_or(A) + t2 = sign(t1) # corresponds to sign of det(A): set first- and second-order derivatives to zero + return (t1, t2) +end + +## Norm +function LinearAlgebra.norm(A::AbstractArray{T}, p::Real=2) where {T<:AbstractTracer} + if isone(p) || isinf(p) + return first_order_or(A) + else + return second_order_or(A) + end +end +function LinearAlgebra.opnorm(A::AbstractArray{T}, p::Real=2) where {T<:AbstractTracer} + if isone(p) || isinf(p) + return first_order_or(A) + else + return second_order_or(A) + end +end +function LinearAlgebra.opnorm(A::AbstractMatrix{T}, p::Real=2) where {T<:AbstractTracer} + if isone(p) || isinf(p) + return first_order_or(A) + else + return second_order_or(A) + end +end + +## Eigenvalues + +function LinearAlgebra.eigmax( + A::Union{T,AbstractMatrix{T}}; permute::Bool=true, scale::Bool=true +) where {T<:AbstractTracer} + return second_order_or(A) +end +function LinearAlgebra.eigmin( + A::Union{T,AbstractMatrix{T}}; permute::Bool=true, scale::Bool=true +) where {T<:AbstractTracer} + return second_order_or(A) +end +function LinearAlgebra.eigen( + A::AbstractMatrix{T}; + permute::Bool=true, + scale::Bool=true, + sortby::Union{Function,Nothing}=nothing, +) where {T<:AbstractTracer} + LinearAlgebra.checksquare(A) + n = size(A, 1) + t = second_order_or(A) + values = Fill(t, n) + vectors = Fill(t, n, n) + return LinearAlgebra.Eigen(values, vectors) +end + +## Inverse +function LinearAlgebra.inv(A::StridedMatrix{T}) where {T<:AbstractTracer} + LinearAlgebra.checksquare(A) + t = second_order_or(A) + return Fill(t, size(A)...) 
+end +function LinearAlgebra.pinv( + A::AbstractMatrix{T}; atol::Real=0.0, rtol::Real=0.0 +) where {T<:AbstractTracer} + n, m = size(A) + t = second_order_or(A) + return Fill(t, m, n) +end + +## Division +function LinearAlgebra.:\( + A::AbstractMatrix{T}, B::AbstractVecOrMat +) where {T<:AbstractTracer} + Ainv = LinearAlgebra.pinv(A) + return Ainv * B +end + +## Exponential +function LinearAlgebra.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer} + LinearAlgebra.checksquare(A) + n = size(A, 1) + t = second_order_or(A) + return Fill(t, n, n) +end + +## Matrix power +function LinearAlgebra.:^(A::AbstractMatrix{T}, p::Integer) where {T<:AbstractTracer} + LinearAlgebra.checksquare(A) + n = size(A, 1) + if iszero(p) + return Fill(myempty(T), n, n) + else + t = second_order_or(A) + return Fill(t, n, n) + end +end + +#==========================# +# LinearAlgebra.jl on Dual # +#==========================# + +# `Duals` should use LinearAlgebra's generic fallback implementations +# to compute the "least conservative" sparsity patterns possible on a scalar level. + +# The following three methods are a temporary fix for issue #108. +# TODO: instead overload `lu` on AbstractMatrix of Duals. +function LinearAlgebra.det(A::AbstractMatrix{D}) where {D<:Dual} + primals, tracers = split_dual_array(A) + p = LinearAlgebra.logdet(primals) + t = LinearAlgebra.logdet(tracers) + return D(p, t) +end +function LinearAlgebra.logdet(A::AbstractMatrix{D}) where {D<:Dual} + primals, tracers = split_dual_array(A) + p = LinearAlgebra.logdet(primals) + t = LinearAlgebra.logdet(tracers) + return D(p, t) +end +function LinearAlgebra.logabsdet(A::AbstractMatrix{D}) where {D<:Dual} + primals, tracers = split_dual_array(A) + p1, p2 = LinearAlgebra.logabsdet(primals) + t1, t2 = LinearAlgebra.logabsdet(tracers) + return (D(p1, t1), D(p2, t2)) +end + +#==============# +# SparseArrays # +#==============# + +# Conversion of matrices of tracers to SparseMatrixCSC has to be rewritten +# due to use of `count(_isnotzero, M)` in SparseArrays.jl +# +# Code modified from MIT licensed SparseArrays.jl source: +# https://github.com/JuliaSparse/SparseArrays.jl/blob/45dfe459ede2fa1419e7068d4bda92d9d22bd44d/src/sparsematrix.jl#L901-L920 +# Copyright (c) 2009-2024: Jeff Bezanson, Stefan Karpinski, Viral B. 
Shah, and other contributors: https://github.com/JuliaLang/julia/contributors +function SparseArrays.SparseMatrixCSC{Tv,Ti}( + M::StridedMatrix{Tv} +) where {Tv<:AbstractTracer,Ti} + nz = count(!isemptytracer, M) + colptr = zeros(Ti, size(M, 2) + 1) + nzval = Vector{Tv}(undef, nz) + rowval = Vector{Ti}(undef, nz) + colptr[1] = 1 + cnt = 1 + @inbounds for j in 1:size(M, 2) + for i in 1:size(M, 1) + v = M[i, j] + if !isemptytracer(v) + rowval[cnt] = i + nzval[cnt] = v + cnt += 1 + end + end + colptr[j + 1] = cnt + end + return SparseArrays.SparseMatrixCSC(size(M, 1), size(M, 2), colptr, rowval, nzval) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1fa254b4..6cdf3f39 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,7 +34,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") Aqua.test_all( SparseConnectivityTracer; ambiguities=false, - deps_compat=(ignore=[:Random, :SparseArrays], check_extras=false), + deps_compat=(check_extras=false,), stale_deps=(ignore=[:Requires],), persistent_tasks=false, ) @@ -82,6 +82,9 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @testset "HessianTracer" begin include("test_hessian.jl") end + @testset "Array overloads" begin + include("test_arrays.jl") + end end end diff --git a/test/test_arrays.jl b/test/test_arrays.jl new file mode 100644 index 00000000..90b224e6 --- /dev/null +++ b/test/test_arrays.jl @@ -0,0 +1,179 @@ +import SparseConnectivityTracer as SCT +using SparseConnectivityTracer +using SparseConnectivityTracer: GradientTracer +using LinearAlgebra: Symmetric, Diagonal +using LinearAlgebra: det, logdet, logabsdet, norm, opnorm +using LinearAlgebra: eigen, eigmax, eigmin +using LinearAlgebra: inv, pinv +using SparseArrays: sparse, spdiagm +using Test + +PATTERN_FUNCTIONS = (connectivity_pattern, jacobian_pattern, hessian_pattern) + +TEST_SQUARE_MATRICES = Dict( + "`Matrix` (3×3)" => rand(3, 3), + "`Symmetric` (3×3)" => Symmetric(rand(3, 3)), + "`Diagonal` (3×3)" => Diagonal(rand(3)), +) +TEST_MATRICES = merge(TEST_SQUARE_MATRICES, Dict("`Matrix` (3×4)" => rand(3, 4))) + +S = BitSet +TG = GradientTracer{S} + +# NOTE: we currently test for conservative patterns on array overloads +# Changes making array overloads less convervative will break these tests, but are welcome! 
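+#
+# For intuition, "conservative" here means that a scalar function of a matrix is traced as
+# depending on every input entry. A rough sketch of what the checks below assert for `det`,
+# assuming the default BitSet-based tracers:
+#
+#   jacobian_pattern(det, rand(3, 3))  # 1×9 pattern with every entry equal to one
+#   hessian_pattern(det, rand(3, 3))   # 9×9 pattern with every entry equal to one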
+function test_patterns(f, x; outsum=false, con=isone, jac=isone, hes=isone) + @testset "$f" begin + if outsum + _f(x) = sum(f(x)) + else + _f = f + end + @testset "Connecivity pattern" begin + pattern = connectivity_pattern(_f, x) + @test all(con, pattern) + end + @testset "Jacobian pattern" begin + pattern = jacobian_pattern(_f, x) + @test all(jac, pattern) + end + @testset "Hessian pattern" begin + pattern = hessian_pattern(_f, x) + @test all(hes, pattern) + end + end +end + +@testset "Scalar functions" begin + norm1(A) = norm(A, 1) + norm2(A) = norm(A, 2) + norminf(A) = norm(A, Inf) + opnorm1(A) = opnorm(A, 1) + opnorm2(A) = opnorm(A, 2) + opnorminf(A) = opnorm(A, Inf) + logabsdet_first(A) = first(logabsdet(A)) + logabsdet_last(A) = last(logabsdet(A)) + + @testset "$name" for (name, A) in TEST_MATRICES + test_patterns(det, A) + test_patterns(logdet, A) + test_patterns(norm1, A; hes=iszero) + test_patterns(norm2, A) + test_patterns(norminf, A; hes=iszero) + test_patterns(eigmax, A) + test_patterns(eigmin, A) + test_patterns(opnorm1, A; hes=iszero) + test_patterns(opnorm2, A) + test_patterns(opnorminf, A; hes=iszero) + test_patterns(logabsdet_first, A) + test_patterns(logabsdet_last, A; jac=iszero, hes=iszero) + end + @testset "`SparseMatrixCSC` (3×3)" begin + # TODO: this is a temporary solution until sparse matrix inputs are supported (#28) + test_patterns(A -> det(sparse(A)), rand(3, 3)) + test_patterns(A -> logdet(sparse(A)), rand(3, 3)) + test_patterns(A -> norm(sparse(A)), rand(3, 3)) + test_patterns(A -> eigmax(sparse(A)), rand(3, 3)) + test_patterns(A -> eigmin(sparse(A)), rand(3, 3)) + test_patterns(A -> opnorm1(sparse(A)), rand(3, 3); hes=iszero) + test_patterns(A -> logabsdet_first(sparse(A)), rand(3, 3)) + test_patterns(A -> logabsdet_last(sparse(A)), rand(3, 3); jac=iszero, hes=iszero) + + test_patterns(v -> det(spdiagm(v)), rand(3)) + test_patterns(v -> logdet(spdiagm(v)), rand(3)) + test_patterns(v -> norm(spdiagm(v)), rand(3)) + test_patterns(v -> eigmax(spdiagm(v)), rand(3)) + test_patterns(v -> eigmin(spdiagm(v)), rand(3)) + test_patterns(v -> opnorm1(spdiagm(v)), rand(3); hes=iszero) + test_patterns(v -> logabsdet_first(spdiagm(v)), rand(3)) + test_patterns(v -> logabsdet_last(spdiagm(v)), rand(3); jac=iszero, hes=iszero) + end +end + +@testset "Matrix-valued functions" begin + pow0(A) = A^0 + pow3(A) = A^3 + + # Functions that only work on square matrices + @testset "$name" for (name, A) in TEST_SQUARE_MATRICES + test_patterns(inv, A; outsum=true) + test_patterns(exp, A; outsum=true) + test_patterns(pow0, A; outsum=true, con=iszero, jac=iszero, hes=iszero) + test_patterns(pow3, A; outsum=true) + end + @testset "`SparseMatrixCSC` (3×3)" begin + # TODO: this is a temporary solution until sparse matrix inputs are supported (#28) + + test_patterns(A -> exp(sparse(A)), rand(3, 3); outsum=true) + test_patterns( + A -> pow0(sparse(A)), + rand(3, 3); + outsum=true, + con=iszero, + jac=iszero, + hes=iszero, + ) + test_patterns(A -> pow3(sparse(A)), rand(3, 3); outsum=true) + + test_patterns(v -> exp(spdiagm(v)), rand(3); outsum=true) + + if VERSION >= v"1.10" + # issue with custom _mapreducezeros in SparseArrays on Julia 1.6 + test_patterns( + v -> pow0(spdiagm(v)), + rand(3); + outsum=true, + con=iszero, + jac=iszero, + hes=iszero, + ) + test_patterns(v -> pow3(spdiagm(v)), rand(3); outsum=true) + end + end + + # Functions that work on all matrices + @testset "$name" for (name, A) in TEST_MATRICES + test_patterns(pinv, A; outsum=true) + end + @testset "`SparseMatrixCSC` 
(3×4)" begin + test_patterns(A -> pinv(sparse(A)), rand(3, 4); outsum=true) + end +end + +@testset "Matrix division" begin + t1 = TG(S([1, 3, 4])) + t2 = TG(S([2, 4])) + t3 = TG(S([8, 9])) + t4 = TG(S([8, 9])) + A = [t1 t2; t3 t4] + s_out = S([1, 2, 3, 4, 8, 9]) + + x = rand(2) + b = A \ x + @test all(t -> SCT.gradient(t) == s_out, b) +end + +@testset "Eigenvalues" begin + t1 = TG(S([1, 3, 4])) + t2 = TG(S([2, 4])) + t3 = TG(S([8, 9])) + t4 = TG(S([8, 9])) + A = [t1 t2; t3 t4] + s_out = S([1, 2, 3, 4, 8, 9]) + values, vectors = eigen(A) + @test size(values) == (2,) + @test size(vectors) == (2, 2) + @test all(t -> SCT.gradient(t) == s_out, values) + @test all(t -> SCT.gradient(t) == s_out, vectors) +end + +@testset "SparseMatrixCSC construction" begin + t1 = TG(S(1)) + t2 = TG(S(2)) + t3 = TG(S(3)) + SA = sparse([t1 t2; t3 0]) + @test length(SA.nzval) == 3 + + res = opnorm(SA, 1) + @test SCT.gradient(res) == S([1, 2, 3]) +end From b56207bfed83c3e11757488e1ceacdcc4f6e2c3a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 26 Jun 2024 11:46:54 +0200 Subject: [PATCH 2/4] Put benchmark boilerplate into subpackage (#136) * Put benchmark boilerplate into subpackage * Use PkgJogger to test accidental breakage of the benchmark suite --- .github/workflows/CI.yml | 5 ++- .gitignore | 1 + benchmark/Project.toml | 1 + .../Project.toml | 15 +++++++++ .../src/SparseConnectivityTracerBenchmarks.jl | 27 +++++++++++++++ .../src/brusselator.jl | 0 .../src/nlpmodels.jl | 17 ++-------- benchmark/bench_jogger.jl | 33 +++++++++++++++++++ benchmark/benchmarks.jl | 33 +++---------------- benchmark/jacobian.jl | 3 +- benchmark/nlpmodels.jl | 4 +-- test/Project.toml | 3 ++ test/benchmarks_correctness.jl | 4 +++ test/brusselator.jl | 3 +- test/nlpmodels.jl | 11 +++++-- test/runtests.jl | 12 +++++++ 16 files changed, 118 insertions(+), 54 deletions(-) create mode 100644 benchmark/SparseConnectivityTracerBenchmarks/Project.toml create mode 100644 benchmark/SparseConnectivityTracerBenchmarks/src/SparseConnectivityTracerBenchmarks.jl rename test/definitions/brusselator_definition.jl => benchmark/SparseConnectivityTracerBenchmarks/src/brusselator.jl (100%) rename test/definitions/nlpmodels_definitions.jl => benchmark/SparseConnectivityTracerBenchmarks/src/nlpmodels.jl (90%) create mode 100644 benchmark/bench_jogger.jl create mode 100644 test/benchmarks_correctness.jl diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a791561e..95f3f47c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -27,10 +27,13 @@ jobs: - '1' group: - Core + - Benchmarks - NLPModels exclude: - version: '1.6' group: NLPModels + - version: '1.6' + group: Benchmarks steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -44,7 +47,7 @@ jobs: JULIA_SCT_TEST_GROUP: ${{ matrix.group }} - uses: julia-actions/julia-processcoverage@v1 with: - directories: src,ext,test + directories: src,ext,test,benchmark - uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.gitignore b/.gitignore index a75e4104..e8aa8d64 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ /docs/build/ /docs/src/index.md /benchmark/Manifest.toml +/benchmark/SparseConnectivityTracerBenchmarks/Manifest.toml /references/* \ No newline at end of file diff --git a/benchmark/Project.toml b/benchmark/Project.toml index 19512fbe..a40f9589 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -8,6 +8,7 @@ LinearAlgebra = 
"37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e" OptimizationProblems = "5049e819-d29b-5fba-b941-0eee7e64c1c6" +PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01" SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" diff --git a/benchmark/SparseConnectivityTracerBenchmarks/Project.toml b/benchmark/SparseConnectivityTracerBenchmarks/Project.toml new file mode 100644 index 00000000..a6e9ae19 --- /dev/null +++ b/benchmark/SparseConnectivityTracerBenchmarks/Project.toml @@ -0,0 +1,15 @@ +name = "SparseConnectivityTracerBenchmarks" +uuid = "fb1f6577-eb25-4e27-a243-7f62c22307d7" +authors = ["Guillaume Dalle", "Adrian Hill"] +version = "0.1.0" + +[deps] +ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6" +NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e" +OptimizationProblems = "5049e819-d29b-5fba-b941-0eee7e64c1c6" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" diff --git a/benchmark/SparseConnectivityTracerBenchmarks/src/SparseConnectivityTracerBenchmarks.jl b/benchmark/SparseConnectivityTracerBenchmarks/src/SparseConnectivityTracerBenchmarks.jl new file mode 100644 index 00000000..b1d5d076 --- /dev/null +++ b/benchmark/SparseConnectivityTracerBenchmarks/src/SparseConnectivityTracerBenchmarks.jl @@ -0,0 +1,27 @@ +module SparseConnectivityTracerBenchmarks + +module ODE + include("brusselator.jl") + export Brusselator!, brusselator_2d_loop! 
+end + +module Optimization + using ADTypes: ADTypes + using SparseConnectivityTracer + import SparseConnectivityTracer as SCT + + using ADNLPModels: ADNLPModels + using NLPModels: NLPModels, AbstractNLPModel + using NLPModelsJuMP: NLPModelsJuMP + using OptimizationProblems: OptimizationProblems + + using LinearAlgebra + using SparseArrays + + include("nlpmodels.jl") + export optimization_problem_names + export compute_jac_sparsity_sct, compute_hess_sparsity_sct + export compute_jac_and_hess_sparsity_sct, compute_jac_and_hess_sparsity_and_value_jump +end + +end # module SparseConnectivityTracerBenchmarks diff --git a/test/definitions/brusselator_definition.jl b/benchmark/SparseConnectivityTracerBenchmarks/src/brusselator.jl similarity index 100% rename from test/definitions/brusselator_definition.jl rename to benchmark/SparseConnectivityTracerBenchmarks/src/brusselator.jl diff --git a/test/definitions/nlpmodels_definitions.jl b/benchmark/SparseConnectivityTracerBenchmarks/src/nlpmodels.jl similarity index 90% rename from test/definitions/nlpmodels_definitions.jl rename to benchmark/SparseConnectivityTracerBenchmarks/src/nlpmodels.jl index 2993b6b1..74f1ae99 100644 --- a/test/definitions/nlpmodels_definitions.jl +++ b/benchmark/SparseConnectivityTracerBenchmarks/src/nlpmodels.jl @@ -1,18 +1,3 @@ -using ADTypes: ADTypes -using SparseConnectivityTracer -import SparseConnectivityTracer as SCT - -using ADNLPModels: ADNLPModels -using NLPModels: NLPModels, AbstractNLPModel -using NLPModelsJuMP: NLPModelsJuMP -using OptimizationProblems: OptimizationProblems - -using Dates: now -using LinearAlgebra -using SparseArrays - -problem_names() = Symbol.(OptimizationProblems.meta[!, :name]) - #= Given an optimization problem `min f(x) s.t. c(x) <= 0`, we study @@ -31,6 +16,8 @@ Package ecosystem overview: https://jso.dev/ecosystems/models/ - OptimizationProblems.PureJuMP: spits out `JuMP.Model` =# +optimization_problem_names() = Symbol.(OptimizationProblems.meta[!, :name]) + ## SCT #= diff --git a/benchmark/bench_jogger.jl b/benchmark/bench_jogger.jl new file mode 100644 index 00000000..525a384e --- /dev/null +++ b/benchmark/bench_jogger.jl @@ -0,0 +1,33 @@ +using Pkg +Pkg.develop(; path=joinpath(@__DIR__, "SparseConnectivityTracerBenchmarks")) + +using BenchmarkTools +using SparseConnectivityTracer +using SparseConnectivityTracer: GradientTracer, HessianTracer +using SparseConnectivityTracer: DuplicateVector, SortedVector, RecursiveSet + +SET_TYPES = (BitSet, Set{Int}, DuplicateVector{Int}, RecursiveSet{Int}, SortedVector{Int}) + +include("jacobian.jl") +include("hessian.jl") +include("nlpmodels.jl") + +suite = BenchmarkGroup() + +suite["OptimizationProblems"] = optbench([:britgas]) + +for S1 in SET_TYPES + S2 = Set{Tuple{Int,Int}} + + G = GradientTracer{S1} + H = HessianTracer{S1,S2} + + suite["Jacobian"]["Global"][nameof(S1)] = jacbench(TracerSparsityDetector(G, H)) + suite["Jacobian"]["Local"][nameof(S1)] = jacbench(TracerLocalSparsityDetector(G, H)) + suite["Hessian"]["Global"][(nameof(S1), nameof(S2))] = hessbench( + TracerSparsityDetector(G, H) + ) + suite["Hessian"]["Local"][(nameof(S1), nameof(S2))] = hessbench( + TracerLocalSparsityDetector(G, H) + ) +end diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index e6b36ad6..25b09c40 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,30 +1,5 @@ -using BenchmarkTools +using PkgJogger using SparseConnectivityTracer -using SparseConnectivityTracer: GradientTracer, HessianTracer -using SparseConnectivityTracer: 
DuplicateVector, SortedVector, RecursiveSet - -SET_TYPES = (BitSet, Set{Int}, DuplicateVector{Int}, RecursiveSet{Int}, SortedVector{Int}) - -include("jacobian.jl") -include("hessian.jl") -include("nlpmodels.jl") - -SUITE = BenchmarkGroup() - -SUITE["OptimizationProblems"] = optbench([:britgas]) - -for S1 in SET_TYPES - S2 = Set{Tuple{Int,Int}} - - G = GradientTracer{S1} - H = HessianTracer{S1,S2} - - SUITE["Jacobian"]["Global"][nameof(S1)] = jacbench(TracerSparsityDetector(G, H)) - SUITE["Jacobian"]["Local"][nameof(S1)] = jacbench(TracerLocalSparsityDetector(G, H)) - SUITE["Hessian"]["Global"][(nameof(S1), nameof(S2))] = hessbench( - TracerSparsityDetector(G, H) - ) - SUITE["Hessian"]["Local"][(nameof(S1), nameof(S2))] = hessbench( - TracerLocalSparsityDetector(G, H) - ) -end +# Use PkgJogger.@jog to create the JogSparseConnectivityTracer module +@jog SparseConnectivityTracer +SUITE = JogSparseConnectivityTracer.suite() diff --git a/benchmark/jacobian.jl b/benchmark/jacobian.jl index 3669c623..ced93945 100644 --- a/benchmark/jacobian.jl +++ b/benchmark/jacobian.jl @@ -2,6 +2,7 @@ using BenchmarkTools using ADTypes: AbstractSparsityDetector, jacobian_sparsity using SparseConnectivityTracer +using SparseConnectivityTracerBenchmarks.ODE: Brusselator!, brusselator_2d_loop! using SparseArrays: sprand using SimpleDiffEq: ODEProblem, solve, SimpleEuler @@ -49,8 +50,6 @@ end ## Brusselator -include("../test/definitions/brusselator_definition.jl") - function jacbench_brusselator(method) suite = BenchmarkGroup() for N in (6, 24) diff --git a/benchmark/nlpmodels.jl b/benchmark/nlpmodels.jl index 0aa802c2..a2711ae3 100644 --- a/benchmark/nlpmodels.jl +++ b/benchmark/nlpmodels.jl @@ -1,7 +1,7 @@ using BenchmarkTools using OptimizationProblems: ADNLPProblems - -include("../test/definitions/nlpmodels_definitions.jl") +using SparseConnectivityTracerBenchmarks.Optimization: + compute_jac_sparsity_sct, compute_hess_sparsity_sct function optbench(names::Vector{Symbol}) suite = BenchmarkGroup() diff --git a/test/Project.toml b/test/Project.toml index 1ae6ce53..81e54b10 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ ADNLPModels = "54578032-b7ea-4c30-94aa-7cbd1cce6c9a" ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" @@ -15,8 +16,10 @@ NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" OptimizationProblems = "5049e819-d29b-5fba-b941-0eee7e64c1c6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PkgJogger = "10150987-6cc1-4b76-abee-b1c1cbd91c01" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf" +SimpleDiffEq = "05bca326-078c-5bf0-a5bf-ce7c7982d7fd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/benchmarks_correctness.jl b/test/benchmarks_correctness.jl new file mode 100644 index 00000000..6b25db18 --- /dev/null +++ b/test/benchmarks_correctness.jl @@ -0,0 +1,4 @@ +using PkgJogger +using SparseConnectivityTracer + +PkgJogger.@test_benchmarks SparseConnectivityTracer diff --git a/test/brusselator.jl b/test/brusselator.jl index a8683eea..5cadfd5e 100644 --- a/test/brusselator.jl +++ b/test/brusselator.jl @@ -3,6 +3,7 @@ 
using ADTypes: AbstractSparsityDetector using ReferenceTests using SparseConnectivityTracer using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector +using SparseConnectivityTracerBenchmarks.ODE: Brusselator! using Test GRADIENT_TRACERS = ( @@ -12,8 +13,6 @@ GRADIENT_TRACERS = ( GradientTracer{SortedVector{Int}}, ) -include("definitions/brusselator_definition.jl") - function test_brusselator(method::AbstractSparsityDetector) N = 6 f! = Brusselator!(N) diff --git a/test/nlpmodels.jl b/test/nlpmodels.jl index d5a4bf9a..940d2e78 100644 --- a/test/nlpmodels.jl +++ b/test/nlpmodels.jl @@ -1,7 +1,12 @@ +using Dates: now +using LinearAlgebra using OptimizationProblems +using SparseArrays using Test - -include("definitions/nlpmodels_definitions.jl"); +using SparseConnectivityTracerBenchmarks.Optimization: + compute_jac_and_hess_sparsity_sct, + compute_jac_and_hess_sparsity_and_value_jump, + optimization_problem_names function compare_patterns( truth::AbstractMatrix{<:Real}; sct::AbstractMatrix{Bool}, jump::AbstractMatrix{Bool} @@ -33,7 +38,7 @@ Please look at the warnings displayed at the end. jac_inconsistencies = [] hess_inconsistencies = [] -@testset "$name" for name in problem_names() +@testset "$name" for name in optimization_problem_names() @info "$(now()) - $name" (jac_sparsity_sct, hess_sparsity_sct) = compute_jac_and_hess_sparsity_sct(name) diff --git a/test/runtests.jl b/test/runtests.jl index 6cdf3f39..89afe3b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,8 @@ +using Pkg +Pkg.develop(; + path=joinpath(@__DIR__, "..", "benchmark", "SparseConnectivityTracerBenchmarks") +) + using SparseConnectivityTracer using Compat @@ -109,6 +114,13 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") end end + if GROUP in ("Benchmarks", "All") + @info "Testing benchmarks correctness..." + @testset "Benchmarks correctness" begin + include("benchmarks_correctness.jl") + end + end + if GROUP in ("NLPModels", "All") @info "Testing NLPModels..." 
@testset "NLPModels" begin From f662af1da1ff26398977c58f209cda1497146a29 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 27 Jun 2024 17:12:16 +0200 Subject: [PATCH 3/4] Introduce sparsity patterns (#139) --- benchmark/bench_jogger.jl | 8 +- src/SparseConnectivityTracer.jl | 1 + src/interface.jl | 8 +- src/overloads/connectivity_tracer.jl | 14 +-- src/overloads/gradient_tracer.jl | 21 +++- src/overloads/hessian_tracer.jl | 70 ++++++------- src/overloads/ifelse_global.jl | 16 +-- src/patterns.jl | 149 +++++++++++++++++++++++++++ src/tracers.jl | 97 ++++++++--------- test/brusselator.jl | 9 +- test/classification.jl | 3 + test/flux.jl | 8 +- test/nlpmodels.jl | 1 + test/runtests.jl | 3 + test/test_arrays.jl | 25 ++--- test/test_connectivity.jl | 11 +- test/test_constructors.jl | 24 +---- test/test_gradient.jl | 11 +- test/test_hessian.jl | 12 +-- test/tracers_definitions.jl | 23 +++++ 20 files changed, 334 insertions(+), 180 deletions(-) create mode 100644 src/patterns.jl create mode 100644 test/tracers_definitions.jl diff --git a/benchmark/bench_jogger.jl b/benchmark/bench_jogger.jl index 525a384e..25e153ca 100644 --- a/benchmark/bench_jogger.jl +++ b/benchmark/bench_jogger.jl @@ -4,6 +4,7 @@ Pkg.develop(; path=joinpath(@__DIR__, "SparseConnectivityTracerBenchmarks")) using BenchmarkTools using SparseConnectivityTracer using SparseConnectivityTracer: GradientTracer, HessianTracer +using SparseConnectivityTracer: IndexSetGradientPattern, IndexSetHessianPattern using SparseConnectivityTracer: DuplicateVector, SortedVector, RecursiveSet SET_TYPES = (BitSet, Set{Int}, DuplicateVector{Int}, RecursiveSet{Int}, SortedVector{Int}) @@ -19,8 +20,11 @@ suite["OptimizationProblems"] = optbench([:britgas]) for S1 in SET_TYPES S2 = Set{Tuple{Int,Int}} - G = GradientTracer{S1} - H = HessianTracer{S1,S2} + PG = IndexSetGradientPattern{Int,S1} + PH = IndexSetHessianPattern{Int,S1,S2} + + G = GradientTracer{PG} + H = HessianTracer{PH} suite["Jacobian"]["Global"][nameof(S1)] = jacbench(TracerSparsityDetector(G, H)) suite["Jacobian"]["Local"][nameof(S1)] = jacbench(TracerLocalSparsityDetector(G, H)) diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index 846f8d81..a2791a79 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -19,6 +19,7 @@ include("settypes/duplicatevector.jl") include("settypes/recursiveset.jl") include("settypes/sortedvector.jl") +include("patterns.jl") include("tracers.jl") include("exceptions.jl") include("operators.jl") diff --git a/src/interface.jl b/src/interface.jl index 9e3f85de..d7e80138 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,6 +1,8 @@ -const DEFAULT_CONNECTIVITY_TRACER = ConnectivityTracer{BitSet} -const DEFAULT_GRADIENT_TRACER = GradientTracer{BitSet} -const DEFAULT_HESSIAN_TRACER = HessianTracer{BitSet,Set{Tuple{Int,Int}}} +const DEFAULT_CONNECTIVITY_TRACER = ConnectivityTracer{IndexSetGradientPattern{Int,BitSet}} +const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}} +const DEFAULT_HESSIAN_TRACER = HessianTracer{ + IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}} +} #==================# # Enumerate inputs # diff --git a/src/overloads/connectivity_tracer.jl b/src/overloads/connectivity_tracer.jl index eb666229..6b58073a 100644 --- a/src/overloads/connectivity_tracer.jl +++ b/src/overloads/connectivity_tracer.jl @@ -49,23 +49,23 @@ end return connectivity_tracer_1_to_1(ty, is_infl_arg2_zero) else i_out = connectivity_tracer_2_to_1_inner( - inputs(tx), 
inputs(ty), is_infl_arg1_zero, is_infl_arg2_zero + pattern(tx), pattern(ty), is_infl_arg1_zero, is_infl_arg2_zero ) return T(i_out) # return tracer end end function connectivity_tracer_2_to_1_inner( - sx::S, sy::S, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool -) where {S<:AbstractSet{<:Integer}} + px::P, py::P, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool +) where {P<:IndexSetGradientPattern} if is_infl_arg1_zero && is_infl_arg2_zero - return myempty(S) + return myempty(P) elseif !is_infl_arg1_zero && is_infl_arg2_zero - return sx + return px elseif is_infl_arg1_zero && !is_infl_arg2_zero - return sy + return py else - return union(sx, sy) # return set + return P(union(set(px), set(py))) # return pattern end end diff --git a/src/overloads/gradient_tracer.jl b/src/overloads/gradient_tracer.jl index b339078c..3a5ad831 100644 --- a/src/overloads/gradient_tracer.jl +++ b/src/overloads/gradient_tracer.jl @@ -10,7 +10,14 @@ end end -# Called by HessianTracer with AbstractSet +function gradient_tracer_1_to_1_inner( + p::P, is_der1_zero::Bool +) where {P<:IndexSetGradientPattern} + return P(gradient_tracer_1_to_1_inner(set(p), is_der1_zero)) # return pattern +end + +# This is only required because it is called by HessianTracer with IndexSetHessianPattern +# Otherwise, we would just have the method on IndexSetGradientPattern above. function gradient_tracer_1_to_1_inner( s::S, is_der1_zero::Bool ) where {S<:AbstractSet{<:Integer}} @@ -60,12 +67,22 @@ end return gradient_tracer_1_to_1(ty, is_der1_arg2_zero) else g_out = gradient_tracer_2_to_1_inner( - gradient(tx), gradient(ty), is_der1_arg1_zero, is_der1_arg2_zero + pattern(tx), pattern(ty), is_der1_arg1_zero, is_der1_arg2_zero ) return T(g_out) # return tracer end end +function gradient_tracer_2_to_1_inner( + px::P, py::P, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool +) where {P<:IndexSetGradientPattern} + return P( + gradient_tracer_2_to_1_inner(set(px), set(py), is_der1_arg1_zero, is_der1_arg2_zero) + ) # return pattern +end + +# This is only required because it is called by HessianTracer with IndexSetHessianPattern +# Otherwise, we would just have the method on IndexSetGradientPattern above. 
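+# A rough sketch of the set-level behavior, assuming `BitSet` index sets:
+#
+#   gradient_tracer_2_to_1_inner(BitSet([1, 2]), BitSet([2, 3]), false, false)
+#   # -> BitSet([1, 2, 3]): when neither first derivative is known to be zero,
+#   #    the index sets of both arguments are combined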
function gradient_tracer_2_to_1_inner( sx::S, sy::S, is_der1_arg1_zero::Bool, is_der1_arg2_zero::Bool ) where {S<:AbstractSet{<:Integer}} diff --git a/src/overloads/hessian_tracer.jl b/src/overloads/hessian_tracer.jl index c15d7ee2..7b845f9d 100644 --- a/src/overloads/hessian_tracer.jl +++ b/src/overloads/hessian_tracer.jl @@ -1,32 +1,32 @@ ## 1-to-1 @noinline function hessian_tracer_1_to_1( - t::T, is_der1_zero::Bool, is_secondder_zero::Bool + t::T, is_der1_zero::Bool, is_der2_zero::Bool ) where {T<:HessianTracer} if isemptytracer(t) # TODO: add test return t else - g_out, h_out = hessian_tracer_1_to_1_inner( - gradient(t), hessian(t), is_der1_zero, is_secondder_zero - ) - return T(g_out, h_out) # return tracer + p_out = hessian_tracer_1_to_1_inner(pattern(t), is_der1_zero, is_der2_zero) + return T(p_out) # return tracer end end function hessian_tracer_1_to_1_inner( - sg::G, sh::H, is_der1_zero::Bool, is_secondder_zero::Bool -) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}} + p::P, is_der1_zero::Bool, is_der2_zero::Bool +) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}} + sg = gradient(p) + sh = hessian(p) sg_out = gradient_tracer_1_to_1_inner(sg, is_der1_zero) - sh_out = if is_der1_zero && is_secondder_zero - myempty(H) - elseif !is_der1_zero && is_secondder_zero + sh_out = if is_der1_zero && is_der2_zero + myempty(SH) + elseif !is_der1_zero && is_der2_zero sh - elseif is_der1_zero && !is_secondder_zero - union_product!(myempty(H), sg, sg) + elseif is_der1_zero && !is_der2_zero + union_product!(myempty(SH), sg, sg) else union_product!(copy(sh), sg, sg) end - return sg_out, sh_out # return sets + return P(sg_out, sh_out) # return pattern end function overload_hessian_1_to_1(M, op) @@ -62,54 +62,52 @@ end tx::T, ty::T, is_der1_arg1_zero::Bool, - is_secondder_arg1_zero::Bool, + is_der2_arg1_zero::Bool, is_der1_arg2_zero::Bool, - is_secondder_arg2_zero::Bool, + is_der2_arg2_zero::Bool, is_der_cross_zero::Bool, ) where {T<:HessianTracer} # TODO: add tests for isempty if tx.isempty && ty.isempty return tx # empty tracer elseif ty.isempty - return hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_secondder_arg1_zero) + return hessian_tracer_1_to_1(tx, is_der1_arg1_zero, is_der2_arg1_zero) elseif tx.isempty - return hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_secondder_arg2_zero) + return hessian_tracer_1_to_1(ty, is_der1_arg2_zero, is_der2_arg2_zero) else - g_out, h_out = hessian_tracer_2_to_1_inner( - gradient(tx), - hessian(tx), - gradient(ty), - hessian(ty), + p_out = hessian_tracer_2_to_1_inner( + pattern(tx), + pattern(ty), is_der1_arg1_zero, - is_secondder_arg1_zero, + is_der2_arg1_zero, is_der1_arg2_zero, - is_secondder_arg2_zero, + is_der2_arg2_zero, is_der_cross_zero, ) - return T(g_out, h_out) # return tracer + return T(p_out) # return tracer end end function hessian_tracer_2_to_1_inner( - sgx::G, - shx::H, - sgy::G, - shy::H, + px::P, + py::P, is_der1_arg1_zero::Bool, - is_secondder_arg1_zero::Bool, + is_der2_arg1_zero::Bool, is_der1_arg2_zero::Bool, - is_secondder_arg2_zero::Bool, + is_der2_arg2_zero::Bool, is_der_cross_zero::Bool, -) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}} +) where {I,SG,SH,P<:IndexSetHessianPattern{I,SG,SH}} + sgx, shx = gradient(px), hessian(px) + sgy, shy = gradient(py), hessian(py) sg_out = gradient_tracer_2_to_1_inner(sgx, sgy, is_der1_arg1_zero, is_der1_arg2_zero) - sh_out = myempty(H) + sh_out = myempty(SH) !is_der1_arg1_zero && union!(sh_out, shx) # hessian alpha !is_der1_arg2_zero && union!(sh_out, shy) # 
hessian beta - !is_secondder_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha - !is_secondder_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta + !is_der2_arg1_zero && union_product!(sh_out, sgx, sgx) # product alpha + !is_der2_arg2_zero && union_product!(sh_out, sgy, sgy) # product beta !is_der_cross_zero && union_product!(sh_out, sgx, sgy) # cross product 1 !is_der_cross_zero && union_product!(sh_out, sgy, sgx) # cross product 2 - return sg_out, sh_out # return sets + return P(sg_out, sh_out) # return pattern end function overload_hessian_2_to_1(M, op) diff --git a/src/overloads/ifelse_global.jl b/src/overloads/ifelse_global.jl index f9ad5836..2609acc1 100644 --- a/src/overloads/ifelse_global.jl +++ b/src/overloads/ifelse_global.jl @@ -9,16 +9,16 @@ end ## output union on scalar outputs - function output_union(tx::T, ty::T) where {T<:ConnectivityTracer} - return T(union(inputs(tx), inputs(ty))) + function output_union(tx::T, ty::T) where {T<:AbstractTracer} + return T(output_union(pattern(tx), pattern(ty))) # return tracer end - function output_union(tx::T, ty::T) where {T<:GradientTracer} - return T(union(gradient(tx), gradient(ty))) + function output_union(px::P, py::P) where {P<:IndexSetGradientPattern} + return P(union(set(px), set(py))) # return pattern end - function output_union(tx::T, ty::T) where {T<:HessianTracer} - g_out = union(gradient(tx), gradient(ty)) - h_out = union(hessian(tx), hessian(ty)) - return T(g_out, h_out) + function output_union(px::P, py::P) where {P<:IndexSetHessianPattern} + g_out = union(gradient(px), gradient(py)) + h_out = union(hessian(px), hessian(py)) + return P(g_out, h_out) # return pattern end output_union(tx::AbstractTracer, y) = tx diff --git a/src/patterns.jl b/src/patterns.jl new file mode 100644 index 00000000..afea743c --- /dev/null +++ b/src/patterns.jl @@ -0,0 +1,149 @@ +""" + AbstractPattern + +Abstract supertype of all sparsity pattern representations. + +## Type hierarchy +``` +AbstractPattern +├── AbstractGradientPattern: used in GradientTracer, ConnectivityTracer +│ └── IndexSetGradientPattern +└── AbstractHessianPattern: used in HessianTracer + └── IndexSetHessianPattern +``` +""" +abstract type AbstractPattern end + +""" + myempty(T) + myempty(tracer) + myempty(pattern) + + +Constructor for an empty tracer or pattern of type `T` representing a new number (usually an empty pattern). +""" +myempty + +""" + seed(T, i) + seed(tracer, i) + seed(pattern, i) + +Constructor for a tracer or pattern of type `T` that only contains the given index `i`. +""" +seed + +#==========================# +# Utilities on AbstractSet # +#==========================# + +myempty(::Type{S}) where {S<:AbstractSet} = S() +seed(::Type{S}, i::Integer) where {S<:AbstractSet} = S(i) + +"""" + product(a::S{T}, b::S{T})::S{Tuple{T,T}} + +Inner product of set-like inputs `a` and `b`. +""" +product(a::AbstractSet{I}, b::AbstractSet{I}) where {I<:Integer} = + Set((i, j) for i in a, j in b) + +function union_product!( + hessian::SH, gradient_x::SG, gradient_y::SG +) where {I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}} + hxy = product(gradient_x, gradient_y) + return union!(hessian, hxy) +end + +#=======================# +# AbstractGradientPattern # +#=======================# + +# For use with GradientTracer. + +""" + AbstractGradientPattern <: AbstractPattern + +Abstract supertype of sparsity patterns representing a vector. +For use with [`GradientTracer`](@ref). 
+
+## Expected interface
+
+* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
+* `seed(::Type{MyPattern}, i::Integer)`: return a pattern that only contains the given index `i`
+* `gradient(p::MyPattern)`: return non-zero indices `i` for use with `GradientTracer`
+
+Note that besides their names, the last two functions are usually identical.
+"""
+abstract type AbstractGradientPattern <: AbstractPattern end
+
+"""
+$(TYPEDEF)
+
+Vector sparsity pattern represented by an `AbstractSet` of indices ``{i}`` of non-zero values.
+
+## Fields
+$(TYPEDFIELDS)
+"""
+struct IndexSetGradientPattern{I<:Integer,S<:AbstractSet{I}} <: AbstractGradientPattern
+    "Set of indices representing non-zero entries ``i`` in a vector."
+    gradient::S
+end
+
+set(v::IndexSetGradientPattern) = v.gradient
+
+Base.show(io::IO, p::IndexSetGradientPattern) = Base.show(io, set(p))
+
+function myempty(::Type{IndexSetGradientPattern{I,S}}) where {I,S}
+    return IndexSetGradientPattern{I,S}(myempty(S))
+end
+function seed(::Type{IndexSetGradientPattern{I,S}}, i) where {I,S}
+    return IndexSetGradientPattern{I,S}(seed(S, i))
+end
+
+# Tracer compatibility
+inputs(s::IndexSetGradientPattern) = s.gradient
+gradient(s::IndexSetGradientPattern) = s.gradient
+
+#========================#
+# AbstractHessianPattern #
+#========================#
+
+# For use with HessianTracer.
+
+"""
+    AbstractHessianPattern <: AbstractPattern
+
+Abstract supertype of sparsity patterns representing both gradient and Hessian sparsity.
+For use with [`HessianTracer`](@ref).
+
+## Expected interface
+
+* `myempty(::Type{MyPattern})`: return a pattern representing a new number (usually an empty pattern)
+* `seed(::Type{MyPattern}, i::Integer)`: return a pattern that only contains the given index `i` in the first-order representation
+* `gradient(p::MyPattern)`: return non-zero indices `i` in the first-order representation
+* `hessian(p::MyPattern)`: return non-zero indices `(i, j)` in the second-order representation
+"""
+abstract type AbstractHessianPattern <: AbstractPattern end
+
+"""
+    IndexSetHessianPattern(gradient::AbstractSet{I}, hessian::AbstractSet{Tuple{I,I}})
+
+Gradient and Hessian sparsity patterns constructed by combining two `AbstractSet`s.
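+
+For example, a pattern tracking first-order dependence on inputs `1` and `2` with a single
+second-order interaction `(1, 2)` could combine an index set such as `BitSet([1, 2])` with an
+index-pair set such as `Set([(1, 2)])`.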
+""" +struct IndexSetHessianPattern{I<:Integer,SG<:AbstractSet{I},SH<:AbstractSet{Tuple{I,I}}} <: + AbstractHessianPattern + gradient::SG + hessian::SH +end + +function myempty(::Type{IndexSetHessianPattern{I,SG,SH}}) where {I,SG,SH} + return IndexSetHessianPattern{I,SG,SH}(myempty(SG), myempty(SH)) +end +function seed(::Type{IndexSetHessianPattern{I,SG,SH}}, index) where {I,SG,SH} + return IndexSetHessianPattern{I,SG,SH}(seed(SG, index), myempty(SH)) +end + +# Tracer compatibility +gradient(s::IndexSetHessianPattern) = s.gradient +hessian(s::IndexSetHessianPattern) = s.hessian diff --git a/src/tracers.jl b/src/tracers.jl index 672d9a42..94345b1e 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -1,20 +1,4 @@ -abstract type AbstractTracer <: Real end - -#===================# -# Set operations # -#===================# - -myempty(::Type{S}) where {S<:AbstractSet} = S() -seed(::Type{S}, i::Integer) where {S<:AbstractSet} = S(i) - -product(a::AbstractSet{I}, b::AbstractSet{I}) where {I} = Set((i, j) for i in a, j in b) - -function union_product!( - h::H, gx::G, gy::G -) where {I<:Integer,G<:AbstractSet{I},H<:AbstractSet{Tuple{I,I}}} - hxy = product(gx, gy) - return union!(h, hxy) -end +abstract type AbstractTracer{P<:AbstractPattern} <: Real end #====================# # ConnectivityTracer # @@ -30,14 +14,14 @@ For a higher-level interface, refer to [`connectivity_pattern`](@ref). ## Fields $(TYPEDFIELDS) """ -struct ConnectivityTracer{I} <: AbstractTracer +struct ConnectivityTracer{P<:AbstractGradientPattern} <: AbstractTracer{P} "Sparse representation of connected inputs." - inputs::I + pattern::P "Indicator whether pattern in tracer contains only zeros." isempty::Bool - function ConnectivityTracer{I}(inputs::I, isempty::Bool=false) where {I} - return new{I}(inputs, isempty) + function ConnectivityTracer{P}(inputs::P, isempty::Bool=false) where {P} + return new{P}(inputs, isempty) end end @@ -45,12 +29,13 @@ end # Generic code expecting "regular" numbers `x` will sometimes convert them # by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`. # When this happens, we create a new empty tracer with no input pattern. -ConnectivityTracer{I}(::Real) where {I} = myempty(ConnectivityTracer{I}) -ConnectivityTracer{I}(t::ConnectivityTracer{I}) where {I} = t +ConnectivityTracer{P}(::Real) where {P} = myempty(ConnectivityTracer{P}) +ConnectivityTracer{P}(t::ConnectivityTracer{P}) where {P} = t ConnectivityTracer(t::ConnectivityTracer) = t -inputs(t::ConnectivityTracer) = t.inputs isemptytracer(t::ConnectivityTracer) = t.isempty +pattern(t::ConnectivityTracer) = t.pattern +inputs(t::ConnectivityTracer) = inputs(pattern(t)) function Base.show(io::IO, t::ConnectivityTracer) print(io, typeof(t)) @@ -77,23 +62,24 @@ For a higher-level interface, refer to [`jacobian_pattern`](@ref). ## Fields $(TYPEDFIELDS) """ -struct GradientTracer{G} <: AbstractTracer +struct GradientTracer{P<:AbstractGradientPattern} <: AbstractTracer{P} "Sparse representation of non-zero entries in the gradient." - gradient::G + pattern::P "Indicator whether gradient in tracer contains only zeros." 
isempty::Bool - function GradientTracer{G}(gradient::G, isempty::Bool=false) where {G} - return new{G}(gradient, isempty) + function GradientTracer{P}(gradient::P, isempty::Bool=false) where {P} + return new{P}(gradient, isempty) end end -GradientTracer{G}(::Real) where {G} = myempty(GradientTracer{G}) -GradientTracer{G}(t::GradientTracer{G}) where {G} = t +GradientTracer{P}(::Real) where {P} = myempty(GradientTracer{P}) +GradientTracer{P}(t::GradientTracer{P}) where {P} = t GradientTracer(t::GradientTracer) = t -gradient(t::GradientTracer) = t.gradient isemptytracer(t::GradientTracer) = t.isempty +pattern(t::GradientTracer) = t.pattern +gradient(t::GradientTracer) = gradient(pattern(t)) function Base.show(io::IO, t::GradientTracer) print(io, typeof(t)) @@ -120,26 +106,25 @@ For a higher-level interface, refer to [`hessian_pattern`](@ref). ## Fields $(TYPEDFIELDS) """ -struct HessianTracer{G,H} <: AbstractTracer +struct HessianTracer{P<:AbstractHessianPattern} <: AbstractTracer{P} "Sparse representation of non-zero entries in the gradient and the Hessian." - gradient::G - "Sparse representation of non-zero entries in the Hessian." - hessian::H + pattern::P "Indicator whether gradient and Hessian in tracer both contain only zeros." isempty::Bool - function HessianTracer{G,H}(gradient::G, hessian::H, isempty::Bool=false) where {G,H} - return new{G,H}(gradient, hessian, isempty) + function HessianTracer{P}(pattern::P, isempty::Bool=false) where {P} + return new{P}(pattern, isempty) end end -HessianTracer{G,H}(::Real) where {G,H} = myempty(HessianTracer{G,H}) -HessianTracer{G,H}(t::HessianTracer{G,H}) where {G,H} = t +HessianTracer{P}(::Real) where {P} = myempty(HessianTracer{P}) +HessianTracer{P}(t::HessianTracer{P}) where {P} = t HessianTracer(t::HessianTracer) = t -gradient(t::HessianTracer) = t.gradient -hessian(t::HessianTracer) = t.hessian isemptytracer(t::HessianTracer) = t.isempty +pattern(t::HessianTracer) = t.pattern +gradient(t::HessianTracer) = gradient(pattern(t)) +hessian(t::HessianTracer) = hessian(pattern(t)) function Base.show(io::IO, t::HessianTracer) print(io, typeof(t)) @@ -199,27 +184,31 @@ end # Utilities # #===========# -myempty(::Type{ConnectivityTracer{I}}) where {I} = ConnectivityTracer{I}(myempty(I), true) -myempty(::Type{GradientTracer{G}}) where {G} = GradientTracer{G}(myempty(G), true) -myempty(::Type{HessianTracer{G,H}}) where {G,H} = HessianTracer{G,H}(myempty(G), myempty(H), true) +myempty(::T) where {T<:AbstractTracer} = myempty(T) + +# myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this +myempty(::Type{T}) where {P,T<:ConnectivityTracer{P}} = T(myempty(P), true) +myempty(::Type{T}) where {P,T<:GradientTracer{P}} = T(myempty(P), true) +myempty(::Type{T}) where {P,T<:HessianTracer{P}} = T(myempty(P), true) + +seed(::T, i) where {T<:AbstractTracer} = seed(T, i) + +# seed(::Type{T}, i) where {P,T<:AbstractTracer{P}} = T(seed(P, i)) # JET complains about this +seed(::Type{T}, i) where {P,T<:ConnectivityTracer{P}} = T(seed(P, i)) +seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i)) +seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i)) """ create_tracer(T, index) where {T<:AbstractTracer} Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices. 
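+
+For example (an illustrative call using the default `BitSet`-backed pattern),
+`create_tracer(GradientTracer{IndexSetGradientPattern{Int,BitSet}}, 1.0, 2)` returns a tracer
+whose gradient pattern is seeded with input index `2`; the primal value is ignored for non-`Dual` tracers.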
""" -function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T} - return Dual(primal, create_tracer(T, primal, index)) +function create_tracer(::Type{T}, ::Real, index::Integer) where {P,T<:AbstractTracer{P}} + return T(seed(P, index)) end -function create_tracer(::Type{ConnectivityTracer{I}}, ::Real, index::Integer) where {I} - return ConnectivityTracer{I}(seed(I, index)) -end -function create_tracer(::Type{GradientTracer{G}}, ::Real, index::Integer) where {G} - return GradientTracer{G}(seed(G, index)) -end -function create_tracer(::Type{HessianTracer{G,H}}, ::Real, index::Integer) where {G,H} - return HessianTracer{G,H}(seed(G, index), myempty(H)) +function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P,T} + return Dual(primal, create_tracer(T, primal, index)) end # Pretty-printing of Dual tracers diff --git a/test/brusselator.jl b/test/brusselator.jl index 5cadfd5e..6229ff20 100644 --- a/test/brusselator.jl +++ b/test/brusselator.jl @@ -6,12 +6,8 @@ using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using SparseConnectivityTracerBenchmarks.ODE: Brusselator! using Test -GRADIENT_TRACERS = ( - GradientTracer{BitSet}, - GradientTracer{Set{Int}}, - GradientTracer{DuplicateVector{Int}}, - GradientTracer{SortedVector{Int}}, -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") function test_brusselator(method::AbstractSparsityDetector) N = 6 @@ -26,4 +22,5 @@ end @testset "$T" for T in GRADIENT_TRACERS method = TracerSparsityDetector(; gradient_tracer_type=T) test_brusselator(method) + yield() end diff --git a/test/classification.jl b/test/classification.jl index f4aee69c..f25909f9 100644 --- a/test/classification.jl +++ b/test/classification.jl @@ -89,6 +89,7 @@ end correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for _ in 1:DEFAULT_TRIALS ) + yield() end end end; @@ -132,6 +133,7 @@ end op, random_first_input(op), random_second_input(op); atol=DEFAULT_ATOL ) for _ in 1:DEFAULT_TRIALS ) + yield() end end end; @@ -170,6 +172,7 @@ end correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for _ in 1:DEFAULT_TRIALS ) + yield() end end end; diff --git a/test/flux.jl b/test/flux.jl index e428f552..2cbf33ed 100644 --- a/test/flux.jl +++ b/test/flux.jl @@ -6,12 +6,8 @@ using SparseConnectivityTracer using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test -GRADIENT_TRACERS = ( - GradientTracer{BitSet}, - GradientTracer{Set{Int}}, - GradientTracer{DuplicateVector{Int}}, - GradientTracer{SortedVector{Int}}, -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") const INPUT_FLUX = reshape( [ diff --git a/test/nlpmodels.jl b/test/nlpmodels.jl index 940d2e78..87ec57c8 100644 --- a/test/nlpmodels.jl +++ b/test/nlpmodels.jl @@ -67,6 +67,7 @@ hess_inconsistencies = [] push!(hess_inconsistencies, (name, message)) end end + yield() end; if !isempty(jac_inconsistencies) || !isempty(hess_inconsistencies) diff --git a/test/runtests.jl b/test/runtests.jl index 89afe3b4..f9e70e64 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,11 +31,13 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @info "Testing formalities..." 
if VERSION >= v"1.10" @testset "Code formatting" begin + @info "...with JuliaFormatter.jl" @test JuliaFormatter.format( SparseConnectivityTracer; verbose=false, overwrite=false ) end @testset "Aqua tests" begin + @info "...with Aqua.jl" Aqua.test_all( SparseConnectivityTracer; ambiguities=false, @@ -45,6 +47,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") ) end @testset "JET tests" begin + @info "...with JET.jl" JET.test_package(SparseConnectivityTracer; target_defined_modules=true) end end diff --git a/test/test_arrays.jl b/test/test_arrays.jl index 90b224e6..336a6326 100644 --- a/test/test_arrays.jl +++ b/test/test_arrays.jl @@ -18,7 +18,8 @@ TEST_SQUARE_MATRICES = Dict( TEST_MATRICES = merge(TEST_SQUARE_MATRICES, Dict("`Matrix` (3×4)" => rand(3, 4))) S = BitSet -TG = GradientTracer{S} +P = IndexSetGradientPattern{Int,S} +TG = GradientTracer{P} # NOTE: we currently test for conservative patterns on array overloads # Changes making array overloads less convervative will break these tests, but are welcome! @@ -141,10 +142,10 @@ end end @testset "Matrix division" begin - t1 = TG(S([1, 3, 4])) - t2 = TG(S([2, 4])) - t3 = TG(S([8, 9])) - t4 = TG(S([8, 9])) + t1 = TG(P(S([1, 3, 4]))) + t2 = TG(P(S([2, 4]))) + t3 = TG(P(S([8, 9]))) + t4 = TG(P(S([8, 9]))) A = [t1 t2; t3 t4] s_out = S([1, 2, 3, 4, 8, 9]) @@ -154,10 +155,10 @@ end end @testset "Eigenvalues" begin - t1 = TG(S([1, 3, 4])) - t2 = TG(S([2, 4])) - t3 = TG(S([8, 9])) - t4 = TG(S([8, 9])) + t1 = TG(P(S([1, 3, 4]))) + t2 = TG(P(S([2, 4]))) + t3 = TG(P(S([8, 9]))) + t4 = TG(P(S([8, 9]))) A = [t1 t2; t3 t4] s_out = S([1, 2, 3, 4, 8, 9]) values, vectors = eigen(A) @@ -168,9 +169,9 @@ end end @testset "SparseMatrixCSC construction" begin - t1 = TG(S(1)) - t2 = TG(S(2)) - t3 = TG(S(3)) + t1 = TG(P(S(1))) + t2 = TG(P(S(2))) + t3 = TG(P(S(3))) SA = sparse([t1 t2; t3 0]) @test length(SA.nzval) == 3 diff --git a/test/test_connectivity.jl b/test/test_connectivity.jl index 46f263b8..5d6e6812 100644 --- a/test/test_connectivity.jl +++ b/test/test_connectivity.jl @@ -1,17 +1,12 @@ using SparseConnectivityTracer using SparseConnectivityTracer: ConnectivityTracer, Dual, MissingPrimalError, trace_input -using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using LinearAlgebra: det, dot, logdet using SpecialFunctions: erf, beta using NNlib: NNlib using Test -CONNECTIVITY_TRACERS = ( - ConnectivityTracer{BitSet}, - ConnectivityTracer{Set{Int}}, - ConnectivityTracer{DuplicateVector{Int}}, - ConnectivityTracer{SortedVector{Int}}, -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") NNLIB_ACTIVATIONS_S = ( NNlib.σ, @@ -121,6 +116,7 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) @test_throws TypeError connectivity_pattern( x -> x[1] > x[2] ? 
x[3] : x[4], [1.0, 2.0, 3.0, 4.0], T ) == [0 0 1 1;] + yield() end end @@ -133,5 +129,6 @@ end x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 3 2 4], T ) == [0 0 1 1] @test local_connectivity_pattern(x -> 0, 1, T) ≈ [0;;] + yield() end end diff --git a/test/test_constructors.jl b/test/test_constructors.jl index cb6f479c..470502c1 100644 --- a/test/test_constructors.jl +++ b/test/test_constructors.jl @@ -3,30 +3,10 @@ using SparseConnectivityTracer: AbstractTracer, ConnectivityTracer, GradientTracer, HessianTracer, Dual using SparseConnectivityTracer: inputs, primal, tracer, isemptytracer using SparseConnectivityTracer: myempty, name -using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test -CONNECTIVITY_TRACERS = ( - ConnectivityTracer{BitSet}, - ConnectivityTracer{Set{Int}}, - ConnectivityTracer{DuplicateVector{Int}}, - ConnectivityTracer{SortedVector{Int}}, -) - -GRADIENT_TRACERS = ( - GradientTracer{BitSet}, - GradientTracer{Set{Int}}, - GradientTracer{DuplicateVector{Int}}, - GradientTracer{SortedVector{Int}}, -) - -HESSIAN_TRACERS = ( - HessianTracer{BitSet,Set{Tuple{Int,Int}}}, - HessianTracer{Set{Int},Set{Tuple{Int,Int}}}, - HessianTracer{DuplicateVector{Int},DuplicateVector{Tuple{Int,Int}}}, - HessianTracer{SortedVector{Int},SortedVector{Tuple{Int,Int}}}, - # TODO: test on RecursiveSet -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") function test_nested_duals(::Type{T}) where {T<:AbstractTracer} # Putting Duals into Duals is prohibited diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 8f5aea5e..1f783fa5 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -1,18 +1,13 @@ using SparseConnectivityTracer using SparseConnectivityTracer: GradientTracer, Dual, MissingPrimalError, trace_input -using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using ADTypes: jacobian_sparsity using LinearAlgebra: det, dot, logdet using SpecialFunctions: erf, beta using NNlib: NNlib using Test -GRADIENT_TRACERS = ( - GradientTracer{BitSet}, - GradientTracer{Set{Int}}, - GradientTracer{DuplicateVector{Int}}, - GradientTracer{SortedVector{Int}}, -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") NNLIB_ACTIVATIONS_S = ( NNlib.σ, @@ -130,6 +125,7 @@ NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) @test_throws TypeError jacobian_sparsity( x -> x[1] > x[2] ? 
x[3] : x[4], [1.0, 2.0, 3.0, 4.0], method ) == [0 0 1 1;] + yield() end end @@ -254,5 +250,6 @@ end @test jacobian_sparsity(NNlib.softshrink, -1, method) ≈ [1;;] @test jacobian_sparsity(NNlib.softshrink, 0, method) ≈ [0;;] @test jacobian_sparsity(NNlib.softshrink, 1, method) ≈ [1;;] + yield() end end diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 6075e38d..68cf5f08 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -1,17 +1,11 @@ using SparseConnectivityTracer using SparseConnectivityTracer: Dual, HessianTracer, MissingPrimalError, trace_input, empty -using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using ADTypes: hessian_sparsity using SpecialFunctions: erf, beta using Test -HESSIAN_TRACERS = ( - HessianTracer{BitSet,Set{Tuple{Int,Int}}}, - HessianTracer{Set{Int},Set{Tuple{Int,Int}}}, - HessianTracer{DuplicateVector{Int},DuplicateVector{Tuple{Int,Int}}}, - HessianTracer{SortedVector{Int},SortedVector{Tuple{Int,Int}}}, - # TODO: test on RecursiveSet -) +# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +include("tracers_definitions.jl") @testset "Global Hessian" begin @testset "Default hessian_pattern" begin @@ -214,6 +208,7 @@ HESSIAN_TRACERS = ( 1 1 0 0 0 0 ] + yield() end end @@ -268,5 +263,6 @@ end @test hessian_sparsity(x -> (2//3)^zero(x), 1, method) ≈ [0;;] @test hessian_sparsity(x -> zero(x)^ℯ, 1, method) ≈ [0;;] @test hessian_sparsity(x -> ℯ^zero(x), 1, method) ≈ [0;;] + yield() end end diff --git a/test/tracers_definitions.jl b/test/tracers_definitions.jl new file mode 100644 index 00000000..c7a3514f --- /dev/null +++ b/test/tracers_definitions.jl @@ -0,0 +1,23 @@ +using SparseConnectivityTracer: + AbstractTracer, ConnectivityTracer, GradientTracer, HessianTracer, Dual +using SparseConnectivityTracer: IndexSetGradientPattern, IndexSetHessianPattern +using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector + +VECTOR_PATTERNS = ( + IndexSetGradientPattern{Int,BitSet}, + IndexSetGradientPattern{Int,Set{Int}}, + IndexSetGradientPattern{Int,DuplicateVector{Int}}, + IndexSetGradientPattern{Int,SortedVector{Int}}, +) + +HESSIAN_PATTERNS = ( + IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}}, + IndexSetHessianPattern{Int,Set{Int},Set{Tuple{Int,Int}}}, + IndexSetHessianPattern{Int,DuplicateVector{Int},DuplicateVector{Tuple{Int,Int}}}, + IndexSetHessianPattern{Int,SortedVector{Int},SortedVector{Tuple{Int,Int}}}, + # TODO: test on RecursiveSet +) + +CONNECTIVITY_TRACERS = (ConnectivityTracer{P} for P in VECTOR_PATTERNS) +GRADIENT_TRACERS = (GradientTracer{P} for P in VECTOR_PATTERNS) +HESSIAN_TRACERS = (HessianTracer{P} for P in HESSIAN_PATTERNS) From ac94586c854726370a0e464ee62074774295a929 Mon Sep 17 00:00:00 2001 From: Adrian Hill Date: Thu, 27 Jun 2024 18:37:18 +0200 Subject: [PATCH 4/4] Remove `ConnectivityTracer` and legacy API (#140) * Remove `ConnectivityTracer` * Remove legacy interface * Add CHANGELOG: this is our first actually breaking release in a while --- CHANGELOG.md | 44 +++ README.md | 43 ++- docs/src/api.md | 45 +-- ext/SparseConnectivityTracerNNlibExt.jl | 2 - ...seConnectivityTracerSpecialFunctionsExt.jl | 1 - src/SparseConnectivityTracer.jl | 10 +- src/adtypes.jl | 20 +- src/exceptions.jl | 8 +- src/interface.jl | 283 +----------------- src/operators.jl | 58 ---- src/overloads/arrays.jl | 6 - src/overloads/connectivity_tracer.jl | 202 ------------- src/overloads/ifelse_global.jl | 1 - src/overloads/overload_all.jl | 6 - src/patterns.jl | 3 +- src/tracers.jl | 84 
+----- test/brusselator.jl | 2 +- test/flux.jl | 2 +- test/runtests.jl | 3 - test/test_arrays.jl | 11 +- test/test_connectivity.jl | 134 --------- test/test_constructors.jl | 29 +- test/test_gradient.jl | 2 +- test/test_hessian.jl | 12 +- test/tracers_definitions.jl | 4 +- 25 files changed, 132 insertions(+), 883 deletions(-) create mode 100644 CHANGELOG.md delete mode 100644 src/overloads/connectivity_tracer.jl delete mode 100644 test/test_connectivity.jl diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..336a9078 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,44 @@ +# SparseConnectivityTracer.jl + +## Version `v0.6.0` +* ![BREAKING][badge-breaking] Remove `ConnectivityTracer` ([#140][pr-140]) +* ![BREAKING][badge-breaking] Remove legacy interface ([#140][pr-140]) + * instead of `jacobian_pattern(f, x)`, use `jacobian_sparsity(f, x, TracerSparsityDetector())` + * instead of `hessian_pattern(f, x)`, use `hessian_sparsity(f, x, TracerSparsityDetector())` + * instead of `local_jacobian_pattern(f, x)`, use `jacobian_sparsity(f, x, TracerLocalSparsityDetector())` + * instead of `local_hessian_pattern(f, x)`, use `hessian_sparsity(f, x, TracerLocalSparsityDetector())` +* ![Bugfix][badge-bugfix] Remove overloads on `similar` to reduce amount of invalidations ([#132][pr-132]) +* ![Enhancement][badge-enhancement] Add array overloads ([#131][pr-131]) +* ![Enhancement][badge-enhancement] Generalize sparsity pattern representations ([#139][pr-139], [#119][pr-119]) +* ![Enhancement][badge-enhancement] Reduce allocations of new tracers ([#128][pr-128]) +* ![Enhancement][badge-enhancement] Reduce compile times ([#119][pr-119]) + +[pr-140]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/140 +[pr-139]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/139 +[pr-132]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/132 +[pr-131]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/131 +[pr-128]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/128 +[pr-126]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/126 +[pr-119]: https://github.com/adrhill/SparseConnectivityTracer.jl/pull/119 + + + +[badge-breaking]: https://img.shields.io/badge/BREAKING-red.svg +[badge-deprecation]: https://img.shields.io/badge/deprecation-orange.svg +[badge-feature]: https://img.shields.io/badge/feature-green.svg +[badge-enhancement]: https://img.shields.io/badge/enhancement-blue.svg +[badge-bugfix]: https://img.shields.io/badge/bugfix-purple.svg +[badge-security]: https://img.shields.io/badge/security-black.svg +[badge-experimental]: https://img.shields.io/badge/experimental-lightgrey.svg +[badge-maintenance]: https://img.shields.io/badge/maintenance-gray.svg +[badge-docs]: https://img.shields.io/badge/docs-orange.svg \ No newline at end of file diff --git a/README.md b/README.md index f1f9a0cf..eb3d4901 100644 --- a/README.md +++ b/README.md @@ -8,9 +8,6 @@ Fast Jacobian and Hessian sparsity detection via operator-overloading. -> [!WARNING] -> This package is in early development. Expect frequent breaking changes and refer to the stable documentation. 
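
[Editor's note] The CHANGELOG above maps each removed legacy function to its ADTypes replacement. A minimal migration sketch follows; the function `f` and input are made up for illustration, and both detectors are constructed with their default tracer types:

```julia
using SparseConnectivityTracer

f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]
x = rand(3)

# v0.5.x (legacy interface, removed in v0.6.0):
# jacobian_pattern(f, x)
# local_jacobian_pattern(f, x)

# v0.6.0 (ADTypes interface; `jacobian_sparsity` is re-exported by SparseConnectivityTracer):
jacobian_sparsity(f, x, TracerSparsityDetector())       # global pattern
jacobian_sparsity(f, x, TracerLocalSparsityDetector())  # local pattern at this `x`
```

The same substitution applies to `hessian_pattern` / `local_hessian_pattern`, which become `hessian_sparsity` with the corresponding detector.
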
- ## Installation To install this package, open the Julia REPL and run @@ -21,17 +18,19 @@ julia> ]add SparseConnectivityTracer ## Examples ### Jacobian -For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian of $f$ can be obtained -by computing a single forward-pass through `f`: +For functions `y = f(x)` and `f!(y, x)`, the sparsity pattern of the Jacobian can be obtained +by computing a single forward-pass through the function: ```julia-repl julia> using SparseConnectivityTracer +julia> detector = TracerSparsityDetector(); + julia> x = rand(3); julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])]; -julia> jacobian_pattern(f, x) +julia> jacobian_sparsity(f, x, detector) 3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries: 1 ⋅ ⋅ 1 1 ⋅ @@ -43,11 +42,13 @@ As a larger example, let's compute the sparsity pattern from a convolutional lay ```julia-repl julia> using SparseConnectivityTracer, Flux +julia> detector = TracerSparsityDetector(); + julia> x = rand(28, 28, 3, 1); julia> layer = Conv((3, 3), 3 => 2); -julia> jacobian_pattern(layer, x) +julia> jacobian_sparsity(layer, x, detector) 1352×2352 SparseArrays.SparseMatrixCSC{Bool, Int64} with 36504 stored entries: ⎡⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣄⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎤ ⎢⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣤⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⎥ @@ -64,7 +65,7 @@ julia> jacobian_pattern(layer, x) ⎣⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠙⢿⣦⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠛⢿⣦⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠻⣷⣄⎦ ``` -The type of index set `S` that is internally used to keep track of connectivity can be specified via `jacobian_pattern(f, x, S)`, defaulting to `BitSet`. +The type of index set `S` that is internally used to keep track of connectivity can be specified via `jacobian_sparsity(f, x, S)`, defaulting to `BitSet`. For high-dimensional functions, `Set{Int64}` can be more efficient . ### Hessian @@ -77,7 +78,7 @@ julia> x = rand(5); julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; -julia> hessian_pattern(f, x) +julia> hessian_sparsity(f, x, detector) 5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ @@ -87,7 +88,7 @@ julia> hessian_pattern(f, x) julia> g(x) = f(x) + x[2]^x[5]; -julia> hessian_pattern(g, x) +julia> hessian_sparsity(g, x, detector) 5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 7 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 1 ⋅ 1 @@ -100,23 +101,27 @@ For more detailled examples, take a look at the [documentation](https://adrianhi ### Local tracing -The functions `jacobian_pattern`, `hessian_pattern` and `connectivity_pattern` return conservative sparsity patterns over the entire input domain of `x`. -They are not compatible with functions that require information about the primal values of a computation (e.g. `iszero`, `>`, `==`). +`TracerSparsityDetector` returns conservative sparsity patterns over the entire input domain of `x`. +It is not compatible with functions that require information about the primal values of a computation (e.g. `iszero`, `>`, `==`). -To compute a less conservative sparsity pattern at an input point `x`, use `local_jacobian_pattern`, `local_hessian_pattern` and `local_connectivity_pattern` instead. -Note that these patterns depend on the input `x`: +To compute a less conservative sparsity pattern at an input point `x`, use `TracerLocalSparsityDetector` instead. 
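
[Editor's note] A hedged sketch of the incompatibility described above; the branching function `g` is made up for illustration, and the failure mode is the `TypeError` that this PR's test suite checks for when global tracers reach primal-dependent control flow:

```julia
using SparseConnectivityTracer

g(x) = x[1] > x[2] ? x[3] : x[4]  # branch depends on primal values

# jacobian_sparsity(g, [1.0, 2.0, 3.0, 4.0], TracerSparsityDetector())
# expected to fail: a global tracer carries no primal values to compare

jacobian_sparsity(g, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector())
# 1×4 pattern containing only the entry of the branch actually taken at this input
```
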
+Note that patterns computed with `TracerLocalSparsityDetector` depend on the input `x`: ```julia-repl +julia> using SparseConnectivityTracer + +julia> detector = TracerLocalSparsityDetector(); + julia> f(x) = ifelse(x[2] < x[3], x[1] ^ x[2], x[3] * x[4]); -julia> local_hessian_pattern(f, [1 2 3 4]) +julia> hessian_sparsity(f, [1 2 3 4], detector) 4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries: 1 1 ⋅ ⋅ 1 1 ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ -julia> local_hessian_pattern(f, [1 3 2 4]) +julia> hessian_sparsity(f, [1 3 2 4], detector) 4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ @@ -124,6 +129,12 @@ julia> local_hessian_pattern(f, [1 3 2 4]) ⋅ ⋅ 1 ⋅ ``` +## ADTypes.jl compatibility +SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface for [sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector), +making it compatible with [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl)'s [sparse automatic differentiation](https://gdalle.github.io/DifferentiationInterface.jl/DifferentiationInterface/stable/tutorial2/) functionality. + +In fact, the functions `jacobian_sparsity` and `hessian_sparsity` are re-exported from ADTypes. + ## Related packages * [SparseDiffTools.jl](https://github.com/JuliaDiff/SparseDiffTools.jl): automatic sparsity detection via Symbolics.jl and Cassette.jl * [SparsityTracing.jl](https://github.com/PALEOtoolkit/SparsityTracing.jl): automatic Jacobian sparsity detection using an algorithm based on SparsLinC by Bischof et al. (1996) diff --git a/docs/src/api.md b/docs/src/api.md index 731dbf38..8d8564c2 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -10,55 +10,28 @@ CollapsedDocStrings = true ## ADTypes Interface -For package developers, we recommend using the [ADTypes.jl](https://github.com/SciML/ADTypes.jl) interface. +SparseConnectivityTracer uses [ADTypes.jl](https://github.com/SciML/ADTypes.jl)'s interface for [sparsity detection](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector). +In fact, the functions `jacobian_sparsity` and `hessian_sparsity` are re-exported from ADTypes. -To compute global sparsity patterns of `f(x)` over the entire input domain `x`, use +To compute **global** sparsity patterns of `f(x)` over the entire input domain `x`, use ```@docs TracerSparsityDetector ``` -To compute local sparsity patterns of `f(x)` at a specific input `x`, use +To compute **local** sparsity patterns of `f(x)` at a specific input `x`, use ```@docs TracerLocalSparsityDetector ``` -## Legacy Interface - -### Global sparsity - -The following functions can be used to compute global sparsity patterns of `f(x)` over the entire input domain `x`. - -```@docs -connectivity_pattern -jacobian_pattern -hessian_pattern -``` - -[`TracerSparsityDetector`](@ref) is the ADTypes equivalent of these functions. - -### Local sparsity - -The following functions can be used to compute local sparsity patterns of `f(x)` at a specific input `x`. -Note that these patterns are sparser than global patterns but need to be recomputed when `x` changes. - -```@docs -local_connectivity_pattern -local_jacobian_pattern -local_hessian_pattern -``` - -[`TracerLocalSparsityDetector`](@ref) is the ADTypes equivalent of these functions. - ## Internals !!! warning Internals may change without warning in a future release of SparseConnectivityTracer. SparseConnectivityTracer works by pushing `Real` number types called tracers through generic functions. 
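
[Editor's note] To make the sentence above concrete, here is a small sketch; `rosenbrock` is only an illustrative test function (not part of the package), and the stated pattern is what one would expect for it:

```julia
using SparseConnectivityTracer

# Generic code with no SparseConnectivityTracer-specific logic:
rosenbrock(x) = sum(100 * (x[i + 1] - x[i]^2)^2 + (1 - x[i])^2 for i in 1:(length(x) - 1))

# Because tracers subtype `Real`, they propagate through the generic code unchanged:
hessian_sparsity(rosenbrock, rand(4), TracerSparsityDetector())
# expected: a tridiagonal 4×4 Bool pattern (each x[i] only interacts with x[i±1])
```
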
-Currently, three tracer types are provided: +Currently, two tracer types are provided: ```@docs -SparseConnectivityTracer.ConnectivityTracer SparseConnectivityTracer.GradientTracer SparseConnectivityTracer.HessianTracer ``` @@ -69,11 +42,3 @@ which keeps track of the primal computation and allows tracing through compariso ```@docs SparseConnectivityTracer.Dual ``` - -We also define alternative pseudo-set types that can deliver faster `union`: - -```@docs -SparseConnectivityTracer.DuplicateVector -SparseConnectivityTracer.RecursiveSet -SparseConnectivityTracer.SortedVector -``` diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index d75de566..cff4f152 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -39,7 +39,6 @@ ops_1_to_1_s = ( for op in ops_1_to_1_s T = typeof(op) - @eval SCT.is_infl_zero_global(::$T) = false @eval SCT.is_der1_zero_global(::$T) = false @eval SCT.is_der2_zero_global(::$T) = false end @@ -69,7 +68,6 @@ ops_1_to_1_f = ( for op in ops_1_to_1_f T = typeof(op) - @eval SCT.is_infl_zero_global(::$T) = false @eval SCT.is_der1_zero_global(::$T) = false @eval SCT.is_der2_zero_global(::$T) = true end diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index bd794cbd..a95a0f5e 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -61,7 +61,6 @@ ops_1_to_1_s = ( for op in ops_1_to_1_s T = typeof(op) - @eval SCT.is_infl_zero_global(::$T) = false @eval SCT.is_der1_zero_global(::$T) = false @eval SCT.is_der2_zero_global(::$T) = false end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index a2791a79..47d421e4 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -1,6 +1,6 @@ module SparseConnectivityTracer -using ADTypes: ADTypes +using ADTypes: ADTypes, jacobian_sparsity, hessian_sparsity using Compat: Returns using SparseArrays: SparseArrays using SparseArrays: sparse @@ -25,7 +25,6 @@ include("exceptions.jl") include("operators.jl") include("overloads/conversion.jl") -include("overloads/connectivity_tracer.jl") include("overloads/gradient_tracer.jl") include("overloads/hessian_tracer.jl") include("overloads/ifelse_global.jl") @@ -36,13 +35,10 @@ include("overloads/arrays.jl") include("interface.jl") include("adtypes.jl") -export connectivity_pattern, local_connectivity_pattern -export jacobian_pattern, local_jacobian_pattern -export hessian_pattern, local_hessian_pattern - -# ADTypes interface export TracerSparsityDetector export TracerLocalSparsityDetector +# Reexport ADTypes interface +export jacobian_sparsity, hessian_sparsity function __init__() @static if !isdefined(Base, :get_extension) diff --git a/src/adtypes.jl b/src/adtypes.jl index 1cc3ed08..a47b4c56 100644 --- a/src/adtypes.jl +++ b/src/adtypes.jl @@ -23,7 +23,7 @@ julia> using ADTypes, SparseConnectivityTracer julia> f(x) = x[1] + x[2]*x[3] + 1/x[4]; -julia> ADTypes.hessian_sparsity(f, rand(4), TracerSparsityDetector()) +julia> hessian_sparsity(f, rand(4), TracerSparsityDetector()) 4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ 1 ⋅ @@ -46,15 +46,15 @@ function TracerSparsityDetector(; end function ADTypes.jacobian_sparsity(f, x, ::TracerSparsityDetector{TG,TH}) where {TG,TH} - return jacobian_pattern(f, x, TG) + return _jacobian_sparsity(f, x, TG) end function ADTypes.jacobian_sparsity(f!, y, x, 
::TracerSparsityDetector{TG,TH}) where {TG,TH} - return jacobian_pattern(f!, y, x, TG) + return _jacobian_sparsity(f!, y, x, TG) end function ADTypes.hessian_sparsity(f, x, ::TracerSparsityDetector{TG,TH}) where {TG,TH} - return hessian_pattern(f, x, TH) + return _hessian_sparsity(f, x, TH) end """ @@ -72,13 +72,13 @@ julia> using ADTypes, SparseConnectivityTracer julia> f(x) = x[1] > x[2] ? x[1:3] : x[2:4]; -julia> ADTypes.jacobian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector()) +julia> jacobian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector()) 3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ ⋅ ⋅ 1 -julia> ADTypes.jacobian_sparsity(f, [2.0, 1.0, 3.0, 4.0], TracerLocalSparsityDetector()) +julia> jacobian_sparsity(f, [2.0, 1.0, 3.0, 4.0], TracerLocalSparsityDetector()) 3×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: 1 ⋅ ⋅ ⋅ ⋅ 1 ⋅ ⋅ @@ -90,7 +90,7 @@ julia> using ADTypes, SparseConnectivityTracer julia> f(x) = x[1] + max(x[2], x[3]) * x[3] + 1/x[4]; -julia> ADTypes.hessian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector()) +julia> hessian_sparsity(f, [1.0, 2.0, 3.0, 4.0], TracerLocalSparsityDetector()) 4×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ ⋅ @@ -113,15 +113,15 @@ function TracerLocalSparsityDetector(; end function ADTypes.jacobian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH} - return local_jacobian_pattern(f, x, TG) + return _local_jacobian_sparsity(f, x, TG) end function ADTypes.jacobian_sparsity( f!, y, x, ::TracerLocalSparsityDetector{TG,TH} ) where {TG,TH} - return local_jacobian_pattern(f!, y, x, TG) + return _local_jacobian_sparsity(f!, y, x, TG) end function ADTypes.hessian_sparsity(f, x, ::TracerLocalSparsityDetector{TG,TH}) where {TG,TH} - return local_hessian_pattern(f, x, TH) + return _local_hessian_sparsity(f, x, TH) end diff --git a/src/exceptions.jl b/src/exceptions.jl index 714f46ca..8b7deede 100644 --- a/src/exceptions.jl +++ b/src/exceptions.jl @@ -7,13 +7,7 @@ function Base.showerror(io::IO, e::MissingPrimalError) println(io, "Function ", e.fn, " requires primal value(s).") print( io, - "A dual-number tracer for local sparsity detection can be used via `", - str_local_pattern_fn(e.tracer), - "`.", + "A dual-number tracer for local sparsity detection can be used via `TracerLocalSparsityDetector`.", ) return nothing end - -str_local_pattern_fn(::ConnectivityTracer) = "local_connectivity_pattern" -str_local_pattern_fn(::GradientTracer) = "local_jacobian_pattern" -str_local_pattern_fn(::HessianTracer) = "local_hessian_pattern" diff --git a/src/interface.jl b/src/interface.jl index d7e80138..0b5a98a5 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,4 +1,3 @@ -const DEFAULT_CONNECTIVITY_TRACER = ConnectivityTracer{IndexSetGradientPattern{Int,BitSet}} const DEFAULT_GRADIENT_TRACER = GradientTracer{IndexSetGradientPattern{Int,BitSet}} const DEFAULT_HESSIAN_TRACER = HessianTracer{ IndexSetHessianPattern{Int,BitSet,Set{Tuple{Int,Int}}} @@ -14,7 +13,7 @@ const DEFAULT_HESSIAN_TRACER = HessianTracer{ Enumerates input indices and constructs the specified type `T` of tracer. -Supports [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref). +Supports [`GradientTracer`](@ref), [`HessianTracer`](@ref) and [`Dual`](@ref). 
""" trace_input(::Type{T}, x) where {T<:Union{AbstractTracer,Dual}} = trace_input(T, x, 1) @@ -50,207 +49,28 @@ to_array(x::AbstractArray) = x _tracer_or_number(x::Real) = x _tracer_or_number(d::Dual) = tracer(d) -#====================# -# ConnectivityTracer # -#====================# - -""" - connectivity_pattern(f, x) - connectivity_pattern(f, x, T) - -Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - -The type of `ConnectivityTracer` can be specified as an optional argument and defaults to `$DEFAULT_CONNECTIVITY_TRACER`. - -## Example - -```jldoctest -julia> x = rand(3); - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])]; - -julia> connectivity_pattern(f, x) -3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries: - 1 ⋅ ⋅ - 1 1 ⋅ - ⋅ ⋅ 1 -``` -""" -function connectivity_pattern( - f, x, ::Type{T}=DEFAULT_CONNECTIVITY_TRACER -) where {T<:ConnectivityTracer} - xt, yt = trace_function(T, f, x) - return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) -end - -""" - connectivity_pattern(f!, y, x) - connectivity_pattern(f!, y, x, T) - -Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - -The type of `ConnectivityTracer` can be specified as an optional argument and defaults to `$DEFAULT_CONNECTIVITY_TRACER`. -""" -function connectivity_pattern( - f!, y, x, ::Type{T}=DEFAULT_CONNECTIVITY_TRACER -) where {T<:ConnectivityTracer} - xt, yt = trace_function(T, f!, y, x) - return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) -end - -""" - local_connectivity_pattern(f, x) - local_connectivity_pattern(f, x, P) - -Enumerates inputs `x` and primal outputs `y = f(x)` and returns sparse matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - -Unlike [`connectivity_pattern`](@ref), this function supports control flow and comparisons. - -The type of `ConnectivityTracer` can be specified as an optional argument and defaults to `$DEFAULT_CONNECTIVITY_TRACER`. - -## Example - -```jldoctest -julia> f(x) = ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]); - -julia> x = [1 2 3 4]; - -julia> local_connectivity_pattern(f, x) -1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: - 1 1 ⋅ ⋅ - -julia> x = [1 3 2 4]; - -julia> local_connectivity_pattern(f, x) -1×4 SparseArrays.SparseMatrixCSC{Bool, Int64} with 2 stored entries: - ⋅ ⋅ 1 1 -``` -""" -function local_connectivity_pattern( - f, x, ::Type{T}=DEFAULT_CONNECTIVITY_TRACER -) where {T<:ConnectivityTracer} - D = Dual{eltype(x),T} - xt, yt = trace_function(D, f, x) - return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) -end - -""" - local_connectivity_pattern(f!, y, x) - local_connectivity_pattern(f!, y, x, T) - -Enumerates inputs `x` and primal outputs `y` after `f!(y, x)` and returns sparse matrix `C` of size `(m, n)` -where `C[i, j]` is true if the compute graph connects the `i`-th entry in `y` to the `j`-th entry in `x`. - -Unlike [`connectivity_pattern`](@ref), this function supports control flow and comparisons. - - -The type of `ConnectivityTracer` can be specified as an optional argument and defaults to `$DEFAULT_CONNECTIVITY_TRACER`. 
-""" -function local_connectivity_pattern( - f!, y, x, ::Type{T}=DEFAULT_CONNECTIVITY_TRACER -) where {T<:ConnectivityTracer} - D = Dual{eltype(x),T} - xt, yt = trace_function(D, f!, y, x) - return connectivity_pattern_to_mat(to_array(xt), to_array(yt)) -end - -function connectivity_pattern_to_mat( - xt::AbstractArray{T}, yt::AbstractArray{<:Real} -) where {T<:ConnectivityTracer} - n, m = length(xt), length(yt) - I = Int[] # row indices - J = Int[] # column indices - V = Bool[] # values - for (i, y) in enumerate(yt) - if y isa T && !isemptytracer(y) - for j in inputs(y) - push!(I, i) - push!(J, j) - push!(V, true) - end - end - end - return sparse(I, J, V, m, n) -end - -function connectivity_pattern_to_mat( - xt::AbstractArray{D}, yt::AbstractArray{<:Real} -) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - return connectivity_pattern_to_mat(tracer.(xt), _tracer_or_number.(yt)) -end - #================# # GradientTracer # #================# -""" - jacobian_pattern(f, x) - jacobian_pattern(f, x, T) - -Compute the sparsity pattern of the Jacobian of `y = f(x)`. - -The type of `GradientTracer` can be specified as an optional argument and defaults to `$DEFAULT_CONNECTIVITY_TRACER`. - -## Example - -```jldoctest -julia> x = rand(3); - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sign(x[3])]; - -julia> jacobian_pattern(f, x) -3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: - 1 ⋅ ⋅ - 1 1 ⋅ - ⋅ ⋅ ⋅ -``` -""" -function jacobian_pattern(f, x, ::Type{T}=DEFAULT_GRADIENT_TRACER) where {T<:GradientTracer} +# Compute the sparsity pattern of the Jacobian of `y = f(x)`. +function _jacobian_sparsity( + f, x, ::Type{T}=DEFAULT_GRADIENT_TRACER +) where {T<:GradientTracer} xt, yt = trace_function(T, f, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end -""" - jacobian_pattern(f!, y, x) - jacobian_pattern(f!, y, x, T) - -Compute the sparsity pattern of the Jacobian of `f!(y, x)`. - -The type of `GradientTracer` can be specified as an optional argument and defaults to `$DEFAULT_GRADIENT_TRACER`. -""" -function jacobian_pattern( +# Compute the sparsity pattern of the Jacobian of `f!(y, x)`. +function _jacobian_sparsity( f!, y, x, ::Type{T}=DEFAULT_GRADIENT_TRACER ) where {T<:GradientTracer} xt, yt = trace_function(T, f!, y, x) return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end -""" - local_jacobian_pattern(f, x) - local_jacobian_pattern(f, x, T) - -Compute the local sparsity pattern of the Jacobian of `y = f(x)` at `x`. - -The type of `GradientTracer` can be specified as an optional argument and defaults to `$DEFAULT_GRADIENT_TRACER`. - -## Example - -```jldoctest -julia> x = [1.0, 2.0, 3.0]; - -julia> f(x) = [x[1]^2, 2 * x[1] * x[2]^2, max(x[2],x[3])]; - -julia> local_jacobian_pattern(f, x) -3×3 SparseArrays.SparseMatrixCSC{Bool, Int64} with 4 stored entries: - 1 ⋅ ⋅ - 1 1 ⋅ - ⋅ ⋅ 1 -``` -""" -function local_jacobian_pattern( +# Compute the local sparsity pattern of the Jacobian of `y = f(x)` at `x`. +function _local_jacobian_sparsity( f, x, ::Type{T}=DEFAULT_GRADIENT_TRACER ) where {T<:GradientTracer} D = Dual{eltype(x),T} @@ -258,15 +78,8 @@ function local_jacobian_pattern( return jacobian_pattern_to_mat(to_array(xt), to_array(yt)) end -""" - local_jacobian_pattern(f!, y, x) - local_jacobian_pattern(f!, y, x, T) - -Compute the local sparsity pattern of the Jacobian of `f!(y, x)` at `x`. - -The type of `GradientTracer` can be specified as an optional argument and defaults to `$DEFAULT_GRADIENT_TRACER`. 
-""" -function local_jacobian_pattern( +# Compute the local sparsity pattern of the Jacobian of `f!(y, x)` at `x`. +function _local_jacobian_sparsity( f!, y, x, ::Type{T}=DEFAULT_GRADIENT_TRACER ) where {T<:GradientTracer} D = Dual{eltype(x),T} @@ -303,80 +116,14 @@ end # HessianTracer # #===============# -""" - hessian_pattern(f, x) - hessian_pattern(f, x, T) - -Computes the sparsity pattern of the Hessian of a scalar function `y = f(x)`. - -The type of `HessianTracer` can be specified as an optional argument and defaults to `$DEFAULT_HESSIAN_TRACER`. - -## Example - -```jldoctest -julia> x = rand(5); - -julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + 1*x[5]; - -julia> hessian_pattern(f, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 3 stored entries: - ⋅ ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ 1 ⋅ ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - ⋅ ⋅ ⋅ ⋅ ⋅ - -julia> g(x) = f(x) + x[2]^x[5]; - -julia> hessian_pattern(g, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 7 stored entries: - ⋅ ⋅ ⋅ ⋅ ⋅ - ⋅ 1 1 ⋅ 1 - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - ⋅ 1 ⋅ ⋅ 1 -``` -""" -function hessian_pattern(f, x, ::Type{T}=DEFAULT_HESSIAN_TRACER) where {T<:HessianTracer} +# Compute the sparsity pattern of the Hessian of a scalar function `y = f(x)`. +function _hessian_sparsity(f, x, ::Type{T}=DEFAULT_HESSIAN_TRACER) where {T<:HessianTracer} xt, yt = trace_function(T, f, x) return hessian_pattern_to_mat(to_array(xt), yt) end -""" - local_hessian_pattern(f, x) - local_hessian_pattern(f, x, T) - -Computes the local sparsity pattern of the Hessian of a scalar function `y = f(x)` at `x`. - -The type of `HessianTracer` can be specified as an optional argument and defaults to `$DEFAULT_HESSIAN_TRACER`. - -## Example - -```jldoctest -julia> x = [1.0 3.0 5.0 1.0 2.0]; - -julia> f(x) = x[1] + x[2]*x[3] + 1/x[4] + x[2] * max(x[1], x[5]); - -julia> local_hessian_pattern(f, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries: - ⋅ ⋅ ⋅ ⋅ ⋅ - ⋅ ⋅ 1 ⋅ 1 - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - -julia> x = [4.0 3.0 5.0 1.0 2.0]; - -julia> local_hessian_pattern(f, x) -5×5 SparseArrays.SparseMatrixCSC{Bool, Int64} with 5 stored entries: - ⋅ 1 ⋅ ⋅ ⋅ - 1 ⋅ 1 ⋅ ⋅ - ⋅ 1 ⋅ ⋅ ⋅ - ⋅ ⋅ ⋅ 1 ⋅ - ⋅ ⋅ ⋅ ⋅ ⋅ -``` -""" -function local_hessian_pattern( +# Compute the local sparsity pattern of the Hessian of a scalar function `y = f(x)` at `x`. 
+function _local_hessian_sparsity( f, x, ::Type{T}=DEFAULT_HESSIAN_TRACER ) where {T<:HessianTracer} D = Dual{eltype(x),T} diff --git a/src/operators.jl b/src/operators.jl index e0a82751..22055a8e 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -12,12 +12,10 @@ ##=================================# # Operators for functions f: ℝ → ℝ # #==================================# -function is_infl_zero_global end function is_der1_zero_global end function is_der2_zero_global end # Fallbacks for local derivatives: -is_infl_zero_local(f::F, x) where {F} = is_infl_zero_global(f) is_der1_zero_local(f::F, x) where {F} = is_der1_zero_global(f) is_der2_zero_local(f::F, x) where {F} = is_der2_zero_global(f) @@ -53,7 +51,6 @@ ops_1_to_1_s = ( ) for op in ops_1_to_1_s T = typeof(op) - @eval is_infl_zero_global(::$T) = false @eval is_der1_zero_global(::$T) = false @eval is_der2_zero_global(::$T) = false end @@ -78,7 +75,6 @@ ops_1_to_1_f = ( ) for op in ops_1_to_1_f T = typeof(op) - @eval is_infl_zero_global(::$T) = false @eval is_der1_zero_global(::$T) = false @eval is_der2_zero_global(::$T) = true end @@ -93,7 +89,6 @@ ops_1_to_1_z = ( ) for op in ops_1_to_1_z T = typeof(op) - @eval is_infl_zero_global(::$T) = false @eval is_der1_zero_global(::$T) = true @eval is_der2_zero_global(::$T) = true end @@ -109,7 +104,6 @@ ops_1_to_1_i = ( ) for op in ops_1_to_1_i T = typeof(op) - @eval is_infl_zero_global(::$T) = true @eval is_der1_zero_global(::$T) = true @eval is_der2_zero_global(::$T) = true end @@ -125,8 +119,6 @@ ops_1_to_1 = union( # Operators for functions f: ℝ² → ℝ # #===================================# -function is_infl_arg1_zero_global end -function is_infl_arg2_zero_global end function is_der1_arg1_zero_global end function is_der2_arg1_zero_global end function is_der1_arg2_zero_global end @@ -134,8 +126,6 @@ function is_der2_arg2_zero_global end function is_der_cross_zero_global end # Fallbacks for local derivatives: -is_infl_arg1_zero_local(f::F, x, y) where {F} = is_infl_arg1_zero_global(f) -is_infl_arg2_zero_local(f::F, x, y) where {F} = is_infl_arg2_zero_global(f) is_der1_arg1_zero_local(f::F, x, y) where {F} = is_der1_arg1_zero_global(f) is_der2_arg1_zero_local(f::F, x, y) where {F} = is_der2_arg1_zero_global(f) is_der1_arg2_zero_local(f::F, x, y) where {F} = is_der1_arg2_zero_global(f) @@ -153,8 +143,6 @@ ops_2_to_1_ssc = ( ) for op in ops_2_to_1_ssc T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = false @eval is_der1_arg2_zero_global(::$T) = false @@ -172,8 +160,6 @@ ops_2_to_1_ssz = () #= for op in ops_2_to_1_ssz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der1_arg2_zero_global(::$T) = false @@ -192,8 +178,6 @@ ops_2_to_1_sfc = () #= for op in ops_2_to_1_sfc T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = false @eval is_der1_arg2_zero_global(::$T) = false @@ -212,8 +196,6 @@ ops_2_to_1_sfz = () #= for op in ops_2_to_1_sfz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = false @eval 
is_der1_arg2_zero_global(::$T) = false @@ -234,8 +216,6 @@ ops_2_to_1_fsc = ( ) for op in ops_2_to_1_fsc T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -256,8 +236,6 @@ ops_2_to_1_fsz = () #= for op in ops_2_to_1_fsz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -277,8 +255,6 @@ ops_2_to_1_ffc = ( ) for op in ops_2_to_1_ffc T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -303,8 +279,6 @@ ops_2_to_1_ffz = ( ) for op in ops_2_to_1_ffz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -330,8 +304,6 @@ ops_2_to_1_szz = () #= for op in ops_2_to_1_szz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = false @eval is_der1_arg2_zero_global(::$T) = true @@ -350,8 +322,6 @@ ops_2_to_1_zsz = () #= for op in ops_2_to_1_zsz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = true @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -371,8 +341,6 @@ ops_2_to_1_fzz = ( ) for op in ops_2_to_1_fzz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = false @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = true @@ -390,8 +358,6 @@ ops_2_to_1_zfz = () #= for op in ops_2_to_1_zfz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = true @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = false @@ -412,8 +378,6 @@ ops_2_to_1_zzz = ( ) for op in ops_2_to_1_zzz T = typeof(op) - @eval is_infl_arg1_zero_global(::$T) = false - @eval is_infl_arg2_zero_global(::$T) = false @eval is_der1_arg1_zero_global(::$T) = true @eval is_der2_arg1_zero_global(::$T) = true @eval is_der1_arg2_zero_global(::$T) = true @@ -448,16 +412,12 @@ ops_2_to_1 = union( # Operators for functions f: ℝ → ℝ² # #===================================# -function is_infl_out1_zero_global end -function is_infl_out2_zero_global end function is_der1_out1_zero_global end function is_der2_out1_zero_global end function is_der1_out2_zero_global end function is_der2_out2_zero_global end # Fallbacks for local derivatives: -is_infl_out1_zero_local(f::F, x) where {F} = is_infl_out1_zero_global(f) -is_infl_out2_zero_local(f::F, x) where {F} = is_infl_out2_zero_global(f) is_der1_out1_zero_local(f::F, x) where {F} = is_der1_out1_zero_global(f) is_der2_out1_zero_local(f::F, x) where {F} = is_der2_out1_zero_global(f) is_der1_out2_zero_local(f::F, x) where {F} = 
is_der1_out2_zero_global(f) @@ -475,8 +435,6 @@ ops_1_to_2_ss = ( ) for op in ops_1_to_2_ss T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = false @eval is_der1_out2_zero_global(::$T) = false @@ -492,8 +450,6 @@ ops_1_to_2_sf = () #= for op in ops_1_to_2_sf T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = false @eval is_der1_out2_zero_global(::$T) = false @@ -510,8 +466,6 @@ ops_1_to_2_sz = () #= for op in ops_1_to_2_sz T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = false @eval is_der1_out2_zero_global(::$T) = true @@ -528,8 +482,6 @@ ops_1_to_2_fs = () #= for op in ops_1_to_2_fs T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = false @@ -546,8 +498,6 @@ ops_1_to_2_ff = () #= for op in ops_1_to_2_ff T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = false @@ -566,8 +516,6 @@ ops_1_to_2_fz = ( #= for op in ops_1_to_2_fz T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = false @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = true @@ -584,8 +532,6 @@ ops_1_to_2_zs = () #= for op in ops_1_to_2_zs T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = true @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = false @@ -602,8 +548,6 @@ ops_1_to_2_zf = () #= for op in ops_1_to_2_zf T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = true @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = false @@ -620,8 +564,6 @@ ops_1_to_2_zz = () #= for op in ops_1_to_2_zz T = typeof(op) - @eval is_infl_out1_zero_global(::$T) = false - @eval is_infl_out2_zero_global(::$T) = false @eval is_der1_out1_zero_global(::$T) = true @eval is_der2_out1_zero_global(::$T) = true @eval is_der1_out2_zero_global(::$T) = true diff --git a/src/overloads/arrays.jl b/src/overloads/arrays.jl index 3fc5f4b0..1b3a6b19 100644 --- a/src/overloads/arrays.jl +++ b/src/overloads/arrays.jl @@ -14,9 +14,6 @@ function second_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer} return reduce(second_order_or, ts; init=myempty(T)) end -function second_order_or(a::T, b::T) where {T<:ConnectivityTracer} - return connectivity_tracer_2_to_1(a, b, false, false) -end function second_order_or(a::T, b::T) where {T<:GradientTracer} return gradient_tracer_2_to_1(a, b, false, false) end @@ -39,9 +36,6 @@ function first_order_or(ts::AbstractArray{T}) where {T<:AbstractTracer} # TODO: improve performance return reduce(first_order_or, ts; init=myempty(T)) end -function first_order_or(a::T, 
b::T) where {T<:ConnectivityTracer} - return connectivity_tracer_2_to_1(a, b, false, false) -end function first_order_or(a::T, b::T) where {T<:GradientTracer} return gradient_tracer_2_to_1(a, b, false, false) end diff --git a/src/overloads/connectivity_tracer.jl b/src/overloads/connectivity_tracer.jl deleted file mode 100644 index 6b58073a..00000000 --- a/src/overloads/connectivity_tracer.jl +++ /dev/null @@ -1,202 +0,0 @@ -## 1-to-1 - -@noinline function connectivity_tracer_1_to_1( - t::T, is_infl_zero::Bool -) where {T<:ConnectivityTracer} - if is_infl_zero && !isemptytracer(t) - return myempty(T) - else - return t - end -end - -function overload_connectivity_1_to_1(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(t::T) where {T<:$SCT.ConnectivityTracer} - is_infl_zero = $SCT.is_infl_zero_global($M.$op) - return $SCT.connectivity_tracer_1_to_1(t, is_infl_zero) - end - end -end - -function overload_connectivity_1_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(d::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out = $M.$op(x) - - t = $SCT.tracer(d) - is_infl_zero = $SCT.is_infl_zero_local($M.$op, x) - t_out = $SCT.connectivity_tracer_1_to_1(t, is_infl_zero) - return $SCT.Dual(p_out, t_out) - end - end -end - -## 2-to-1 - -@noinline function connectivity_tracer_2_to_1( - tx::T, ty::T, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool -) where {T<:ConnectivityTracer} - # TODO: add tests for isempty - if tx.isempty && ty.isempty - return tx # empty tracer - elseif ty.isempty - return connectivity_tracer_1_to_1(tx, is_infl_arg1_zero) - elseif tx.isempty - return connectivity_tracer_1_to_1(ty, is_infl_arg2_zero) - else - i_out = connectivity_tracer_2_to_1_inner( - pattern(tx), pattern(ty), is_infl_arg1_zero, is_infl_arg2_zero - ) - return T(i_out) # return tracer - end -end - -function connectivity_tracer_2_to_1_inner( - px::P, py::P, is_infl_arg1_zero::Bool, is_infl_arg2_zero::Bool -) where {P<:IndexSetGradientPattern} - if is_infl_arg1_zero && is_infl_arg2_zero - return myempty(P) - elseif !is_infl_arg1_zero && is_infl_arg2_zero - return px - elseif is_infl_arg1_zero && !is_infl_arg2_zero - return py - else - return P(union(set(px), set(py))) # return pattern - end -end - -function overload_connectivity_2_to_1(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(tx::T, ty::T) where {T<:$SCT.ConnectivityTracer} - is_infl_arg1_zero = $SCT.is_infl_arg1_zero_global($M.$op) - is_infl_arg2_zero = $SCT.is_infl_arg2_zero_global($M.$op) - return $SCT.connectivity_tracer_2_to_1( - tx, ty, is_infl_arg1_zero, is_infl_arg2_zero - ) - end - - function $M.$op(tx::$SCT.ConnectivityTracer, ::Real) - is_infl_arg1_zero = $SCT.is_infl_arg1_zero_global($M.$op) - return $SCT.connectivity_tracer_1_to_1(tx, is_infl_arg1_zero) - end - - function $M.$op(::Real, ty::$SCT.ConnectivityTracer) - is_infl_arg2_zero = $SCT.is_infl_arg2_zero_global($M.$op) - return $SCT.connectivity_tracer_1_to_1(ty, is_infl_arg2_zero) - end - end -end - -function overload_connectivity_2_to_1_dual(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(dx::D, dy::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - y = $SCT.primal(dy) - p_out = $M.$op(x, y) - - tx = $SCT.tracer(dx) - ty = $SCT.tracer(dy) - is_infl_arg1_zero = $SCT.is_infl_arg1_zero_local($M.$op, x, y) - is_infl_arg2_zero = $SCT.is_infl_arg2_zero_local($M.$op, x, y) - t_out = $SCT.connectivity_tracer_2_to_1( - tx, ty, 
is_infl_arg1_zero, is_infl_arg2_zero - ) - return $SCT.Dual(p_out, t_out) - end - - function $M.$op( - dx::D, y::Real - ) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(dx) - p_out = $M.$op(x, y) - - tx = $SCT.tracer(dx) - is_infl_arg1_zero = $SCT.is_infl_arg1_zero_local($M.$op, x, y) - t_out = $SCT.connectivity_tracer_1_to_1(tx, is_infl_arg1_zero) - return $SCT.Dual(p_out, t_out) - end - - function $M.$op( - x::Real, dy::D - ) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} - y = $SCT.primal(dy) - p_out = $M.$op(x, y) - - ty = $SCT.tracer(dy) - is_infl_arg2_zero = $SCT.is_infl_arg2_zero_local($M.$op, x, y) - t_out = $SCT.connectivity_tracer_1_to_1(ty, is_infl_arg2_zero) - return $SCT.Dual(p_out, t_out) - end - end -end - -## 1-to-2 - -@noinline function connectivity_tracer_1_to_2( - t::T, is_infl_out1_zero::Bool, is_infl_out2_zero::Bool -) where {T<:ConnectivityTracer} - if isemptytracer(t) # TODO: add test - return (t, t) - else - t_out1 = connectivity_tracer_1_to_1(t, is_infl_out1_zero) - t_out2 = connectivity_tracer_1_to_1(t, is_infl_out2_zero) - return (t_out1, t_out2) # return tracers - end -end - -function overload_connectivity_1_to_2(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(t::$SCT.ConnectivityTracer) - is_infl_out1_zero = $SCT.is_infl_out1_zero_global($M.$op) - is_infl_out2_zero = $SCT.is_infl_out2_zero_global($M.$op) - return $SCT.connectivity_tracer_1_to_2(t, is_infl_out1_zero, is_infl_out2_zero) - end - end -end - -function overload_connectivity_1_to_2_dual(M, op) - SCT = SparseConnectivityTracer - return quote - function $M.$op(d::D) where {P,T<:$SCT.ConnectivityTracer,D<:$SCT.Dual{P,T}} - x = $SCT.primal(d) - p_out1, p_out2 = $M.$op(x) - - t = tracer(d) - is_infl_out1_zero = $SCT.is_infl_out1_zero_local($M.$op, x) - is_infl_out2_zero = $SCT.is_infl_out2_zero_local($M.$op, x) - t_out1, t_out2 = $SCT.connectivity_tracer_1_to_2( - t, is_infl_out1_zero, is_infl_out2_zero - )# TODO: add test, this was buggy - return ($SCT.Dual(p_out1, t_out1), $SCT.Dual(p_out2, t_out2)) - end - end -end - -## Special cases - -## Exponent (requires extra types) -for S in (Integer, Rational, Irrational{:ℯ}) - Base.:^(t::ConnectivityTracer, ::S) = t - Base.:^(::S, t::ConnectivityTracer) = t - function Base.:^(dx::D, y::S) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - x = primal(dx) - return Dual(x^y, tracer(dx)) - end - function Base.:^(x::S, dy::D) where {P,T<:ConnectivityTracer,D<:Dual{P,T}} - y = primal(dy) - return Dual(x^y, tracer(dy)) - end -end - -## Rounding -Base.round(t::ConnectivityTracer, ::RoundingMode; kwargs...) 
= t - -## Random numbers -Base.rand(::AbstractRNG, ::SamplerType{T}) where {T<:ConnectivityTracer} = myempty(T) # TODO: was missing Base, add tests diff --git a/src/overloads/ifelse_global.jl b/src/overloads/ifelse_global.jl index 2609acc1..79a16e6d 100644 --- a/src/overloads/ifelse_global.jl +++ b/src/overloads/ifelse_global.jl @@ -50,7 +50,6 @@ for op in (isequal, isapprox, isless, ==, <, >, <=, >=) @eval is_der_cross_zero_global(::$T) = true op_symb = nameof(op) - SparseConnectivityTracer.eval(overload_connectivity_2_to_1(:Base, op_symb)) SparseConnectivityTracer.eval(overload_gradient_2_to_1(:Base, op_symb)) SparseConnectivityTracer.eval(overload_hessian_2_to_1(:Base, op_symb)) end diff --git a/src/overloads/overload_all.jl b/src/overloads/overload_all.jl index 5c47fff1..737fed2c 100644 --- a/src/overloads/overload_all.jl +++ b/src/overloads/overload_all.jl @@ -1,42 +1,36 @@ function overload_all(M) exprs_1_to_1 = [ quote - $(overload_connectivity_1_to_1(M, op)) $(overload_gradient_1_to_1(M, op)) $(overload_hessian_1_to_1(M, op)) end for op in nameof.(list_operators_1_to_1(Val(M))) ] exprs_1_to_1_dual = [ quote - $(overload_connectivity_1_to_1_dual(M, op)) $(overload_gradient_1_to_1_dual(M, op)) $(overload_hessian_1_to_1_dual(M, op)) end for op in nameof.(list_operators_1_to_1(Val(M))) ] exprs_2_to_1 = [ quote - $(overload_connectivity_2_to_1(M, op)) $(overload_gradient_2_to_1(M, op)) $(overload_hessian_2_to_1(M, op)) end for op in nameof.(list_operators_2_to_1(Val(M))) ] exprs_2_to_1_dual = [ quote - $(overload_connectivity_2_to_1_dual(M, op)) $(overload_gradient_2_to_1_dual(M, op)) $(overload_hessian_2_to_1_dual(M, op)) end for op in nameof.(list_operators_2_to_1(Val(M))) ] exprs_1_to_2 = [ quote - $(overload_connectivity_1_to_2(M, op)) $(overload_gradient_1_to_2(M, op)) $(overload_hessian_1_to_2(M, op)) end for op in nameof.(list_operators_1_to_2(Val(M))) ] exprs_1_to_2_dual = [ quote - $(overload_connectivity_1_to_2_dual(M, op)) $(overload_gradient_1_to_2_dual(M, op)) $(overload_hessian_1_to_2_dual(M, op)) end for op in nameof.(list_operators_1_to_2(Val(M))) diff --git a/src/patterns.jl b/src/patterns.jl index afea743c..a3d9b479 100644 --- a/src/patterns.jl +++ b/src/patterns.jl @@ -6,7 +6,7 @@ Abstract supertype of all sparsity pattern representations. ## Type hierarchy ``` AbstractPattern -├── AbstractGradientPattern: used in GradientTracer, ConnectivityTracer +├── AbstractGradientPattern: used in GradientTracer │ └── IndexSetGradientPattern └── AbstractHessianPattern: used in HessianTracer └── IndexSetHessianPattern @@ -102,7 +102,6 @@ function seed(::Type{IndexSetGradientPattern{I,S}}, i) where {I,S} end # Tracer compatibility -inputs(s::IndexSetGradientPattern) = s.gradient gradient(s::IndexSetGradientPattern) = s.gradient #========================# diff --git a/src/tracers.jl b/src/tracers.jl index 94345b1e..6fd50064 100644 --- a/src/tracers.jl +++ b/src/tracers.jl @@ -1,53 +1,5 @@ abstract type AbstractTracer{P<:AbstractPattern} <: Real end -#====================# -# ConnectivityTracer # -#====================# - -""" -$(TYPEDEF) - -`Real` number type keeping track of input indices of previous computations. - -For a higher-level interface, refer to [`connectivity_pattern`](@ref). - -## Fields -$(TYPEDFIELDS) -""" -struct ConnectivityTracer{P<:AbstractGradientPattern} <: AbstractTracer{P} - "Sparse representation of connected inputs." - pattern::P - "Indicator whether pattern in tracer contains only zeros." 
- isempty::Bool - - function ConnectivityTracer{P}(inputs::P, isempty::Bool=false) where {P} - return new{P}(inputs, isempty) - end -end - -# We have to be careful when defining constructors: -# Generic code expecting "regular" numbers `x` will sometimes convert them -# by calling `T(x)` (instead of `convert(T, x)`), where `T` can be `ConnectivityTracer`. -# When this happens, we create a new empty tracer with no input pattern. -ConnectivityTracer{P}(::Real) where {P} = myempty(ConnectivityTracer{P}) -ConnectivityTracer{P}(t::ConnectivityTracer{P}) where {P} = t -ConnectivityTracer(t::ConnectivityTracer) = t - -isemptytracer(t::ConnectivityTracer) = t.isempty -pattern(t::ConnectivityTracer) = t.pattern -inputs(t::ConnectivityTracer) = inputs(pattern(t)) - -function Base.show(io::IO, t::ConnectivityTracer) - print(io, typeof(t)) - if isemptytracer(t) - print(io, "()") - else - printsorted(io, inputs(t)) - end - println(io) - return nothing -end - #================# # GradientTracer # #================# @@ -57,8 +9,6 @@ $(TYPEDEF) `Real` number type keeping track of non-zero gradient entries. -For a higher-level interface, refer to [`jacobian_pattern`](@ref). - ## Fields $(TYPEDFIELDS) """ @@ -101,8 +51,6 @@ $(TYPEDEF) `Real` number type keeping track of non-zero gradient and Hessian entries. -For a higher-level interface, refer to [`hessian_pattern`](@ref). - ## Fields $(TYPEDFIELDS) """ @@ -167,11 +115,10 @@ end primal(d::Dual) = d.primal tracer(d::Dual) = d.tracer -inputs(d::Dual{P,T}) where {P,T<:ConnectivityTracer} = inputs(tracer(d)) -gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(tracer(d)) -gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(tracer(d)) -hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(tracer(d)) -isemptytracer(d::Dual) = isemptytracer(tracer(d)) +gradient(d::Dual{P,T}) where {P,T<:GradientTracer} = gradient(tracer(d)) +gradient(d::Dual{P,T}) where {P,T<:HessianTracer} = gradient(tracer(d)) +hessian(d::Dual{P,T}) where {P,T<:HessianTracer} = hessian(tracer(d)) +isemptytracer(d::Dual) = isemptytracer(tracer(d)) Dual{P,T}(d::Dual{P,T}) where {P<:Real,T<:AbstractTracer} = d Dual(primal::P, tracer::T) where {P,T} = Dual{P,T}(primal, tracer) @@ -187,21 +134,19 @@ end myempty(::T) where {T<:AbstractTracer} = myempty(T) # myempty(::Type{T}) where {P,T<:AbstractTracer{P}} = T(myempty(P), true) # JET complains about this -myempty(::Type{T}) where {P,T<:ConnectivityTracer{P}} = T(myempty(P), true) -myempty(::Type{T}) where {P,T<:GradientTracer{P}} = T(myempty(P), true) -myempty(::Type{T}) where {P,T<:HessianTracer{P}} = T(myempty(P), true) +myempty(::Type{T}) where {P,T<:GradientTracer{P}} = T(myempty(P), true) +myempty(::Type{T}) where {P,T<:HessianTracer{P}} = T(myempty(P), true) seed(::T, i) where {T<:AbstractTracer} = seed(T, i) # seed(::Type{T}, i) where {P,T<:AbstractTracer{P}} = T(seed(P, i)) # JET complains about this -seed(::Type{T}, i) where {P,T<:ConnectivityTracer{P}} = T(seed(P, i)) -seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i)) -seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i)) +seed(::Type{T}, i) where {P,T<:GradientTracer{P}} = T(seed(P, i)) +seed(::Type{T}, i) where {P,T<:HessianTracer{P}} = T(seed(P, i)) """ create_tracer(T, index) where {T<:AbstractTracer} -Convenience constructor for [`ConnectivityTracer`](@ref), [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices. +Convenience constructor for [`GradientTracer`](@ref) and [`HessianTracer`](@ref) from input indices. 
""" function create_tracer(::Type{T}, ::Real, index::Integer) where {P,T<:AbstractTracer{P}} return T(seed(P, index)) @@ -212,12 +157,11 @@ function create_tracer(::Type{Dual{P,T}}, primal::Real, index::Integer) where {P end # Pretty-printing of Dual tracers -name(::Type{T}) where {T<:ConnectivityTracer} = "ConnectivityTracer" -name(::Type{T}) where {T<:GradientTracer} = "GradientTracer" -name(::Type{T}) where {T<:HessianTracer} = "HessianTracer" -name(::Type{D}) where {P,T,D<:Dual{P,T}} = "Dual-$(name(T))" -name(::T) where {T<:AbstractTracer} = name(T) -name(::D) where {D<:Dual} = name(D) +name(::Type{T}) where {T<:GradientTracer} = "GradientTracer" +name(::Type{T}) where {T<:HessianTracer} = "HessianTracer" +name(::Type{D}) where {P,T,D<:Dual{P,T}} = "Dual-$(name(T))" +name(::T) where {T<:AbstractTracer} = name(T) +name(::D) where {D<:Dual} = name(D) # Utilities for printing sets printsorted(io::IO, x) = Base.show_delim_array(io, sort(x), "(", ',', ')', true) diff --git a/test/brusselator.jl b/test/brusselator.jl index 6229ff20..c2cae04d 100644 --- a/test/brusselator.jl +++ b/test/brusselator.jl @@ -6,7 +6,7 @@ using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using SparseConnectivityTracerBenchmarks.ODE: Brusselator! using Test -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS include("tracers_definitions.jl") function test_brusselator(method::AbstractSparsityDetector) diff --git a/test/flux.jl b/test/flux.jl index 2cbf33ed..d28f7bde 100644 --- a/test/flux.jl +++ b/test/flux.jl @@ -6,7 +6,7 @@ using SparseConnectivityTracer using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector using Test -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS include("tracers_definitions.jl") const INPUT_FLUX = reshape( diff --git a/test/runtests.jl b/test/runtests.jl index f9e70e64..0186e6a1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,9 +81,6 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @testset "Tracer Construction" begin include("test_constructors.jl") end - @testset "ConnectivityTracer" begin - include("test_connectivity.jl") - end @testset "GradientTracer" begin include("test_gradient.jl") end diff --git a/test/test_arrays.jl b/test/test_arrays.jl index 336a6326..a1871c2e 100644 --- a/test/test_arrays.jl +++ b/test/test_arrays.jl @@ -8,7 +8,8 @@ using LinearAlgebra: inv, pinv using SparseArrays: sparse, spdiagm using Test -PATTERN_FUNCTIONS = (connectivity_pattern, jacobian_pattern, hessian_pattern) +PATTERN_FUNCTIONS = (jacobian_sparsity, hessian_sparsity) +detector_global = TracerSparsityDetector() TEST_SQUARE_MATRICES = Dict( "`Matrix` (3×3)" => rand(3, 3), @@ -30,16 +31,12 @@ function test_patterns(f, x; outsum=false, con=isone, jac=isone, hes=isone) else _f = f end - @testset "Connecivity pattern" begin - pattern = connectivity_pattern(_f, x) - @test all(con, pattern) - end @testset "Jacobian pattern" begin - pattern = jacobian_pattern(_f, x) + pattern = jacobian_sparsity(_f, x, detector_global) @test all(jac, pattern) end @testset "Hessian pattern" begin - pattern = hessian_pattern(_f, x) + pattern = hessian_sparsity(_f, x, detector_global) @test all(hes, pattern) end end diff --git a/test/test_connectivity.jl b/test/test_connectivity.jl deleted file mode 100644 index 5d6e6812..00000000 --- a/test/test_connectivity.jl +++ /dev/null @@ -1,134 +0,0 @@ 
-using SparseConnectivityTracer -using SparseConnectivityTracer: ConnectivityTracer, Dual, MissingPrimalError, trace_input -using LinearAlgebra: det, dot, logdet -using SpecialFunctions: erf, beta -using NNlib: NNlib -using Test - -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS -include("tracers_definitions.jl") - -NNLIB_ACTIVATIONS_S = ( - NNlib.σ, - NNlib.celu, - NNlib.elu, - NNlib.gelu, - NNlib.hardswish, - NNlib.lisht, - NNlib.logσ, - NNlib.logcosh, - NNlib.mish, - NNlib.selu, - NNlib.softplus, - NNlib.softsign, - NNlib.swish, - NNlib.sigmoid_fast, - NNlib.tanhshrink, - NNlib.tanh_fast, -) -NNLIB_ACTIVATIONS_F = ( - NNlib.hardσ, - NNlib.hardtanh, - NNlib.leakyrelu, - NNlib.relu, - NNlib.relu6, - NNlib.softshrink, - NNlib.trelu, -) -NNLIB_ACTIVATIONS = union(NNLIB_ACTIVATIONS_S, NNLIB_ACTIVATIONS_F) - -@testset "Connectivity Global" begin - @testset "$T" for T in CONNECTIVITY_TRACERS - A = rand(1, 3) - @test connectivity_pattern(x -> only(A * x), rand(3), T) == [1 1 1] - - f(x) = [x[1]^2, 2 * x[1] * x[2]^2, sin(x[3])] - @test connectivity_pattern(f, rand(3), T) == [1 0 0; 1 1 0; 0 0 1] - @test connectivity_pattern(identity, rand(), T) ≈ [1;;] - @test connectivity_pattern(Returns(1), 1, T) ≈ [0;;] - - # Test ConnectivityTracer on functions with zero derivatives - x = rand(2) - g(x) = [x[1] * x[2], ceil(x[1] * x[2]), x[1] * round(x[2])] - @test connectivity_pattern(g, x, T) == [1 1; 1 1; 1 1] - - # Code coverage - @test connectivity_pattern(x -> [sincos(x)...], 1, T) ≈ [1; 1] - @test connectivity_pattern(typemax, 1, T) ≈ [0;;] - @test connectivity_pattern(x -> x^(2//3), 1, T) ≈ [1;;] - @test connectivity_pattern(x -> (2//3)^x, 1, T) ≈ [1;;] - @test connectivity_pattern(x -> x^ℯ, 1, T) ≈ [1;;] - @test connectivity_pattern(x -> ℯ^x, 1, T) ≈ [1;;] - @test connectivity_pattern(x -> round(x, RoundNearestTiesUp), 1, T) ≈ [1;;] - @test connectivity_pattern(x -> 0, 1, T) ≈ [0;;] - - # SpecialFunctions extension - @test connectivity_pattern(x -> erf(x[1]), rand(2), T) == [1 0] - @test connectivity_pattern(x -> beta(x[1], x[2]), rand(3), T) == [1 1 0] - - # NNlib extension - for f in NNLIB_ACTIVATIONS - @test connectivity_pattern(f, 1, T) ≈ [1;;] - end - - # Missing primal errors - @testset "MissingPrimalError on $f" for f in ( - iseven, - isfinite, - isinf, - isinteger, - ismissing, - isnan, - isnothing, - isodd, - isone, - isreal, - iszero, - ) - @test_throws MissingPrimalError connectivity_pattern(f, rand(), T) - end - - # ifelse and comparisons - if VERSION >= v"1.8" - @test connectivity_pattern( - x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], T - ) == [1 1 1 1] - - @test connectivity_pattern( - x -> ifelse(x[2] < x[3], x[1] + x[2], 1.0), [1 2 3 4], T - ) == [1 1 0 0] - - @test connectivity_pattern( - x -> ifelse(x[2] < x[3], 1.0, x[3] * x[4]), [1 2 3 4], T - ) == [0 0 1 1] - end - - function f_ampgo07(x) - return (x[1] <= 0) * convert(eltype(x), Inf) + - sin(x[1]) + - sin(10//3 * x[1]) + - log(abs(x[1])) - 84//100 * x[1] + 3 - end - @test connectivity_pattern(f_ampgo07, [1.0], T) ≈ [1;;] - - # Error handling when applying non-dual tracers to "local" functions with control flow - # TypeError: non-boolean (SparseConnectivityTracer.GradientTracer{BitSet}) used in boolean context - @test_throws TypeError connectivity_pattern( - x -> x[1] > x[2] ? 
x[3] : x[4], [1.0, 2.0, 3.0, 4.0], T - ) == [0 0 1 1;] - yield() - end -end - -@testset "Connectivity Local" begin - @testset "$T" for T in CONNECTIVITY_TRACERS - @test local_connectivity_pattern( - x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 2 3 4], T - ) == [1 1 0 0] - @test local_connectivity_pattern( - x -> ifelse(x[2] < x[3], x[1] + x[2], x[3] * x[4]), [1 3 2 4], T - ) == [0 0 1 1] - @test local_connectivity_pattern(x -> 0, 1, T) ≈ [0;;] - yield() - end -end diff --git a/test/test_constructors.jl b/test/test_constructors.jl index 470502c1..1bf9ed31 100644 --- a/test/test_constructors.jl +++ b/test/test_constructors.jl @@ -1,11 +1,10 @@ # Test construction and conversions of internal tracer types -using SparseConnectivityTracer: - AbstractTracer, ConnectivityTracer, GradientTracer, HessianTracer, Dual -using SparseConnectivityTracer: inputs, primal, tracer, isemptytracer +using SparseConnectivityTracer: AbstractTracer, GradientTracer, HessianTracer, Dual +using SparseConnectivityTracer: primal, tracer, isemptytracer using SparseConnectivityTracer: myempty, name using Test -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS include("tracers_definitions.jl") function test_nested_duals(::Type{T}) where {T<:AbstractTracer} @@ -158,28 +157,6 @@ function test_similar(::Type{D}) where {P,T,D<:Dual{P,T}} @test size(B5) == (5, 6) end -@testset "ConnectivityTracer" begin - P = Float32 - DUAL_CONNECTIVITY_TRACERS = [Dual{P,T} for T in CONNECTIVITY_TRACERS] - ALL_CONNECTIVITY_TRACERS = (CONNECTIVITY_TRACERS..., DUAL_CONNECTIVITY_TRACERS...) - - @testset "Nested Duals on HessianTracer" for T in CONNECTIVITY_TRACERS - test_nested_duals(T) - end - @testset "Constant functions on $T" for T in ALL_CONNECTIVITY_TRACERS - test_constant_functions(T) - end - @testset "Type conversions on $T" for T in ALL_CONNECTIVITY_TRACERS - test_type_conversion_functions(T) - end - @testset "Type casting on $T" for T in ALL_CONNECTIVITY_TRACERS - test_type_casting(T) - end - @testset "similar on $T" for T in ALL_CONNECTIVITY_TRACERS - test_similar(T) - end -end - @testset "GradientTracer" begin P = Float32 DUAL_GRADIENT_TRACERS = [Dual{P,T} for T in GRADIENT_TRACERS] diff --git a/test/test_gradient.jl b/test/test_gradient.jl index 1f783fa5..795e8518 100644 --- a/test/test_gradient.jl +++ b/test/test_gradient.jl @@ -6,7 +6,7 @@ using SpecialFunctions: erf, beta using NNlib: NNlib using Test -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS include("tracers_definitions.jl") NNLIB_ACTIVATIONS_S = ( diff --git a/test/test_hessian.jl b/test/test_hessian.jl index 68cf5f08..1792fa9f 100644 --- a/test/test_hessian.jl +++ b/test/test_hessian.jl @@ -4,20 +4,10 @@ using ADTypes: hessian_sparsity using SpecialFunctions: erf, beta using Test -# Load definitions of CONNECTIVITY_TRACERS, GRADIENT_TRACERS, HESSIAN_TRACERS +# Load definitions of GRADIENT_TRACERS and HESSIAN_TRACERS include("tracers_definitions.jl") @testset "Global Hessian" begin - @testset "Default hessian_pattern" begin - h = hessian_pattern(x -> x[1] / x[2] + x[3] / 1 + 1 / x[4], rand(4)) - @test h == [ - 0 1 0 0 - 1 1 0 0 - 0 0 0 0 - 0 0 0 1 - ] - end - @testset "$T" for T in HESSIAN_TRACERS method = TracerSparsityDetector(; hessian_tracer_type=T) diff --git a/test/tracers_definitions.jl b/test/tracers_definitions.jl index c7a3514f..671f3e2b 100644 --- a/test/tracers_definitions.jl 
+++ b/test/tracers_definitions.jl @@ -1,5 +1,4 @@ -using SparseConnectivityTracer: - AbstractTracer, ConnectivityTracer, GradientTracer, HessianTracer, Dual +using SparseConnectivityTracer: AbstractTracer, GradientTracer, HessianTracer, Dual using SparseConnectivityTracer: IndexSetGradientPattern, IndexSetHessianPattern using SparseConnectivityTracer: DuplicateVector, RecursiveSet, SortedVector @@ -18,6 +17,5 @@ HESSIAN_PATTERNS = ( # TODO: test on RecursiveSet ) -CONNECTIVITY_TRACERS = (ConnectivityTracer{P} for P in VECTOR_PATTERNS) GRADIENT_TRACERS = (GradientTracer{P} for P in VECTOR_PATTERNS) HESSIAN_TRACERS = (HessianTracer{P} for P in HESSIAN_PATTERNS)
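
As a usage sketch for code that previously called the removed `connectivity_pattern`, `jacobian_pattern`, or `hessian_pattern` helpers: the updated tests above query sparsity patterns through the ADTypes interface instead. The function `f` below is only an illustration (taken from the deleted default `hessian_pattern` test); `TracerSparsityDetector`, `jacobian_sparsity`, and `hessian_sparsity` are the entry points the tests now rely on.

```julia
using SparseConnectivityTracer            # exports TracerSparsityDetector
using ADTypes: jacobian_sparsity, hessian_sparsity

detector = TracerSparsityDetector()

# Example function reused from the deleted default `hessian_pattern` test
f(x) = x[1] / x[2] + x[3] / 1 + 1 / x[4]

# Jacobian sparsity of a vector-valued wrapper around f
J = jacobian_sparsity(x -> [f(x)], rand(4), detector)

# Hessian sparsity of the scalar function itself; this should reproduce
# the pattern checked by the removed test:
#   0 1 0 0
#   1 1 0 0
#   0 0 0 0
#   0 0 0 1
H = hessian_sparsity(f, rand(4), detector)
```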