Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make imports explicit, test with ExplicitImports.jl #188

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions ext/SparseConnectivityTracerDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ if isdefined(Base, :get_extension)
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using SparseConnectivityTracer: Fill # from FillArrays.jl
using FillArrays: Fill # from FillArrays.jl
import DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
Expand All @@ -19,15 +18,14 @@ if isdefined(Base, :get_extension)
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
PCHIPInterpolation,
# PCHIPInterpolation,
QuinticHermiteSpline
else
using ..SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using ..SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using ..SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using ..SparseConnectivityTracer: Fill # from FillArrays.jl
using ..FillArrays: Fill # from FillArrays.jl
import ..DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
Expand All @@ -38,7 +36,7 @@ else
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
PCHIPInterpolation,
# PCHIPInterpolation,
QuinticHermiteSpline
end

Expand Down
54 changes: 52 additions & 2 deletions ext/SparseConnectivityTracerLogExpFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,60 @@ module SparseConnectivityTracerLogExpFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using LogExpFunctions
using LogExpFunctions:
LogExpFunctions,
cexpexp,
cloglog,
log1mexp,
log1mlogistic,
log1pexp,
log1pmx,
log1psq,
log2mexp,
logabssinh,
logaddexp,
logcosh,
logexpm1,
logistic,
logit,
logit1mexp,
logitexp,
loglogistic,
logmxp1,
logsubexp,
xexpx,
xexpy,
xlog1py,
xlogx,
xlogy
else
import ..SparseConnectivityTracer as SCT
using ..LogExpFunctions
using ..LogExpFunctions:
LogExpFunctions,
cexpexp,
cloglog,
log1mexp,
log1mlogistic,
log1pexp,
log1pmx,
log1psq,
log2mexp,
logabssinh,
logaddexp,
logcosh,
logexpm1,
logistic,
logit,
logit1mexp,
logitexp,
loglogistic,
logmxp1,
logsubexp,
xexpx,
xexpy,
xlog1py,
xlogx,
xlogy
end

## 1-to-1 functions
Expand Down
52 changes: 50 additions & 2 deletions ext/SparseConnectivityTracerNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,58 @@ module SparseConnectivityTracerNNlibExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using NNlib
using NNlib:
NNlib,
celu,
elu,
gelu,
hardswish,
hardtanh,
hardσ,
leakyrelu,
lisht,
logcosh,
logσ,
mish,
relu,
relu6,
selu,
sigmoid_fast,
softplus,
softshrink,
softsign,
swish,
tanh_fast,
tanhshrink,
trelu,
σ
else
import ..SparseConnectivityTracer as SCT
using ..NNlib
using ..NNlib:
NNlib,
celu,
elu,
gelu,
hardswish,
hardtanh,
hardσ,
leakyrelu,
lisht,
logcosh,
logσ,
mish,
relu,
relu6,
selu,
sigmoid_fast,
softplus,
softshrink,
softsign,
swish,
tanh_fast,
tanhshrink,
trelu,
σ
end

## 1-to-1
Expand Down
4 changes: 2 additions & 2 deletions ext/SparseConnectivityTracerNaNMathExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module SparseConnectivityTracerNaNMathExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using NaNMath
using NaNMath: NaNMath
else
import ..SparseConnectivityTracer as SCT
using ..NaNMath
using ..NaNMath: NaNMath
end

## 1-to-1
Expand Down
92 changes: 90 additions & 2 deletions ext/SparseConnectivityTracerSpecialFunctionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,98 @@ module SparseConnectivityTracerSpecialFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using SpecialFunctions
using SpecialFunctions:
SpecialFunctions,
airyai,
airyaiprime,
airyaiprimex,
airyaix,
airybi,
airybiprime,
airybiprimex,
airybix,
besseli,
besselix,
besselj,
besselj0,
besselj1,
besseljx,
besselk,
besselkx,
bessely,
bessely0,
bessely1,
besselyx,
beta,
cosint,
digamma,
ellipe,
ellipk,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
expint,
expinti,
expintx,
gamma,
invdigamma,
jinc,
logbeta,
logerfc,
loggamma,
sinint,
sphericalbesselj,
sphericalbessely,
trigamma
else
import ..SparseConnectivityTracer as SCT
using ..SpecialFunctions
using ..SpecialFunctions:
SpecialFunctions,
airyai,
airyaiprime,
airyaiprimex,
airyaix,
airybi,
airybiprime,
airybiprimex,
airybix,
besseli,
besselix,
besselj,
besselj0,
besselj1,
besseljx,
besselk,
besselkx,
bessely,
bessely0,
bessely1,
besselyx,
beta,
cosint,
digamma,
ellipe,
ellipk,
erf,
erfc,
erfcinv,
erfcx,
erfinv,
expint,
expinti,
expintx,
gamma,
invdigamma,
jinc,
logbeta,
logerfc,
loggamma,
sinint,
sphericalbesselj,
sphericalbessely,
trigamma
end

#=
Expand Down
2 changes: 1 addition & 1 deletion src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using LinearAlgebra: LinearAlgebra, Symmetric
using LinearAlgebra: Diagonal, diag, diagind
using FillArrays: Fill

using DocStringExtensions
using DocStringExtensions: DocStringExtensions, TYPEDEF, TYPEDFIELDS

if !isdefined(Base, :get_extension)
using Requires
Expand Down
2 changes: 1 addition & 1 deletion src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ for op in ops_2_to_1_fsc
end

# gradient of x/y: [1/y -x/y²]
SparseConnectivityTracer.is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x)
is_der1_arg2_zero_local(::typeof(/), x, y) = iszero(x)

# ops_2_to_1_fsz:
# ∂f/∂x != 0
Expand Down
8 changes: 4 additions & 4 deletions src/overloads/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ function LinearAlgebra.eigen(
end

## Inverse
function LinearAlgebra.inv(A::StridedMatrix{T}) where {T<:AbstractTracer}
function Base.inv(A::StridedMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
t = second_order_or(A)
return Fill(t, size(A)...)
end
function LinearAlgebra.inv(D::Diagonal{T}) where {T<:AbstractTracer}
function Base.inv(D::Diagonal{T}) where {T<:AbstractTracer}
ts_in = D.diag
ts_out = similar(ts_in)
for i in 1:length(ts_out)
Expand All @@ -132,7 +132,7 @@ function LinearAlgebra.pinv(
t = second_order_or(A)
return Fill(t, m, n)
end
LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = LinearAlgebra.inv(D)
LinearAlgebra.pinv(D::Diagonal{T}) where {T<:AbstractTracer} = inv(D)

## Division
function LinearAlgebra.:\(
Expand All @@ -143,7 +143,7 @@ function LinearAlgebra.:\(
end

## Exponential
function LinearAlgebra.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer}
function Base.exp(A::AbstractMatrix{T}) where {T<:AbstractTracer}
LinearAlgebra.checksquare(A)
n = size(A, 1)
t = second_order_or(A)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand Down
Loading
Loading