diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 35a5e7eec..29f9d2e03 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -32,10 +32,10 @@ using Distributions, DataFramesMeta # Exported utilities -export create_discrete_pmf, spread_draws +export create_discrete_pmf, spread_draws, scan # Exported types -export EpiData, Renewal, ExpGrowthRate, DirectInfections +export EpiData, Renewal, ExpGrowthRate, DirectInfections, AbstractEpiModel # Exported Turing model constructors export make_epi_inference_model diff --git a/EpiAware/src/utilities.jl b/EpiAware/src/utilities.jl index 35c2c4c1f..4b1fd4095 100644 --- a/EpiAware/src/utilities.jl +++ b/EpiAware/src/utilities.jl @@ -1,23 +1,32 @@ + """ - scan(f, init, xs) + scan(f::F, init, xs) where {F <: AbstractEpiModel} + +Apply `f` to each element of `xs` and accumulate the results. + +`f` must be a [callable](https://docs.julialang.org/en/v1/manual/methods/#Function-like-objects) + on a sub-type of `AbstractEpiModel`. + +### Design note +`scan` is being restricted to `AbstractEpiModel` sub-types to ensure: + 1. That compiler specialization is [activated](https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing) + 2. Also avoids potential compiler [overhead](https://docs.julialang.org/en/v1/devdocs/functions/#compiler-efficiency-issues) + from specialisation on `f<: Function`. + -Apply a function `f` to each element of `xs` along with an accumulator hidden state with intial -value `init`. The function `f` takes the current accumulator value and the current element of `xs` as -arguments, and returns a new accumulator value and a result value. The function `scan` returns a tuple -`(ys, carry)`, where `ys` is an array containing the result values and `carry` is the final accumulator -value. This is similar to the JAX function `jax.lax.scan`. # Arguments -- `f`: A function that takes an accumulator value and an element of `xs` as arguments and returns a new - hidden state. -- `init`: The initial accumulator value. +- `f`: A callable/functor that takes two arguments, `carry` and `x`, and returns a new + `carry` and a result `y`. +- `init`: The initial value for the `carry` variable. - `xs`: An iterable collection of elements. # Returns -- `ys`: An array containing the result values of applying `f` to each element of `xs`. -- `carry`: The final accumulator value. +- `ys`: An array containing the results of applying `f` to each element of `xs`. +- `carry`: The final value of the `carry` variable after processing all elements of `xs`. + """ -function scan(f, init, xs::Vector{T}) where {T <: Union{Integer, AbstractFloat}} +function scan(f::F, init, xs) where {F <: AbstractEpiModel} carry = init ys = similar(xs) for (i, x) in enumerate(xs) diff --git a/EpiAware/test/test_utilities.jl b/EpiAware/test/test_utilities.jl index 8ae9b58d2..8e1295353 100644 --- a/EpiAware/test/test_utilities.jl +++ b/EpiAware/test/test_utilities.jl @@ -7,7 +7,19 @@ xs = [1, 2, 3, 4, 5] expected_ys = [1, 3, 6, 10, 15] expected_carry = 15 - ys, carry = EpiAware.scan(add, 0, xs) + + # Check that a generic function CAN'T be used + @test_throws MethodError EpiAware.scan(add, 0, xs) + + # Check that a callable subtype of `AbstractEpiModel` CAN be used + struct TestEpiModelAdd <: AbstractEpiModel + end + function (epimodel::TestEpiModelAdd)(a, b) + return a + b, a + b + end + + ys, carry = EpiAware.scan(TestEpiModelAdd(), 0, xs) + @test ys == expected_ys @test carry == expected_carry end @@ -22,7 +34,19 @@ end expected_ys = [1, 2, 6, 24, 120] expected_carry = 120 - ys, carry = EpiAware.scan(multiply, 1, xs) + # Check that a generic function CAN'T be used + @test_throws MethodError ys, carry=EpiAware.scan(multiply, 1, xs) + + # Check that a callable subtype of `AbstractEpiModel` CAN be used + struct TestEpiModelMult <: AbstractEpiModel + end + + function (epimodel::TestEpiModelMult)(a, b) + return a * b, a * b + end + + ys, carry = EpiAware.scan(TestEpiModelMult(), 1, xs) + @test ys == expected_ys @test carry == expected_carry end