Skip to content

Commit

Permalink
Add LogExpFunctions package extension (#184)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Sep 2, 2024
1 parent 61efc9a commit 840fc4e
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 11 deletions.
16 changes: 10 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,33 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"

[extensions]
SparseConnectivityTracerDataInterpolationsExt = "DataInterpolations"
SparseConnectivityTracerLogExpFunctionsExt = "LogExpFunctions"
SparseConnectivityTracerNNlibExt = "NNlib"
SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
ADTypes = "1"
DataInterpolations = "6.2"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
LogExpFunctions = "0.3"
NNlib = "0.8, 0.9"
Random = "<0.0.1, 1"
Requires = "1.3"
SparseArrays = "<0.0.1, 1"
SpecialFunctions = "2.4"
julia = "1.6"

[extras]
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
65 changes: 65 additions & 0 deletions ext/SparseConnectivityTracerLogExpFunctionsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
module SparseConnectivityTracerLogExpFunctionsExt

if isdefined(Base, :get_extension)
import SparseConnectivityTracer as SCT
using LogExpFunctions
else
import ..SparseConnectivityTracer as SCT
using ..LogExpFunctions
end

# 1-to-1 functions

ops_1_to_1 = (
xlogx,
xexpx,
logistic,
logit,
logcosh,
logabssinh,
log1psq,
log1pexp,
log1mexp,
log2mexp,
logexpm1,
softplus,
invsoftplus,
log1pmx,
logmxp1,
cloglog,
cexpexp,
loglogistic,
logitexp,
log1mlogistic,
logit1mexp,
)

for op in ops_1_to_1
T = typeof(op)
@eval SCT.is_der1_zero_global(::$T) = false
@eval SCT.is_der2_zero_global(::$T) = false
end

# 2-to-1 functions

ops_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp)

for op in ops_2_to_1
T = typeof(op)
@eval SCT.is_der1_arg1_zero_global(::$T) = false
@eval SCT.is_der2_arg1_zero_global(::$T) = false
@eval SCT.is_der1_arg2_zero_global(::$T) = false
@eval SCT.is_der2_arg2_zero_global(::$T) = false
@eval SCT.is_der_cross_zero_global(::$T) = false
end

# Generate overloads
eval(SCT.generate_code_1_to_1(:LogExpFunctions, ops_1_to_1))
eval(SCT.generate_code_2_to_1(:LogExpFunctions, ops_2_to_1))

# List operators for later testing
SCT.test_operators_1_to_1(::Val{:LogExpFunctions}) = ops_1_to_1
SCT.test_operators_2_to_1(::Val{:LogExpFunctions}) = ops_2_to_1
SCT.test_operators_1_to_2(::Val{:LogExpFunctions}) = ()

end
3 changes: 3 additions & 0 deletions src/SparseConnectivityTracer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ function __init__()
@require NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" include(
"../ext/SparseConnectivityTracerNNlibExt.jl"
)
@require LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" include(
"../ext/SparseConnectivityTracerLogExpFunctionsExt.jl"
)
# NOTE: SparseConnectivityTracerDataInterpolationsExt is not loaded on Julia <1.10
end
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
NLPModelsJuMP = "792afdf1-32c1-5681-94e0-d7bf7a5df49e"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand Down
15 changes: 11 additions & 4 deletions test/classification.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SparseConnectivityTracer: # 1-to-1
using SparseConnectivityTracer: # 1-to-1
is_der1_zero_global,
is_der2_zero_global,
is_der1_zero_local,
Expand Down Expand Up @@ -29,6 +29,7 @@ using SparseConnectivityTracer: # testing
test_operators_1_to_2
using SpecialFunctions: SpecialFunctions
using NNlib: NNlib
using LogExpFunctions: LogExpFunctions
using Test
using ForwardDiff: derivative, gradient, hessian

Expand All @@ -43,6 +44,12 @@ random_input(op) = rand()
random_input(::Union{typeof(acosh),typeof(acoth),typeof(acsc),typeof(asec)}) = 1 + rand()
random_input(::typeof(sincosd)) = 180 * rand()

# LogExpFunctions.jl
random_input(::typeof(LogExpFunctions.log1mexp)) = -rand() # log1mexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.log2mexp)) = -rand() # log2mexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.logitexp)) = -rand() # logitexp(x) is defined for x < 0
random_input(::typeof(LogExpFunctions.logit1mexp)) = -rand() # logit1mexp(x) is defined for x < 0

