diff --git a/ext/SparseConnectivityTracerDataInterpolationsExt.jl b/ext/SparseConnectivityTracerDataInterpolationsExt.jl index 71d5a8be..aaa2cb9e 100644 --- a/ext/SparseConnectivityTracerDataInterpolationsExt.jl +++ b/ext/SparseConnectivityTracerDataInterpolationsExt.jl @@ -6,9 +6,8 @@ if isdefined(Base, :get_extension) using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1 using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1 - using SparseConnectivityTracer: Fill # from FillArrays.jl + using FillArrays: Fill # from FillArrays.jl import DataInterpolations: - AbstractInterpolation, LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, @@ -19,15 +18,14 @@ if isdefined(Base, :get_extension) BSplineInterpolation, BSplineApprox, CubicHermiteSpline, - PCHIPInterpolation, + # PCHIPInterpolation, QuinticHermiteSpline else using ..SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer using ..SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1 using ..SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1 - using ..SparseConnectivityTracer: Fill # from FillArrays.jl + using ..FillArrays: Fill # from FillArrays.jl import ..DataInterpolations: - AbstractInterpolation, LinearInterpolation, QuadraticInterpolation, LagrangeInterpolation, @@ -38,7 +36,7 @@ else BSplineInterpolation, BSplineApprox, CubicHermiteSpline, - PCHIPInterpolation, + # PCHIPInterpolation, QuinticHermiteSpline end diff --git a/ext/SparseConnectivityTracerLogExpFunctionsExt.jl b/ext/SparseConnectivityTracerLogExpFunctionsExt.jl index 63a75ed9..e0798034 100644 --- a/ext/SparseConnectivityTracerLogExpFunctionsExt.jl +++ b/ext/SparseConnectivityTracerLogExpFunctionsExt.jl @@ -2,10 +2,60 @@ module SparseConnectivityTracerLogExpFunctionsExt if isdefined(Base, :get_extension) import SparseConnectivityTracer as SCT - using LogExpFunctions + using LogExpFunctions: + LogExpFunctions, + cexpexp, + cloglog, + log1mexp, + log1mlogistic, + log1pexp, + log1pmx, + log1psq, + log2mexp, + logabssinh, + logaddexp, + logcosh, + logexpm1, + logistic, + logit, + logit1mexp, + logitexp, + loglogistic, + logmxp1, + logsubexp, + xexpx, + xexpy, + xlog1py, + xlogx, + xlogy else import ..SparseConnectivityTracer as SCT - using ..LogExpFunctions + using ..LogExpFunctions: + LogExpFunctions, + cexpexp, + cloglog, + log1mexp, + log1mlogistic, + log1pexp, + log1pmx, + log1psq, + log2mexp, + logabssinh, + logaddexp, + logcosh, + logexpm1, + logistic, + logit, + logit1mexp, + logitexp, + loglogistic, + logmxp1, + logsubexp, + xexpx, + xexpy, + xlog1py, + xlogx, + xlogy end ## 1-to-1 functions diff --git a/ext/SparseConnectivityTracerNNlibExt.jl b/ext/SparseConnectivityTracerNNlibExt.jl index ee74ecb4..b9a8a43e 100644 --- a/ext/SparseConnectivityTracerNNlibExt.jl +++ b/ext/SparseConnectivityTracerNNlibExt.jl @@ -4,10 +4,58 @@ module SparseConnectivityTracerNNlibExt if isdefined(Base, :get_extension) import SparseConnectivityTracer as SCT - using NNlib + using NNlib: + NNlib, + celu, + elu, + gelu, + hardswish, + hardtanh, + hardσ, + leakyrelu, + lisht, + logcosh, + logσ, + mish, + relu, + relu6, + selu, + sigmoid_fast, + softplus, + softshrink, + softsign, + swish, + tanh_fast, + tanhshrink, + trelu, + σ else import ..SparseConnectivityTracer as SCT - using ..NNlib + using ..NNlib: + NNlib, + celu, + elu, + gelu, + hardswish, + hardtanh, + hardσ, + leakyrelu, + lisht, + logcosh, + logσ, + mish, + relu, + relu6, + selu, + sigmoid_fast, + softplus, + softshrink, + softsign, + swish, + tanh_fast, + tanhshrink, + trelu, + σ end ## 1-to-1 diff --git a/ext/SparseConnectivityTracerNaNMathExt.jl b/ext/SparseConnectivityTracerNaNMathExt.jl index 01ba5151..9502fca9 100644 --- a/ext/SparseConnectivityTracerNaNMathExt.jl +++ b/ext/SparseConnectivityTracerNaNMathExt.jl @@ -2,10 +2,10 @@ module SparseConnectivityTracerNaNMathExt if isdefined(Base, :get_extension) import SparseConnectivityTracer as SCT - using NaNMath + using NaNMath: NaNMath else import ..SparseConnectivityTracer as SCT - using ..NaNMath + using ..NaNMath: NaNMath end ## 1-to-1 diff --git a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl index 210a2604..42a7d819 100644 --- a/ext/SparseConnectivityTracerSpecialFunctionsExt.jl +++ b/ext/SparseConnectivityTracerSpecialFunctionsExt.jl @@ -2,10 +2,98 @@ module SparseConnectivityTracerSpecialFunctionsExt if isdefined(Base, :get_extension) import SparseConnectivityTracer as SCT - using SpecialFunctions + using SpecialFunctions: + SpecialFunctions, + airyai, + airyaiprime, + airyaiprimex, + airyaix, + airybi, + airybiprime, + airybiprimex, + airybix, + besseli, + besselix, + besselj, + besselj0, + besselj1, + besseljx, + besselk, + besselkx, + bessely, + bessely0, + bessely1, + besselyx, + beta, + cosint, + digamma, + ellipe, + ellipk, + erf, + erfc, + erfcinv, + erfcx, + erfinv, + expint, + expinti, + expintx, + gamma, + invdigamma, + jinc, + logbeta, + logerfc, + loggamma, + sinint, + sphericalbesselj, + sphericalbessely, + trigamma else import ..SparseConnectivityTracer as SCT - using ..SpecialFunctions + using ..SpecialFunctions: + SpecialFunctions, + airyai, + airyaiprime, + airyaiprimex, + airyaix, + airybi, + airybiprime, + airybiprimex, + airybix, + besseli, + besselix, + besselj, + besselj0, + besselj1, + besseljx, + besselk, + besselkx, + bessely, + bessely0, + bessely1, + besselyx, + beta, + cosint, + digamma, + ellipe, + ellipk, + erf, + erfc, + erfcinv, + erfcx, + erfinv, + expint, + expinti, + expintx, + gamma, + invdigamma, + jinc, + logbeta, + logerfc, + loggamma, + sinint, + sphericalbesselj, + sphericalbessely, + trigamma end #= diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index fc920f16..221fa1ef 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -9,7 +9,7 @@ using LinearAlgebra: LinearAlgebra, Symmetric using LinearAlgebra: Diagonal, diag, diagind using FillArrays: Fill -using DocStringExtensions +using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDFIELDS if !isdefined(Base, :get_extension) using Requires diff --git a/src/operators.jl b/src/operators.jl index 8885043c..3f826d48 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -223,7 +223,7 @@ for op in ops_2_to_1_fsc end # gradient of x/y: [1/y -x/y²] -SparseConnectivityTracer.is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x) +is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x) # ops_2_to_1_fsz: # ∂f/∂x != 0 diff --git a/src/overloads/arrays.jl b/src/overloads/arrays.jl index ff8e6334..85ce3bc8 100644 --- a/src/overloads/arrays.jl +++ b/src/overloads/arrays.jl @@ -111,12 +111,12 @@ function LinearAlgebra.eigen( end ## Inverse -function LinearAlgebra.inv(A::StridedMatrix{T}) where {T<:AbstractTracer} +function Base.inv(A::StridedMatrix{T}) where {T<:AbstractTracer} LinearAlgebra.checksquare(A) t = second_order_or(A) return Fill(t, size(A)...) end -function LinearAlgebra.inv(D::Diagonal{T}) where {T<:AbstractTracer} +function Base.inv(D::Diagonal{T}) where {T<:AbstractTracer} ts_in = D.diag ts_out = similar(ts_in) for i in 1:length(ts_out) @@ -132,7 +132,7 @@ function LinearAlgebra.pinv( t = second_order_or(A) return Fill(t, m, n) end -LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = LinearAlgebra.inv(D) +LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = inv(D) ## Division function LinearAlgebra.:\( @@ -143,7 +143,7 @@ function LinearAlgebra.:\( end ## Exponential -function LinearAlgebra.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer} +function Base.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer} LinearAlgebra.checksquare(A) n = size(A, 1) t = second_order_or(A) diff --git a/test/Project.toml b/test/Project.toml index e5ec5cfc..26182e37 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,6 +8,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" diff --git a/test/linting.jl b/test/linting.jl new file mode 100644 index 00000000..4ae6a269 --- /dev/null +++ b/test/linting.jl @@ -0,0 +1,78 @@ +using SparseConnectivityTracer +using Test + +using JuliaFormatter: JuliaFormatter +using Aqua: Aqua +using JET: JET +using ExplicitImports: ExplicitImports + +# Load package extensions so they get tested by ExplicitImports.jl +using DataInterpolations: DataInterpolations +using NaNMath: NaNMath +using NNlib: NNlib +using SpecialFunctions: SpecialFunctions + +@testset "Code formatting" begin + @info "...with JuliaFormatter.jl" + @test JuliaFormatter.format(SparseConnectivityTracer; verbose=false, overwrite=false) +end + +@testset "Aqua tests" begin + @info "...with Aqua.jl" + Aqua.test_all( + SparseConnectivityTracer; + ambiguities=false, + deps_compat=(check_extras=false,), + stale_deps=(ignore=[:Requires],), + persistent_tasks=false, + ) +end + +@testset "JET tests" begin + @info "...with JET.jl" + JET.test_package(SparseConnectivityTracer; target_defined_modules=true) +end + +@testset "ExplicitImports tests" begin + @info "...with ExplicitImports.jl" + @testset "Improper implicit imports" begin + @test ExplicitImports.check_no_implicit_imports(SparseConnectivityTracer) === + nothing + end + @testset "Improper explicit imports" begin + @test ExplicitImports.check_no_stale_explicit_imports( + SparseConnectivityTracer; + ignore=( + # Used in code generation, which ExplicitImports doesn't pick up + :AbstractTracer, + :AkimaInterpolation, + :BSplineApprox, + :BSplineInterpolation, + :CubicHermiteSpline, + :CubicSpline, + :LagrangeInterpolation, + :QuadraticInterpolation, + :QuadraticSpline, + :QuinticHermiteSpline, + ), + ) === nothing + @test ExplicitImports.check_all_explicit_imports_via_owners( + SparseConnectivityTracer + ) === nothing + # TODO: test in the future when `public` is more common + # @test ExplicitImports.check_all_explicit_imports_are_public( + # SparseConnectivityTracer + # ) === nothing + end + @testset "Improper qualified accesses" begin + @test ExplicitImports.check_all_qualified_accesses_via_owners( + SparseConnectivityTracer + ) === nothing + @test ExplicitImports.check_no_self_qualified_accesses(SparseConnectivityTracer) === + nothing + # TODO: test in the future when `public` is more common + # @test ExplicitImports.check_all_qualified_accesses_are_public( + # SparseConnectivityTracer + # ) === nothing + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 04e253db..c2eb0739 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,13 +4,9 @@ Pkg.develop(; ) using SparseConnectivityTracer - using Compat: pkgversion +using Documenter: Documenter, DocMeta using Test -using JuliaFormatter -using Aqua -using JET -using Documenter DocMeta.setdocmeta!( SparseConnectivityTracer, @@ -23,36 +19,18 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core") @testset verbose = true "SparseConnectivityTracer.jl" begin if GROUP in ("Core", "All") - @testset verbose = true "Formalities" begin - @info "Testing formalities..." - if VERSION >= v"1.10" - @testset "Code formatting" begin - @info "...with JuliaFormatter.jl" - @test JuliaFormatter.format( - SparseConnectivityTracer; verbose=false, overwrite=false - ) - end - @testset "Aqua tests" begin - @info "...with Aqua.jl" - Aqua.test_all( - SparseConnectivityTracer; - ambiguities=false, - deps_compat=(check_extras=false,), - stale_deps=(ignore=[:Requires],), - persistent_tasks=false, - ) - end - @testset "JET tests" begin - @info "...with JET.jl" - JET.test_package(SparseConnectivityTracer; target_defined_modules=true) - end - end - @testset "Doctests" begin - Documenter.doctest(SparseConnectivityTracer) + if VERSION >= v"1.10" + @testset verbose = true "Linting" begin + @info "Testing linting..." + include("linting.jl") end end end - + if GROUP in ("Core", "All") + @testset "Doctests" begin + Documenter.doctest(SparseConnectivityTracer) + end + end if GROUP in ("Core", "All") @testset verbose = true "Set types" begin @testset "Correctness" begin