diff --git a/src/aggregators/coevolve.jl b/src/aggregators/coevolve.jl index 9414cd78..e469a45d 100644 --- a/src/aggregators/coevolve.jl +++ b/src/aggregators/coevolve.jl @@ -1,7 +1,7 @@ """ Queue method. This method handles variable intensity rates. """ -mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: +mutable struct CoevolveJumpAggregation{T, S, F1, F2, F3, RNG, GR, PQ} <: AbstractSSAJumpAggregator next_jump::Int # the next jump to execute prev_jump::Int # the previous jump that was executed @@ -18,7 +18,7 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <: pq::PQ # priority queue of next time lrates::F1 # vector of rate lower bound functions urates::F1 # vector of rate upper bound functions - rateintervals::F1 # vector of interval length functions + rateintervals::F3 # vector of interval length functions haslratevec::Vector{Bool} # vector of whether an lrate was provided for this vrj end @@ -46,7 +46,7 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not end pq = MutableBinaryMinHeap{T}() - CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof(dg), + CoevolveJumpAggregation{T, S, F1, F2, typeof(rateintervals), RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng, dg, pq, lrates, urates, rateintervals, haslratevec) end @@ -58,6 +58,9 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{Any}} RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t), Tuple{typeof(u), typeof(p), typeof(t)}} + RateIntervalWrapper = FunctionWrappers.FunctionWrapper{typeof(t), + Tuple{typeof(u), typeof(p), + typeof(t), typeof(t)}} ncrjs = (constant_jumps === nothing) ? 0 : length(constant_jumps) nvrjs = (variable_jumps === nothing) ? 0 : length(variable_jumps) @@ -65,7 +68,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, affects! = Vector{AffectWrapper}(undef, nrjs) rates = Vector{RateWrapper}(undef, nvrjs) lrates = similar(rates) - rateintervals = similar(rates) + rateintervals = Vector{RateIntervalWrapper}(undef, nvrjs) urates = Vector{RateWrapper}(undef, nrjs) haslratevec = zeros(Bool, nvrjs) @@ -84,7 +87,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps, urates[idx] = RateWrapper(vrj.urate) idx += 1 rates[i] = RateWrapper(vrj.rate) - rateintervals[i] = RateWrapper(vrj.rateinterval) + rateintervals[i] = RateIntervalWrapper(vrj.rateinterval) haslratevec[i] = haslrate(vrj) lrates[i] = haslratevec[i] ? RateWrapper(vrj.lrate) : RateWrapper(nullrate) end @@ -143,8 +146,8 @@ end @inbounds return p.urates[uidx](u, params, t) end -@inline function get_rateinterval(p::CoevolveJumpAggregation, lidx, u, params, t) - @inbounds return p.rateintervals[lidx](u, params, t) +@inline function get_rateinterval(p::CoevolveJumpAggregation, lidx, u, params, t, urate) + @inbounds return p.rateintervals[lidx](u, params, t, urate) end @inline function get_lrate(p::CoevolveJumpAggregation, lidx, u, params, t) @@ -171,7 +174,7 @@ function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) whe _t = t + s if lidx > 0 while t < tstop - rateinterval = get_rateinterval(p, lidx, u, params, t) + rateinterval = get_rateinterval(p, lidx, u, params, t, urate) if s > rateinterval t = t + rateinterval urate = get_urate(p, uidx, u, params, t) diff --git a/test/hawkes_test.jl b/test/hawkes_test.jl index 0de428e3..2c12e77e 100644 --- a/test/hawkes_test.jl +++ b/test/hawkes_test.jl @@ -37,17 +37,11 @@ function hawkes_jump(i::Int, g, h; uselrate = true) urate = rate if uselrate lrate(u, p, t) = p[1] - rateinterval = (u, p, t) -> begin - _lrate = lrate(u, p, t) - _urate = urate(u, p, t) - return _urate == _lrate ? typemax(t) : 1 / (2 * _urate) - end + rateinterval = (u, p, t, urate) -> begin return urate == p[1] ? typemax(t) : + 1 / (2 * urate) end else lrate = nothing - rateinterval = (u, p, t) -> begin - _urate = urate(u, p, t) - return 1 / (2 * _urate) - end + rateinterval = (u, p, t, urate) -> begin return 1 / (2 * urate) end end function affect!(integrator) push!(h[i], integrator.t)