Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

debug tutorial #362

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
8007fee
first checkin
damonbayer Jul 23, 2024
7695435
ignore .json
damonbayer Jul 23, 2024
a45cd58
initialization working
damonbayer Jul 24, 2024
e8a3fad
load variables from stan_data
damonbayer Jul 24, 2024
b4eb2e8
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
8940186
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
8519449
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 25, 2024
2315d0f
prep for update from main
damonbayer Jul 26, 2024
179f741
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 26, 2024
4a1cdc6
checkin (broken)
damonbayer Jul 26, 2024
3915fc4
working Rt
damonbayer Jul 29, 2024
428b37f
fix a prior
damonbayer Jul 29, 2024
e5252ab
infection with feedback (works but looks wrong)
damonbayer Jul 29, 2024
b25fc8e
cleanup
damonbayer Jul 29, 2024
3759770
fixing some things
damonbayer Jul 30, 2024
60013ff
check in
damonbayer Jul 30, 2024
a73a85f
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 30, 2024
a162235
fix AR process
damonbayer Jul 30, 2024
c8b71c2
Merge branch 'main' into dmb_hosp_only_ww_model
damonbayer Jul 30, 2024
9a66578
move n_timepoints to sample arg
damonbayer Jul 31, 2024
6ce532c
fix ihr broadcasting
damonbayer Aug 1, 2024
0ac1006
cleanup initialization
damonbayer Aug 1, 2024
023d91a
day of week effect
damonbayer Aug 1, 2024
b6a39f8
cleanup day of week effect
damonbayer Aug 1, 2024
68f1977
correct dotwe
damonbayer Aug 1, 2024
6e7e0b2
add latent hospitalizations
damonbayer Aug 2, 2024
5cfb946
rename hospitalizations
damonbayer Aug 2, 2024
61b5c2d
data observation model
damonbayer Aug 2, 2024
5be7e9a
rename n_timepoints
damonbayer Aug 2, 2024
0e00066
work with supplied data
damonbayer Aug 2, 2024
e52da09
replace ceil with integer division
damonbayer Aug 2, 2024
f5e5e13
prior predictive
damonbayer Aug 2, 2024
2f3e611
posterior
damonbayer Aug 2, 2024
49f879f
remove json from gitignore
damonbayer Aug 2, 2024
c2d48dd
add stan data
damonbayer Aug 2, 2024
ebfec69
convert to tutorial
damonbayer Aug 2, 2024
468743c
fix document title
damonbayer Aug 2, 2024
c5178a7
try avoiding ipython error
damonbayer Aug 2, 2024
3252e31
Print debug tutorial
dylanhmorris Aug 5, 2024
e48164b
Merge branch 'main' into dhm_hosp_only_example_ci
dylanhmorris Aug 5, 2024
ee2c4d1
more print debug
dylanhmorris Aug 5, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ extend-exclude = [
".gitignore",
".pre-commit-config.yaml",
".github/ISSUE_TEMPLATE/general_issue.md",
"*.json"
]
250 changes: 250 additions & 0 deletions docs/source/tutorials/hosp_only_ww_model.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
---
title: "Replicating Hospital Only Model from `cdcgov/wastewater-informed-covid-forecasting`"
format: gfm
engine: jupyter
---

```{python}
# | label: setup
import json

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
import numpyro.distributions.transforms as transforms
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import DistributionalRV, TransformedRandomVariable
from pyrenew.model import hosp_only_ww_model

numpyro.set_host_device_count(1)
# model crashes if run in parallel
# see https://github.com/pyro-ppl/numpyro/issues/1836
```

## Background

This tutorial provides a demonstration of our reimplementation of "Model 2" from the `wastewater-informed-covid-forecasting` project.
The model is described [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/model_definition.md).
Stan code for the model is [here](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/inst/stan/renewal_ww_hosp_site_level_inf_dynamics.stan).

