Skip to content

Commit

Permalink
Simplify DataInterpolations.jl extension (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrhill authored Oct 21, 2024
1 parent 94b5bc3 commit e6d357b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 121 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ SparseConnectivityTracerSpecialFunctionsExt = "SpecialFunctions"

[compat]
ADTypes = "1"
DataInterpolations = "6.4.2"
DataInterpolations = "6.5"
DocStringExtensions = "0.9"
FillArrays = "1"
LinearAlgebra = "<0.0.1, 1"
Expand Down
156 changes: 61 additions & 95 deletions ext/SparseConnectivityTracerDataInterpolationsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using FillArrays: Fill # from FillArrays.jl
import DataInterpolations: AbstractInterpolation
import DataInterpolations:
LinearInterpolation,
QuadraticInterpolation,
Expand All @@ -20,42 +21,71 @@ import DataInterpolations:
# PCHIPInterpolation,
QuinticHermiteSpline

#========================#
# General interpolations #
#========================#
#===========#
# Utilities #
#===========#

function _sct_interpolate(
::AbstractInterpolation{T,N},
uType::Type{V},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
) where {T,N,V<:AbstractVector}
return gradient_tracer_1_to_1(t, is_der_1_zero)
end
function _sct_interpolate(
::AbstractInterpolation{T,N},
uType::Type{V},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
) where {T,N,V<:AbstractVector}
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
end
function _sct_interpolate(
::AbstractInterpolation{T,N},
uType::Type{M},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
) where {T,N,M<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, is_der_1_zero)
return Fill(t, N)
end
function _sct_interpolate(
::AbstractInterpolation{T,N},
uType::Type{M},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
) where {T,N,M<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
return Fill(t, N)
end

#===========#
# Overloads #
#===========#

# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
# all interpolations have a non-zero second derivative at some point in the input domain.

for I in (
:QuadraticInterpolation,
:LagrangeInterpolation,
:AkimaInterpolation,
:QuadraticSpline,
:CubicSpline,
:BSplineInterpolation,
:BSplineApprox,
:CubicHermiteSpline,
:QuinticHermiteSpline,
for (I, is_der1_zero, is_der2_zero) in (
(:ConstantInterpolation, true, true),
(:LinearInterpolation, false, true),
(:QuadraticInterpolation, false, false),
(:LagrangeInterpolation, false, false),
(:AkimaInterpolation, false, false),
(:QuadraticSpline, false, false),
(:CubicSpline, false, false),
(:BSplineInterpolation, false, false),
(:BSplineApprox, false, false),
(:CubicHermiteSpline, false, false),
(:QuinticHermiteSpline, false, false),
)
# 1D Interpolations (uType<:AbstractVector)
@eval function (interp::$(I){uType})(t::GradientTracer) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, false)
end
@eval function (interp::$(I){uType})(t::HessianTracer) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, false, false)
end

# ND Interpolations (uType<:AbstractMatrix)
@eval function (interp::$(I){uType})(t::GradientTracer) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
@eval function (interp::$(I){uType})(t::HessianTracer) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, false, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
@eval function (interp::$(I){uType})(t::AbstractTracer) where {uType}
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
end
end

Expand All @@ -80,68 +110,4 @@ for I in (
end
end

#=======================#
# ConstantInterpolation #
#=======================#

# 1D Interpolations (uType<:AbstractVector)
function (interp::ConstantInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, true)
end
function (interp::ConstantInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, true, true)
end

# ND Interpolations (uType<:AbstractMatrix)
function (interp::ConstantInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
function (interp::ConstantInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, true, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end

#=====================#
# LinearInterpolation #
#=====================#

# 1D Interpolations (uType<:AbstractVector)
function (interp::LinearInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractVector}
return gradient_tracer_1_to_1(t, false)
end
function (interp::LinearInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractVector}
return hessian_tracer_1_to_1(t, false, true)
end

# ND Interpolations (uType<:AbstractMatrix)
function (interp::LinearInterpolation{uType})(
t::GradientTracer
) where {uType<:AbstractMatrix}
t = gradient_tracer_1_to_1(t, false)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end
function (interp::LinearInterpolation{uType})(
t::HessianTracer
) where {uType<:AbstractMatrix}
t = hessian_tracer_1_to_1(t, false, true)
nstates = size(interp.u, 1)
return Fill(t, nstates)
end

end # module SparseConnectivityTracerDataInterpolationsExt
50 changes: 25 additions & 25 deletions test/ext/test_DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ struct InterpolationTest{N,I<:AbstractInterpolation} # N = output dim. of interp
is_der2_zero::Bool
end
function InterpolationTest(
N, interp::I; is_der1_zero=false, is_der2_zero=false
) where {I<:AbstractInterpolation}
return InterpolationTest{N,I}(interp, is_der1_zero, is_der2_zero)
interp::I; is_der1_zero=false, is_der2_zero=false
) where {T,N,I<:AbstractInterpolation{T,N}}
return InterpolationTest{only(N),I}(interp, is_der1_zero, is_der2_zero)
end
testname(t::InterpolationTest{N}) where {N} = "$N-dim $(typeof(t.interp))"

Expand Down Expand Up @@ -231,19 +231,19 @@ end
@testset "1D Interpolations" begin
@testset "$(testname(t))" for t in (
InterpolationTest(
1, ConstantInterpolation(u, t); is_der1_zero=true, is_der2_zero=true
ConstantInterpolation(u, t); is_der1_zero=true, is_der2_zero=true
),
InterpolationTest(1, LinearInterpolation(u, t); is_der2_zero=true),
InterpolationTest(1, QuadraticInterpolation(u, t)),
InterpolationTest(1, LagrangeInterpolation(u, t)),
InterpolationTest(1, AkimaInterpolation(u, t)),
InterpolationTest(1, QuadraticSpline(u, t)),
InterpolationTest(1, CubicSpline(u, t)),
InterpolationTest(1, BSplineInterpolation(u, t, 3, :ArcLen, :Average)),
InterpolationTest(1, BSplineApprox(u, t, 3, 4, :ArcLen, :Average)),
InterpolationTest(1, PCHIPInterpolation(u, t)),
InterpolationTest(1, CubicHermiteSpline(du, u, t)),
InterpolationTest(1, QuinticHermiteSpline(ddu, du, u, t)),
InterpolationTest(LinearInterpolation(u, t); is_der2_zero=true),
InterpolationTest(QuadraticInterpolation(u, t)),
InterpolationTest(LagrangeInterpolation(u, t)),
InterpolationTest(AkimaInterpolation(u, t)),
InterpolationTest(QuadraticSpline(u, t)),
InterpolationTest(CubicSpline(u, t)),
InterpolationTest(BSplineInterpolation(u, t, 3, :ArcLen, :Average)),
InterpolationTest(BSplineApprox(u, t, 3, 4, :ArcLen, :Average)),
InterpolationTest(PCHIPInterpolation(u, t)),
InterpolationTest(CubicHermiteSpline(du, u, t)),
InterpolationTest(QuinticHermiteSpline(ddu, du, u, t)),
)
test_jacobian(t)
test_hessian(t)
Expand All @@ -258,18 +258,18 @@ for N in (2, 5)
@testset "$(N)D Interpolations" begin
@testset "$(testname(t))" for t in (
InterpolationTest(
N, ConstantInterpolation(um, t); is_der1_zero=true, is_der2_zero=true
ConstantInterpolation(um, t); is_der1_zero=true, is_der2_zero=true
),
InterpolationTest(N, LinearInterpolation(um, t); is_der2_zero=true),
InterpolationTest(N, QuadraticInterpolation(um, t)),
InterpolationTest(N, LagrangeInterpolation(um, t)),
InterpolationTest(LinearInterpolation(um, t); is_der2_zero=true),
InterpolationTest(QuadraticInterpolation(um, t)),
InterpolationTest(LagrangeInterpolation(um, t)),
## The following interpolations appear to not be supported on N dimensions as of DataInterpolations v6.2.0:
# InterpolationTest(N, AkimaInterpolation(um, t)),
# InterpolationTest(N, BSplineApprox(um, t, 3, 4, :ArcLen, :Average)),
# InterpolationTest(N, QuadraticSpline(um, t)),
# InterpolationTest(N, CubicSpline(um, t)),
# InterpolationTest(N, BSplineInterpolation(um, t, 3, :ArcLen, :Average)),
# InterpolationTest(N, PCHIPInterpolation(um, t)),
# InterpolationTest(AkimaInterpolation(um, t)),
# InterpolationTest(BSplineApprox(um, t, 3, 4, :ArcLen, :Average)),
# InterpolationTest(QuadraticSpline(um, t)),
# InterpolationTest(CubicSpline(um, t)),
# InterpolationTest(BSplineInterpolation(um, t, 3, :ArcLen, :Average)),
# InterpolationTest(PCHIPInterpolation(um, t)),
)
test_jacobian(t)
test_hessian(t)
Expand Down
2 changes: 2 additions & 0 deletions test/linting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@ end
:AkimaInterpolation,
:BSplineApprox,
:BSplineInterpolation,
:ConstantInterpolation,
:CubicHermiteSpline,
:CubicSpline,
:LagrangeInterpolation,
:LinearInterpolation,
:QuadraticInterpolation,
:QuadraticSpline,
:QuinticHermiteSpline,
Expand Down

0 comments on commit e6d357b

Please sign in to comment.