diff --git a/NAMESPACE b/NAMESPACE index b899324..d0bf7f3 100755 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,6 @@ useDynLib(causalTree, .registration = TRUE, .fixes = "C_") -export(causalTree, honest.causalTree, na.causalTree, estimate.causalTree, causalTree.matrix, causalForest, propensityForest,honest.rparttree) +export(causalTree, honest.causalTree, na.causalTree, estimate.causalTree, causalTree.matrix, causalForest, propensityForest,honest.rparttree, plot.causalTree) importFrom(grDevices, dev.cur, dev.off) importFrom(graphics, plot, text) diff --git a/R/causalForest.R b/R/causalForest.R index a06dedb..ef9866a 100644 --- a/R/causalForest.R +++ b/R/causalForest.R @@ -18,7 +18,7 @@ predict.causalForest <- function(forest, newdata, predict.all = FALSE, type="vec }) #replace sapply with a loop if needed - print(dim(individual)) + #print(dim(individual)) aggregate <- rowMeans(individual) if (predict.all) { list(aggregate = aggregate, individual = individual) diff --git a/R/causalTree.R b/R/causalTree.R index 7896bf4..6a26b66 100755 --- a/R/causalTree.R +++ b/R/causalTree.R @@ -93,8 +93,8 @@ causalTree <- function(formula, data, weights, treatment, subset, split.Rule.int <- pmatch(split.Rule, c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")) - print(split.Rule.int) - print(split.Rule) + #print(split.Rule.int) + #print(split.Rule) if (is.na(split.Rule.int)) stop("Invalid splitting rule.") split.Rule <- c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")[split.Rule.int] diff --git a/R/honest.causalTree.R b/R/honest.causalTree.R index 6c70c0d..0961cea 100755 --- a/R/honest.causalTree.R +++ b/R/honest.causalTree.R @@ -134,8 +134,8 @@ honest.causalTree <- function(formula, data, weights, treatment, subset, split.Rule.int <- pmatch(split.Rule, c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")) if (is.na(split.Rule.int)) stop("Invalid splitting rule.") split.Rule <- c("TOT", "CT", "fit", "tstats", "TOTD", "CTD", "fitD", "tstatsD", "user", "userD","policy","policyD")[split.Rule.int] - print(split.Rule.int) - print(split.Rule) + #print(split.Rule.int) + #print(split.Rule) ## check the Split.Honest, for convenience if (split.Rule.int %in% c(1, 5)) { if (!missing(split.Honest)) { diff --git a/R/plot.causalTree.R b/R/plot.causalTree.R new file mode 100755 index 0000000..e815d8b --- /dev/null +++ b/R/plot.causalTree.R @@ -0,0 +1,38 @@ +# +# Plot best fit causal tree and p-values at leaf nodes. +# +# + +# Helper functions for plot.causalTree + +# Concatenates p-values to leaf node labels +node.pvals <- function(x, labs, digits, varlen) { + ifelse(is.na(x$frame$p.value), labs, + paste(labs, "\np=", round(x$frame$p.value, 4))) +} + +# Uses the model and response (x and y) values from a rpart object to construct +# a data frame from which p-values between control and treatment within a +# leaf node is calculated. TODO: multiple testing corrections? +add.pvals <- function(tree) { + dat <- cbind(treatment=tree$x[,1], outcome=tree$y, node=tree$where) + dat <- aggregate(outcome ~ treatment + node, data=dat, FUN=c) + merged <- merge(dat[dat$treatment == 0,], dat[dat$treatment == 1,], by="node", + suffixes=c(".ctl", ".trt")) + p.values <- do.call( + rbind, apply(merged, 1, FUN=function(x) { + data.frame(node=x$node, + p.value=t.test(x$outcome.ctl, x$outcome.trt)$p.value) })) + tree$frame$p.value[p.values$node] <- p.values$p.value + return(tree) +} + +# Takes the optimally pruned causal tree and adds pvalues. Then plots. +plot.causalTree <- function(tree, ...) { + if (is.null(tree$x) || is.null(tree$y)) + stop("Must build causalTree with x=TRUE, y=TRUE") + opCp <- tree$cptable[,1][which.min(tree$cptable[,4])] + opFit <- prune(tree, opCp) + opFit <- add.pvals(opFit) + rpart.plot(opFit, node.fun=node.pvals, ...) +}