From 776d934677e447b5c8cab4e12c885f3404409bed Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Mon, 12 Feb 2024 18:38:17 +0000 Subject: [PATCH] Add a utility for generating the observation kernel, along with unit test and update the `./test` env --- EpiAware/src/EpiAware.jl | 1 + EpiAware/src/epimodel.jl | 16 ++-------------- EpiAware/src/utilities.jl | 24 ++++++++++++++++++++++++ EpiAware/test/Manifest.toml | 2 +- EpiAware/test/Project.toml | 1 + EpiAware/test/test_utilities.jl | 16 +++++++++++++++- 6 files changed, 44 insertions(+), 16 deletions(-) diff --git a/EpiAware/src/EpiAware.jl b/EpiAware/src/EpiAware.jl index 9758f6c57..2740ffdc3 100644 --- a/EpiAware/src/EpiAware.jl +++ b/EpiAware/src/EpiAware.jl @@ -33,6 +33,7 @@ using Distributions, export scan, create_discrete_pmf, growth_rate_to_reproductive_ratio, + generate_observation_kernel, EpiModel, log_daily_infections, random_walk diff --git a/EpiAware/src/epimodel.jl b/EpiAware/src/epimodel.jl index f2df9fbe8..d8ad31ab7 100644 --- a/EpiAware/src/epimodel.jl +++ b/EpiAware/src/epimodel.jl @@ -36,13 +36,7 @@ struct EpiModel{T<:Real} <: AbstractEpiModel @assert sum(gen_int) ≈ 1 "Generation interval must sum to 1" @assert sum(delay_int) ≈ 1 "Delay interval must sum to 1" #construct observation delay kernel - K = zeros(time_horizon, time_horizon) |> SparseMatrixCSC - for i = 1:time_horizon, j = 1:time_horizon - m = i - j - if m >= 0 && m <= (length(delay_int) - 1) - K[i, j] = delay_int[m+1] - end - end + K = generate_observation_kernel(delay_int, time_horizon) new{eltype(gen_int)}( gen_int, @@ -70,13 +64,7 @@ struct EpiModel{T<:Real} <: AbstractEpiModel #construct observation delay kernel #Recall first element is zero delay - K = zeros(time_horizon, time_horizon) |> SparseMatrixCSC - for i = 1:time_horizon, j = 1:time_horizon - m = i - j - if m >= 0 && m <= (length(delay_int) - 1) - K[i, j] = delay_int[m+1] - end - end + K = generate_observation_kernel(delay_int, time_horizon) new{eltype(gen_int)}( gen_int, diff --git a/EpiAware/src/utilities.jl b/EpiAware/src/utilities.jl index bbd9e33fc..ddef40845 100644 --- a/EpiAware/src/utilities.jl +++ b/EpiAware/src/utilities.jl @@ -145,3 +145,27 @@ function mean_cc_neg_bin(μ, α) r = μ^2 / ex_σ² return NegativeBinomial(r, p) end + + +""" + generate_observation_kernel(delay_int, time_horizon) + +Generate an observation kernel matrix based on the given delay interval and time horizon. + +# Arguments +- `delay_int::Vector{Float64}`: The delay PMF vector. +- `time_horizon::Int`: The number of time steps of the observation period. + +# Returns +- `K::SparseMatrixCSC{Float64, Int}`: The observation kernel matrix. +""" +function generate_observation_kernel(delay_int, time_horizon) + K = zeros(eltype(delay_int), time_horizon, time_horizon) |> SparseMatrixCSC + for i = 1:time_horizon, j = 1:time_horizon + m = i - j + if m >= 0 && m <= (length(delay_int) - 1) + K[i, j] = delay_int[m+1] + end + end + return K +end diff --git a/EpiAware/test/Manifest.toml b/EpiAware/test/Manifest.toml index 7ed8a6240..acba947c6 100644 --- a/EpiAware/test/Manifest.toml +++ b/EpiAware/test/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.10.0" manifest_format = "2.0" -project_hash = "0b01aa91e53bb772f02a49192dfa1019eaa23f4b" +project_hash = "0dea5a2fa6648a3a05ed8cb24ee73213ffe76d33" [[deps.ADTypes]] git-tree-sha1 = "41c37aa88889c171f1300ceac1313c06e891d245" diff --git a/EpiAware/test/Project.toml b/EpiAware/test/Project.toml index 5211de1c5..96861b1d0 100644 --- a/EpiAware/test/Project.toml +++ b/EpiAware/test/Project.toml @@ -2,6 +2,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" diff --git a/EpiAware/test/test_utilities.jl b/EpiAware/test/test_utilities.jl index e766516c5..7b503ead1 100644 --- a/EpiAware/test/test_utilities.jl +++ b/EpiAware/test/test_utilities.jl @@ -81,7 +81,7 @@ end end -@testitem"Testing growth_rate_to_reproductive_ratio function" begin +@testitem "Testing growth_rate_to_reproductive_ratio function" begin #Test that zero exp growth rate imples R0 = 1 @testset "Test case 1" begin r = 0 @@ -99,3 +99,17 @@ end end end + +@testitem "Testing generate_observation_kernel function" begin + using SparseArrays + @testset "Test case 1" begin + delay_int = [0.2, 0.5, 0.3] + time_horizon = 5 + expected_K = SparseMatrixCSC( + [0.2 0 0 0 0 0.5 0.2 0 0 0 0.3 0.5 0.2 0 0 0 0.3 0.5 0.2 0 0 0 0.3 0.5 0.2], + ) + K = generate_observation_kernel(delay_int, time_horizon) + @test K == expected_K + end + +end