Commit

Merge pull request #26 from mnwright/tree_splitweights

Tree split weights (issue #15)

mnwright committed Dec 3, 2015
2 parents 9c93cf8 + 884e280 commit e811c27
Showing 12 changed files with 121 additions and 41 deletions.
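For orientation before the file-by-file diff: this commit lets split.select.weights be either a single numeric weight vector (as before) or a list with one weight vector per tree. A minimal usage sketch, adapted from the test cases added in this commit (the iris data and the random weights are purely illustrative):

```r
library(ranger)

## Existing behaviour: one weight in [0, 1] per independent variable,
## shared by all trees.
rf.shared <- ranger(Species ~ ., iris, num.trees = 5,
                    split.select.weights = c(0.1, 0.2, 0.3, 0.4))

## New in 0.3.5: a list of length num.trees, one weight vector per tree.
num.trees <- 5
tree.weights <- replicate(num.trees, runif(ncol(iris) - 1), simplify = FALSE)
rf.treewise <- ranger(Species ~ ., iris, num.trees = num.trees,
                      split.select.weights = tree.weights)
```

With the list form each tree samples its split candidates according to its own weight vector; a plain numeric vector keeps the previous shared-weights behaviour.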
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,3 +1,6 @@
##### Version 0.3.5
* Add tree-wise split.select.weights

##### Version 0.3.4
* Add predict.all option in predict() to get individual predictions for each tree for classification and regression
* Small changes in documentation
13 changes: 7 additions & 6 deletions ranger-r-package/ranger/DESCRIPTION
@@ -1,14 +1,15 @@
Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.3.4
Date: 2015-12-01
Version: 0.3.5
Date: 2015-12-03
Author: Marvin N. Wright
Maintainer: Marvin N. Wright <[email protected]>
Description: A fast implementation of Random Forests, particularly suited for high dimensional data. Ensembles of
classification, regression, survival and probability prediction trees are supported. Data from genome-wide
association studies can be analyzed efficiently. In addition to data frames, datasets of class 'gwaa.data'
(R package GenABEL) can be directly analyzed.
Description: A fast implementation of Random Forests, particularly suited for high dimensional data.
Ensembles of classification, regression, survival and probability prediction trees are
supported. Data from genome-wide association studies can be analyzed efficiently. In
addition to data frames, datasets of class 'gwaa.data' (R package GenABEL) can be directly
analyzed.
License: GPL-3
Imports: Rcpp (>= 0.11.2)
LinkingTo: Rcpp
3 changes: 3 additions & 0 deletions ranger-r-package/ranger/NEWS
@@ -1,3 +1,6 @@
##### Version 0.3.5
* Add tree-wise split.select.weights

##### Version 0.3.4
* Add predict.all option in predict() to get individual predictions for each tree for classification and regression
* Small changes in documentation
2 changes: 1 addition & 1 deletion ranger-r-package/ranger/R/predict.R
@@ -167,7 +167,7 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
mtry <- 0
importance <- 0
min.node.size <- 0
split.select.weights <- c(0, 0)
split.select.weights <- list(c(0, 0))
use.split.select.weights <- FALSE
always.split.variables <- c("0", "0")
use.always.split.variables <- FALSE
19 changes: 15 additions & 4 deletions ranger-r-package/ranger/R/ranger.R
@@ -77,7 +77,7 @@
##' @param replace Sample with replacement.
##' @param splitrule Splitting rule, survival only. The splitting rule can be chosen of "logrank" and "C" with default "logrank".
##' @param case.weights Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees.
##' @param split.select.weights Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting.
##' @param split.select.weights Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used.
##' @param always.split.variables Character vector with variable names to be always tried for splitting.
##' @param respect.unordered.factors Regard unordered factor covariates as unordered categorical variables. If \code{FALSE}, all factors are regarded ordered.
##' @param scale.permutation.importance Scale permutation importance by standard error as in (Breiman 2001). Only applicable if permutation variable importance mode selected.
@@ -330,10 +330,21 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,

## Split select weights: NULL for no weights
if (is.null(split.select.weights)) {
split.select.weights <- c(0,0)
split.select.weights <- list(c(0,0))
use.split.select.weights <- FALSE
} else {
} else if (is.numeric(split.select.weights)) {
if (length(split.select.weights) != length(all.independent.variable.names)) {
stop("Error: Number of split select weights not equal to number of independent variables.")
}
split.select.weights <- list(split.select.weights)
use.split.select.weights <- TRUE
} else if (is.list(split.select.weights)) {
if (length(split.select.weights) != num.trees) {
stop("Error: Size of split select weights list not equal to number of trees.")
}
use.split.select.weights <- TRUE
} else {
stop("Error: Invalid split select weights.")
}

## Always split variables: NULL for no variables
@@ -345,7 +356,7 @@ }
}

if (use.split.select.weights & use.always.split.variables) {
stop("Error: Please use only one option of use.split.select.weights and use.always.split.variables.")
stop("Error: Please use only one option of split.select.weights and always.split.variables.")
}

## Splitting rule
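As a reading aid for the branching added above: ranger() now normalizes the argument so that the C++ backend always receives a list of weight vectors. A standalone sketch of that normalization (a hypothetical helper written for illustration only, not part of the package):

```r
## Mirrors the checks in ranger(): NULL disables weighting, a numeric vector
## must have one weight per independent variable and is wrapped in a list of
## length one, and a list must have one weight vector per tree.
normalize.split.select.weights <- function(weights, num.trees, num.independent.vars) {
  if (is.null(weights)) {
    list(use = FALSE, weights = list(c(0, 0)))
  } else if (is.numeric(weights)) {
    if (length(weights) != num.independent.vars) {
      stop("Error: Number of split select weights not equal to number of independent variables.")
    }
    list(use = TRUE, weights = list(weights))
  } else if (is.list(weights)) {
    if (length(weights) != num.trees) {
      stop("Error: Size of split select weights list not equal to number of trees.")
    }
    list(use = TRUE, weights = weights)
  } else {
    stop("Error: Invalid split select weights.")
  }
}
```

Either way the backend sees a list of numeric vectors, matching the new std::vector<std::vector<double>> parameter of rangerCpp() below.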
2 changes: 1 addition & 1 deletion ranger-r-package/ranger/man/ranger.Rd
@@ -37,7 +37,7 @@ ranger(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,

\item{case.weights}{Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees.}

\item{split.select.weights}{Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting.}
\item{split.select.weights}{Numeric vector with weights between 0 and 1, representing the probability to select variables for splitting. Alternatively, a list of size num.trees, containing split select weight vectors for each tree can be used.}

\item{always.split.variables}{Character vector with variable names to be always tried for splitting.}

4 changes: 2 additions & 2 deletions ranger-r-package/ranger/src/RcppExports.cpp
@@ -7,7 +7,7 @@
using namespace Rcpp;

// rangerCpp
Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name, Rcpp::NumericMatrix input_data, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, std::vector<double>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, bool predict_all);
Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name, Rcpp::NumericMatrix input_data, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose, uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size, std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights, std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names, std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data, bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names, bool use_unordered_variable_names, bool save_memory, uint splitrule_r, std::vector<double>& case_weights, bool use_case_weights, bool predict_all);
RcppExport SEXP ranger_rangerCpp(SEXP treetypeSEXP, SEXP dependent_variable_nameSEXP, SEXP input_dataSEXP, SEXP variable_namesSEXP, SEXP mtrySEXP, SEXP num_treesSEXP, SEXP verboseSEXP, SEXP seedSEXP, SEXP num_threadsSEXP, SEXP write_forestSEXP, SEXP importance_mode_rSEXP, SEXP min_node_sizeSEXP, SEXP split_select_weightsSEXP, SEXP use_split_select_weightsSEXP, SEXP always_split_variable_namesSEXP, SEXP use_always_split_variable_namesSEXP, SEXP status_variable_nameSEXP, SEXP prediction_modeSEXP, SEXP loaded_forestSEXP, SEXP sparse_dataSEXP, SEXP sample_with_replacementSEXP, SEXP probabilitySEXP, SEXP unordered_variable_namesSEXP, SEXP use_unordered_variable_namesSEXP, SEXP save_memorySEXP, SEXP splitrule_rSEXP, SEXP case_weightsSEXP, SEXP use_case_weightsSEXP, SEXP predict_allSEXP) {
BEGIN_RCPP
Rcpp::RObject __result;
@@ -24,7 +24,7 @@ BEGIN_RCPP
Rcpp::traits::input_parameter< bool >::type write_forest(write_forestSEXP);
Rcpp::traits::input_parameter< uint >::type importance_mode_r(importance_mode_rSEXP);
Rcpp::traits::input_parameter< uint >::type min_node_size(min_node_sizeSEXP);
Rcpp::traits::input_parameter< std::vector<double>& >::type split_select_weights(split_select_weightsSEXP);
Rcpp::traits::input_parameter< std::vector<std::vector<double>>& >::type split_select_weights(split_select_weightsSEXP);
Rcpp::traits::input_parameter< bool >::type use_split_select_weights(use_split_select_weightsSEXP);
Rcpp::traits::input_parameter< std::vector<std::string>& >::type always_split_variable_names(always_split_variable_namesSEXP);
Rcpp::traits::input_parameter< bool >::type use_always_split_variable_names(use_always_split_variable_namesSEXP);
2 changes: 1 addition & 1 deletion ranger-r-package/ranger/src/rangerCpp.cpp
@@ -45,7 +45,7 @@
Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name,
Rcpp::NumericMatrix input_data, std::vector<std::string> variable_names, uint mtry, uint num_trees, bool verbose,
uint seed, uint num_threads, bool write_forest, uint importance_mode_r, uint min_node_size,
std::vector<double>& split_select_weights, bool use_split_select_weights,
std::vector<std::vector<double>>& split_select_weights, bool use_split_select_weights,
std::vector<std::string>& always_split_variable_names, bool use_always_split_variable_names,
std::string status_variable_name, bool prediction_mode, Rcpp::List loaded_forest, Rcpp::RawMatrix sparse_data,
bool sample_with_replacement, bool probability, std::vector<std::string>& unordered_variable_names,
17 changes: 17 additions & 0 deletions ranger-r-package/ranger/tests/testthat/test_ranger.R
@@ -322,3 +322,20 @@ test_that("Mean of predict.all for regression is equal to forest prediction", {
expect_that(rowMeans(pred_trees$predictions), equals(pred_forest$predictions))
})

test_that("split select weights work", {
expect_that(ranger(Species ~ ., iris, num.trees = 5, split.select.weights = c(0.1, 0.2, 0.3, 0.4)),
not(throws_error()))
expect_that(ranger(Species ~ ., iris, num.trees = 5, split.select.weights = c(0.1, 0.2, 0.3)),
throws_error())
})

test_that("Tree-wise split select weights work", {
num.trees <- 5
weights <- replicate(num.trees, runif(ncol(iris)-1), simplify = FALSE)
expect_that(ranger(Species ~ ., iris, num.trees = num.trees, split.select.weights = weights),
not(throws_error()))

weights <- replicate(num.trees+1, runif(ncol(iris)-1), simplify = FALSE)
expect_that(ranger(Species ~ ., iris, num.trees = num.trees, split.select.weights = weights),
throws_error())
})
89 changes: 67 additions & 22 deletions source/src/Forest/Forest.cpp
@@ -106,11 +106,13 @@ void Forest::initCpp(std::string dependent_variable_name, MemoryMode memory_mode
setAlwaysSplitVariables(always_split_variable_names);
}

// TODO: Read 2d weights for tree-wise split select weights
// Load split select weights from file
if (!split_select_weights_file.empty()) {
std::vector<double> split_select_weights;
loadDoubleVectorFromFile(split_select_weights, split_select_weights_file);
if (split_select_weights.size() != num_variables - 1) {
std::vector<std::vector<double>> split_select_weights;
split_select_weights.resize(1);
loadDoubleVectorFromFile(split_select_weights[0], split_select_weights_file);
if (split_select_weights[0].size() != num_variables - 1) {
throw std::runtime_error("Number of split select weights is not equal to number of independent variables.");
}
setSplitWeightVector(split_select_weights);
@@ -135,7 +137,7 @@

void Forest::initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees,
std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size,
std::vector<double>& split_select_weights, std::vector<std::string>& always_split_variable_names,
std::vector<std::vector<double>>& split_select_weights, std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
std::vector<double>& case_weights, bool predict_all) {
@@ -235,6 +237,9 @@ void Forest::init(std::string dependent_variable_name, MemoryMode memory_mode, D
// Sort no split variables in ascending order
std::sort(no_split_variables.begin(), no_split_variables.end());

// Init split select weights
split_select_weights.push_back(std::vector<double>());

// Check if mtry is in valid range
if (this->mtry > num_variables - 1) {
throw std::runtime_error("mtry can not be larger than number of variables in data.");
@@ -292,7 +297,7 @@ void Forest::writeOutput() {
*verbose_out << "Overall OOB prediction error: " << overall_prediction_error << std::endl;
*verbose_out << std::endl;

if (!split_select_weights.empty()) {
if (!split_select_weights.empty() & !split_select_weights[0].empty()) {
*verbose_out
<< "Warning: Split select weights used. Variable importance measures are only comparable for variables with equal weights."
<< std::endl;
@@ -380,8 +385,17 @@ void Forest::grow() {
} else {
tree_seed = (i + 1) * seed;
}

// Get split select weights for tree
std::vector<double>* tree_split_select_weights;
if (split_select_weights.size() > 1) {
tree_split_select_weights = &split_select_weights[i];
} else {
tree_split_select_weights = &split_select_weights[0];
}

trees[i]->init(data, mtry, dependent_varID, num_samples, tree_seed, &deterministic_varIDs, &split_select_varIDs,
&split_select_weights, importance_mode, min_node_size, &no_split_variables, sample_with_replacement,
tree_split_select_weights, importance_mode, min_node_size, &no_split_variables, sample_with_replacement,
&is_ordered_variable, memory_saving_splitting, splitrule, &case_weights);
}

@@ -711,29 +725,58 @@ void Forest::loadFromFile(std::string filename) {
equalSplit(thread_ranges, 0, num_trees - 1, num_threads);
}

void Forest::setSplitWeightVector(std::vector<double>& split_select_weights) {
void Forest::setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights) {

// Size should be 1 x num_independent_variables or num_trees x num_independent_variables
if (split_select_weights.size() != 1 && split_select_weights.size() != num_trees) {
throw std::runtime_error("Size of split select weights not equal to 1 or number of trees.");
}

if (split_select_weights.size() != num_independent_variables) {
throw std::runtime_error("Number of split select weights not equal to number of independent variables.");
// Reserve space
if (split_select_weights.size() == 1) {
this->split_select_weights[0].resize(num_independent_variables);
} else {
this->split_select_weights.clear();
this->split_select_weights.resize(num_trees, std::vector<double>(num_independent_variables));
}
this->split_select_varIDs.resize(num_independent_variables);
deterministic_varIDs.reserve(num_independent_variables);

// Split up in deterministic and weighted variables, ignore zero weights
for (size_t i = 0; i < split_select_weights.size(); ++i) {
double weight = split_select_weights[i];
size_t varID = i;
for (auto& skip : no_split_variables) {
if (varID >= skip) {
++varID;
}

// Size should be 1 x num_independent_variables or num_trees x num_independent_variables
if (split_select_weights[i].size() != num_independent_variables) {
throw std::runtime_error("Number of split select weights not equal to number of independent variables.");
}

if (weight == 1) {
deterministic_varIDs.push_back(varID);
} else if (weight < 1 && weight > 0) {
this->split_select_varIDs.push_back(varID);
this->split_select_weights.push_back(weight);
} else if (weight < 0 || weight > 1) {
throw std::runtime_error("One or more split select weights not in range [0,1].");
for (size_t j = 0; j < split_select_weights[i].size(); ++j) {
double weight = split_select_weights[i][j];

if (i == 0) {
size_t varID = j;
for (auto& skip : no_split_variables) {
if (varID >= skip) {
++varID;
}
}

if (weight == 1) {
deterministic_varIDs.push_back(varID);
} else if (weight < 1 && weight > 0) {
this->split_select_varIDs[j] = varID;
this->split_select_weights[i][j] = weight;
} else if (weight < 0 || weight > 1) {
throw std::runtime_error("One or more split select weights not in range [0,1].");
}

} else {
if (weight < 1 && weight > 0) {
this->split_select_weights[i][j] = weight;
} else if (weight < 0 || weight > 1) {
throw std::runtime_error("One or more split select weights not in range [0,1].");
}
}
}
}

@@ -747,6 +790,8 @@ void Forest::setSplitWeightVector(std::vector<double>& split_select_weights) {

void Forest::setAlwaysSplitVariables(std::vector<std::string>& always_split_variable_names) {

deterministic_varIDs.reserve(num_independent_variables);

for (auto& variable_name : always_split_variable_names) {
size_t varID = data->getVariableID(variable_name);
deterministic_varIDs.push_back(varID);
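To summarize the semantics enforced by the rewritten setSplitWeightVector(): a weight of exactly 1 makes a variable deterministic (always tried for splitting), a weight strictly between 0 and 1 is used as a selection probability, zero weights are ignored, and values outside [0, 1] are rejected. A hypothetical R helper, written only to restate those rules:

```r
## Illustrative only; mirrors the per-vector checks in Forest::setSplitWeightVector().
classify.split.weights <- function(weights) {
  if (any(weights < 0 | weights > 1)) {
    stop("One or more split select weights not in range [0,1].")
  }
  list(deterministic = which(weights == 1),               # always considered
       weighted      = which(weights > 0 & weights < 1),  # sampled with this probability
       ignored       = which(weights == 0))               # never considered
}
```

In Forest::grow(), each tree then receives either its own vector (when a per-tree list was supplied) or the single shared vector, via the tree_split_select_weights pointer shown above.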
6 changes: 3 additions & 3 deletions source/src/Forest/Forest.h
@@ -59,7 +59,7 @@ class Forest {
std::string case_weights_file, bool predict_all);
void initR(std::string dependent_variable_name, Data* input_data, uint mtry, uint num_trees,
std::ostream* verbose_out, uint seed, uint num_threads, ImportanceMode importance_mode, uint min_node_size,
std::vector<double>& split_select_weights, std::vector<std::string>& always_split_variable_names,
std::vector<std::vector<double>>& split_select_weights, std::vector<std::string>& always_split_variable_names,
std::string status_variable_name, bool prediction_mode, bool sample_with_replacement,
std::vector<std::string>& unordered_variable_names, bool memory_saving_splitting, SplitRule splitrule,
std::vector<double>& case_weights, bool predict_all);
@@ -160,7 +160,7 @@
virtual void loadFromFileInternal(std::ifstream& infile) = 0;

// Set split select weights and variables to be always considered for splitting
void setSplitWeightVector(std::vector<double>& split_select_weights);
void setSplitWeightVector(std::vector<std::vector<double>>& split_select_weights);
void setAlwaysSplitVariables(std::vector<std::string>& always_split_variable_names);

// Show progress every few seconds
@@ -212,7 +212,7 @@
// Deterministic variables are always selected
std::vector<size_t> deterministic_varIDs;
std::vector<size_t> split_select_varIDs;
std::vector<double> split_select_weights;
std::vector<std::vector<double>> split_select_weights;

// Bootstrap weights
std::vector<double> case_weights;
2 changes: 1 addition & 1 deletion source/src/version.h
@@ -1,3 +1,3 @@
#ifndef RANGER_VERSION
#define RANGER_VERSION "0.3.4"
#define RANGER_VERSION "0.3.5"
#endif
