Skip to content

Commit

Permalink
Add NaNMath package extension
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill committed Sep 2, 2024
1 parent 2a67957 commit a392de3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"
SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions"
SparseConnectivityTracerNaNMathExt = "NaNMath"
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

Expand All @@ -31,6 +33,7 @@ DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
LogExpFunctions = "0.3"
NaNMath = "1"
NNlib = "0.8, 0.9"
Random = "<0.0.1, 1"
Requires = "1.3"
Expand All @@ -41,5 +44,6 @@ julia = "1.6"
[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
94 changes: 94 additions & 0 deletions ext/SparseConnectivityTracerNaNMathExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
module SparseConnectivityTracerNaNMathExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using NaNMath
else
import ..SparseConnectivityTracer as SCT
using ..NaNMath
end

## 1-to-1

# ops_1_to_1_s:
# x -> f != 0
# ∂f/∂x != 0
# ∂²f/∂x² != 0
ops_1_to_1_s = (
NaNMath.sqrt,
NaNMath.sin,
NaNMath.cos,
NaNMath.tan,
NaNMath.asin,
NaNMath.acos,
NaNMath.acosh,
NaNMath.atanh,
NaNMath.log,
NaNMath.log2,
NaNMath.log10,
NaNMath.log1p,
NaNMath.lgamma,
)

for op in ops_1_to_1_s
T = typeof(op)
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = false
end

ops_1_to_1 = ops_1_to_1_s

## 2-to-1

# ops_2_to_1_ssc:
# ∂f/∂x != 0
# ∂²f/∂x² != 0
# ∂f/∂y != 0
# ∂²f/∂y² != 0
# ∂²f/∂x∂y != 0
ops_2_to_1_ssc = (NaNMath.pow,)

for op in ops_2_to_1_ssc
T = typeof(op)
@eval SCT.is_der1_arg1_zero_global(::$T) = false
@eval SCT.is_der2_arg1_zero_global(::$T) = false
@eval SCT.is_der1_arg2_zero_global(::$T) = false
@eval SCT.is_der2_arg2_zero_global(::$T) = false
@eval SCT.is_der_cross_zero_global(::$T) = false
end

# ops_2_to_1_ffz:
# ∂f/∂x != 0
# ∂²f/∂x² == 0
# ∂f/∂y != 0
# ∂²f/∂y² == 0
# ∂²f/∂x∂y == 0
ops_2_to_1_ffz = (NaNMath.max, NaNMath.min)

for op in ops_2_to_1_ffz
T = typeof(op)
@eval SCT.is_der1_arg1_zero_global(::$T) = false
@eval SCT.is_der2_arg1_zero_global(::$T) = true
@eval SCT.is_der1_arg2_zero_global(::$T) = false
@eval SCT.is_der2_arg2_zero_global(::$T) = true
@eval SCT.is_der_cross_zero_global(::$T) = true
end

SCT.is_der1_arg1_zero_local(::typeof(NaNMath.max), x, y) = x < y
SCT.is_der1_arg2_zero_local(::typeof(NaNMath.max), x, y) = y < x

SCT.is_der1_arg1_zero_local(::typeof(NaNMath.min), x, y) = x > y
SCT.is_der1_arg2_zero_local(::typeof(NaNMath.min), x, y) = y > x

ops_2_to_1 = union(ops_2_to_1_ssc, ops_2_to_1_ffz)

## Overloads
eval(SCT.generate_code_1_to_1(:NaNMath, ops_1_to_1))
eval(SCT.generate_code_2_to_1(:NaNMath, ops_2_to_1))

## List operators for later testing
SCT.test_operators_1_to_1(::Val{:NaNMath}) = ops_1_to_1
SCT.test_operators_2_to_1(::Val{:NaNMath}) = ops_2_to_1
SCT.test_operators_1_to_2(::Val{:NaNMath}) = ()

end
3 changes: 3 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ function __init__()
@require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include(
"../ext/SparseConnectivityTracerLogExpFunctionsExt.jl"
)
@require NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" include(
"../ext/SparseConnectivityTracerNaNMathExt.jl"
)
# NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10
end
end
Expand Down

0 comments on commit a392de3

Please sign in to comment.