The model we provide is designed to be fully-compatible with the stan_data generated in the that project.
We provide the stan data used in the `toy_data_vignette` [vignette](https://github.com/CDCgov/wastewater-informed-covid-forecasting/blob/prod/cfaforecastrenewalww/vignettes/toy_data_vignette.Rmd) in the `wastewater-informed-covid-forecasting` project.
The data is available in `scratch/stan_data_hosp_only.json`.
This data was generated by running `scratch/save_from_vignette.R` after running all the cells in the vignette.
This script also saves the posterior samples from the model for comparison to our own model.

## Load Data and create Priors

We begin by loading the stan_data and converting it to priors used in our model.
```{python}
# | label: Load data and create priors
# | code-fold: true
def convert_to_logmean_log_sd(mean, sd):
logmean = np.log(
np.power(mean, 2) / np.sqrt(np.power(sd, 2) + np.power(mean, 2))
)
logsd = np.sqrt(np.log(1 + (np.power(sd, 2) / np.power(mean, 2))))
return logmean, logsd


# Load the JSON file
import os

print(os.listdir("."))

print(os.listdir("../../"))

with open(
"../../scratch/stan_data_hosp_only.json",
"r",
) as file:
stan_data = json.load(file)

i0_over_n_prior_a = stan_data["i0_over_n_prior_a"][0]
i0_over_n_prior_b = stan_data["i0_over_n_prior_b"][0]
i0_over_n_rv = DistributionalRV(
"i0_over_n_rv", dist.Beta(i0_over_n_prior_a, i0_over_n_prior_b)
)

initial_growth_prior_mean = stan_data["initial_growth_prior_mean"][0]
initial_growth_prior_sd = stan_data["initial_growth_prior_sd"][0]
initialization_rate_rv = DistributionalRV(
"rate",
dist.TruncatedNormal(
loc=initial_growth_prior_mean,
scale=initial_growth_prior_sd,
low=-1,
high=1,
),
)
# could reasonably switch to non-Truncated

r_prior_mean = stan_data["r_prior_mean"][0]
r_prior_sd = stan_data["r_prior_sd"][0]
r_logmean, r_logsd = convert_to_logmean_log_sd(r_prior_mean, r_prior_sd)
log_r_mu_intercept_rv = DistributionalRV(
"log_r_mu_intercept_rv", dist.Normal(r_logmean, r_logsd)
)


eta_sd_sd = stan_data["eta_sd_sd"][0]
eta_sd_rv = DistributionalRV(
"eta_sd", dist.TruncatedNormal(0, eta_sd_sd, low=0)
)

autoreg_rt_a = stan_data["autoreg_rt_a"][0]
autoreg_rt_b = stan_data["autoreg_rt_b"][0]
autoreg_rt_rv = DistributionalRV(
"autoreg_rt", dist.Beta(autoreg_rt_a, autoreg_rt_b)
)


generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf", jnp.array(stan_data["generation_interval"])
)

infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf", jnp.array(stan_data["infection_feedback_pmf"])
)

inf_feedback_prior_logmean = stan_data["inf_feedback_prior_logmean"][0]
inf_feedback_prior_logsd = stan_data["inf_feedback_prior_logsd"][0]
inf_feedback_strength_rv = TransformedRandomVariable(
"inf_feedback",
DistributionalRV(
"inf_feedback_raw",
dist.LogNormal(inf_feedback_prior_logmean, inf_feedback_prior_logsd),
),
transforms=transforms.AffineTransform(loc=0, scale=-1),
)
# Could be reparameterized?

p_hosp_prior_mean = stan_data["p_hosp_prior_mean"][0]
p_hosp_sd_logit = stan_data["p_hosp_sd_logit"][0]

p_hosp_mean_rv = DistributionalRV(
"p_hosp_mean",
dist.Normal(transforms.logit(p_hosp_prior_mean), p_hosp_sd_logit),
) # logit scale

p_hosp_w_sd_sd = stan_data["p_hosp_w_sd_sd"][0]
p_hosp_w_sd_rv = DistributionalRV(
"p_hosp_w_sd_sd", dist.TruncatedNormal(0, p_hosp_w_sd_sd, low=0)
)

autoreg_p_hosp_a = stan_data["autoreg_p_hosp_a"][0]
autoreg_p_hosp_b = stan_data["autoreg_p_hosp_b"][0]
autoreg_p_hosp_rv = DistributionalRV(
"autoreg_p_hosp", dist.Beta(autoreg_p_hosp_a, autoreg_p_hosp_b)
)

# hosp_wday_effect ~ normal(effect_mean, wday_effect_prior_sd);
# wday_effect_prior_mean = stan_data["wday_effect_prior_mean"][0]
# wday_effect_prior_sd = stan_data["wday_effect_prior_sd"][0]
# Instead of the above, use a Dirichlet prior (see https://github.com/CDCgov/ww-inference-model/issues/42)

hosp_wday_effect_rv = TransformedRandomVariable(
"hosp_wday_effect",
DistributionalRV(
"hosp_wday_effect_raw", dist.Dirichlet(concentration=jnp.ones(7))
),
transforms.AffineTransform(loc=0, scale=7),
)

inf_to_hosp_rv = DeterministicVariable(
"inf_to_hosp", jnp.array(stan_data["inf_to_hosp"])
)

inv_sqrt_phi_prior_mean = stan_data["inv_sqrt_phi_prior_mean"][0]
inv_sqrt_phi_prior_sd = stan_data["inv_sqrt_phi_prior_sd"][0]

phi_rv = TransformedRandomVariable(
"phi",
DistributionalRV(
"inv_sqrt_phi",
dist.TruncatedNormal(
loc=inv_sqrt_phi_prior_mean,
scale=inv_sqrt_phi_prior_sd,
low=1 / jnp.sqrt(5000),
),
),
transforms=transforms.PowerTransform(-2),
)

uot = stan_data["uot"][0]
state_pop = stan_data["state_pop"][0]

data_observed_hospital_admissions = jnp.array(stan_data["hosp"])
```

# Simulate from the model

Next, we define the model:

```{python}
# | label: define the model
my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_over_n_rv=i0_over_n_rv,
initialization_rate_rv=initialization_rate_rv,
log_r_mu_intercept_rv=log_r_mu_intercept_rv,
autoreg_rt_rv=autoreg_rt_rv, # ar process
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
phi_rv=phi_rv,
inf_to_hosp_rv=inf_to_hosp_rv,
n_initialization_points=uot,
)
```


We check that we can simulate from the prior predictive
```{python}
# | label: prior predictive
# | eval: false
# for some reason the posterior inference crashes if we do the prior predictive first
prior_predictive = my_model.prior_predictive(
n_datapoints=len(data_observed_hospital_admissions),
numpyro_predictive_args={"num_samples": 200},
)
```

# Fit the model

Now we can fit the model to the observed data:
```{python}
# | label: fit the model
my_model.run(
num_warmup=750,
num_samples=500,
rng_key=jax.random.key(200),
data_observed_hospital_admissions=data_observed_hospital_admissions,
mcmc_args=dict(num_chains=2),
)
```

Check the posterior predictive:

```{python}
# | label: posterior predictive
my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions)
)
```

Forecasting is broken (dependent on https://github.com/CDCgov/multisignal-epi-inference/issues/328)

```{python}
# | label: posterior forecast
# | eval: false
my_model.posterior_predictive(
n_datapoints=len(data_observed_hospital_admissions) + 2
)
4 changes: 4 additions & 0 deletions model/improving_hosp_only_notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
notes for improving hosp_only_ww_model

- Change I0 reference point to be immediately before the observation period. We may have some idea about the proportion of the population that is infectious at the start of the modeling period, but not 50 days before the modeling period + exponenetioal growth.
- Initial exponential growth rate prior should be positive (Not Truncated Normal(0, 0.01))
7 changes: 6 additions & 1 deletion model/src/pyrenew/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
# numpydoc ignore=GL08

from pyrenew.model.admissionsmodel import HospitalAdmissionsModel
from pyrenew.model.hosp_only_ww_model import hosp_only_ww_model
from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel

__all__ = ["RtInfectionsRenewalModel", "HospitalAdmissionsModel"]
__all__ = [
"HospitalAdmissionsModel",
"hosp_only_ww_model",
"RtInfectionsRenewalModel",
]
Loading