diff --git a/Project.toml b/Project.toml index 292287c..7787e65 100644 --- a/Project.toml +++ b/Project.toml @@ -15,12 +15,14 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [weakdeps] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations" SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions" +SparseConnectivityTracerNaNMathExt = "NaNMath" SparseConnectivityTracerNNlibExt = "NNlib" SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions" @@ -31,6 +33,7 @@ DocStringExtensions = "0.9" FillArrays = "1" LinearAlgebra = "<0.0.1, 1" LogExpFunctions = "0.3" +NaNMath = "1" NNlib = "0.8, 0.9" Random = "<0.0.1, 1" Requires = "1.3" @@ -41,5 +44,6 @@ julia = "1.6" [extras] DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/ext/SparseConnectivityTracerNaNMathExt.jl b/ext/SparseConnectivityTracerNaNMathExt.jl new file mode 100644 index 0000000..31b72ac --- /dev/null +++ b/ext/SparseConnectivityTracerNaNMathExt.jl @@ -0,0 +1,94 @@ +module SparseConnectivityTracerNaNMathExt + +if isdefined(Base, :get_extension) + import SparseConnectivityTracer as SCT + using NaNMath +else + import ..SparseConnectivityTracer as SCT + using ..NaNMath +end + +## 1-to-1 + +# ops_1_to_1_s: +# x -> f != 0 +# ∂f/∂x != 0 +# ∂²f/∂x² != 0 +ops_1_to_1_s = ( + NaNMath.sqrt, + NaNMath.sin, + NaNMath.cos, + NaNMath.tan, + NaNMath.asin, + NaNMath.acos, + NaNMath.acosh, + NaNMath.atanh, + NaNMath.log, + NaNMath.log2, + NaNMath.log10, + NaNMath.log1p, + NaNMath.lgamma, +) + +for op in ops_1_to_1_s + T = typeof(op) + @eval SCT.is_der1_zero_global(::$T) = false + @eval SCT.is_der2_zero_global(::$T) = false +end + +ops_1_to_1 = ops_1_to_1_s + +## 2-to-1 + +# ops_2_to_1_ssc: +# ∂f/∂x != 0 +# ∂²f/∂x² != 0 +# ∂f/∂y != 0 +# ∂²f/∂y² != 0 +# ∂²f/∂x∂y != 0 +ops_2_to_1_ssc = (NaNMath.pow,) + +for op in ops_2_to_1_ssc + T = typeof(op) + @eval SCT.is_der1_arg1_zero_global(::$T) = false + @eval SCT.is_der2_arg1_zero_global(::$T) = false + @eval SCT.is_der1_arg2_zero_global(::$T) = false + @eval SCT.is_der2_arg2_zero_global(::$T) = false + @eval SCT.is_der_cross_zero_global(::$T) = false +end + +# ops_2_to_1_ffz: +# ∂f/∂x != 0 +# ∂²f/∂x² == 0 +# ∂f/∂y != 0 +# ∂²f/∂y² == 0 +# ∂²f/∂x∂y == 0 +ops_2_to_1_ffz = (NaNMath.max, NaNMath.min) + +for op in ops_2_to_1_ffz + T = typeof(op) + @eval SCT.is_der1_arg1_zero_global(::$T) = false + @eval SCT.is_der2_arg1_zero_global(::$T) = true + @eval SCT.is_der1_arg2_zero_global(::$T) = false + @eval SCT.is_der2_arg2_zero_global(::$T) = true + @eval SCT.is_der_cross_zero_global(::$T) = true +end + +SCT.is_der1_arg1_zero_local(::typeof(NaNMath.max), x, y) = x < y +SCT.is_der1_arg2_zero_local(::typeof(NaNMath.max), x, y) = y < x + +SCT.is_der1_arg1_zero_local(::typeof(NaNMath.min), x, y) = x > y +SCT.is_der1_arg2_zero_local(::typeof(NaNMath.min), x, y) = y > x + +ops_2_to_1 = union(ops_2_to_1_ssc, ops_2_to_1_ffz) + +## Overloads +eval(SCT.generate_code_1_to_1(:NaNMath, ops_1_to_1)) +eval(SCT.generate_code_2_to_1(:NaNMath, ops_2_to_1)) + +## List operators for later testing +SCT.test_operators_1_to_1(::Val{:NaNMath}) = ops_1_to_1 +SCT.test_operators_2_to_1(::Val{:NaNMath}) = ops_2_to_1 +SCT.test_operators_1_to_2(::Val{:NaNMath}) = () + +end diff --git a/src/SparseConnectivityTracer.jl b/src/SparseConnectivityTracer.jl index a17511b..fc920f1 100644 --- a/src/SparseConnectivityTracer.jl +++ b/src/SparseConnectivityTracer.jl @@ -53,6 +53,9 @@ function __init__() @require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include( "../ext/SparseConnectivityTracerLogExpFunctionsExt.jl" ) + @require NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" include( + "../ext/SparseConnectivityTracerNaNMathExt.jl" + ) # NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10 end end