Further Prune PostProcessing Code, Specifically plot_posterior And spread_draws #431

Merged
19 commits
33 changes: 32 additions & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
@@ -252,7 +252,38 @@ Now, let's investigate the output, particularly the posterior distribution of th
```{python}
# | label: fig-output-rt
# | fig-cap: Rt posterior distribution
out = model1.plot_posterior(var="Rt")
import arviz as az

# Create arviz inference data object
idata = az.from_numpyro(
posterior=model1.mcmc,
)

# Extract Rt signal samples across chains
rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values


# Plot Rt signal
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
np.arange(rt.shape[0]),
rt,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples")
ax.plot(
np.arange(rt.shape[0]),
rt.mean(axis=1),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel(r"$\mathscr{R}_t$ Signal", fontsize=20)
ax.set_xlabel("Days", fontsize=20)
plt.show()
```
We can use the `get_samples` method to extract samples from the model
```{python}
149 changes: 131 additions & 18 deletions docs/source/tutorials/day_of_the_week.qmd
@@ -171,6 +171,7 @@ obs = observation.NegativeBinomialObservation(

```{python}
# | label: init-model
# | code-fold: true
hosp_model = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp,
@@ -186,6 +187,7 @@ Here is what the model looks like without the day-of-the-week effect:
```{python}
# | label: fig-output-admissions-padding-and-weekday
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
# | code-fold: true
import jax
import numpy as np

@@ -197,16 +199,56 @@ hosp_model.run(
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False),
)
```

```{python}
# | code-fold: true
import arviz as az
import matplotlib.pyplot as plt


# Retrieve the posterior samples from the model
ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)

# Create an InferenceData object from model
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
)

# Plotting the posterior
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# Use a time series plot (plot_ts) from arviz for plotting
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best"
)
plt.show()
```





## Round 2: Incorporating day-of-the-week effects

We will reuse the infection-to-admission interval and the infection-to-hospitalization rate from the previous model, and add a day-of-the-week effect. To do this, we pass two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` from 0 to 6 encoding the day of the week, with 0 being Monday). The `sample` method of `day_of_the_week_rv` should return a vector of length seven; those values are then repeated to match the length of the dataset. Moreover, since the observed data may start on a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.
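
The repetition and offset described above can be sketched in plain Python. This is only an illustration of the idea, not PyRenew's implementation: the function name `tile_weekday_effect` and the effect values are hypothetical, chosen to make the offset easy to see.

```python
def tile_weekday_effect(effect, n_days, first_day):
    """Repeat a length-7 effect over n_days, starting at weekday first_day.

    Hypothetical helper for illustration; 0 is Monday and 6 is Sunday.
    """
    if len(effect) != 7:
        raise ValueError("effect must have exactly seven entries")
    if not 0 <= first_day <= 6:
        raise ValueError("first_day must be between 0 and 6")
    # Day i of the data falls on weekday (first_day + i) modulo 7.
    return [effect[(first_day + i) % 7] for i in range(n_days)]


# Made-up multipliers, Monday through Sunday.
effect = [1.0, 1.1, 1.2, 1.3, 1.4, 0.6, 0.4]

# Ten days of data whose first observation is a Wednesday (index 2).
daily_effect = tile_weekday_effect(effect, n_days=10, first_day=2)
```

With `first_day=2`, the first element of `daily_effect` is the Wednesday multiplier, and the sequence wraps back to Monday on the sixth day.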
@@ -280,34 +322,105 @@ As a result, we can see the posterior distribution of our novel day-of-the-week
```{python}
# | label: fig-output-day-of-week
# | fig-cap: Day of the week effect
out = hosp_model_dow.plot_posterior(
var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500
# Create an InferenceData object from hosp_model_dow
dow_idata = az.from_numpyro(
posterior=hosp_model_dow.mcmc,
)

sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"])
# dayofweek_effect is not recorded
# Extract day of week effect (DOW)
dow_effect_raw = dow_idata.posterior["dayofweek_effect_raw"].squeeze().T
indices = np.random.choice(dow_effect_raw.shape[1], size=200, replace=False)
dow_plot_samples = dow_effect_raw[:, indices]
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
np.arange(dow_effect_raw.shape[0]),
dow_plot_samples,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.10, label="DOW Posterior Samples")
ax.plot(
np.arange(dow_effect_raw.shape[0]),
dow_plot_samples.mean(dim="draw"),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel("Effect", fontsize=20)
ax.set_xlabel("Day Of Week", fontsize=20)
plt.show()
```
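
The subsampling step in the block above (choosing a fixed number of draws without replacement, then averaging them per day) can be looked at in isolation. The sketch below uses only the standard library and synthetic numbers in place of real posterior draws; the names `draws`, `subset`, and `day_means` are illustrative and do not come from the tutorial.

```python
import random

# Synthetic stand-in for posterior draws: 1000 draws of a 7-day effect vector.
# These values are made up for illustration; they are not model output.
rng = random.Random(0)
draws = [[rng.gauss(1.0, 0.1) for _ in range(7)] for _ in range(1000)]

# Pick 200 distinct draw indices (sampling without replacement).
indices = rng.sample(range(len(draws)), k=200)
subset = [draws[i] for i in indices]

# Average the selected draws separately for each day of the week.
day_means = [sum(d[day] for d in subset) / len(subset) for day in range(7)]
```

Note that the mean here is taken over the plotted subset rather than over all draws, which matches how the "Sample Mean" line is computed in the block above.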

The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect:

```{python}
# | label: fig-output-admissions-original
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
# Figure without weekday effect
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# Without weekday effect (from earlier)
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles,
["Observed", "Posterior Predictive", "Samples wo/ WDE"],
loc="best",
)
plt.show()
```

```{python}
# | label: fig-output-admissions-wof
# | fig-cap: Hospital Admissions posterior distribution with weekday effect
# Figure with weekday effect
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
ppc_samples = hosp_model_dow.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model_dow.mcmc,
posterior_predictive=ppc_samples,
)

axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Posterior Predictive", "Samples w/ WDE"], loc="best"
)
plt.show()
```
43 changes: 35 additions & 8 deletions docs/source/tutorials/hospital_admissions_model.qmd
@@ -297,18 +297,47 @@ hosp_model.run(
)
```

We can use the `Model` object's `plot_posterior` method to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]:
We can use `arviz` to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]:

[^capture]: The output is captured to prevent `quarto` from displaying it twice.

```{python}
# | label: fig-output-hospital-admissions
# | fig-cap: Latent hospital admissions posterior samples (blue) and observed admissions timeseries (black).
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# | fig-cap: Latent hospital admissions posterior samples (sky blue) and observed admissions timeseries (blue).
import arviz as az

ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
)

axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best"
)
plt.show()
```

## Results exploration and MCMC diagnostics
@@ -317,7 +346,6 @@ To explore further, we can use [ArviZ](https://www.arviz.org/) to visualize the

```{python}
# | label: convert-inferenceData
import arviz as az

idata = az.from_numpyro(hosp_model.mcmc)
```
@@ -419,7 +447,6 @@ We can use the `Model`'s `posterior_predictive` and `prior_predictive` methods t

```{python}
# | label: demonstrate-use-of-predictive-methods
import arviz as az

idata = az.from_numpyro(
hosp_model.mcmc,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -16,7 +16,6 @@ python = "^3.12"
jax = ">=0.4.30"
numpy = "^2.0.0"
polars = "^1.2.1"
matplotlib = "^3.8.3"
numpyro = ">=0.15.3"

[tool.poetry.group.dev]
@@ -30,6 +29,7 @@ deptry = "^0.17.0"
optional = true

[tool.poetry.group.docs.dependencies]
matplotlib = "^3.8.3"
ipykernel = "^6.29.3"
pyyaml = "^6.0.0"
nbclient = "^0.10.0"