From ddf9ea63e8a6f9c59da4bb8d9edc72997bb81345 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 8 May 2025 02:22:01 -0500 Subject: [PATCH 1/4] Initial implementation of categorical feature re-weighting in the GFR algorithm --- include/stochtree/tree_sampler.h | 44 ++++-- tools/debug/gfr_mcmc_categorical_comparison.R | 131 ++++++++++++++++++ 2 files changed, 167 insertions(+), 8 deletions(-) create mode 100644 tools/debug/gfr_mcmc_categorical_comparison.R diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index a47660ea..c3e9e7b1 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -471,8 +471,8 @@ template & log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, std::vector& cutpoint_feature_types, - data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector& variable_weights, - std::vector& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args + data_size_t& valid_cutpoint_count, std::vector& feature_cutpoint_counts, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, + std::vector& variable_weights, std::vector& feature_types, LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args ) { // Initialize sufficient statistics LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...); @@ -496,6 +496,7 @@ static inline void EvaluateAllPossibleSplits( int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf(); // Compute sufficient statistics for each possible split + data_size_t feature_cutpoints; data_size_t num_cutpoints = 0; bool valid_split = false; data_size_t node_row_iter; @@ -509,6 +510,8 @@ static inline void EvaluateAllPossibleSplits( double log_split_eval = 0.0; double split_log_ml; for (int j = 0; j < covariates.cols(); j++) { + // Reset feature cutpoint counter + feature_cutpoints = 0; if (std::abs(variable_weights.at(j)) > kEpsilon) { // Enumerate cutpoint strides @@ -542,6 +545,7 @@ static inline void EvaluateAllPossibleSplits( valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) && right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf)); if (valid_split) { + feature_cutpoints++; num_cutpoints++; // Add to split rule vector cutpoint_feature_types.push_back(feature_type); @@ -553,7 +557,8 @@ static inline void EvaluateAllPossibleSplits( } } } - + // Add feature_cutpoints to feature_cutpoint_counts + feature_cutpoint_counts.push_back(feature_cutpoints); } // Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper) @@ -570,16 +575,38 @@ template & log_cutpoint_evaluations, std::vector& cutpoint_features, std::vector& cutpoint_values, - std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& variable_weights, - std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) { + std::vector& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector& feature_cutpoint_counts, + std::vector& variable_weights, std::vector& feature_types, CutpointGridContainer& cutpoint_grid_container, + LeafSuffStatConstructorArgs&... 
leaf_suff_stat_args) { // Evaluate all possible cutpoints according to the leaf node model, // recording their log-likelihood and other split information in a series of vectors. // The last element of these vectors concerns the "no-split" option. EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations, - cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, + cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args... ); + + // Compute weighting adjustments for low-cardinality categorical features + // Check if the dataset has continuous features, ignore this adjustment if not + bool has_continuous_features = false; + int max_feature_cutpoint_count = 0; + for (int j = 0; j < feature_types.size(); j++) { + if (feature_types.at(j) == FeatureType::kNumeric) { + has_continuous_features = true; + if (feature_cutpoint_counts[j] > max_feature_cutpoint_count) max_feature_cutpoint_count = feature_cutpoint_counts[j]; + } + } + if (has_continuous_features) { + double feature_weight; + for (data_size_t i = 0; i < valid_cutpoint_count; i++) { + // Determine whether the feature is categorical (and thus needs to be re-weighted) + if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) || (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) { + feature_weight = ((double) max_feature_cutpoint_count) / ((double) feature_cutpoint_counts[cutpoint_features[i]]); + log_cutpoint_evaluations[i] += std::log(feature_weight); + } + } + } // Compute an adjustment to reflect the no split prior probability and the number of cutpoints double bart_prior_no_split_adj; @@ -614,12 +641,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel std::vector 
cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count; + std::vector feature_cutpoint_counts; CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); EvaluateCutpoints( tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance, cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features, - cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types, - cutpoint_grid_container, leaf_suff_stat_args... + cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, variable_weights, + feature_types, cutpoint_grid_container, leaf_suff_stat_args... ); // TODO: maybe add some checks here? diff --git a/tools/debug/gfr_mcmc_categorical_comparison.R b/tools/debug/gfr_mcmc_categorical_comparison.R new file mode 100644 index 00000000..12e813d1 --- /dev/null +++ b/tools/debug/gfr_mcmc_categorical_comparison.R @@ -0,0 +1,131 @@ +################################################################################ +## Comparison of GFR / warm start with pure MCMC on datasets with a +## mix of numeric features and low-cardinality categorical features. 
+################################################################################ + +# Load libraries +library(stochtree) + +# Generate data +n <- 500 +p_continuous <- 5 +p_binary <- 2 +p_ordered_cat <- 2 +p <- p_continuous + p_binary + p_ordered_cat +stopifnot(p_continuous >= 3) +stopifnot(p_binary >= 2) +stopifnot(p_ordered_cat >= 1) +x_continuous <- matrix( + runif(n*p_continuous), + ncol = p_continuous +) +x_binary <- matrix( + rbinom(n*p_binary, size = 1, prob = 0.5), + ncol = p_binary +) +x_ordered_cat <- matrix( + sample(1:5, size = n*p_ordered_cat, replace = T), + ncol = p_ordered_cat +) +X_matrix <- cbind(x_continuous, x_binary, x_ordered_cat) +X_df <- as.data.frame(X_matrix) +colnames(X_df) <- paste0("x", 1:p) +for (i in (p_continuous+1):(p_continuous+p_binary+p_ordered_cat)) { + X_df[,i] <- factor(X_df[,i], ordered = T) +} +f_x_cont <- (2 + 4*x_continuous[,1] - 6*(x_continuous[,2] < 0) + + 6*(x_continuous[,2] >= 0) + 5*(abs(x_continuous[,3]) - sqrt(2/pi))) +f_x_binary <- -1.5 + 1*x_binary[,1] + 2*x_binary[,2] +f_x_ordered_cat <- 3 - 1*x_ordered_cat[,1] +pct_var_cont <- 1/3 +pct_var_binary <- 1/3 +pct_var_ordered_cat <- 1/3 +stopifnot(pct_var_cont + pct_var_binary + pct_var_ordered_cat == 1.0) +total_var <- var(f_x_cont+f_x_binary+f_x_ordered_cat) +f_x_cont_rescaled <- f_x_cont * sqrt( + pct_var_cont / (var(f_x_cont) / total_var) +) +f_x_binary_rescaled <- f_x_binary * sqrt( + pct_var_binary / (var(f_x_binary) / total_var) +) +f_x_ordered_cat_rescaled <- f_x_ordered_cat * sqrt( + pct_var_ordered_cat / (var(f_x_ordered_cat) / total_var) +) +E_y <- f_x_cont_rescaled + f_x_binary_rescaled + f_x_ordered_cat_rescaled +# var(f_x_cont_rescaled) / var(E_y) +# var(f_x_binary_rescaled) / var(E_y) +# var(f_x_ordered_cat_rescaled) / var(E_y) +snr <- 3 +epsilon <- rnorm(n, 0, 1) * sd(E_y) / snr +y <- E_y + epsilon +jitter_eps <- 0.1 +x_binary_jitter <- x_binary + matrix( + runif(n*p_binary, -jitter_eps, jitter_eps), ncol = p_binary +) +x_ordered_cat_jitter <- 
x_ordered_cat + matrix( + runif(n*p_ordered_cat, -jitter_eps, jitter_eps), ncol = p_ordered_cat +) +X_matrix_jitter <- cbind(x_continuous, x_binary_jitter, x_ordered_cat_jitter) +X_df_jitter <- as.data.frame(X_matrix_jitter) +colnames(X_df_jitter) <- paste0("x", 1:p) + +# Test-train split +test_set_pct <- 0.2 +n_test <- round(test_set_pct*n) +n_train <- n - n_test +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) +train_inds <- (1:n)[!((1:n) %in% test_inds)] +X_df_test <- X_df[test_inds,] +X_df_train <- X_df[train_inds,] +X_df_jitter_test <- X_df_jitter[test_inds,] +X_df_jitter_train <- X_df_jitter[train_inds,] +y_test <- y[test_inds] +y_train <- y[train_inds] + +# Fit BART with warmstart on the original data +ws_bart_fit <- bart(X_train = X_df_train, y_train = y_train, + X_test = X_df_test, num_gfr = 15, + num_burnin = 0, num_mcmc = 100) + +# Fit BART with MCMC only on the original data +bart_fit <- bart(X_train = X_df_train, y_train = y_train, + X_test = X_df_test, num_gfr = 0, + num_burnin = 2000, num_mcmc = 100) + +# Fit BART with warmstart on the jittered data +ws_bart_jitter_fit <- bart(X_train = X_df_jitter_train, y_train = y_train, + X_test = X_df_jitter_test, num_gfr = 15, + num_burnin = 0, num_mcmc = 100) + +# Fit BART with MCMC only on the jittered data +bart_jitter_fit <- bart(X_train = X_df_jitter_train, y_train = y_train, + X_test = X_df_jitter_test, num_gfr = 0, + num_burnin = 2000, num_mcmc = 100) + +# Compare the variable split counts +ws_bart_fit$mean_forests$get_aggregate_split_counts(p) +bart_fit$mean_forests$get_aggregate_split_counts(p) +ws_bart_jitter_fit$mean_forests$get_aggregate_split_counts(p) +bart_jitter_fit$mean_forests$get_aggregate_split_counts(p) + +# Compute out-of-sample RMSE +sqrt(mean((rowMeans(ws_bart_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(bart_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(ws_bart_jitter_fit$y_hat_test) - y_test)^2)) +sqrt(mean((rowMeans(bart_jitter_fit$y_hat_test) - y_test)^2)) + +# 
Compare sigma traceplots +sigma_min <- min(c(ws_bart_fit$sigma2_global_samples, + bart_fit$sigma2_global_samples, + ws_bart_jitter_fit$sigma2_global_samples, + bart_jitter_fit$sigma2_global_samples)) +sigma_max <- max(c(ws_bart_fit$sigma2_global_samples, + bart_fit$sigma2_global_samples, + ws_bart_jitter_fit$sigma2_global_samples, + bart_jitter_fit$sigma2_global_samples)) +plot(ws_bart_fit$sigma2_global_samples, + ylim = c(sigma_min - 0.1, sigma_max + 0.1), + type = "line", col = "black") +lines(bart_fit$sigma2_global_samples, col = "blue") +lines(ws_bart_jitter_fit$sigma2_global_samples, col = "green") +lines(bart_jitter_fit$sigma2_global_samples, col = "red") From 1ea27d7ce703985258a7182c065fa90bcabe212a Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 8 May 2025 02:27:23 -0500 Subject: [PATCH 2/4] Adjusted C++ unit tests to reflect updated API --- test/cpp/test_model.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/test/cpp/test_model.cpp b/test/cpp/test_model.cpp index 0e729bef..a320b8ac 100644 --- a/test/cpp/test_model.cpp +++ b/test/cpp/test_model.cpp @@ -44,6 +44,7 @@ TEST(LeafConstantModel, FullEnumeration) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -52,7 +53,7 @@ TEST(LeafConstantModel, FullEnumeration) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // 
Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered @@ -109,7 +110,7 @@ TEST(LeafConstantModel, CutpointThinning) { StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau); // Evaluate all possible cutpoints - StochTree::EvaluateAllPossibleSplits( + StochTree::( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types ); @@ -162,6 +163,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -170,7 +172,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered @@ -222,6 +224,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model @@ -230,7 +233,7 @@ 
TEST(LeafUnivariateRegressionModel, CutpointThinning) { // Evaluate all possible cutpoints StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); From 755e6f0373eee9a5c5cb6881b6f03e4947745060 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Thu, 8 May 2025 02:39:17 -0500 Subject: [PATCH 3/4] Fixed C++ unit test typo --- test/cpp/test_model.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/cpp/test_model.cpp b/test/cpp/test_model.cpp index a320b8ac..9746f67c 100644 --- a/test/cpp/test_model.cpp +++ b/test/cpp/test_model.cpp @@ -104,15 +104,16 @@ TEST(LeafConstantModel, CutpointThinning) { std::vector cutpoint_values; std::vector cutpoint_feature_types; StochTree::data_size_t valid_cutpoint_count = 0; + std::vector feature_cutpoint_counts; StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size); // Initialize a leaf model StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau); // Evaluate all possible cutpoints - StochTree::( + StochTree::EvaluateAllPossibleSplits( dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values, - cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types + cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types ); // Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered From 6911cf5b93f94c2657cffd9a8d8aaf8eb37c6f4e Mon 
Sep 17 00:00:00 2001 From: Drew Herren Date: Fri, 9 May 2025 01:15:12 -0500 Subject: [PATCH 4/4] Formatted C++ logic --- include/stochtree/tree_sampler.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/stochtree/tree_sampler.h b/include/stochtree/tree_sampler.h index c3e9e7b1..f41fcc28 100644 --- a/include/stochtree/tree_sampler.h +++ b/include/stochtree/tree_sampler.h @@ -601,7 +601,9 @@ static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafMod double feature_weight; for (data_size_t i = 0; i < valid_cutpoint_count; i++) { // Determine whether the feature is categorical (and thus needs to be re-weighted) - if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) || (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) { + if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) || + (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) { + // Weight according to max continuous feature cutpoint count / categorical feature cutpoint count feature_weight = ((double) max_feature_cutpoint_count) / ((double) feature_cutpoint_counts[cutpoint_features[i]]); log_cutpoint_evaluations[i] += std::log(feature_weight); }