Skip to content

Commit

Permalink
Merge pull request #10 from LucasKook/dev
Browse files Browse the repository at this point in the history
Add survival forests
  • Loading branch information
LucasKook authored Jan 24, 2024
2 parents 5b77ca6 + 0ab101b commit fda95a4
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ inst/results*
*.pdf
*.sh
*.sav

model-classes/
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,6 @@ export(lmICP)
export(polrICP)
export(pvalues)
export(rangerICP)
export(survforestICP)
export(survregICP)
import(tram)
28 changes: 28 additions & 0 deletions R/alias.R
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,31 @@ rangerICP <- function(formula, data, env, verbose = TRUE, type = "residual",
ret$call <- call
ret
}

#' nonparametric ICP for right-censored observations with ranger GCM
#' @rdname tramicp-alias
#'
#' @inheritParams dicp
#'
#' @export
#'
#' @examples
#' \donttest{
#' set.seed(12)
#' d <- dgp_dicp(mod = "coxph", n = 3e2)
#' d$Y <- survival::Surv(d$Y, sample(0:1, 3e2, TRUE, prob = c(0.1, 0.9)))
#' survforestICP(Y ~ X1 + X2 + X3, data = d, env = ~ E)
#' }
#'
survforestICP <- function(formula, data, env, verbose = TRUE, type = "residual",
test = "gcm.test", controls = NULL, alpha = 0.05,
baseline_fixed = TRUE, greedy = FALSE, max_size = NULL,
mandatory = NULL, ...) {
call <- match.call()
ret <- dicp(formula = formula, data = data, env = env, modFUN = survforest,
verbose = verbose, type = type, test = test, controls = controls,
alpha = alpha, baseline_fixed = baseline_fixed, greedy = greedy,
max_size = max_size, mandatory = mandatory, ... = ...)
ret$call <- call
ret
}
2 changes: 1 addition & 1 deletion R/controls.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ dicp_controls <- function(

.test_fun <- function(type, test, ctest) {
if (is.function(test))
return(list(test = "custom", test_fun = test_fun, test_name = ctest))
return(list(test = "custom", test_fun = test, test_name = ctest))

if (type %in% c("wald", "partial")) {
ctest <- "wald"
Expand Down
24 changes: 24 additions & 0 deletions R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,27 @@ residuals.ranger <- function(object, newdata = NULL, newy = NULL, ...) {
return(character(0))
ret
}

survforest <- function(formula, data, ...) {
tms <- .get_terms(formula)
if (identical(tms$me, character(0))) {
return(survival::coxph(formula, data))
}
rf <- ranger::ranger(formula, data, ...)
class(rf) <- c("survforest", class(rf))
rf$y <- stats::model.response(stats::model.frame(formula, data))
rf$data <- data
rf
}

residuals.survforest <- function(object, ...) {
times <- object$y[, 1]
status <- object$y[, 2]
pred <- stats::predict(object, data = object$data)
idx <- match(times, pred$unique.death.times)
preds <- pred$survival
ipreds <- sapply(seq_len(nrow(preds)), \(smpl) {
-log(preds[smpl, idx[smpl]])
})
status - ipreds
}
4 changes: 2 additions & 2 deletions R/invariance-types.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@

### Return
if (set == 1) tset <- "Empty"
structure(list(set = tset, test = tst, coef = stats::coef(m), tram = m$tram),
class = "dICPtest")
structure(list(set = tset, test = tst, coef = stats::coef(m), tram = m$tram,
rYX = r, rEX = e), class = "dICPtest")

}

Expand Down
4 changes: 2 additions & 2 deletions R/tramicp.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,8 @@ invariant_sets <- function(object, with_pvalues = FALSE) {
modFUN = modFUN, data = data, controls = controls,
mandatory = mandatory, ... = ...
)

if (.get_pvalue(ret$test) > controls$alpha) {
tpv <- .get_pvalue(ret$test)
if (!is.nan(tpv) && tpv > controls$alpha) {
MI <- c(MI, lps[[set]])
}

Expand Down
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ the `dicp()` function, for instance, after loading `tramME`, `dicp(..., modFUN =
"BoxCoxME")` can be used.

Nonparametric ICP via the GCM test [4] and random forests for the two
regressions is implemented in the alias `rangerICP()`.
regressions is implemented in the alias `rangerICP()`. Survival forests
are supported for right-censored observations and implemented in
`survforestICP()`.

# Replication materials

Expand Down
24 changes: 24 additions & 0 deletions man/tramicp-alias.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit fda95a4

Please sign in to comment.