From f9b3019268eab608cc68cd21437dd8c66836520f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 29 Aug 2024 16:25:32 +0100 Subject: [PATCH 1/2] changes to replication and reverse AR damp priors --- .../replications/mishra-2020/index.jl | 364 +++++++++--------- 1 file changed, 181 insertions(+), 183 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl index 553a3973a..2d8977f8a 100644 --- a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl +++ b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl @@ -1,5 +1,5 @@ ### A Pluto.jl notebook ### -# v0.19.43 +# v0.19.46 using Markdown using InteractiveUtils @@ -45,37 +45,19 @@ end # ╔═╡ 8a8d5682-2f89-443b-baf0-d4d3b134d311 md" -# Getting started with `EpiAware` +# Example: Early COVID-19 case data in South Korea -This tutorial introduces the basic functionality of `EpiAware`. `EpiAware` is a package for making inferences on epidemiological case/determined infection data using a model-based approach. +In this example we use `EpiAware` functionality to largely recreate an epidemiological model presented in [On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective, _Mishra et al_ (2020)](https://arxiv.org/abs/2006.16487). _Mishra et al_ consider test-confirmed cases of COVID-19 in South Korea between January and July 2020. The components of the epidemiological model they consider are: -It is common to conceptualise the generative process of public health data, e.g a time series of reported cases of an infectious pathogen, in a modular way. For example, it is common to abstract the underlying latent infection process away from downstream issues of observation, or to treat quanitites such as the time-varying reproduction number as being itself generated as a random process. 
- -`EpiAware` is built using the [`DynamicPPL`](https://github.com/TuringLang/DynamicPPL.jl) probabilistic programming domain-specific language, which is part of the [`Turing`](https://turinglang.org/dev/docs/using-turing/guide/) PPL. The structural concept behind `EpiAware` is that each module of an epidemiological model is a self-contained `Turing` [`Model`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.Model-Tuple{}); that is each module is an object that can be conditioned on observable data and sampled from. A complete `EpiAware` model is the composition of these objects using the [`@submodel`](https://turinglang.org/DynamicPPL.jl/stable/api/#DynamicPPL.@submodel) macro. -" - -# ╔═╡ 27d73202-a93e-4471-ab50-d59345304a0b -md" -## Dependencies for this notebook -Now we want to import these dependencies into scope. If evaluating these code lines/blocks in REPL, then the REPL will offer to install any missing dependencies. Alternatively, you can add them to your active environment using `Pkg.add`. -" - -# ╔═╡ 9161ab72-5c39-4a67-9762-e19f1c54c7fd -md" -## Example: Early COVID-19 case data in South Korea - -To demonstrate `EpiAware` we largely recreate an epidemiological model presented in [On the derivation of the renewal equation from an age-dependent branching process: an epidemic modelling perspective, _Mishra et al_ (2020)](https://arxiv.org/abs/2006.16487). _Mishra et al_ consider test-confirmed cases of COVID-19 in South Korea between January to July 2020. The components of the epidemilogical model they consider are: - -- The log-time varying reproductive number $\log R_t$ is modelled as an AR(2) process. -- The latent infection ($I_t$) generating process is a renewal model: +- The time-varying reproductive number is modelled as an [AR(2) process](https://en.wikipedia.org/wiki/Autoregressive_model) on the log scale, $\log R_t \sim \text{AR(2)}$. 
+- The latent infection ($I_t$) generating process is a renewal model (note that we leave out external infections in this note): ```math -I_t = \mu_t + R_t \sum_{s\geq 1} I_{t-s} g_s. +I_t = R_t \sum_{s\geq 1} I_{t-s} g_s. ``` -Where $g_t$ is a daily discretisation of the probability mass function of an estimated serial interval distribution: +- The discrete generation interval $g_t$ is a daily discretisation of the probability mass function of an estimated serial interval distribution for SARS-CoV-2: ```math G \sim \text{Gamma}(6.5,0.62). ``` -And $\mu_t$ is an external importation of infection process. - Observed cases $C_t$ are distributed around latent infections with negative binomial errors: ```math C_t \sim \text{NegBin}(\text{mean} = I_t,~ \text{overdispersion} = \phi). @@ -84,10 +66,18 @@ C_t \sim \text{NegBin}(\text{mean} = I_t,~ \text{overdispersion} = \phi). In the examples below we are going to largely recreate the _Mishra et al_ model, whilst emphasing that each component of the overall epidemiological model is, itself, a stand alone model that can be sampled from. " +# ╔═╡ 27d73202-a93e-4471-ab50-d59345304a0b +md" +## Dependencies for this notebook +Now we want to import these dependencies into scope. If evaluating these code lines/blocks in REPL, then the REPL will offer to install any missing dependencies. Alternatively, you can add them to your active environment using `Pkg.add`. +" + # ╔═╡ 1d3b9541-80ad-41b5-a5ed-a947f5c0731b md" -## Load the data into scope -First, we make sure that we have the data we want to analysis in scope by downloading it. +## Load early SARS-CoV-2 case data for South Korea +First, we make sure that the data we want to analyse is in scope by downloading it from where we have saved a copy in the `EpiAware` repository. + +NB: The case data is curated by the [`covidregionaldata`](https://github.com/epiforecasts/covidregionaldata) package. 
We accessed the South Korean case data using a short [R script](https://github.com/CDCgov/Rt-without-renewal/blob/main/EpiAware/docs/src/showcase/replications/mishra-2020/get_data.R). It is possible to interface directly with R from a Julia session using the `RCall.jl` package, but we do not do this here, to reduce the number of underlying dependencies required to run this notebook. " # ╔═╡ 4e5e0e24-8c55-4cb4-be3a-d28198f81a69 @@ -98,138 +88,124 @@ data = CSV.read(download(url), DataFrame) # ╔═╡ 104f4d16-7433-4a2d-89e7-288a9b223563 md" -### Time-varying reproduction number as a `LatentModel` type +## Time-varying reproduction number as an `AbstractLatentModel` type -`EpiAware` exposes a `LatentModel` type system; the purpose of which is to define stochastic processes which can be interpreted as generating time-varying parameters/quantities of interest. +`EpiAware` exposes an `AbstractLatentModel` abstract type; its purpose is to group stochastic processes that can be interpreted as generating time-varying parameters/quantities of interest, which we call latent process models. -In the _Mishra et al_ model the log-time varying reproductive number is _a priori_ assumed to evolve as an auto-regressive process, AR(2): +In the _Mishra et al_ model the log of the time-varying reproductive number, $Z_t$, is assumed to evolve as an auto-regressive process, AR(2): ```math \begin{align} -Z_t &= \log R_t, \\ +R_t &= \exp Z_t, \\ Z_t &= \rho_1 Z_{t-1} + \rho_2 Z_{t-2} + \epsilon_t, \\ -\epsilon_t &\sim \text{Normal}(0, \sigma). +\epsilon_t &\sim \text{Normal}(0, \sigma^*). \end{align} ``` +Where $\rho_1,\rho_2$ are the parameters of the AR process, and $\epsilon_t$ is a white noise process with standard deviation $\sigma^*$. +" + +# ╔═╡ d753b21f-cf8e-4a25-bab3-46c811c80a78 +md" +In `EpiAware` we determine the behaviour of a latent process by choosing a concrete subtype (i.e. 
a struct) of `AbstractLatentModel` which has fields that set the priors of the various parameters required for the latent process. + +The AR process has the struct `AR <: AbstractLatentModel`. The user can supply the priors for $\rho_1,\rho_2$ in the field `damp_priors`, for $\sigma^*$ in the field `std_prior`, and the initial values $Z_1, Z_2$ in the field `init_priors`. " # ╔═╡ d201c82b-8efd-41e2-96d7-4f5e0c67088c md" -`EpiAware` gives a concrete subtype `AR <: AbstractLatentModel` which defines this behaviour of the latent model. The user can supply the priors for $\rho_1,\rho_2$, wich we call `damp_priors`, as well as for $\sigma$ (`std_prior`) and the initial values $Z_1, Z_2$ (`init_priors`). +We choose priors based on _Mishra et al_ using the `Distributions.jl` interface to probability distributions. Note that we restrict the AR parameters to $[0,1]$, as in _Mishra et al_, using the `truncated` function. + +In _Mishra et al_ the standard deviation of the _stationary distribution_ of $Z_t$, $\sigma$, has a standard normal distribution conditioned to be positive, $\sigma \sim \mathcal{N}^+(0,1)$. The value $\sigma^*$ was determined from a nonlinear function of sampled $\sigma, ~\rho_1, ~\rho_2$ values. Since _Mishra et al_ give sharply informative priors for $\rho_1,~\rho_2$ (see below), we simplify by calculating $\sigma^*$ at the prior mode of $\rho_1,~\rho_2$. This results in a $\sigma^* \sim \mathcal{N}^+(0, 0.5)$ prior. 
" # ╔═╡ c88bbbd6-0101-4c04-97c9-c5887ef23999 ar = AR( - damp_priors = [truncated(Normal(0.8, 0.05), 0, 1), - truncated(Normal(0.05, 0.05), 0, 1)], - std_prior = HalfNormal(1.0), - init_priors = [Normal(-1.0, 0.1), Normal(-1.0, 0.1)] + damp_priors = reverse([truncated(Normal(0.8, 0.05), 0, 1), + truncated(Normal(0.1, 0.05), 0, 1)]), + std_prior = HalfNormal(0.5), + init_priors = [Normal(-1.0, 0.1), Normal(-1.0, 0.5)] ) -# ╔═╡ 40352cd6-3592-438b-b5d8-f56dcb1a4d27 -md" -The priors here are based on _Mishra et al_, note that we have decreased the _a priori_ belief in the correlation parameter $\rho_1$. -" - # ╔═╡ 31ee2757-0409-45df-b193-60c552797a3d md" -##### `Turing` model interface +### `Turing` model interface to the AR process -As mentioned above, we can use this instance of the `AR` latent model to construct a `Turing` `Model` which implements the probabilistic behaviour determined by `ar`. +As mentioned above, we can use this instance of the `AR` latent model to construct a [`Turing`](https://turinglang.org/) model object which implements the probabilistic behaviour determined by `ar`. We do this with the constructor function exposed by `EpiAware`: `generate_latent`, which combines an `AbstractLatentModel` subtype struct with the number of time steps for which we want to generate the latent process. -We do this with the constructor function `generate_latent` which combines `ar` with a number of time steps to generate for (in this case we choose 30). +As a refresher, recall that the `Turing.Model` object has the following properties: + +- The model object parameters are sampleable using `rand`; that is we can generate parameters from the specified priors e.g. `θ = rand(mdl)`. +- The model object is generative as a callable; that is we can sample instances of $Z_t$ e.g. `Z_t = mdl()`. +- The model object can construct new model objects by conditioning parameters using the [`DynamicPPL.jl`](https://turinglang.org/DynamicPPL.jl/stable/) syntax, e.g. 
`conditional_mdl = mdl | (σ_AR = 1.0, )`. + +As a concrete example we create a model object for the AR(2) process we specified above for 50 time steps: " # ╔═╡ 2bf22866-b785-4ee0-953d-ac990a197561 -ar_mdl = generate_latent(ar, 30) +ar_mdl = generate_latent(ar, 50) # ╔═╡ 25e25125-8587-4451-8600-9b55a04dbcd9 md" -We can sample from this model, which is useful for model diagnostic and prior predictive checking. +Ultimately, this will only be one component of the full epidemiological model. However, it is useful to visualise its probabilistic behaviour for model diagnostic and prior predictive checking. + +We can spaghetti plot generative samples from the AR(2) process with the priors specified above. " # ╔═╡ fbe117b7-a0b8-4604-a5dd-e71a0a1a4fc3 plt_ar_sample = let n_samples = 100 ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(ar_mdl, θ) + ar_mdl() #Sample Z_t trajectories for the model end - plot(ar_mdl_samples, + plot(ar_mdl_samples .|> exp, #R_t = exp(Z_t) lab = "", c = :grey, alpha = 0.25, - title = "$(n_samples) draws from the AR(2) model", - ylabel = "Log Rt") + title = "$(n_samples) draws from the prior Rₜ model", + ylabel = "Time varying Rₜ", + yticks = [10.0^n for n in -4:4], + yscale = :log10) end # ╔═╡ 9f84dec1-70f1-442e-8bef-a9494921549e md" -And we can sample from this model with some parameters conditioned, for example with $\sigma = 0$. In this case the AR process is an initial perturbation model with return to baseline. +This suggests that _a priori_ we believe that there is a few percent chance of achieving very high $R_t$ values, i.e. $R_t \sim 10-1000$ is not excluded by our priors. 
" -# ╔═╡ 51a82a62-2c59-43c9-8562-69d15a7edfdd -cond_ar_mdl = ar_mdl | (σ_AR = 0.0,) - -# ╔═╡ d3938381-01b7-40c6-b369-a456ff6dba72 -let - n_samples = 100 - ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(cond_ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(cond_ar_mdl, θ) - end - - plot(ar_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "AR(2) model conditioned on sigma = 0", - ylabel = "Log Rt") -end - -# ╔═╡ 12fd3bd5-657e-4b1a-aa88-6063419aaceb +# ╔═╡ 6a9e871f-a2fa-4e41-af89-8b0b3c3b5b4b md" -In this note, we are going to treat $R_t$ as varying every two days. The reason for this is to 1) reduce the effective number of parameters, and 2) showcase the `BroadcastLatentModel` wrapper. +## The Renewal model as an `AbstractEpiModel` type -In `EpiAware` we set this behaviour by wrapping a `LatentModel` in a `BroadcastLatentModel`. This allows us to set the broadcasting period and type. In this case we broadcast each latent process value over $2$ days in a `RepeatBlock`. -" - -# ╔═╡ 61eac666-9fe4-4918-bd3f-68e89275d07a -twod_ar = BroadcastLatentModel(ar, 2, RepeatBlock()) - -# ╔═╡ 5a96e7e9-0376-4365-8eb1-b2fad9be8fef -let - n_samples = 100 - twod_ar_mdl = generate_latent(twod_ar, 30) - twod_ar_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(twod_ar_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(twod_ar_mdl, θ) - end +The abstract type for models that generate infections exposed by `EpiAware` is called `AbstractEpiModel`. As with latent models different concrete subtypes of `AbstractEpiModel` define different classes of infection generating process. In this case we want to implement a renewal model. 
- plot(twod_ar_mdl_samples, - lab = "", - c = :grey, - alpha = 0.25, - title = "$(n_samples) draws from the weekly AR(2) model", - ylabel = "Log Rt") -end +The `Renewal <: AbstractEpiModel` type of struct needs two fields: -# ╔═╡ 6a9e871f-a2fa-4e41-af89-8b0b3c3b5b4b -md" -## The Renewal model as an `EpiModel` type +- Data about the generation interval of the infectious disease so it can construct $g_t$. +- A prior for the initial numbers of infected. -`EpiAware` has an `EpiModel` type system which we use to set the behaviour of the latent infection model. In this case we want to implement a renewal model. +In _Mishra et al_ they use an estimate of the serial interval of SARS-CoV-2 as an estimate of the generation interval. -To construct an `EpiModel` we need to supply some fixed data for the model contained in an `EpiData` object. The `EpiData` constructor performs double interval censoring to convert our _continuous_ estimate of the generation interval into a discretized version $g_t$. We also implement right truncation, the default is rounding the 99th percentile of the generation interval distribution, but this can be controlled using the keyword `D_gen`. " # ╔═╡ c1fc1929-0624-45c0-9a89-86c8479b2675 truth_GI = Gamma(6.5, 0.62) +# ╔═╡ ab0c6bec-1ab7-43d1-aa59-11225dea79eb +md" +This is a representation of the generation interval distribution as continuous whereas the infection process will be formulated in discrete daily time steps. By default, `EpiAware` performs [double interval censoring](https://www.medrxiv.org/content/10.1101/2024.01.12.24301247v1) to convert our continuous estimate of the generation interval into a discretized version $g_t$, whilst also applying left truncation such that $g_0 = 0$ and normalising $\sum_t g_t = 1.$ + +The constructor for converting a continuous estimate of the generation interval distribution into a usable discrete time estimate is `EpiData`. 
+" + # ╔═╡ 99c9ba2c-20a5-4c7f-94d2-272d6c9d5904 model_data = EpiData(gen_distribution = truth_GI) +# ╔═╡ 3c9849a8-1361-49e7-8b4e-cc4035b3fc70 +md" +We can compare the discretized generation interval with the continuous estimate, which in this example is the serial interval estimate. +" + # ╔═╡ 71d08f7e-c409-4fbe-b154-b21d09010683 let bar(model_data.gen_int, @@ -258,7 +234,7 @@ R_1 = 1 \Big{/} \sum_{t\geq 1} e^{-rt} g_t log_I0_prior = Normal(log(1.0), 1.0) # ╔═╡ 8487835e-d430-4300-bd7c-e33f5769ee32 -epi = RenewalWithPopulation(model_data, log_I0_prior, 1e8) +epi = Renewal(model_data, log_I0_prior) # ╔═╡ 2119319f-a2ef-4c96-82c4-3c7eaf40d2e0 md" @@ -267,17 +243,15 @@ _NB: We don't implement a background infection rate in this model._ # ╔═╡ 51b5d5b6-3ad3-4967-ad1d-b1caee201fcb md" -##### `Turing` model interface - -As mentioned above, we can use this instance of the `Renewal` latent infection model to construct a `Turing` `Model` which implements the probabilistic behaviour determined by `epi`. +### `Turing` model interface to `Renewal` process -We do this with the constructor function `generate_latent_infs` which combines `epi` with a provided $\log R_t$ time series. +As mentioned above, we can use this instance of the `Renewal` latent infection model to construct a `Turing` `Model` which implements the probabilistic behaviour determined by `epi` using the constructor function `generate_latent_infs` which combines `epi` with a provided $\log R_t$ time series. -Here we choose an example where $R_t$ decreases from $R_t = 3$ to $R_t = 0.5$ over the course of 30 days. +Here we choose an example where $R_t$ decreases from $R_t = 3$ to $R_t = 0.5$ over the course of 50 days. 
" # ╔═╡ 9e564a6e-f521-41e8-8604-6a9d73af9ba7 -R_t_fixed = [0.5 + 2.5 / (1 + exp(t - 15)) for t in 1:30] +R_t_fixed = [0.5 + 2.5 / (1 + exp(t - 15)) for t in 1:50] # ╔═╡ 72bdb47d-4967-4f20-9ae5-01f82e7b32c5 latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed)) @@ -286,8 +260,7 @@ latent_inf_mdl = generate_latent_infs(epi, log.(R_t_fixed)) plt_epi = let n_samples = 100 epi_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(latent_inf_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(latent_inf_mdl, θ) + latent_inf_mdl() #Sample the model's generated quantities unconditionally from the prior end p1 = plot(epi_mdl_samples, @@ -316,17 +289,20 @@ Observation models are set in `EpiAware` as concrete subtypes of an `Observation ```math \text{var} = \text{mean} + {\text{mean}^2 \over \phi}. ``` -In `EpiAware`, we default to a prior on $\sqrt{1/\phi}$ because this quantity has the dimensions of a standard deviation and, therefore, is easier to reason on _a priori_ beliefs. +In `EpiAware`, we default to a prior on $\sqrt{1/\phi}$ because this quantity is approximately the coefficient of variation of the observation noise and, therefore, is easier to reason about when setting _a priori_ beliefs. We call this quantity the cluster factor. + +A prior for $\phi$ was not specified in _Mishra et al_; we select one below, but we will condition on a fixed value in the analysis below. 
" # ╔═╡ 714908a1-dc85-476f-a99f-ec5c95a78b60 -obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.15)) +obs = NegativeBinomialError(cluster_factor_prior = HalfNormal(0.1)) +# obs = PoissonError() # ╔═╡ dacb8094-89a4-404a-8243-525c0dbfa482 md" -##### `Turing` model interface +### `Turing` model interface to the `NegativeBinomialError` model -We can construct a `NegativeBinomialError` model implementation as a `Turing` `Model` using `generate_observations` +We can construct a `NegativeBinomialError` model implementation as a `Turing` `Model` using the `EpiAware` `generate_observations` function. `Turing` uses `missing` arguments to indicate variables that are to be sampled. We use this to observe a forward model that samples observations, conditional on an underlying expected observation time series. " @@ -346,8 +322,7 @@ obs_mdl = generate_observations(obs, missing, expected_cases) plt_obs = let n_samples = 100 obs_mdl_samples = mapreduce(hcat, 1:n_samples) do _ - θ = rand(obs_mdl) #Sample unconditionally the underlying parameters of the model - gen = generated_quantities(obs_mdl, θ) + obs_mdl() #Sample the model's generated quantities unconditionally from the prior end scatter(obs_mdl_samples, lab = "", @@ -362,14 +337,13 @@ plt_obs = let lab = "Expected cases") end -# ╔═╡ de5d96f0-4df6-4cc3-9f1d-156176b2b676 -md"A _reverse_ observation model, which samples the underlying latent infections conditional on observations would require a prior on the latent infections. This is the purpose of composing multiple models; as we'll see below the latent infection and latent $R_t$ models are informative priors on the latent infection time series underlying the observations." - # ╔═╡ a06065e1-0e20-4cf8-8d5a-2d588da20bee md" ## Composing models into an `EpiProblem` -As mentioned above, each module of the overall epidemiological model we are interested in is a `Turing` `Model` in its own right. 
In this section, we compose the individual models into the full epidemiological model using the `EpiProblem` struct. +_Mishra et al_ follows a common pattern of having an infection generation process driven by a latent process with an observation model that links the infection process to a discrete valued time series of incidence data. + +In `EpiAware` we provide an `EpiProblem` constructor for this common epidemiological model pattern. The constructor for an `EpiProblem` requires: - An `epi_model`. @@ -382,7 +356,7 @@ The `tspan` set the range of the time index for the models. # ╔═╡ eaad5f46-e928-47c2-90ec-2cca3871c75d epi_prob = EpiProblem(epi_model = epi, - latent_model = twod_ar, + latent_model = ar, observation_model = obs, tspan = (45, 80)) @@ -405,7 +379,8 @@ num_threads = min(10, Threads.nthreads()) # ╔═╡ 88b43e23-1e06-4716-b284-76e8afc6171b inference_method = EpiMethod( pre_sampler_steps = [ManyPathfinder(nruns = 4, maxiters = 100)], - sampler = NUTSampler(adtype = AutoReverseDiff(), + sampler = NUTSampler( + adtype = AutoReverseDiff(), ndraws = 2000, nchains = num_threads, mcmc_parallel = MCMCThreads()) @@ -414,9 +389,6 @@ inference_method = EpiMethod( # ╔═╡ 92333a96-5c9b-46e1-9a8f-f1890831066b md" ## Inference and analysis - -In the background of this note (see hidden top cell and short R script in this directory), we load daily reported cases from South Korea from Jan-July 2020 which were gathered using `covidregionaldata` from ECDC data archives. - We supply the data as a `NamedTuple` with the `y_t` field containing the observed data, shortened to fit the chosen `tspan` of `epi_prob`. 
" @@ -424,12 +396,33 @@ We supply the data as a `NamedTuple` with the `y_t` field containing the observe south_korea_data = (y_t = data.cases_new[epi_prob.tspan[1]:epi_prob.tspan[2]], dates = data.date[epi_prob.tspan[1]:epi_prob.tspan[2]]) +# ╔═╡ f6c168e5-6933-4bd7-bf71-35a37551d040 +md" +In the epidemiological model it is hard to distinguish between the AR parameters, such as the standard deviation of the AR process, and the cluster factor of the negative binomial observation model. The reason for this identifiability problem is that the model assumes no delay between infection and observation. Therefore, on any day the data could be explained by $R_t$ changing _or_ observation noise, and it is not easy to disentangle greater volatility in $R_t$ from higher noise in the observations. + +In models with latent delays, changes in $R_t$ impact the observed cases over several days, which means that it is easier to disentangle trend effects from observation-to-observation fluctuations. + +To counteract this problem we condition the model on a fixed cluster factor value. +" + +# ╔═╡ 9cbacc02-9c76-41eb-9c75-fec667b60829 +fixed_cluster_factor = 0.25 + +# ╔═╡ b2074ff2-562d-44e6-b4b4-7a77c0f85c16 +md" +`EpiAware` has the `generate_epiaware` function which joins an `EpiProblem` object with the data to produce a `Turing` model. This `Turing` model composes the three unit `Turing` models defined above: the Renewal infection generating process, the AR latent process for $\log R_t$, and the negative binomial observation model. Therefore, [we can condition on variables as with any other `Turing` model](https://turinglang.org/DynamicPPL.jl/stable/api/#Condition-and-decondition). 
+" + +# ╔═╡ fe47748e-151b-4819-987a-07cf35e6cc80 +mdl = generate_epiaware(epi_prob, south_korea_data) | + (var"obs.cluster_factor" = fixed_cluster_factor,) + # ╔═╡ 9970adfd-ee88-4598-87a3-ffde5297031c md" ### Sampling with `apply_method` The `apply_method` function combines the elements above: -- An `EpiProblem` object. +- An `EpiProblem` object or `Turing` model. - An `EpiMethod` object. - Data to condition the model upon. @@ -439,8 +432,8 @@ And returns a collection of results: - Generated quantities of the model. " -# ╔═╡ 660a8511-4dd1-4788-9c14-fdd604bf83ad -inference_results = apply_method(epi_prob, +# ╔═╡ 3d10379a-3bb4-474c-ad20-de767b82d52b +inference_results = apply_method(mdl, inference_method, south_korea_data ) @@ -449,11 +442,9 @@ inference_results = apply_method(epi_prob, md" ### Results and Predictive plotting -We can spaghetti plot generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. This is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). +To assess the quality of the inference visually we can plot predictive quantiles for generated case data from the version of the model _which hasn't conditioned on case data_ using posterior parameters inferred from the version conditioned on observed data. For this purpose, we add a `generated_quantiles` utility function. This kind of visualisation is known as _posterior predictive checking_, and is a useful diagnostic tool for Bayesian inference (see [here](http://www.stat.columbia.edu/~gelman/book/BDA3.pdf)). -Because we are using synthetic data we can also plot the model predictions for the _unobserved_ infections and check that (at least in this example) we were able to capture some unobserved/latent variables in the process accurate. 
- -We find that the `EpiAware` model recovers the main finding in _Mishra et al_; that the $R_t$ in South Korea peaked at a very high value ($R_t \sim 10$ at peak) before rapidly dropping below 1 in early March 2020. +We also plot the inferred $R_t$ estimates from the model. We find that the `EpiAware` model recovers the main finding in _Mishra et al_; that the $R_t$ in South Korea peaked at a very high value ($R_t \sim 10$ at peak) before rapidly dropping below 1 in early March 2020. Note that, in reality, the peak $R_t$ found here and in _Mishra et al_ is unrealistically high, this might be due to a combination of: - A mis-estimated generation interval/serial interval distribution. @@ -462,56 +453,63 @@ Note that, in reality, the peak $R_t$ found here and in _Mishra et al_ is unreal In a future note, we'll demonstrate having a time-varying ascertainment rate. " +# ╔═╡ aa1d8b72-a3d2-4844-bb43-406b98b2648f +function generated_quantiles(gens, quantity, qs; transformation = x -> x) + mapreduce(hcat, gens) do gen #loop over sampled generated quantities + getfield(gen, quantity) |> transformation + end |> mat -> mapreduce(hcat, qs) do q #Loop over matrix row to condense into qs + map(eachrow(mat)) do row + if any(ismissing, row) + return missing + else + quantile(row, q) + end + end + end +end + # ╔═╡ 8b557bf1-f3dd-4f42-a250-ce965412eb32 let C = south_korea_data.y_t D = south_korea_data.dates - gens = inference_results.generated #Unconditional model for posterior predictive sampling - mdl_unconditional = generate_epiaware(epi_prob, (y_t = missing,)) - predicted_y_t = mapreduce( - hcat, generated_quantities(mdl_unconditional, inference_results.samples)) do gen - gen.generated_y_t - end - predicted_I_t = mapreduce( - hcat, gens) do gen - gen.I_t - end - predicted_R_t = mapreduce( - hcat, gens) do gen - exp.(gen.Z_t) - end + mdl_unconditional = generate_epiaware(epi_prob, (y_t = fill(missing, length(C)),)) | + (var"obs.cluster_factor" = fixed_cluster_factor,) + 
posterior_gens = generated_quantities(mdl_unconditional, inference_results.samples) + + #plotting quantiles + qs = [0.025, 0.25, 0.5, 0.75, 0.975] + + #Prediction quantiles + predicted_y_t = generated_quantiles(posterior_gens, :generated_y_t, qs) + predicted_R_t = generated_quantiles( + posterior_gens, :Z_t, qs; transformation = x -> exp.(x)) + + #Plots + p1 = plot(D, predicted_y_t[:, 3], lw = 2, lab = "post. median", c = :purple) + plot!(p1, D, predicted_y_t[:, 2], fillrange = predicted_y_t[:, 4], + fillalpha = 0.5, lw = 0, c = :purple, lab = "50%") + plot!(p1, D, predicted_y_t[:, 1], fillrange = predicted_y_t[:, 5], + fillalpha = 0.2, lw = 0, c = :purple, lab = "95%") - p1 = plot(D, predicted_y_t, c = :grey, alpha = 0.05, lab = "") scatter!(p1, D, C, lab = "Actual cases", ylabel = "Daily Cases", - title = "Post. predictive: Cases", - ylims = (-0.5, maximum(C) * 2), - c = :red + title = "Posterior predictive: Cases", + ylims = (-50, maximum(C) * 2), + c = :black ) - p2 = plot(D, predicted_I_t, - c = :grey, - alpha = 0.05, - lab = "", - ylabel = "Daily latent infections", - ylims = (-0.5, maximum(C) * 1.5), - title = "Prediction: Latent infections" - ) - - p3 = plot(D, predicted_R_t, - c = :grey, - alpha = 0.025, - lab = "", - ylabel = "Rt", - title = "Prediction: Reproduction number", - yscale = :log10 - ) - hline!(p3, [1.0], lab = "Rt = 1", lw = 2, c = :blue) + p2 = plot(D, predicted_R_t[:, 3], lw = 2, lab = "post. 
median", c = :green, + yscale = :log10, title = "Prediction: Reproduction number") + plot!(p2, D, predicted_R_t[:, 2], fillrange = predicted_R_t[:, 4], + fillalpha = 0.5, lw = 0, c = :green, lab = "50%") + plot!(p2, D, predicted_R_t[:, 1], fillrange = predicted_R_t[:, 5], + fillalpha = 0.2, lw = 0, c = :green, lab = "95%") + hline!(p2, [1.0], lab = "Rt = 1", lw = 2, c = :blue) - plot(p1, p2, p3, layout = (3, 1), size = (500, 700), left_margin = 5mm) + plot(p1, p2, layout = (2, 1), size = (500, 700), left_margin = 5mm) end # ╔═╡ c05ed977-7a89-4ac8-97be-7078d69fce9f @@ -523,13 +521,13 @@ We can interrogate the sampled chains directly from the `samples` field of the ` # ╔═╡ ff21c9ec-1581-405f-8db1-0f522b5bc296 let - p1 = histogram(inference_results.samples["obs.cluster_factor"], + p1 = histogram(inference_results.samples["latent.σ_AR"], lab = "chain " .* string.([1 2 3 4]), fillalpha = 0.4, lw = 0, norm = :pdf, - title = "Posterior dist: Neg. bin. cluster factor") - plot!(p1, obs.cluster_factor_prior, + title = "Posterior dist: AR noise std") + plot!(p1, ar.std_prior, lw = 3, c = :black, lab = "prior") @@ -545,7 +543,7 @@ let c = :black, lab = "prior") - p3 = histogram(inference_results.samples["latent.damp_AR[1]"], + p3 = histogram(inference_results.samples["latent.rev_damp_AR[2]"], lab = "chain " .* string.([1 2 3 4]), fillalpha = 0.4, lw = 0, @@ -556,7 +554,7 @@ let c = :black, lab = "prior") - p4 = histogram(inference_results.samples["latent.damp_AR[2]"], + p4 = histogram(inference_results.samples["latent.rev_damp_AR[1]"], lab = "chain " .* string.([1 2 3 4]), fillalpha = 0.4, lw = 0, @@ -581,27 +579,23 @@ end # ╠═9eb03a0b-c6ca-4e23-8109-fb68f87d7fdf # ╠═97b5374e-7653-4b3b-98eb-d8f73aa30580 # ╠═1642dbda-4915-4e29-beff-bca592f3ec8d -# ╟─9161ab72-5c39-4a67-9762-e19f1c54c7fd # ╟─1d3b9541-80ad-41b5-a5ed-a947f5c0731b # ╠═4e5e0e24-8c55-4cb4-be3a-d28198f81a69 # ╠═a59d977c-0178-11ef-0063-83e30e0cf9f0 # ╟─104f4d16-7433-4a2d-89e7-288a9b223563 +# 
╟─d753b21f-cf8e-4a25-bab3-46c811c80a78 # ╟─d201c82b-8efd-41e2-96d7-4f5e0c67088c # ╠═c88bbbd6-0101-4c04-97c9-c5887ef23999 -# ╟─40352cd6-3592-438b-b5d8-f56dcb1a4d27 # ╟─31ee2757-0409-45df-b193-60c552797a3d # ╠═2bf22866-b785-4ee0-953d-ac990a197561 # ╟─25e25125-8587-4451-8600-9b55a04dbcd9 # ╠═fbe117b7-a0b8-4604-a5dd-e71a0a1a4fc3 # ╟─9f84dec1-70f1-442e-8bef-a9494921549e -# ╠═51a82a62-2c59-43c9-8562-69d15a7edfdd -# ╠═d3938381-01b7-40c6-b369-a456ff6dba72 -# ╟─12fd3bd5-657e-4b1a-aa88-6063419aaceb -# ╠═61eac666-9fe4-4918-bd3f-68e89275d07a -# ╠═5a96e7e9-0376-4365-8eb1-b2fad9be8fef # ╟─6a9e871f-a2fa-4e41-af89-8b0b3c3b5b4b # ╠═c1fc1929-0624-45c0-9a89-86c8479b2675 +# ╟─ab0c6bec-1ab7-43d1-aa59-11225dea79eb # ╠═99c9ba2c-20a5-4c7f-94d2-272d6c9d5904 +# ╟─3c9849a8-1361-49e7-8b4e-cc4035b3fc70 # ╠═71d08f7e-c409-4fbe-b154-b21d09010683 # ╟─4a2b5cf1-623c-4fe7-8365-49fb7972af5a # ╠═9e49d451-946b-430b-bcdb-1ef4bba55a4b @@ -614,11 +608,10 @@ end # ╟─c8ef8a60-d087-4ae9-ae92-abeea5afc7ae # ╠═714908a1-dc85-476f-a99f-ec5c95a78b60 # ╟─dacb8094-89a4-404a-8243-525c0dbfa482 -# ╠═d45f34e2-64f0-4828-ae0d-7b4cb3a3287d +# ╟─d45f34e2-64f0-4828-ae0d-7b4cb3a3287d # ╠═2e0e8bf3-f34b-44bc-aa2d-046e1db6ee2d # ╠═55c639f6-b47b-47cf-a3d6-547e793c72bc # ╠═c3a62dda-e054-4c8c-b1b8-ba1b5c4447b3 -# ╟─de5d96f0-4df6-4cc3-9f1d-156176b2b676 # ╟─a06065e1-0e20-4cf8-8d5a-2d588da20bee # ╠═eaad5f46-e928-47c2-90ec-2cca3871c75d # ╟─2678f062-36ec-40a3-bd85-7b57a08fd809 @@ -626,9 +619,14 @@ end # ╠═88b43e23-1e06-4716-b284-76e8afc6171b # ╟─92333a96-5c9b-46e1-9a8f-f1890831066b # ╠═c7140b20-e030-4dc4-97bc-0efc0ff59631 +# ╟─f6c168e5-6933-4bd7-bf71-35a37551d040 +# ╠═9cbacc02-9c76-41eb-9c75-fec667b60829 +# ╟─b2074ff2-562d-44e6-b4b4-7a77c0f85c16 +# ╠═fe47748e-151b-4819-987a-07cf35e6cc80 # ╟─9970adfd-ee88-4598-87a3-ffde5297031c -# ╠═660a8511-4dd1-4788-9c14-fdd604bf83ad +# ╠═3d10379a-3bb4-474c-ad20-de767b82d52b # ╟─5e6f505b-49fe-4ff4-ac2e-f6adcd445569 +# ╠═aa1d8b72-a3d2-4844-bb43-406b98b2648f # ╠═8b557bf1-f3dd-4f42-a250-ce965412eb32 # 
╟─c05ed977-7a89-4ac8-97be-7078d69fce9f # ╠═ff21c9ec-1581-405f-8db1-0f522b5bc296 From 015e71f2d0df65aedbd62c564e90bc87b798732f Mon Sep 17 00:00:00 2001 From: Samuel Brand Date: Thu, 29 Aug 2024 16:54:23 +0100 Subject: [PATCH 2/2] fix plot --- .../docs/src/showcase/replications/mishra-2020/index.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl index 2d8977f8a..87dc1ca05 100644 --- a/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl +++ b/EpiAware/docs/src/showcase/replications/mishra-2020/index.jl @@ -543,24 +543,24 @@ let c = :black, lab = "prior") - p3 = histogram(inference_results.samples["latent.rev_damp_AR[2]"], + p3 = histogram(inference_results.samples["latent.damp_AR[2]"], lab = "chain " .* string.([1 2 3 4]), fillalpha = 0.4, lw = 0, norm = :pdf, title = "Posterior dist: rho_1") - plot!(p3, ar.damp_prior.v[1], + plot!(p3, ar.damp_prior.v[2], lw = 3, c = :black, lab = "prior") - p4 = histogram(inference_results.samples["latent.rev_damp_AR[1]"], + p4 = histogram(inference_results.samples["latent.damp_AR[1]"], lab = "chain " .* string.([1 2 3 4]), fillalpha = 0.4, lw = 0, norm = :pdf, title = "Posterior dist: rho_2") - plot!(p4, ar.damp_prior.v[2], + plot!(p4, ar.damp_prior.v[1], lw = 3, c = :black, lab = "prior")