Skip to content

Commit

Permalink
Extend Weibull Regression Examples (#256)
Browse files Browse the repository at this point in the history
Updated Weibull Regression Examples namely:
- Added flexsurv and survreg examples for Weibull regression
- Added second weibull regression examples file showing it applied to a common dataset (not simulated data)
- Updated Stan code to be more efficient (but otherwise no changes in functionality or interfaces)
  • Loading branch information
gowerc authored Feb 12, 2024
1 parent 736cabd commit 51dc9ee
Show file tree
Hide file tree
Showing 3 changed files with 213 additions and 59 deletions.
39 changes: 36 additions & 3 deletions design/examples/weibull.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,43 @@ dat <- tibble(

coxph(
Surv(time, event) ~ age + sex + trt,
data =dat
data = dat
)


################################
#
# Flexsurv parametric Regression
#


flexsurvreg(
Surv(time, event) ~ sex + trt,
data = dat,
dist = "weibullPH"
)


################################
#
# Survreg parametric Regression
#


mod <- survreg(
Surv(time, event) ~ age + sex + trt,
data = dat,
dist = "weibull"
)

gamma <- 1 / mod$scale
lambda <- exp(-mod$coefficients[1] * gamma)
param_log_coefs <- -mod$coefficients[-1] * gamma
c(gamma, lambda, param_log_coefs)
## Need to use delta method to get standard errors



################################
#
# Bayesian Weibull Regression
Expand All @@ -67,12 +100,11 @@ mod <- cmdstan_model(
exe_file = here("design/examples/models/weibull")
)

design_mat <- model.matrix(~ age + sex + trt, data = dat)
design_mat <- model.matrix(~ sex + trt, data = dat)


stan_data <- list(
n = nrow(dat),
x = rnorm(50, 5, 2),
design = design_mat,
p = ncol(design_mat),
times = dat$time,
Expand Down Expand Up @@ -133,3 +165,4 @@ vars <- c(
)

mp@results$summary(vars)

92 changes: 36 additions & 56 deletions design/examples/weibull.stan
Original file line number Diff line number Diff line change
@@ -1,72 +1,52 @@


functions {
// Stan's in-built weibull functions are for the AFT formulation
// so instead we define here the PH formulation
real weibull_ph_lpdf (real t, real lambda, real gamma) {
return log(lambda) + log(gamma) + (gamma -1)* log(t) -lambda * t^gamma;
}
real weibull_ph_lccdf (real t, real lambda, real gamma) {
return -lambda * t^gamma;
}
}


data {
int<lower=1> n;
int<lower=0> p;
vector[n] times;
vector[n] event_fl;
matrix[n, p] design;
int<lower=1> n; // Number of subjects
int<lower=0> p; // Number of covariates (including intercept)
vector<lower=0>[n] times; // Event|Censor times
array[n] int<lower=0, upper=1> event_fl; // 1=event 0=censor
matrix[n, p] design; // Design matrix
}

transformed data {
// To make life easier for the user we manually derive the indices
// of the event and censoring times (this would be easier to implement
// on the R side but would require more code from the user).

// First we work out how many censoring and events there are each so that
// we can construct index vectors of the correct size
int n_event = 0;
for (i in 1:n) {
if (event_fl[i] == 1) {
n_event += 1;
}
}
int n_censor = n - n_event;

// Now we construct the index vectors
array[n_event] int event_idx;
array[n_censor] int censor_idx;
int j = 1;
int k = 1;
for (i in 1:n) {
if (event_fl[i] == 1) {
event_idx[j] = i;
j+=1;
} else {
censor_idx[k] = i;
k+=1;
}
}
// Assume the first column of the design matrix is an intercept. It would be
// redundant with the baseline lambda_0 term, so drop it
matrix[n, p-1] design_reduced = design[, 2:p];
}

parameters {
vector[p] beta;
// real<lower = 0> lambda; // Lambda is represented via exp(intercept) term of the design matrix
real<lower = 0> gamma;
vector[p-1] beta;
real<lower=0> lambda_0;
real<lower=0> gamma_0;
}


model {
// Stans in built distributions use the AFT parameterisation of the Weibull distribution
// As such we need to convert our PH parameters to the AFT parameters
real d_alpha;
vector[n] d_beta;
d_beta = exp(design * beta)^(-1/gamma);
d_alpha = gamma;


vector[n] lambda = lambda_0 .* exp(design_reduced * beta);

// Priors
beta ~ normal(0, 2);
gamma ~ normal(1, 0.5);
beta ~ normal(0, 3);
gamma_0 ~ lognormal(log(1), 1.5);
lambda_0 ~ lognormal(log(0.05), 1.5);

// Likelihood
target += weibull_lpdf(times[event_idx] | d_alpha, d_beta[event_idx]);
target += weibull_lccdf(times[censor_idx] | d_alpha, d_beta[censor_idx]);
}

generated quantities {
// Reconstruct the baseline lambda value from the intercept
// of the design matrix
real lambda_0 = exp(beta[1]);
for (i in 1:n) {
if (event_fl[i] == 1) {
target += weibull_ph_lpdf(times[i] | lambda[i], gamma_0);
} else {
target += weibull_ph_lccdf(times[i] | lambda[i], gamma_0);
}
}
}


141 changes: 141 additions & 0 deletions design/examples/weibull_std.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@


library(dplyr)
library(flexsurv)
library(survival)
library(cmdstanr)
library(posterior)
library(bayesplot)
library(here)


# Example dataset `flexsurv::bc`, converted to a tibble with constant `arm` and
# `study` labels plus a unique per-row subject id (`pt`) added. These three
# columns are the ones handed to DataSubject() later in this script.
dat <- flexsurv::bc |>
  as_tibble() |>
  mutate(
    arm = "A",
    study = "S",
    # seq_len() rather than 1:n(): on zero rows 1:n() would give c(1, 0)
    pt = sprintf("pt-%05d", seq_len(n()))
  )



# Cox regression ----
# Fit with `recyrs` / `censrec` as the Surv() response and `group` as the
# only covariate.

coxph(
  formula = Surv(recyrs, censrec) ~ group,
  data = dat
)


# Flexsurv parametric regression ----
# Same response and covariate as the Cox fit, using flexsurv's "weibullPH"
# distribution.

flexsurvreg(
  formula = Surv(recyrs, censrec) ~ group,
  data = dat,
  dist = "weibullPH"
)


# Survreg parametric regression ----
# Fit the Weibull model with survreg(), then transform its scale and
# coefficient estimates into gamma, lambda and log-coefficient values.

mod <- survreg(
  formula = Surv(recyrs, censrec) ~ group,
  data = dat,
  dist = "weibull"
)

# gamma is the reciprocal of the fitted scale
gamma <- 1 / mod[["scale"]]
# lambda comes from the first coefficient (the intercept term)
lambda <- exp(-mod[["coefficients"]][1] * gamma)
# All remaining coefficients are sign-flipped and rescaled by gamma
param_log_coefs <- -mod[["coefficients"]][-1] * gamma

c(gamma, lambda, param_log_coefs)
# NOTE: point estimates only; standard errors for the transformed values
# would need the delta method.



# Bayesian Weibull regression ----
# Fit the Stan model in design/examples/weibull.stan through cmdstanr.

mod <- cmdstan_model(
  stan_file = here("design/examples/weibull.stan"),
  exe_file = here("design/examples/models/weibull")
)

# Design matrix built from `group`
design_mat <- model.matrix(~ group, data = dat)

# Named list handed to the sampler as the model's data
stan_data <- list(
  n = nrow(dat),
  design = design_mat,
  p = ncol(design_mat),
  times = dat$recyrs,
  event_fl = dat$censrec
)

fit <- mod$sample(
  data = stan_data,
  chains = 2,
  parallel_chains = 2,
  iter_warmup = 1000,
  iter_sampling = 1500,
  refresh = 200
)

# Posterior summary table
fit$summary()

# Pairs plot of the posterior draws
bayesplot::mcmc_pairs(fit$draws())

# JMpost ----
# Fit a survival-only JointModel (Weibull PH) to the same data via the
# package under development.

devtools::load_all()
# library(jmpost)

# Survival sub-model with lognormal priors on lambda and gamma
jm <- JointModel(
  survival = SurvivalWeibullPH(
    lambda = prior_lognormal(log(1 / 200), 1.3),
    gamma = prior_lognormal(log(1), 1.3)
  )
)

# Subject-level and survival data, both taken from `dat`
jdat <- DataJoint(
  subject = DataSubject(
    data = dat,
    subject = "pt",
    arm = "arm",
    study = "study"
  ),
  survival = DataSurvival(
    data = dat,
    formula = Surv(recyrs, censrec) ~ group
  )
)

mp <- sampleStanModel(
  jm,
  data = jdat,
  iter_warmup = 1000,
  iter_sampling = 1500,
  chains = 2,
  parallel_chains = 2
)

# Parameters to report from the fitted model
vars <- c(
  "sm_weibull_ph_lambda",
  "sm_weibull_ph_gamma",
  "beta_os_cov"
)

mp@results$summary(vars)

0 comments on commit 51dc9ee

Please sign in to comment.