random_first_input(op) = random_input(op)
random_second_input(op) = random_input(op)

Expand Down Expand Up @@ -90,7 +97,7 @@ function correct_classification_1_to_1(op, x; atol)
end

@testset verbose = true "1-to-1" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_1_to_1(Val(Symbol(m)))
@test all(
correct_classification_1_to_1(op, random_input(op); atol=DEFAULT_ATOL) for
Expand Down Expand Up @@ -133,7 +140,7 @@ function correct_classification_2_to_1(op, x, y; atol)
end

@testset verbose = true "2-to-1" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_2_to_1(Val(Symbol(m)))
@test all(
correct_classification_2_to_1(
Expand Down Expand Up @@ -173,7 +180,7 @@ function correct_classification_1_to_2(op, x; atol)
end

@testset verbose = true "1-to-2" begin
@testset "$m" for m in (Base, SpecialFunctions, NNlib)
@testset "$m" for m in (Base, SpecialFunctions, NNlib, LogExpFunctions)
@testset "$op" for op in test_operators_1_to_2(Val(Symbol(m)))
@test all(
correct_classification_1_to_2(op, random_input(op); atol=DEFAULT_ATOL) for
Expand Down
98 changes: 98 additions & 0 deletions test/ext/test_LogExpFunctions.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using SparseConnectivityTracer
using LogExpFunctions
using Test

# Load definitions of GRADIENT_TRACERS, GRADIENT_PATTERNS, HESSIAN_TRACERS and HESSIAN_PATTERNS
include("../tracers_definitions.jl")

lef_1_to_1_pos_input = (
xlogx,
logistic,
logit,
log1psq,
log1pexp,
logexpm1,
softplus,
invsoftplus,
log1pmx,
logmxp1,
logcosh,
logabssinh,
cloglog,
cexpexp,
loglogistic,
log1mlogistic,
)
lef_1_to_1_neg_input = (log1mexp, log2mexp, logitexp, logit1mexp)
lef_1_to_1 = union(lef_1_to_1_pos_input, lef_1_to_1_neg_input)
lef_2_to_1 = (xlogy, xlog1py, xexpy, logaddexp, logsubexp)

@testset "Jacobian Global" begin
method = TracerSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1
@test J(x -> f(x[1]), rand(2)) == [1 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test J(x -> f(x[1], x[2]), rand(3)) == [1 1 0]
end
end
end

@testset "Jacobian Local" begin
method = TracerLocalSparsityDetector()
J(f, x) = jacobian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1_pos_input
@test J(x -> f(x[1]), [0.5, 1.0]) == [1 0]
end
@testset "$f" for f in lef_1_to_1_neg_input
@test J(x -> f(x[1]), [-0.5, 1.0]) == [1 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test J(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0]
end
end
end

@testset "Hessian Global" begin
method = TracerSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1
@test H(x -> f(x[1]), rand(2)) == [1 0; 0 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test H(x -> f(x[1], x[2]), rand(3)) == [1 1 0; 1 1 0; 0 0 0]
end
end
end

@testset "Hessian Local" begin
method = TracerLocalSparsityDetector()
H(f, x) = hessian_sparsity(f, x, method)

@testset "1-to-1 functions" begin
@testset "$f" for f in lef_1_to_1_pos_input
@test H(x -> f(x[1]), [0.5, 1.0]) == [1 0; 0 0]
end
@testset "$f" for f in lef_1_to_1_neg_input
@test H(x -> f(x[1]), [-0.5, 1.0]) == [1 0; 0 0]
end
end
@testset "2-to-1 functions" begin
@testset "$f" for f in lef_2_to_1
@test H(x -> f(x[1], x[2]), [0.5, 1.0, 2.0]) == [1 1 0; 1 1 0; 0 0 0]
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ GROUP = get(ENV, "JULIA_SCT_TEST_GROUP", "Core")
if GROUP in ("Core", "All")
@info "Testing package extensions..."
@testset verbose = true "Package extensions" begin
for ext in (:NNlib, :SpecialFunctions)
for ext in (:NNlib, :SpecialFunctions, :LogExpFunctions)
@testset "$ext" begin
@info "...$ext"
include("ext/test_$ext.jl")
Expand Down

0 comments on commit 840fc4e

Please sign in to comment.