Seeking advice on MCMC sampling using Numpyro for an ODE model solved with Diffrax #26022
Replacing the list comprehension in the function that computes the active dose with array slicing:

```python
import jax.numpy as jnp

def dose_rate(t, dose_schedule, D1):  # hypothetical name; dose_schedule: (n_doses, 2) array
    # times = jnp.array([d[0] for d in dose_schedule])
    times = dose_schedule[:, 0]
    # doses = jnp.array([d[1] for d in dose_schedule])
    doses = dose_schedule[:, 1]
    # Each dose contributes only while times <= t < times + D1
    active_doses = jnp.where((t >= times) & (t < times + D1), doses, 0.0)
    return jnp.sum(active_doses)
```

Often, it is easier to improve sampler performance with a change of variables than by making the individual likelihood evaluations run faster. Stan has a good discussion of this topic. If the condition number of your posterior correlation matrix is large, a change of variables can have a significant impact on sampling performance. You can check it like this:

```python
import jax.numpy as jnp

samples = mcmc.get_samples()  # `mcmc`: your fitted numpyro.infer.MCMC; exclude deterministic sites
# Flatten each site to (n_samples, -1) and stack the columns
all_samples = jnp.concatenate([v.reshape(v.shape[0], -1) for v in samples.values()], axis=1)
all_samples.shape  # Should be (n_posterior_samples, n_parameters_of_your_model).
corr = jnp.corrcoef(all_samples.T)
jnp.linalg.cond(corr)
```

If the condition number is much larger than the number of parameters, reparameterizing the model is a good candidate for improving its sampling performance. The Stan forums are a great resource.
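To illustrate why a change of variables helps, here is a minimal sketch on synthetic data (not from your model; the covariance values are made up). Draws from a strongly correlated 2-D Gaussian give a badly conditioned correlation matrix, and a linear change of variables using the Cholesky factor of the sample covariance (a whitening transform, just one possible reparameterization) brings the condition number back to roughly 1.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# Synthetic "posterior" draws with correlation 0.999 (illustrative values only)
cov = jnp.array([[1.0, 0.999], [0.999, 1.0]])
key = jax.random.PRNGKey(0)
draws = jax.random.multivariate_normal(key, jnp.zeros(2), cov, shape=(5000,))

# Condition number of the sample correlation matrix before the change of variables
cond_before = jnp.linalg.cond(jnp.corrcoef(draws.T))

# Linear change of variables z = L^{-1} (x - mean), where L L^T = sample covariance
mean = draws.mean(axis=0)
L = jnp.linalg.cholesky(jnp.cov(draws.T))
z = solve_triangular(L, (draws - mean).T, lower=True).T

# After whitening, the sample correlation matrix is (numerically) the identity
cond_after = jnp.linalg.cond(jnp.corrcoef(z.T))
print(cond_before, cond_after)
```

For purely linear correlations like this one, NumPyro's NUTS can learn a similar linear transform during warmup via `dense_mass=True`. Nonlinear dependencies (for example, funnel-shaped posteriors) usually need an explicit reparameterization inside the model instead.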
Seeking advice on optimizing a Diffrax/NumPyro implementation of a pharmacokinetic-pharmacodynamic (PK-PD) model
I've reimplemented a published PK-PD model using diffrax and numpyro, and I'm looking for advice on optimizing performance, particularly around JAX usage and MCMC sampling efficiency.
I have been setting up a Bayesian model for this problem but have struggled with long sampling times, which make it hard to evaluate alternative priors and model structures efficiently. For now I am drawing a very small number of samples for illustrative purposes, and I am seeking advice on whether I am following JAX best practices for high performance. If anyone has advice, I would greatly appreciate it.
Model Overview
[Aldea R, Grimm HP, Gieschke R, et al. In silico exploration of amyloid-related imaging abnormalities in the gantenerumab open-label extension trials using a semi-mechanistic model. Alzheimer's Dement. 2022; 8:e12306. https://doi.org/10.1002/trc2.12306]
The model tracks drug concentration through absorption, central, and peripheral compartments, then models its effect on amyloid beta and vascular wall damage (VWD). The BGTS score (biomarker) is calculated from VWD using a sigmoidal response function. Most PK parameters are fixed to literature values, while I'm trying to estimate:
Note: This is not the true data from the paper
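The post does not state the exact form of the sigmoidal link from VWD to BGTS. A common choice in PK-PD modelling is a Hill (sigmoid Emax) function; the sketch below assumes that form, with hypothetical parameter names `emax`, `ec50`, and `hill`, and is not necessarily the equation used by Aldea et al.

```python
import jax.numpy as jnp

def bgts_from_vwd(vwd, emax, ec50, hill):
    """Hypothetical sigmoid Emax (Hill) link from VWD to the BGTS score."""
    vwd = jnp.maximum(vwd, 0.0)  # guard against tiny negative ODE outputs
    return emax * vwd**hill / (ec50**hill + vwd**hill)

vwd = jnp.linspace(0.0, 10.0, 5)
print(bgts_from_vwd(vwd, emax=4.0, ec50=2.0, hill=3.0))
```

For VWD > 0 this is algebraically equal to `emax * jax.nn.sigmoid(hill * (jnp.log(vwd) - jnp.log(ec50)))`, which can be better behaved numerically when `hill` is large.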
I'm using Diffrax's Dopri5 solver with a fixed step size. (With a PID step-size controller, the solver often fails to finish within the maximum number of steps.)
Current Performance
System Details
I'm particularly interested in:
I would like to run this code efficiently for much longer chains. The current performance on a relatively small model makes me question whether it is feasible to fit larger models in a similar way.
Full implementation code below. Any suggestions for improving performance would be greatly appreciated!