Skip to content

Commit

Permalink
Fix Claret Unit test (#441)
Browse files Browse the repository at this point in the history
  • Loading branch information
gowerc authored Feb 26, 2025
1 parent 6b20f37 commit 3d7f616
Show file tree
Hide file tree
Showing 11 changed files with 617 additions and 44 deletions.
2 changes: 2 additions & 0 deletions design/debug-cb/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

**/claret_bruno
37 changes: 37 additions & 0 deletions design/debug-cb/A/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@





## Model Specification

Fixed-effects-only model; that is, there are no patient-level random effects and no study- or arm-level hierarchical effects. There is no censoring, and only strictly positive times are considered.


$$
\begin{align*}
y_i &\sim \mathcal{N} \left( \mu_i, \mu_i^2 \sigma^2 \right) \\
\\
\mu_i &=
b \cdot \exp \left( g t_{i} - \frac{p}{c} \left( 1 - e^{-c t_{i}} \right) \right)
\end{align*}
$$

- $i$ is the observation index

### Priors

$$
\begin{align*}
b &\sim \text{LogNormal} \left( \right) \\
g &\sim \text{LogNormal} \left( \right) \\
p &\sim \text{LogNormal} \left( \right) \\
c &\sim \text{LogNormal} \left( \right) \\
\sigma &\sim \text{LogNormal} \left( \right)
\end{align*}
$$





58 changes: 58 additions & 0 deletions design/debug-cb/A/claret_bruno.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@


library(cmdstanr)
library(dplyr)
library(ggplot2)
library(bayesplot)


#' Evaluate the Claret-Bruno mean curve at the given times
#'
#' Computes `b * exp(g * t - (p / c) * (1 - exp(-c * t)))`, vectorised over
#' `t` (and over the parameters, following the usual R recycling rules).
#'
#' @param t Numeric vector of times.
#' @param b Multiplicative scale; this is the value returned at `t = 0`.
#' @param g Coefficient of the linear-in-time term of the exponent.
#' @param c Rate inside the `1 - exp(-c * t)` saturation term.
#' @param p Numerator of the `p / c` weight applied to the saturation term.
#' @return Numeric vector of curve values.
get_sld <- function(t, b, g, c, p) {
  linear_term <- g * t
  saturating_term <- (p / c) * (1 - exp(-c * t))
  b * exp(linear_term - saturating_term)
}

# True parameter values used to simulate the data.
pars <- list(
  b = 60,
  g = 0.5,
  c = 0.4,
  p = 0.7,
  sigma = 0.002
)

# Fix the RNG seed so that the simulated dataset, and therefore the behaviour
# being debugged, is identical on every run.
set.seed(441)

# Every 5th value from 1 to 900, divided by 365
# (presumably days converted to years -- TODO confirm).
dat <- tibble(
  t = seq(1, 900, by = 5) / 365,
  sld_mu = get_sld(t, pars$b, pars$g, pars$c, pars$p),
  # Proportional error: the standard deviation scales with the mean.
  sld = rnorm(length(t), sld_mu, sld_mu * pars$sigma)
)

# Simulated observations (points) against the noise-free curve (red line).
ggplot(data = dat, aes(x = t, y = sld)) +
  geom_point() +
  geom_line(aes(y = sld_mu), color = "red")


mod <- cmdstan_model(
  stan_file = here::here("design/debug-cb/A/claret_bruno.stan")
)

# Names must match the `data` block of claret_bruno.stan.
stan_data <- list(
  N = nrow(dat),
  values = dat$sld,
  times = dat$t
)

fit <- mod$sample(
  data = stan_data,
  # Seed the sampler as well, so the chains are reproducible.
  seed = 441,
  chains = 2,
  parallel_chains = 2,
  refresh = 200,
  iter_warmup = 1000,
  iter_sampling = 2000
)

fit$summary()







52 changes: 52 additions & 0 deletions design/debug-cb/A/claret_bruno.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

functions {
    // Claret-Bruno mean at each time point:
    //   b * exp(g * t - (p / c) * (1 - exp(-c * t)))
    // Returns a vector with one entry per element of `times`.
    vector claret_bruno_mu(vector times, real b, real c, real p, real g) {
        vector[rows(times)] linear_term = g * times;
        vector[rows(times)] saturating_term = (p / c) * (1 - exp(-c * times));
        return b * exp(linear_term - saturating_term);
    }
}


data {
    int<lower=0> N;       // Number of observations
    vector[N] values;     // Observed values, one per observation
    vector[N] times;      // Observation times, aligned with `values`
}

parameters {
    // All parameters are constrained to be non-negative.
    real<lower=0> b;
    real<lower=0> c;
    real<lower=0> p;
    real<lower=0> g;
    real<lower=0> sigma;
}

transformed parameters {
    // Model mean at every observation time.
    vector[N] mu;
    mu = claret_bruno_mu(times, b, c, p, g);
}

// Values used by claret_bruno.R to simulate the data:
// pars <- list(
//     b = 60,
//     g = 0.5,
//     c = 0.4,
//     p = 0.7,
//     sigma = 0.002
// )

model {
    // Lognormal priors whose medians sit on the simulation values of the
    // matching parameter (c = 0.4, p = 0.7, g = 0.5). Every parameter is
    // positive, so all of them use a lognormal, as stated in the README.
    b ~ lognormal(log(60), 0.5);
    c ~ lognormal(log(0.4), 0.5);
    p ~ lognormal(log(0.7), 0.5);
    g ~ lognormal(log(0.5), 0.5);
    // NOTE(review): this prior is centred on 0.004 while claret_bruno.R
    // simulates with sigma = 0.002 -- confirm which value is intended.
    sigma ~ lognormal(log(0.004), 0.5);

    // Proportional error: the standard deviation scales with the mean.
    values ~ normal(mu, mu * sigma);
}




50 changes: 50 additions & 0 deletions design/debug-cb/B/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@





## Model Specification

Simple by-patient random-effects model. There are no study- or arm-level hierarchical effects. There is no censoring, and only strictly positive times are considered.


$$
\begin{align*}
y_{ij} &\sim \mathcal{N} \left( \mu_{ij}, \mu_{ij}^2 \sigma^2 \right) \\
\\
\mu_{ij} &=
b_i \cdot \exp \left( g_i t_{ij} - \frac{p_i}{c_i} \left( 1 - e^{-c_i t_{ij}} \right) \right)
\end{align*}
$$


- $i$ is the patient index
- $j$ is the time index

### Priors

$$
\begin{align*}
b_i &\sim \text{LogNormal} \left( \log(\mu_b) , \sigma_b \right) \\
g_i &\sim \text{LogNormal} \left( \log(\mu_g) , \sigma_g \right) \\
p_i &\sim \text{LogNormal} \left( \log(\mu_p) , \sigma_p \right) \\
c_i &\sim \text{LogNormal} \left( \log(\mu_c) , \sigma_c \right) \\
\\
\mu_b &\sim \text{LogNormal} \left( \right) \\
\mu_g &\sim \text{LogNormal} \left( \right) \\
\mu_p &\sim \text{LogNormal} \left( \right) \\
\mu_c &\sim \text{LogNormal} \left( \right) \\
\\
\sigma_b &\sim \text{LogNormal} \left( \right) \\
\sigma_g &\sim \text{LogNormal} \left( \right) \\
\sigma_p &\sim \text{LogNormal} \left( \right) \\
\sigma_c &\sim \text{LogNormal} \left( \right) \\
\\
\sigma &\sim \text{LogNormal} \left( \right) \\
\end{align*}
$$





164 changes: 164 additions & 0 deletions design/debug-cb/B/claret_bruno.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@


library(cmdstanr)
library(dplyr)
library(ggplot2)
library(bayesplot)
library(tidyr)
library(brms)


#' Evaluate the Claret-Bruno mean curve at the given times
#'
#' Computes `b * exp(g * t - (p / c) * (1 - exp(-c * t)))`. All arguments are
#' combined element-wise, so per-row parameter vectors can be passed.
#'
#' @param t Numeric vector of times.
#' @param b Multiplicative scale; this is the value returned at `t = 0`.
#' @param g Coefficient of the linear-in-time term of the exponent.
#' @param c Rate inside the `1 - exp(-c * t)` saturation term.
#' @param p Numerator of the `p / c` weight applied to the saturation term.
#' @return Numeric vector of curve values.
get_sld <- function(t, b, g, c, p) {
  log_ratio <- (g * t) - (p / c) * (1 - exp(-c * t))
  b * exp(log_ratio)
}

# Population-level parameter values (medians of the patient-level lognormal
# distributions) and the residual sigma.
pars_mu <- list(
  b = 60,
  g = 0.5,
  c = 0.4,
  p = 0.7,
  sigma = 0.02
)

# Standard deviations of the patient-level parameters on the log scale.
pars_sigma <- list(
  b = 0.1,
  g = 0.1,
  c = 0.1,
  p = 0.1
)

# Number of patients.
N <- 120

# Fix the RNG seed so that the simulated dataset, and therefore the behaviour
# being debugged, is identical on every run.
set.seed(441)

# One row per patient, holding that patient's own parameter values.
dat_baseline <- tibble(
  pt = sprintf("pt%05d", seq_len(N)),
  b = exp(rnorm(N, log(pars_mu$b), pars_sigma$b)),
  g = exp(rnorm(N, log(pars_mu$g), pars_sigma$g)),
  c = exp(rnorm(N, log(pars_mu$c), pars_sigma$c)),
  p = exp(rnorm(N, log(pars_mu$p), pars_sigma$p))
)

# Every patient is observed at the same 8 equally spaced times.
dat <- tidyr::crossing(
  pt = dat_baseline$pt,
  t = seq(1, 900, length.out = 8) / 365
) |>
  left_join(dat_baseline, by = "pt") |>
  mutate(
    sld_mu = get_sld(t, b, g, c, p),
    # Proportional error: the standard deviation scales with the mean.
    sld = rnorm(n(), sld_mu, sld_mu * pars_mu$sigma)
  ) |>
  arrange(pt, t) |>
  mutate(pt = factor(pt))


mod <- cmdstan_model(
  stan_file = here::here("design/debug-cb/B/claret_bruno.stan")
)

stan_data <- list(
  N_obs = nrow(dat),
  N_pt = N,
  # Integer codes of the patient factor, used as a 1-based index.
  pt_index = as.integer(dat$pt),
  values = dat$sld,
  times = dat$t
)


fit <- mod$sample(
  data = stan_data,
  # Seed the sampler as well, so the chains are reproducible.
  seed = 441,
  chains = 3,
  parallel_chains = 3,
  refresh = 200,
  iter_warmup = 1500,
  iter_sampling = 2000
)

fit
fit$summary()





######################
#
# brms implementation
#
# Fits the same by-patient random-effects curve with brms, on the simulated
# data in `dat`, as a cross-check of the Stan model above.
#


# pars_mu <- list(
# b = 60,
# g = 0.5,
# c = 0.4,
# p = 0.7,
# sigma = 0.02
# )

# The non-linear parameters b, g, c and p are modelled on the log scale: the
# formula exponentiates each of them, and exp(p - c) equals exp(p) / exp(c),
# i.e. the p / c ratio used in get_sld().
bfit <- brm(
bf(
value ~ exp(b) * exp( exp(g) * t - exp(p-c) * (1 - exp(- exp(c) * t))),
# Each parameter gets an intercept plus a per-patient random intercept.
b ~ 1 + (1 | pt),
g ~ 1 + (1 | pt),
c ~ 1 + (1 | pt),
p ~ 1 + (1 | pt),
nl = TRUE
),
data = dat |> select(pt, value = sld, t),
# Normal priors on the log-scale intercepts are centred on the log of the
# simulation values listed in the commented-out pars_mu above.
prior = c(
prior("normal(log(60), 0.3)", nlpar = "b"), # b intercept
prior("normal(log(0.5), 0.3)", nlpar = "g"), # g intercept
prior("normal(log(0.4), 0.3)", nlpar = "c"), # c intercept
prior("normal(log(0.7), 0.3)", nlpar = "p"), # p intercept
prior("lognormal(log(0.1), 0.3)", nlpar = "b", class = "sd"), # b random effect sigma
prior("lognormal(log(0.1), 0.3)", nlpar = "g", class = "sd"), # g random effect sigma
prior("lognormal(log(0.1), 0.3)", nlpar = "c", class = "sd"), # c random effect sigma
prior("lognormal(log(0.1), 0.3)", nlpar = "p", class = "sd"), # p random effect sigma
prior("lognormal(log(0.02), 0.3)", class = "sigma") # overall sigma
),
# In brms `iter` is the total per chain including warmup, so this keeps
# 1000 post-warmup draws per chain.
warmup = 1500,
iter = 2500,
chains = 3,
cores = 3,
backend = "cmdstanr",
control = list(adapt_delta = 0.95)
)




#####################
#
# Debugging
#
#



# Plot patient profiles for 5 randomly chosen patients.
# Sample from the distinct patient ids: `dat$pt` holds one entry per
# observation, so sampling it directly can pick the same patient several
# times and plot fewer than 5 profiles.
pdat <- dat |> filter(pt %in% sample(unique(dat$pt), 5))

ggplot(data = pdat, aes(x = t, y = sld, group = pt, col = pt)) +
  geom_point() +
  geom_line(aes(y = sld_mu))


# Plotting priors
plot(density(exp(rnorm(5000, log(0.6), 0.2))))


# Plotting Joint Priors
# Use a dedicated name for the number of prior draws, so that `N` (the number
# of patients used to build `stan_data`) is not overwritten.
n_draws <- 100000
mu <- rnorm(n_draws, log(0.6), 0.3)
sigma <- exp(rnorm(n_draws, log(0.1), 0.3))
value <- exp(rnorm(n_draws, mu, sigma))

pdat <- tibble(
  mu = mu,
  sigma = sigma,
  value = value
)

ggplot(data = pdat, aes(x = value, y = mu)) +
  geom_bin2d(bins = 300)


Loading

0 comments on commit 3d7f616

Please sign in to comment.