From 7fe47fa6ad37486c4d70c94872abd0f94a833306 Mon Sep 17 00:00:00 2001 From: lkook Date: Sat, 26 Aug 2023 12:28:15 +0200 Subject: [PATCH] sim with RF GCM --- inst/code/run-simulation.R | 2 +- inst/code/vis-simulation.R | 5 ++++- inst/helpers/R/helpers.R | 26 ++++++++++++++++++++++++++ inst/helpers/R/simdesign.R | 14 ++++++++++---- 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/inst/code/run-simulation.R b/inst/code/run-simulation.R index a5eb598..d76b538 100644 --- a/inst/code/run-simulation.R +++ b/inst/code/run-simulation.R @@ -99,7 +99,7 @@ dags <- if (fixed) { } else NULL if (TEST) { - ns <- ns[1] + ns <- ns[2] mods <- mods[1] lmods <- lmods[1] } diff --git a/inst/code/vis-simulation.R b/inst/code/vis-simulation.R index 0fe1e67..202948c 100644 --- a/inst/code/vis-simulation.R +++ b/inst/code/vis-simulation.R @@ -4,7 +4,7 @@ settings <- c("main", "app", "hidden", "link", "wald-extended") -setting <- settings[as.numeric(commandArgs(TRUE))] +setting <- settings[as.numeric(commandArgs(TRUE))[1]] if (is.na(setting)) setting <- settings[1] @@ -162,6 +162,9 @@ prdat <- res %>% fwer = unlist(map2(splpaY, splset, ~ as.numeric(length(setdiff(.y[.y != "Empty"], .x)) > 0))), ) +prdat %>% group_by(n, mod, type, test) %>% + summarize(jaccard = mean(jaccard), fwer = mean(fwer > 0)) + if (setting == "hidden") { anY <- paste0("X1+X2+X3", unlist(lapply(fobs$dags, \(x) { ret <- names(which(x$dag["Y", c("X4", "X5")] == 0)) diff --git a/inst/helpers/R/helpers.R b/inst/helpers/R/helpers.R index 937a654..e8e0c4c 100644 --- a/inst/helpers/R/helpers.R +++ b/inst/helpers/R/helpers.R @@ -135,3 +135,29 @@ vis <- function(mods = tmods, tests = ttests) { scale_color_manual(values = c(cols, "Oracle" = "gray60")) + scale_fill_manual(values = c(cols, "Oracle" = "gray60")) } + +RANGER <- function(formula, data, ...) { + response <- model.response(model.frame(formula, data)) + binary <- is.factor(response) + tms <- .get_terms(formula) + if (identical(tms$me, character(0))) { + if (binary) return(glm(formula, data, family = "binomial")) else + return(lm(formula, data)) + } + ret <- ranger(formula, data, probability = binary, ...) + ret$data <- data + ret$response <- if(!binary) response else as.numeric(response) - 1 + ret$binary <- binary + ret +} + +residuals.ranger <- function(object) { + if ("glm" %in% class(object)) + return(residuals.binglm(object)) + else if ("lm" %in% class(object)) + return(residuals(object)) + preds <- predict(object, data = object$data)$predictions + if (object$binary) + preds <- preds[, 2] + object$response - preds +} diff --git a/inst/helpers/R/simdesign.R b/inst/helpers/R/simdesign.R index c3f6b58..e6d5d5f 100644 --- a/inst/helpers/R/simdesign.R +++ b/inst/helpers/R/simdesign.R @@ -41,10 +41,16 @@ ANA <- function(condition, dat, fixed_objects = NULL) { kbw <- if (condition$mod == "polr") 0.01 else 0 oicp <- attr(dat, "oracle_icp") pvals <- if (ttype == "kci") { - tmp <- cdkci(fixed_objects$resp, fixed_objects$env, fixed_objects$preds, - data = dat, coin = fixed_objects$coin) - inv <- attr(tmp, "intersection") - tmp + tmp <- dicp(as.formula(fixed_objects$fml), data = dat, + env = reformulate(fixed_objects$env), + modFUN = RANGER, type = "residual", test = "gcm.test", + controls = dicp_controls(residuals = residuals.ranger)) + inv <- tmp$candidate + pvalues(tmp, which = "set") + # tmp <- cdkci(fixed_objects$resp, fixed_objects$env, fixed_objects$preds, + # data = dat, coin = fixed_objects$coin) + # inv <- attr(tmp, "intersection") + # tmp } else { if (condition$mod == "polr") { ctrls <- dicp_controls(type = ttype, test = ttest,