From 48ddc1a6006c3f46887b4d90b04b156db2fd60eb Mon Sep 17 00:00:00 2001 From: gowerc Date: Thu, 20 Feb 2025 16:19:44 +0000 Subject: [PATCH] support logical events --- R/SurvivalQuantities.R | 2 +- tests/testthat/test-brierScore.R | 67 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/R/SurvivalQuantities.R b/R/SurvivalQuantities.R index 9f10be723..4a4009252 100644 --- a/R/SurvivalQuantities.R +++ b/R/SurvivalQuantities.R @@ -371,7 +371,7 @@ brierScore.SurvivalQuantities <- function( event_col <- extractVariableNames(object@data@survival)$event groups <- as.character(object@data@survival@data[[subject_col]]) orig_times <- object@data@survival@data[[time_col]] - events <- object@data@survival@data[[event_col]] + events <- as.numeric(object@data@survival@data[[event_col]]) pred_mat <- matrix( ncol = length(times), diff --git a/tests/testthat/test-brierScore.R b/tests/testthat/test-brierScore.R index 2b28a39f6..2f28763f9 100644 --- a/tests/testthat/test-brierScore.R +++ b/tests/testthat/test-brierScore.R @@ -238,3 +238,70 @@ test_that("reverse_km_event_first() and reverse_km_cen_first() work as expected" extract_prodlim(mod, new_times) ) }) + + + +test_that("brierScore() works on logical events #438", { + set.seed(739) + simjdat <- SimJointData( + design = list( + SimGroup(75, "Arm-A", "Study-X"), + SimGroup(75, "Arm-B", "Study-X") + ), + survival = SimSurvivalExponential( + lambda = 1 / 100, + time_max = 2000 + ), + longitudinal = SimLongitudinalRandomSlope( + times = c(0, 1, 100, 200, 250, 300, 350), + intercept = 30, + sigma = 3, + slope_mu = c(1, 3), + slope_sigma = 0.2, + link_dsld = 0 + ), + .silent = TRUE + ) + dat_os <- simjdat@survival + dat_lm <- simjdat@longitudinal + + jm <- JointModel( + survival = SurvivalExponential( + lambda = prior_lognormal(log(1 / 100), 1 / 100) + ) + ) + + jdat <- DataJoint( + subject = DataSubject( + data = dat_os, + subject = "subject", + arm = "arm", + study = "study" + ), + survival = DataSurvival( + data = dat_os, + formula = Surv(time, event) ~ cov_cat + cov_cont + ) + ) + + mp <- sampleStanModel( + jm, + data = jdat, + iter_sampling = 100, + iter_warmup = 150, + chains = 2, + refresh = 0, + parallel_chains = 1 + ) + + t_grid <- c(1, 30, 45, 60, 425, 750) + sq <- SurvivalQuantities( + mp, + grid = GridFixed(times = t_grid), + type = "surv" + ) + expected <- brierScore(sq) + sq@data@survival@data$event <- as.logical(sq@data@survival@data$event) + actual <- brierScore(sq) + expect_equal(actual, expected) +})