From 22a13fe88471d87f501e155d6b346b43b0d9fac8 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 3 Dec 2015 13:32:41 +0100 Subject: [PATCH 1/3] add tree-wise split select weights --- NEWS.md | 3 + ranger-r-package/ranger/DESCRIPTION | 13 +-- ranger-r-package/ranger/NEWS | 3 + ranger-r-package/ranger/R/predict.R | 2 +- ranger-r-package/ranger/R/ranger.R | 17 +++- ranger-r-package/ranger/src/RcppExports.cpp | 4 +- ranger-r-package/ranger/src/rangerCpp.cpp | 2 +- source/src/Forest/Forest.cpp | 89 ++++++++++++++++----- source/src/Forest/Forest.h | 6 +- source/src/version.h | 2 +- 10 files changed, 102 insertions(+), 39 deletions(-) diff --git a/NEWS.md b/NEWS.md index bf5a36e98..d946741e5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,6 @@ +##### Version 0.3.5 +* Add tree-wise split.select.weights + ##### Version 0.3.4 * Add predict.all option in predict() to get individual predictions for each tree for classification and regression * Small changes in documentation diff --git a/ranger-r-package/ranger/DESCRIPTION b/ranger-r-package/ranger/DESCRIPTION index 3e024ff7b..1e13c12da 100644 --- a/ranger-r-package/ranger/DESCRIPTION +++ b/ranger-r-package/ranger/DESCRIPTION @@ -1,14 +1,15 @@ Package: ranger Type: Package Title: A Fast Implementation of Random Forests -Version: 0.3.4 -Date: 2015-12-01 +Version: 0.3.5 +Date: 2015-12-03 Author: Marvin N. Wright Maintainer: Marvin N. Wright -Description: A fast implementation of Random Forests, particularly suited for high dimensional data. Ensembles of - classification, regression, survival and probability prediction trees are supported. Data from genome-wide - association studies can be analyzed efficiently. In addition to data frames, datasets of class 'gwaa.data' - (R package GenABEL) can be directly analyzed. +Description: A fast implementation of Random Forests, particularly suited for high dimensional data. + Ensembles of classification, regression, survival and probability prediction trees are + supported. Data from genome-wide association studies can be analyzed efficiently. In + addition to data frames, datasets of class 'gwaa.data' (R package GenABEL) can be directly + analyzed. License: GPL-3 Imports: Rcpp (>= 0.11.2) LinkingTo: Rcpp diff --git a/ranger-r-package/ranger/NEWS b/ranger-r-package/ranger/NEWS index bf5a36e98..d946741e5 100644 --- a/ranger-r-package/ranger/NEWS +++ b/ranger-r-package/ranger/NEWS @@ -1,3 +1,6 @@ +##### Version 0.3.5 +* Add tree-wise split.select.weights + ##### Version 0.3.4 * Add predict.all option in predict() to get individual predictions for each tree for classification and regression * Small changes in documentation diff --git a/ranger-r-package/ranger/R/predict.R b/ranger-r-package/ranger/R/predict.R index 89bd6d5fa..8a8ef3824 100644 --- a/ranger-r-package/ranger/R/predict.R +++ b/ranger-r-package/ranger/R/predict.R @@ -167,7 +167,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE, mtry <- 0 importance <- 0 min.node.size <- 0 - split.select.weights <- c(0, 0) + split.select.weights <- list(c(0, 0)) use.split.select.weights <- FALSE always.split.variables <- c("0", "0") use.always.split.variables <- FALSE diff --git a/ranger-r-package/ranger/R/ranger.R b/ranger-r-package/ranger/R/ranger.R index aa0dc623a..8d5e0c658 100644 --- a/ranger-r-package/ranger/R/ranger.R +++ b/ranger-r-package/ranger/R/ranger.R @@ -330,10 +330,21 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, ## Split select weights: NULL for no weights if (is.null(split.select.weights)) { - split.select.weights <- c(0,0) + split.select.weights <- list(c(0,0)) use.split.select.weights <- FALSE - } else { + } else if (is.numeric(split.select.weights)) { + if (length(split.select.weights) != length(all.independent.variable.names)) { + stop("Error: Number of split select weights not equal to number of independent variables.") + } + split.select.weights <- list(split.select.weights) use.split.select.weights <- TRUE + } else if (is.list(split.select.weights)) { + if (length(split.select.weights) != num.trees) { + stop("Error: Size of split select weights list not equal to number of trees.") + } + use.split.select.weights <- TRUE + } else { + stop("Error: Invalid split select weights.") } ## Always split variables: NULL for no variables @@ -345,7 +356,7 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, } if (use.split.select.weights & use.always.split.variables) { - stop("Error: Please use only one option of use.split.select.weights and use.always.split.variables.") + stop("Error: Please use only one option of split.select.weights and always.split.variables.") } ## Splitting rule diff --git a/ranger-r-package/ranger/src/RcppExports.cpp b/ranger-r-package/ranger/src/RcppExports.cpp index 7bbb8a3a8..24181f437 100644 --- a/ranger-r-package/ranger/src/RcppExports.cpp +++ b/ranger-r-package/ranger/src/RcppExports.cpp @@ -7,7 +7,7 @@ using namespace Rcpp; // rangerCpp -Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name, Rcpp::NumericMatrix input_data, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, std::vector& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, bool predict_all); +Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name, Rcpp::NumericMatrix input_data, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector& case_weights, bool use_case_weights, bool predict_all); RcppExport SEXP ranger_rangerCpp(SEXP treetypeSEXP, SEXP dependent_variable_nameSEXP, SEXP input_dataSEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP status_variable_nameSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP sparse_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP predict_allSEXP) { BEGIN_RCPP Rcpp::RObject __result; @@ -24,7 +24,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP); Rcpp::traits::input_parameter< uint >::type importance_mode_r(importance_mode_rSEXP); Rcpp::traits::input_parameter< uint >::type min_node_size(min_node_sizeSEXP); - Rcpp::traits::input_parameter< std::vector& >::type split_select_weights(split_select_weightsSEXP); + Rcpp::traits::input_parameter< std::vector>& >::type split_select_weights(split_select_weightsSEXP); Rcpp::traits::input_parameter< bool >::type use_split_select_weights(use_split_select_weightsSEXP); Rcpp::traits::input_parameter< std::vector& >::type always_split_variable_names(always_split_variable_namesSEXP); Rcpp::traits::input_parameter< bool >::type use_always_split_variable_names(use_always_split_variable_namesSEXP); diff --git a/ranger-r-package/ranger/src/rangerCpp.cpp b/ranger-r-package/ranger/src/rangerCpp.cpp index 4967d3c36..17a010532 100644 --- a/ranger-r-package/ranger/src/rangerCpp.cpp +++ b/ranger-r-package/ranger/src/rangerCpp.cpp @@ -45,7 +45,7 @@ Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name, Rcpp::NumericMatrix input_data, std::vector variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, - std::vector& split_select_weights, bool use_split_select_weights, + std::vector>& split_select_weights, bool use_split_select_weights, std::vector& always_split_variable_names, bool use_always_split_variable_names, std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data, bool sample_with_replacement, bool probability, std::vector& unordered_variable_names, diff --git a/source/src/Forest/Forest.cpp b/source/src/Forest/Forest.cpp index 7c2bdcb04..e061d5155 100644 --- a/source/src/Forest/Forest.cpp +++ b/source/src/Forest/Forest.cpp @@ -106,11 +106,13 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode setAlwaysSplitVariables(always_split_variable_names); } + // TODO: Read 2d weights for tree-wise split select weights // Load split select weights from file if (!split_select_weights_file.empty()) { - std::vector split_select_weights; - loadDoubleVectorFromFile(split_select_weights, split_select_weights_file); - if (split_select_weights.size() != num_variables - 1) { + std::vector> split_select_weights; + split_select_weights.resize(1); + loadDoubleVectorFromFile(split_select_weights[0], split_select_weights_file); + if (split_select_weights[0].size() != num_variables - 1) { throw std::runtime_error("Number of split select weights is not equal to number of independent variables."); } setSplitWeightVector(split_select_weights); @@ -135,7 +137,7 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode void Forest::initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, - std::vector& split_select_weights, std::vector& always_split_variable_names, + std::vector>& split_select_weights, std::vector& always_split_variable_names, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, std::vector& case_weights, bool predict_all) { @@ -235,6 +237,9 @@ void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, D // Sort no split variables in ascending order std::sort(no_split_variables.begin(), no_split_variables.end()); + // Init split select weights + split_select_weights.push_back(std::vector()); + // Check if mtry is in valid range if (this->mtry > num_variables - 1) { throw std::runtime_error("mtry can not be larger than number of variables in data."); @@ -292,7 +297,7 @@ void Forest::writeOutput() { *verbose_out << "Overall OOB prediction error: " << overall_prediction_error << std::endl; *verbose_out << std::endl; - if (!split_select_weights.empty()) { + if (!split_select_weights.empty() & !split_select_weights[0].empty()) { *verbose_out << "Warning: Split select weights used. Variable importance measures are only comparable for variables with equal weights." << std::endl; @@ -380,8 +385,17 @@ void Forest::grow() { } else { tree_seed = (i + 1) * seed; } + + // Get split select weights for tree + std::vector* tree_split_select_weights; + if (split_select_weights.size() > 1) { + tree_split_select_weights = &split_select_weights[i]; + } else { + tree_split_select_weights = &split_select_weights[0]; + } + trees[i]->init(data, mtry, dependent_varID, num_samples, tree_seed, &deterministic_varIDs, &split_select_varIDs, - &split_select_weights, importance_mode, min_node_size, &no_split_variables, sample_with_replacement, + tree_split_select_weights, importance_mode, min_node_size, &no_split_variables, sample_with_replacement, &is_ordered_variable, memory_saving_splitting, splitrule, &case_weights); } @@ -711,29 +725,58 @@ void Forest::loadFromFile(std::string filename) { equalSplit(thread_ranges, 0, num_trees - 1, num_threads); } -void Forest::setSplitWeightVector(std::vector& split_select_weights) { +void Forest::setSplitWeightVector(std::vector>& split_select_weights) { + + // Size should be 1 x num_independent_variables or num_trees x num_independent_variables + if (split_select_weights.size() != 1 && split_select_weights.size() != num_trees) { + throw std::runtime_error("Size of split select weights not equal to 1 or number of trees."); + } - if (split_select_weights.size() != num_independent_variables) { - throw std::runtime_error("Number of split select weights not equal to number of independent variables."); + // Reserve space + if (split_select_weights.size() == 1) { + this->split_select_weights[0].resize(num_independent_variables); + } else { + this->split_select_weights.clear(); + this->split_select_weights.resize(num_trees, std::vector(num_independent_variables)); } + this->split_select_varIDs.resize(num_independent_variables); + deterministic_varIDs.reserve(num_independent_variables); // Split up in deterministic and weighted variables, ignore zero weights for (size_t i = 0; i < split_select_weights.size(); ++i) { - double weight = split_select_weights[i]; - size_t varID = i; - for (auto& skip : no_split_variables) { - if (varID >= skip) { - ++varID; - } + + // Size should be 1 x num_independent_variables or num_trees x num_independent_variables + if (split_select_weights[i].size() != num_independent_variables) { + throw std::runtime_error("Number of split select weights not equal to number of independent variables."); } - if (weight == 1) { - deterministic_varIDs.push_back(varID); - } else if (weight < 1 && weight > 0) { - this->split_select_varIDs.push_back(varID); - this->split_select_weights.push_back(weight); - } else if (weight < 0 || weight > 1) { - throw std::runtime_error("One or more split select weights not in range [0,1]."); + for (size_t j = 0; j < split_select_weights[i].size(); ++j) { + double weight = split_select_weights[i][j]; + + if (i == 0) { + size_t varID = j; + for (auto& skip : no_split_variables) { + if (varID >= skip) { + ++varID; + } + } + + if (weight == 1) { + deterministic_varIDs.push_back(varID); + } else if (weight < 1 && weight > 0) { + this->split_select_varIDs[j] = varID; + this->split_select_weights[i][j] = weight; + } else if (weight < 0 || weight > 1) { + throw std::runtime_error("One or more split select weights not in range [0,1]."); + } + + } else { + if (weight < 1 && weight > 0) { + this->split_select_weights[i][j] = weight; + } else if (weight < 0 || weight > 1) { + throw std::runtime_error("One or more split select weights not in range [0,1]."); + } + } } } @@ -747,6 +790,8 @@ void Forest::setSplitWeightVector(std::vector& split_select_weights) { void Forest::setAlwaysSplitVariables(std::vector& always_split_variable_names) { + deterministic_varIDs.reserve(num_independent_variables); + for (auto& variable_name : always_split_variable_names) { size_t varID = data->getVariableID(variable_name); deterministic_varIDs.push_back(varID); diff --git a/source/src/Forest/Forest.h b/source/src/Forest/Forest.h index a93d51af0..51396a140 100644 --- a/source/src/Forest/Forest.h +++ b/source/src/Forest/Forest.h @@ -59,7 +59,7 @@ class Forest { std::string case_weights_file, bool predict_all); void initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees, std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size, - std::vector& split_select_weights, std::vector& always_split_variable_names, + std::vector>& split_select_weights, std::vector& always_split_variable_names, std::string status_variable_name, bool prediction_mode, bool sample_with_replacement, std::vector& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule, std::vector& case_weights, bool predict_all); @@ -160,7 +160,7 @@ class Forest { virtual void loadFromFileInternal(std::ifstream& infile) = 0; // Set split select weights and variables to be always considered for splitting - void setSplitWeightVector(std::vector& split_select_weights); + void setSplitWeightVector(std::vector>& split_select_weights); void setAlwaysSplitVariables(std::vector& always_split_variable_names); // Show progress every few seconds @@ -212,7 +212,7 @@ class Forest { // Deterministic variables are always selected std::vector deterministic_varIDs; std::vector split_select_varIDs; - std::vector split_select_weights; + std::vector> split_select_weights; // Bootstrap weights std::vector case_weights; diff --git a/source/src/version.h b/source/src/version.h index 4e3c9292b..6f277a1b3 100644 --- a/source/src/version.h +++ b/source/src/version.h @@ -1,3 +1,3 @@ #ifndef RANGER_VERSION -#define RANGER_VERSION "0.3.4" +#define RANGER_VERSION "0.3.5" #endif From e6f98280f056afcce9b39189e50755b45a5b9138 Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 3 Dec 2015 13:33:52 +0100 Subject: [PATCH 2/3] add tests for tree-wise split select weights --- .../ranger/tests/testthat/test_ranger.R | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/ranger-r-package/ranger/tests/testthat/test_ranger.R b/ranger-r-package/ranger/tests/testthat/test_ranger.R index 0dd4ca47c..0422b7ec0 100644 --- a/ranger-r-package/ranger/tests/testthat/test_ranger.R +++ b/ranger-r-package/ranger/tests/testthat/test_ranger.R @@ -322,3 +322,20 @@ test_that("Mean of predict.all for regression is equal to forest prediction", { expect_that(rowMeans(pred_trees$predictions), equals(pred_forest$predictions)) }) +test_that("split select weights work", { + expect_that(ranger(Species ~ ., iris, num.trees = 5, split.select.weights = c(0.1, 0.2, 0.3, 0.4)), + not(throws_error())) + expect_that(ranger(Species ~ ., iris, num.trees = 5, split.select.weights = c(0.1, 0.2, 0.3)), + throws_error()) +}) + +test_that("Tree-wise split select weights work", { + num.trees <- 5 + weights <- replicate(num.trees, runif(ncol(iris)-1), simplify = FALSE) + expect_that(ranger(Species ~ ., iris, num.trees = num.trees, split.select.weights = weights), + not(throws_error())) + + weights <- replicate(num.trees+1, runif(ncol(iris)-1), simplify = FALSE) + expect_that(ranger(Species ~ ., iris, num.trees = num.trees, split.select.weights = weights), + throws_error()) +}) From 884e2805567a4e3f144f5de5e4d63784dc7e5bed Mon Sep 17 00:00:00 2001 From: Marvin Wright Date: Thu, 3 Dec 2015 13:36:49 +0100 Subject: [PATCH 3/3] add tree-wise split select weights to doc --- ranger-r-package/ranger/R/ranger.R | 2 +- ranger-r-package/ranger/man/ranger.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ranger-r-package/ranger/R/ranger.R b/ranger-r-package/ranger/R/ranger.R index 8d5e0c658..98dd0676c 100644 --- a/ranger-r-package/ranger/R/ranger.R +++ b/ranger-r-package/ranger/R/ranger.R @@ -77,7 +77,7 @@ ##' @param replace Sample with replacement. ##' @param splitrule Splitting rule, survival only. The splitting rule can be chosen of "logrank" and "C" with default "logrank". ##' @param case.weights Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees. -##' @param split.select.weights Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting. +##' @param split.select.weights Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used. ##' @param always.split.variables Character vector with variable names to be always tried for splitting. ##' @param respect.unordered.factors Regard unordered factor covariates as unordered categorical variables. If \code{FALSE}, all factors are regarded ordered. ##' @param scale.permutation.importance Scale permutation importance by standard error as in (Breiman 2001). Only applicable if permutation variable importance mode selected. diff --git a/ranger-r-package/ranger/man/ranger.Rd b/ranger-r-package/ranger/man/ranger.Rd index 2c3e774a7..29a21bc85 100644 --- a/ranger-r-package/ranger/man/ranger.Rd +++ b/ranger-r-package/ranger/man/ranger.Rd @@ -37,7 +37,7 @@ ranger(formula = NULL, data = NULL, num.trees = 500, mtry = NULL, \item{case.weights}{Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees.} -\item{split.select.weights}{Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting.} +\item{split.select.weights}{Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used.} \item{always.split.variables}{Character vector with variable names to be always tried for splitting.}