Skip to content

Commit c9db428

Browse files
Merge pull request #2095 from SciML/saveend
Fix save_end overriding behavior
2 parents f1b8d90 + d054a0e commit c9db428

File tree

3 files changed

+43
-5
lines changed

3 files changed

+43
-5
lines changed

src/integrators/integrator_utils.jl

+7-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
7979
integrator.cache.current)
8080
end
8181
else # ==t, just save
82+
if curt == integrator.sol.prob.tspan[2] && !integrator.opts.save_end
83+
integrator.saveiter -= 1
84+
continue
85+
end
8286
savedexactly = true
8387
copyat_or_push!(integrator.sol.t, integrator.saveiter, integrator.t)
8488
if integrator.opts.save_idxs === nothing
@@ -107,7 +111,9 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
107111
end
108112
end
109113
if force_save || (integrator.opts.save_everystep &&
110-
(isempty(integrator.sol.t) || (integrator.t !== integrator.sol.t[end])))
114+
(isempty(integrator.sol.t) || (integrator.t !== integrator.sol.t[end]) &&
115+
(integrator.opts.save_end || integrator.t !== integrator.sol.prob.tspan[2])
116+
))
111117
integrator.saveiter += 1
112118
saved, savedexactly = true, true
113119
if integrator.opts.save_idxs === nothing

src/solve.jl

+10-3
Original file line numberDiff line numberDiff line change
@@ -293,9 +293,16 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,
293293
sizehint!(ts, 50)
294294
sizehint!(ks, 50)
295295
elseif !isempty(saveat_internal)
296-
sizehint!(timeseries, length(saveat_internal) + 1)
297-
sizehint!(ts, length(saveat_internal) + 1)
298-
sizehint!(ks, length(saveat_internal) + 1)
296+
savelength = length(saveat_internal) + 1
297+
if save_start == false
298+
savelength -= 1
299+
end
300+
if save_end == false && prob.tspan[2] in saveat_internal.valtree
301+
savelength -= 1
302+
end
303+
sizehint!(timeseries, savelength)
304+
sizehint!(ts, savelength)
305+
sizehint!(ks, savelength)
299306
else
300307
sizehint!(timeseries, 2)
301308
sizehint!(ts, 2)

test/interface/ode_saveat_tests.jl

+26-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ integ = init(ODEProblem((u, p, t) -> u, 0.0, (0.0, 1.0)), Tsit5(), saveat = _sav
160160
save_end = false)
161161
add_tstop!(integ, 2.0)
162162
solve!(integ)
163-
@test integ.sol.t == _saveat
163+
@test integ.sol.t == _saveat[1:end-1]
164164

165165
# Catch save for maxiters
166166
ode = ODEProblem((u, p, t) -> u, 1.0, (0.0, 1.0))
@@ -187,3 +187,28 @@ prob = ODEProblem(SIR!, [0.99, 0.01, 0.0], (t_obs[1], t_obs[end]), [0.20, 0.15])
187187
sol = solve(prob, DP5(), reltol = 1e-6, abstol = 1e-6, saveat = t_obs)
188188
@test maximum(sol) <= 1
189189
@test minimum(sol) >= 0
190+
191+
@testset "Proper save_start and save_end behavior" begin
192+
function f2(du, u, p, t)
193+
du[1] = -cos(u[1]) * u[1]
194+
end
195+
prob = ODEProblem(f2, [10], (0.0, 0.4))
196+
197+
@test solve(prob, Tsit5(); saveat = 0:.1:.4).t == [0.0; 0.1; 0.2; 0.3; 0.4]
198+
@test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = true, save_end = true).t == [0.0; 0.1; 0.2; 0.3; 0.4]
199+
@test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = false, save_end = false).t == [0.1; 0.2; 0.3]
200+
201+
ts = solve(prob, Tsit5()).t
202+
@test 0.0 in ts
203+
@test 0.4 in ts
204+
ts = solve(prob, Tsit5(); save_start = true, save_end = true).t
205+
@test 0.0 in ts
206+
@test 0.4 in ts
207+
ts = solve(prob, Tsit5(); save_start = false, save_end = false).t
208+
@test 0.0 ts
209+
@test 0.4 ts
210+
211+
@test solve(prob, Tsit5(); saveat = [.2]).t == [0.2]
212+
@test solve(prob, Tsit5(); saveat = [.2], save_start = true, save_end = true).t == [0.0; 0.2; 0.4]
213+
@test solve(prob, Tsit5(); saveat = [.2], save_start = false, save_end = false).t == [0.2]
214+
end

0 commit comments

Comments
 (0)