Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SpecialFunctions extension #82

Merged
merged 18 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,21 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1"
Compat = "3,4"
DocStringExtensions = "0.9"
Requires = "1.3"
SparseArrays = "1"
SpecialFunctions = "2.4"
julia = "1.6"
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
module SparseConnectivityTracerSpecialFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using SpecialFunctions
else
import ..SparseConnectivityTracer as SCT
using ..SpecialFunctions
end

#=
Complex functions are ignored.
Functions with more than 2 arguments are ignored.
Functions with integer arguments are ignored.
adrhill marked this conversation as resolved.
Show resolved Hide resolved
=#

## 1-to-1

ops_1_to_1_s = (
# Gamma Function
gamma,
loggamma,
digamma,
invdigamma,
trigamma,
# Exponential and Trigonometric Integrals
expinti,
sinint,
cosint,
# Error functions, Dawson's and Fresnel Integrals
erf,
erfc,
erfcinv,
erfcx,
logerfc,
erfinv,
# Airy and Related Functions
airyai,
airyaiprime,
airybi,
airybiprime,
airyaix,
airyaiprimex,
airybix,
airybiprimex,
# Bessel Functions
besselj0,
besselj1,
bessely0,
bessely1,
jinc,
# Elliptic Integrals
ellipk,
ellipe,
)

for op in ops_1_to_1_s
T = typeof(op)
@eval SCT.is_influence_zero_global(::$T) = false
@eval SCT.is_firstder_zero_global(::$T) = false
@eval SCT.is_seconder_zero_global(::$T) = false
end

ops_1_to_1 = ops_1_to_1_s

## 2-to-1

ops_2_to_1_ssc = (
# Gamma Function
gamma,
loggamma,
beta,
logbeta,
# Exponential and Trigonometric Integrals
expint,
expintx,
# Error functions, Dawson's and Fresnel Integrals
erf,
# Bessel Functions
besselj,
besseljx,
sphericalbesselj,
bessely,
besselyx,
sphericalbessely,
besseli,
besselix,
besselk,
besselkx,
)

for op in ops_2_to_1_ssc
T = typeof(op)
@eval SCT.is_influence_arg1_zero_global(::$T) = false
@eval SCT.is_influence_arg2_zero_global(::$T) = false
@eval SCT.is_firstder_arg1_zero_global(::$T) = false
@eval SCT.is_seconder_arg1_zero_global(::$T) = false
@eval SCT.is_firstder_arg2_zero_global(::$T) = false
@eval SCT.is_seconder_arg2_zero_global(::$T) = false
@eval SCT.is_crossder_zero_global(::$T) = false
end

ops_2_to_1 = ops_2_to_1_ssc

## Lists

SCT.list_operators_1_to_1(::Val{:SpecialFunctions}) = ops_1_to_1
SCT.list_operators_2_to_1(::Val{:SpecialFunctions}) = ops_2_to_1
SCT.list_operators_1_to_2(::Val{:SpecialFunctions}) = ()

## Overloads

eval(SCT.overload_all(:SpecialFunctions))

end
17 changes: 17 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
module SparseConnectivityTracer

const SCT = SparseConnectivityTracer

using ADTypes: ADTypes
using Compat: Returns
import SparseArrays: sparse
import Random: rand, AbstractRNG, SamplerType

using DocStringExtensions

if !isdefined(Base, :get_extension)
using Requires
end

include("settypes/duplicatevector.jl")
include("settypes/recursiveset.jl")
include("settypes/sortedvector.jl")
Expand All @@ -15,10 +21,13 @@ include("tracers.jl")
include("exceptions.jl")
include("conversion.jl")
include("operators.jl")

include("overload_connectivity.jl")
include("overload_gradient.jl")
include("overload_hessian.jl")
include("overload_dual.jl")
include("overload_all.jl")

include("pattern.jl")
include("adtypes.jl")

Expand All @@ -30,4 +39,12 @@ export hessian_pattern, local_hessian_pattern
export TracerSparsityDetector
export TracerLocalSparsityDetector

function __init__()
@static if !isdefined(Base, :get_extension)
@require SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" include(
"../ext/SparseConnectivityTracerSpecialFunctionsExt/SparseConnectivityTracerSpecialFunctionsExt.jl",
)
end
end

end # module
4 changes: 4 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,3 +641,7 @@ ops_1_to_2 = union(
ops_1_to_2_zz,
)
#! format: on

list_operators_1_to_1(::Val{:Base}) = ops_1_to_1
list_operators_2_to_1(::Val{:Base}) = ops_2_to_1
list_operators_1_to_2(::Val{:Base}) = ops_1_to_2
30 changes: 30 additions & 0 deletions src/overload_all.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function overload_all(M)
exprs_1_to_1 = [
quote
$(overload_connectivity_1_to_1(M, op))
$(overload_gradient_1_to_1(M, op))
$(overload_hessian_1_to_1(M, op))
end for op in nameof.(list_operators_1_to_1(Val(M)))
]
exprs_2_to_1 = [
quote
$(overload_connectivity_2_to_1(M, op))
$(overload_gradient_2_to_1(M, op))
$(overload_hessian_2_to_1(M, op))
end for op in nameof.(list_operators_2_to_1(Val(M)))
]
exprs_1_to_2 = [
quote
$(overload_connectivity_1_to_2(M, op))
$(overload_gradient_1_to_2(M, op))
$(overload_hessian_1_to_2(M, op))
end for op in nameof.(list_operators_1_to_2(Val(M)))
]
return quote
$(exprs_1_to_1...)
$(exprs_2_to_1...)
$(exprs_1_to_2...)
end
end

eval(overload_all(:Base))
Comment on lines +1 to +30
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this combined with list_operators_* and the big quotes in overload_*.jl feels very hacky. 🫤
There are too many layers of metaprogramming with a lot of implicit dependencies in their design that interact across several files.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I spent quite a lot of time trying out various things to make it work with minimal changes, and this seems to work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's in the quotes is exactly the same as what was directly eval-ed earlied (if you ignore the additional module prefixes, which I got from ForwardDiff).
The difference is that we split code generation (creating the expression) from code evaluation (running it through eval)

Copy link
Collaborator Author

@gdalle gdalle May 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The contents of overload_all are only designed so that it can be a one-liner in all the package extensions we will no doubt need to add. It's nothing but a big concatenation of all the expressions

Loading
Loading