From 8afeaa5c1b40c980a292967e0b9d365def27e1f6 Mon Sep 17 00:00:00 2001 From: roboton Date: Mon, 2 Oct 2017 17:44:14 -0400 Subject: [PATCH 1/4] Comment out undocumented debug print statements. --- R/causalForest.R | 2 +- R/causalTree.R | 4 ++-- R/honest.causalTree.R | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/R/causalForest.R b/R/causalForest.R index a06dedb..ef9866a 100644 --- a/R/causalForest.R +++ b/R/causalForest.R @@ -18,7 +18,7 @@ predict.causalForest <- function(forest, newdata, predict.all = FALSE, type="vec }) #replace sapply with a loop if needed - print(dim(individual)) + #print(dim(individual)) aggregate <- rowMeans(individual) if (predict.all) { list(aggregate = aggregate, individual = individual) diff --git a/R/causalTree.R b/R/causalTree.R index 7896bf4..6a26b66 100755 --- a/R/causalTree.R +++ b/R/causalTree.R @@ -93,8 +93,8 @@ causalTree <- function(formula, data, weights, treatment, subset, split.Rule.int <- pmatch(split.Rule, c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")) - print(split.Rule.int) - print(split.Rule) + #print(split.Rule.int) + #print(split.Rule) if (is.na(split.Rule.int)) stop("Invalid splitting rule.") split.Rule <- c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")[split.Rule.int] diff --git a/R/honest.causalTree.R b/R/honest.causalTree.R index 6c70c0d..0961cea 100755 --- a/R/honest.causalTree.R +++ b/R/honest.causalTree.R @@ -134,8 +134,8 @@ honest.causalTree <- function(formula, data, weights, treatment, subset, split.Rule.int <- pmatch(split.Rule, c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")) if (is.na(split.Rule.int)) stop("Invalid splitting rule.") split.Rule <- c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")[split.Rule.int] - print(split.Rule.int) - print(split.Rule) + #print(split.Rule.int) + #print(split.Rule) ## check the Split.Honest, for convenience if (split.Rule.int %in% c(1, 5)) { if (!missing(split.Honest)) { From f8adbe7f8bbeeb2bdb8b2d4741115e11bea19546 Mon Sep 17 00:00:00 2001 From: roboton Date: Thu, 5 Oct 2017 23:19:22 -0400 Subject: [PATCH 2/4] Add plot.causalTree function to add p-values. --- R/plot.causalTree.R | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100755 R/plot.causalTree.R diff --git a/R/plot.causalTree.R b/R/plot.causalTree.R new file mode 100755 index 0000000..6733145 --- /dev/null +++ b/R/plot.causalTree.R @@ -0,0 +1,38 @@ +# +# Plot best fit causal tree and p-values at leaf nodes. +# +# + +# Helper functions for plot.causalTree + +# Concatenates p-values to leaf node labels +node.pvals <- function(x, labs, digits, varlen) { + ifelse(is.na(x$frame$p.value), labs, + paste(labs, "\np=", round(x$frame$p.value, 4))) +} + +# Uses the model and response (x and y) values from a rpart object to construct +# a data frame from which p-values between control and treatment within a +# leaf node is calculated. TODO: multiple testing corrections? +addPvalues <- function(tree) { + dat <- cbind(treatment=tree$x[,1], outcome=tree$y, node=tree$where) + dat <- aggregate(outcome ~ treatment + node, data=dat, FUN=c) + merged <- merge(dat[dat$treatment == 0,], dat[dat$treatment == 1,], by="node", + suffixes=c(".ctl", ".trt")) + p.values <- do.call( + rbind, apply(merged, 1, FUN=function(x) { + data.frame(node=x$node, + p.value=t.test(x$outcome.ctl, x$outcome.trt)$p.value) })) + tree$frame$p.value[p.values$node] <- p.values$p.value + return(tree) +} + +# Takes the optimally pruned causal tree and adds pvalues. Then plots. +plot.causalTree <- function(tree) { + opCp <- tree$cptable[,1][which.min(tree$cptable[,4])] + opFit <- prune(tree, opCp) + opFit <- addPvalues(opFit) + prp(opFit, type=2, extra=1, under=FALSE, fallen.leaves=TRUE, digits=4, + varlen=0, faclen=0, cex=NULL, tweak=1, snip=FALSE, shadow.col=0, + box.palette="auto", branch.type=0, node.fun=node.pvals) +} From 33e80f7e4d3be95eb2823fad78e4feedc17d43bb Mon Sep 17 00:00:00 2001 From: roboton Date: Thu, 5 Oct 2017 23:25:19 -0400 Subject: [PATCH 3/4] Add function to namespace. Change helper function name. --- NAMESPACE | 2 +- R/plot.causalTree.R | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index b899324..d0bf7f3 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,6 @@ useDynLib(causalTree, .registration = TRUE, .fixes = "C_") -export(causalTree, honest.causalTree, na.causalTree, estimate.causalTree, causalTree.matrix, causalForest, propensityForest,honest.rparttree) +export(causalTree, honest.causalTree, na.causalTree, estimate.causalTree, causalTree.matrix, causalForest, propensityForest,honest.rparttree, plot.causalTree) importFrom(grDevices, dev.cur, dev.off) importFrom(graphics, plot, text) diff --git a/R/plot.causalTree.R b/R/plot.causalTree.R index 6733145..6d12a4d 100755 --- a/R/plot.causalTree.R +++ b/R/plot.causalTree.R @@ -14,7 +14,7 @@ node.pvals <- function(x, labs, digits, varlen) { # Uses the model and response (x and y) values from a rpart object to construct # a data frame from which p-values between control and treatment within a # leaf node is calculated. TODO: multiple testing corrections? -addPvalues <- function(tree) { +add.pvals <- function(tree) { dat <- cbind(treatment=tree$x[,1], outcome=tree$y, node=tree$where) dat <- aggregate(outcome ~ treatment + node, data=dat, FUN=c) merged <- merge(dat[dat$treatment == 0,], dat[dat$treatment == 1,], by="node", @@ -29,9 +29,11 @@ addPvalues <- function(tree) { # Takes the optimally pruned causal tree and adds pvalues. Then plots. plot.causalTree <- function(tree) { + if (is.null(tree$x) || is.null(tree$y)) + stop("Must build causalTree with x=TRUE, y=TRUE") opCp <- tree$cptable[,1][which.min(tree$cptable[,4])] opFit <- prune(tree, opCp) - opFit <- addPvalues(opFit) + opFit <- add.pvals(opFit) prp(opFit, type=2, extra=1, under=FALSE, fallen.leaves=TRUE, digits=4, varlen=0, faclen=0, cex=NULL, tweak=1, snip=FALSE, shadow.col=0, box.palette="auto", branch.type=0, node.fun=node.pvals) From 45c8d15fa47394d8bc318aae9b1e28edf3f75084 Mon Sep 17 00:00:00 2001 From: roboton Date: Thu, 5 Oct 2017 23:48:20 -0400 Subject: [PATCH 4/4] Allow passing parameters to rpart.plot (prp) --- R/plot.causalTree.R | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/R/plot.causalTree.R b/R/plot.causalTree.R index 6d12a4d..e815d8b 100755 --- a/R/plot.causalTree.R +++ b/R/plot.causalTree.R @@ -28,13 +28,11 @@ add.pvals <- function(tree) { } # Takes the optimally pruned causal tree and adds pvalues. Then plots. -plot.causalTree <- function(tree) { +plot.causalTree <- function(tree, ...) { if (is.null(tree$x) || is.null(tree$y)) stop("Must build causalTree with x=TRUE, y=TRUE") opCp <- tree$cptable[,1][which.min(tree$cptable[,4])] opFit <- prune(tree, opCp) opFit <- add.pvals(opFit) - prp(opFit, type=2, extra=1, under=FALSE, fallen.leaves=TRUE, digits=4, - varlen=0, faclen=0, cex=NULL, tweak=1, snip=FALSE, shadow.col=0, - box.palette="auto", branch.type=0, node.fun=node.pvals) + rpart.plot(opFit, node.fun=node.pvals, ...) }