Skip to content

Commit

Permalink
adding first tutorial edited w/ arviz, still need to get proper plot_ppc
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Sep 24, 2024
1 parent 4183ccc commit 51c4c8b
Showing 1 changed file with 34 additions and 5 deletions.
39 changes: 34 additions & 5 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,42 @@ hosp_model.run(
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False),
)
```

# Plotting the posterior
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
```{python}
import arviz as az
import matplotlib.pyplot as plt
ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
constant_data={"daily_hosp_admits": daily_hosp_admits},
coords={"time": np.arange(daily_hosp_admits.size)},
dims={
"daily_hosp_admits": ["time"],
"latent_hospital_admissions": ["time"],
},
)
print(idata.observed_data)
fig, ax = plt.subplots(figsize=(8, 6))
az.plot_ppc(data=idata, kind="kde", ax=ax, num_pp_samples=100)
# ax.plot(np.arange(daily_hosp_admits.size), daily_hosp_admits.astype(float), color="black", label="Observed signal")
ax.legend()
plt.xlabel("Time")
plt.ylabel("Hospital Admissions")
plt.show()
# # Plotting the posterior
# out = hosp_model.plot_posterior(
# var="latent_hospital_admissions",
# ylab="Hospital Admissions",
# obs_signal=daily_hosp_admits.astype(float),
# )
```


Expand Down

0 comments on commit 51c4c8b

Please sign in to comment.