Skip to content

Commit f1b8d90

Browse files
Merge pull request #2094 from SciML/interpolation_output_types
Fix interpolation output types for dynamical ODEs
2 parents 4406430 + dbb3e02 commit f1b8d90

File tree

3 files changed

+53
-9
lines changed

3 files changed

+53
-9
lines changed

src/dense/generic_dense.jl

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,13 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
687687
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
688688
T::Type{Val{0}}, differential_vars) # Default interpolant is Hermite
689689
#@.. broadcast=false (1-Θ)*y₀+Θ*y₁+Θ*(Θ-1)*((1-2Θ)*(y₁-y₀)+(Θ-1)*dt*k[1] + Θ*dt*k[2])
690-
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
691-
differential_vars .**- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
690+
if all(differential_vars)
691+
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
692+
*- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
693+
else
694+
@inbounds (1 - Θ) * y₀ + Θ * y₁ +
695+
differential_vars .**- 1) * ((1 - 2Θ) * (y₁ - y₀) +- 1) * dt * k[1] + Θ * dt * k[2]))
696+
end
692697
end
693698

694699
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
@@ -755,10 +760,17 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
755760
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
756761
T::Type{Val{1}}, differential_vars) # Default interpolant is Hermite
757762
#@.. broadcast=false k[1] + Θ*(-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(3*dt*k[1] + 3*dt*k[2] + 6*y₀ - 6*y₁) + 6*y₁)/dt
758-
@inbounds (.!differential_vars).*(y₁ - y₀)/dt + differential_vars .*(
759-
k[1] +
760-
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
761-
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
763+
if all(differential_vars)
764+
@inbounds (
765+
k[1] +
766+
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
767+
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
768+
else
769+
@inbounds (.!differential_vars).*(y₁ - y₀)/dt + differential_vars .*(
770+
k[1] +
771+
Θ * (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
772+
Θ * (3 * dt * k[1] + 3 * dt * k[2] + 6 * y₀ - 6 * y₁) + 6 * y₁) / dt)
773+
end
762774
end
763775

764776
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
@@ -826,8 +838,13 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
826838
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
827839
T::Type{Val{2}}, differential_vars) # Default interpolant is Hermite
828840
#@.. broadcast=false (-4*dt*k[1] - 2*dt*k[2] - 6*y₀ + Θ*(6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁) + 6*y₁)/(dt*dt)
829-
@inbounds differential_vars .* (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
830-
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
841+
if all(differential_vars)
842+
@inbounds (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
843+
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
844+
else
845+
@inbounds differential_vars .* (-4 * dt * k[1] - 2 * dt * k[2] - 6 * y₀ +
846+
Θ * (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) + 6 * y₁) / (dt * dt)
847+
end
831848
end
832849

833850
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
@@ -887,7 +904,11 @@ Herimte Interpolation, chosen if no other dispatch for ode_interpolant
887904
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{false}}, idxs::Nothing,
888905
T::Type{Val{3}}, differential_vars) # Default interpolant is Hermite
889906
#@.. broadcast=false (6*dt*k[1] + 6*dt*k[2] + 12*y₀ - 12*y₁)/(dt*dt*dt)
890-
@inbounds differential_vars .* (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
907+
if all(differential_vars)
908+
@inbounds (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
909+
else
910+
@inbounds differential_vars .* (6 * dt * k[1] + 6 * dt * k[2] + 12 * y₀ - 12 * y₁) / (dt * dt * dt)
911+
end
891912
end
892913

893914
@muladd function hermite_interpolant(Θ, dt, y₀, y₁, k, ::Type{Val{true}}, idxs::Nothing,
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using OrdinaryDiffEq, Test
2+
3+
# in terms of the voltage across all three elements
4+
rlc1!(v′,v,(R,L,C),t) = -(v′/R + v/L)/C
5+
identity_f(v,u,p,t) = v # needed to form second order dynamical ODE
6+
7+
setup_rlc(R,L,C;v_init=0.0,v′_init=0.0,tspan=(0.0,50.0)) =
8+
DynamicalODEProblem{false}(rlc1!,identity_f,v′_init,v_init,tspan,(R,L,C))
9+
10+
# simulate voltage impulse
11+
R,L,C = 10, 0.3, 2
12+
13+
prob = setup_rlc(R,L,C,v_init=2.0)
14+
15+
res1 = solve(prob,Vern8(),dt=1/10,saveat=1/10)
16+
res3 = solve(prob,CalvoSanz4(),dt=1/10,saveat=1/10)
17+
18+
sol = solve(prob,CalvoSanz4(),dt=1/10)
19+
@test sol(0.32) isa OrdinaryDiffEq.ArrayPartition
20+
@test sol(0.32, Val{1}) isa OrdinaryDiffEq.ArrayPartition
21+
@test sol(0.32, Val{2}) isa OrdinaryDiffEq.ArrayPartition
22+
@test sol(0.32, Val{3}) isa OrdinaryDiffEq.ArrayPartition

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ end
4444
@time @safetestset "Complex Tests" include("interface/complex_tests.jl")
4545
@time @safetestset "Ndim Complex Tests" include("interface/ode_ndim_complex_tests.jl")
4646
@time @safetestset "Number Type Tests" include("interface/ode_numbertype_tests.jl")
47+
@time @safetestset "Interpolation Output Type Tests" include("interface/interpolation_output_types.jl")
4748
@time @safetestset "Stiffness Detection Tests" include("interface/stiffness_detection_test.jl")
4849
@time @safetestset "Composite Interpolation Tests" include("interface/composite_interpolation.jl")
4950
@time @safetestset "Export tests" include("interface/export_tests.jl")

0 commit comments

Comments
 (0)