-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extend Weibull Regression Examples (#256)
Updated the Weibull regression examples, namely: - Added flexsurv and survreg examples for Weibull regression - Added a second Weibull regression examples file showing the model applied to a common dataset (not simulated data) - Updated the Stan code to be more efficient (but otherwise no changes in functionality or interfaces)
- Loading branch information
Showing
3 changed files
with
213 additions
and
59 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,72 +1,52 @@ | ||
|
||
|
||
functions {
    // Weibull distribution in its proportional-hazards (PH) form.
    // Stan's built-in weibull_* functions use the AFT formulation, so the
    // PH versions are written out by hand here.

    // Log density at time t: the log hazard
    //   log(lambda) + log(gamma) + (gamma - 1) * log(t)
    // minus the cumulative hazard lambda * t^gamma.
    real weibull_ph_lpdf(real t, real lambda, real gamma) {
        real log_hazard = log(lambda) + log(gamma) + (gamma - 1) * log(t);
        real cum_hazard = lambda * pow(t, gamma);
        return log_hazard - cum_hazard;
    }

    // Log complementary CDF (log survival) at time t: minus the
    // cumulative hazard lambda * t^gamma.
    real weibull_ph_lccdf(real t, real lambda, real gamma) {
        real cum_hazard = lambda * pow(t, gamma);
        return -cum_hazard;
    }
}
|
||
|
||
data { | ||
int<lower=1> n; | ||
int<lower=0> p; | ||
vector[n] times; | ||
vector[n] event_fl; | ||
matrix[n, p] design; | ||
int<lower=1> n; // Number of subjects | ||
int<lower=0> p; // Number of covariates (including intercept) | ||
vector<lower=0>[n] times; // Event|Censor times | ||
array[n] int<lower=0, upper=1> event_fl; // 1=event 0=censor | ||
matrix[n, p] design; // Design matrix | ||
} | ||
|
||
transformed data { | ||
// To make life easier for the user we manually derive the indices | ||
// of the event and censoring times (this would be easier to implement | ||
// on the R side but would require more code from the user). | ||
|
||
// First we work out how many censoring and events there are each so that | ||
// we can construct index vectors of the correct size | ||
int n_event = 0; | ||
for (i in 1:n) { | ||
if (event_fl[i] == 1) { | ||
n_event += 1; | ||
} | ||
} | ||
int n_censor = n - n_event; | ||
|
||
// Now we construct the index vectors | ||
array[n_event] int event_idx; | ||
array[n_censor] int censor_idx; | ||
int j = 1; | ||
int k = 1; | ||
for (i in 1:n) { | ||
if (event_fl[i] == 1) { | ||
event_idx[j] = i; | ||
j+=1; | ||
} else { | ||
censor_idx[k] = i; | ||
k+=1; | ||
} | ||
} | ||
// Assuming that the first term is an intercept column which | ||
// will conflict with the lambda term so remove it | ||
matrix[n, p-1] design_reduced = design[, 2:p]; | ||
} | ||
|
||
parameters { | ||
vector[p] beta; | ||
// real<lower = 0> lambda; // Lambda is represented via exp(intercept) term of the design matrix | ||
real<lower = 0> gamma; | ||
vector[p-1] beta; | ||
real<lower=0> lambda_0; | ||
real<lower=0> gamma_0; | ||
} | ||
|
||
|
||
model { | ||
// Stan's built-in distributions use the AFT parameterisation of the Weibull distribution | ||
// As such we need to convert our PH parameters to the AFT parameters | ||
real d_alpha; | ||
vector[n] d_beta; | ||
d_beta = exp(design * beta)^(-1/gamma); | ||
d_alpha = gamma; | ||
|
||
|
||
vector[n] lambda = lambda_0 .* exp(design_reduced * beta); | ||
|
||
// Priors | ||
beta ~ normal(0, 2); | ||
gamma ~ normal(1, 0.5); | ||
beta ~ normal(0, 3); | ||
gamma_0 ~ lognormal(log(1), 1.5); | ||
lambda_0 ~ lognormal(log(0.05), 1.5); | ||
|
||
// Likelihood | ||
target += weibull_lpdf(times[event_idx] | d_alpha, d_beta[event_idx]); | ||
target += weibull_lccdf(times[censor_idx] | d_alpha, d_beta[censor_idx]); | ||
} | ||
|
||
generated quantities { | ||
// Reconstruct the baseline lambda value from the intercept | ||
// of the design matrix | ||
real lambda_0 = exp(beta[1]); | ||
for (i in 1:n) { | ||
if (event_fl[i] == 1) { | ||
target += weibull_ph_lpdf(times[i] | lambda[i], gamma_0); | ||
} else { | ||
target += weibull_ph_lccdf(times[i] | lambda[i], gamma_0); | ||
} | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
|
||
|
||
library(dplyr) | ||
library(flexsurv) | ||
library(survival) | ||
library(cmdstanr) | ||
library(posterior) | ||
library(bayesplot) | ||
library(here) | ||
|
||
|
||
# Analysis dataset: the `bc` data from flexsurv as a tibble, with constant
# arm / study labels and a unique subject id ("pt-00001", "pt-00002", ...).
# These three extra columns are the ones handed to DataSubject() further down.
dat <- flexsurv::bc |>
  as_tibble() |>
  mutate(
    arm = "A",
    study = "S",
    # seq_len() yields an empty sequence for zero rows, whereas 1:n() would
    # produce c(1, 0) and break the mutate() on empty input.
    pt = sprintf("pt-%05d", seq_len(n()))
  )
|
||
|
||
|
||
# Cox regression ----
#
# Semi-parametric reference fit: recurrence time by prognostic group.
# The result is not stored; it is printed when the script is run.
coxph(Surv(recyrs, censrec) ~ group, data = dat)
|
||
|
||
# Flexsurv parametric regression ----
#
# Weibull fit using flexsurv's proportional-hazards parameterisation
# ("weibullPH"). The result is not stored; it is printed when the script is run.
flexsurvreg(Surv(recyrs, censrec) ~ group, data = dat, dist = "weibullPH")
|
||
|
||
################################
#
# Survreg parametric Regression
#

# survreg() fits the Weibull model on the log-time (AFT) scale, so its
# estimates are converted below to the proportional-hazards form used by the
# other fits in this script.
mod <- survreg(
  Surv(recyrs, censrec) ~ group,
  data = dat,
  dist = "weibull"
)

# AFT -> PH conversion:
#   gamma  (shape)            = 1 / scale
#   lambda (baseline rate)    = exp(-intercept * gamma)
#   log hazard ratios         = -coefficient * gamma
gamma <- 1 / mod$scale
lambda <- exp(-mod$coefficients[1] * gamma)
param_log_coefs <- -mod$coefficients[-1] * gamma

# Label the estimates explicitly: without this, gamma prints with no name and
# lambda inherits the misleading "(Intercept)" name from the coefficient
# vector.
c(gamma = gamma, lambda = unname(lambda), param_log_coefs)
## Need to use delta method to get standard errors
|
||
|
||
|
||
# Bayesian Weibull regression (cmdstanr) ----

# Compile (or reuse) the standalone Stan model for this example
mod <- cmdstan_model(
  exe_file = here("design/examples/models/weibull"),
  stan_file = here("design/examples/weibull.stan")
)

# Design matrix for the group covariate (model.matrix() adds the intercept
# column by default)
design_mat <- model.matrix(~ group, data = dat)

# Inputs for the Stan program's data block
stan_data <- list(
  n = nrow(dat),
  p = ncol(design_mat),
  times = dat$recyrs,
  event_fl = dat$censrec,
  design = design_mat
)

# Two chains run in parallel: 1000 warmup + 1500 sampling iterations each
fit <- mod$sample(
  data = stan_data,
  chains = 2,
  parallel_chains = 2,
  iter_warmup = 1000,
  iter_sampling = 1500,
  refresh = 200
)

# Posterior summary table, then pairwise plots of the draws
fit$summary()

bayesplot::mcmc_pairs(fit$draws())
|
||
# JMpost ----

# Load the package from the working tree rather than an installed copy
devtools::load_all()
# library(jmpost)

# Data: subject-level identifiers plus the survival outcome and covariates
jdat <- DataJoint(
  subject = DataSubject(
    data = dat,
    subject = "pt",
    arm = "arm",
    study = "study"
  ),
  survival = DataSurvival(
    data = dat,
    formula = Surv(recyrs, censrec) ~ group
  )
)

# Model: survival-only joint model with a Weibull PH sub-model and
# lognormal priors on lambda and gamma
jm <- JointModel(
  survival = SurvivalWeibullPH(
    lambda = prior_lognormal(log(1/200), 1.3),
    gamma = prior_lognormal(log(1), 1.3)
  )
)

# Same sampler settings as the standalone Stan fit above
mp <- sampleStanModel(
  jm,
  data = jdat,
  chains = 2,
  parallel_chains = 2,
  iter_warmup = 1000,
  iter_sampling = 1500
)

# Parameters to report: baseline lambda, gamma, and the covariate coefficients
vars <- c(
  "sm_weibull_ph_lambda",
  "sm_weibull_ph_gamma",
  "beta_os_cov"
)

mp@results$summary(vars)
|