Skip to content

Commit

Permalink
sim with RF GCM
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasKook committed Aug 26, 2023
1 parent 73fb725 commit 7fe47fa
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
2 changes: 1 addition & 1 deletion inst/code/run-simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ dags <- if (fixed) {
} else NULL

if (TEST) {
ns <- ns[1]
ns <- ns[2]
mods <- mods[1]
lmods <- lmods[1]
}
Expand Down
5 changes: 4 additions & 1 deletion inst/code/vis-simulation.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

settings <- c("main", "app", "hidden", "link", "wald-extended")

setting <- settings[as.numeric(commandArgs(TRUE))]
setting <- settings[as.numeric(commandArgs(TRUE))[1]]
if (is.na(setting))
setting <- settings[1]

Expand Down Expand Up @@ -162,6 +162,9 @@ prdat <- res %>%
fwer = unlist(map2(splpaY, splset, ~ as.numeric(length(setdiff(.y[.y != "Empty"], .x)) > 0))),
)

prdat %>% group_by(n, mod, type, test) %>%
summarize(jaccard = mean(jaccard), fwer = mean(fwer > 0))

if (setting == "hidden") {
anY <- paste0("X1+X2+X3", unlist(lapply(fobs$dags, \(x) {
ret <- names(which(x$dag["Y", c("X4", "X5")] == 0))
Expand Down
26 changes: 26 additions & 0 deletions inst/helpers/R/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,29 @@ vis <- function(mods = tmods, tests = ttests) {
scale_color_manual(values = c(cols, "Oracle" = "gray60")) +
scale_fill_manual(values = c(cols, "Oracle" = "gray60"))
}

RANGER <- function(formula, data, ...) {
response <- model.response(model.frame(formula, data))
binary <- is.factor(response)
tms <- .get_terms(formula)
if (identical(tms$me, character(0))) {
if (binary) return(glm(formula, data, family = "binomial")) else
return(lm(formula, data))
}
ret <- ranger(formula, data, probability = binary, ...)
ret$data <- data
ret$response <- if(!binary) response else as.numeric(response) - 1
ret$binary <- binary
ret
}

residuals.ranger <- function(object) {
if ("glm" %in% class(object))
return(residuals.binglm(object))
else if ("lm" %in% class(object))
return(residuals(object))
preds <- predict(object, data = object$data)$predictions
if (object$binary)
preds <- preds[, 2]
object$response - preds
}
14 changes: 10 additions & 4 deletions inst/helpers/R/simdesign.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,16 @@ ANA <- function(condition, dat, fixed_objects = NULL) {
kbw <- if (condition$mod == "polr") 0.01 else 0
oicp <- attr(dat, "oracle_icp")
pvals <- if (ttype == "kci") {
tmp <- cdkci(fixed_objects$resp, fixed_objects$env, fixed_objects$preds,
data = dat, coin = fixed_objects$coin)
inv <- attr(tmp, "intersection")
tmp
tmp <- dicp(as.formula(fixed_objects$fml), data = dat,
env = reformulate(fixed_objects$env),
modFUN = RANGER, type = "residual", test = "gcm.test",
controls = dicp_controls(residuals = residuals.ranger))
inv <- tmp$candidate
pvalues(tmp, which = "set")
# tmp <- cdkci(fixed_objects$resp, fixed_objects$env, fixed_objects$preds,
# data = dat, coin = fixed_objects$coin)
# inv <- attr(tmp, "intersection")
# tmp
} else {
if (condition$mod == "polr") {
ctrls <- dicp_controls(type = ttype, test = ttest,
Expand Down

0 comments on commit 7fe47fa

Please sign in to comment.