Skip to content

Commit 18ab2bc

Browse files
Merge pull request #1858 from chriselrod/unbox
Unbox interpolant
2 parents 8efcd57 + e7f206b commit 18ab2bc

File tree

1 file changed

+50
-22
lines changed

1 file changed

+50
-22
lines changed

src/dense/generic_dense.jl

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,51 @@ end
267267
return expr
268268
end
269269

270+
function _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
271+
cache, idxs,
272+
deriv, ks, ts, p)
273+
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
274+
cache) # update the kcurrent
275+
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
276+
cache, idxs, deriv)
277+
end
278+
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
279+
caches::Tuple{C1, C2, Vararg}, idxs,
280+
deriv, ks, ts, p, cacheid) where {C1, C2}
281+
if (cacheid -= 1) != 0
282+
return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, Base.tail(caches),
283+
idxs,
284+
deriv, ks, ts, p, cacheid)
285+
end
286+
_evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
287+
first(caches), idxs,
288+
deriv, ks, ts, p)
289+
end
290+
function evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊,
291+
caches::Tuple{C}, idxs,
292+
deriv, ks, ts, p, _) where {C}
293+
_evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
294+
only(caches), idxs,
295+
deriv, ks, ts, p)
296+
end
297+
298+
function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
299+
deriv, ks, ts, id, p)
300+
if typeof(cache) <: (FunctionMapCache) || typeof(cache) <: FunctionMapConstantCache
301+
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
302+
deriv)
303+
elseif !id.dense
304+
return linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv)
305+
elseif typeof(cache) <: CompositeCache
306+
return evaluate_composite_cache(f, Θ, dt, timeseries, i₋, i₊, cache.caches, idxs,
307+
deriv, ks, ts, p, id.alg_choice[i₊])
308+
else
309+
return _evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊,
310+
cache, idxs,
311+
deriv, ks, ts, p)
312+
end
313+
end
314+
270315
"""
271316
ode_interpolation(tvals,ts,timeseries,ks)
272317
@@ -278,13 +323,11 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
278323
@unpack ts, timeseries, ks, f, cache = id
279324
@inbounds tdir = sign(ts[end] - ts[1])
280325
idx = sortperm(tvals, rev = tdir < 0)
281-
282326
# start the search thinking it's ts[1]-ts[2]
283-
i₋ = 1
284-
i₊ = 2
327+
i₋₊ref = Ref((1, 2))
285328
vals = map(idx) do j
286329
t = tvals[j]
287-
330+
(i₋, i₊) = i₋₊ref[]
288331
if continuity === :left
289332
# we have i₋ = i₊ = 1 if t = ts[1], i₊ = i₋ + 1 = lastindex(ts) if t > ts[end],
290333
# and otherwise i₋ and i₊ satisfy ts[i₋] < t ≤ ts[i₊]
@@ -296,26 +339,11 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
296339
i₋ = max(1, _searchsortedlast(ts, t, i₋, tdir > 0))
297340
i₊ = i₋ < lastindex(ts) ? i₋ + 1 : i₋
298341
end
299-
342+
i₋₊ref[] = (i₋, i₊)
300343
dt = ts[i₊] - ts[i₋]
301344
Θ = iszero(dt) ? oneunit(t) / oneunit(dt) : (t - ts[i₋]) / dt
302-
303-
if typeof(cache) <: (FunctionMapCache) || typeof(cache) <: FunctionMapConstantCache
304-
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], 0, cache, idxs,
305-
deriv)
306-
elseif !id.dense
307-
return linear_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], idxs, deriv)
308-
elseif typeof(cache) <: CompositeCache
309-
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
310-
cache.caches[id.alg_choice[i₊]]) # update the kcurrent
311-
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊],
312-
cache.caches[id.alg_choice[i₊]], idxs, deriv)
313-
else
314-
_ode_addsteps!(ks[i₊], ts[i₋], timeseries[i₋], timeseries[i₊], dt, f, p,
315-
cache) # update the kcurrent
316-
return ode_interpolant(Θ, dt, timeseries[i₋], timeseries[i₊], ks[i₊], cache,
317-
idxs, deriv)
318-
end
345+
evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
346+
deriv, ks, ts, id, p)
319347
end
320348
invpermute!(vals, idx)
321349
DiffEqArray(vals, tvals)

0 commit comments

Comments
 (0)