Skip to content

Commit

Permalink
1) Fixed Predictor lifecycle
Browse files Browse the repository at this point in the history
2) Fixed Boosting trees initialization

#5482
  • Loading branch information
AndreyOrb committed Jan 8, 2025
1 parent e0c34e7 commit 2dda961
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
13 changes: 12 additions & 1 deletion src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,11 +433,18 @@ class GBDT : public GBDTBase {
num_iteration_for_pred_ = num_iteration_for_pred_ - start_iteration;
}
start_iteration_for_pred_ = start_iteration;
if (is_pred_contrib) {

if (is_pred_contrib && !models_initialized_) {
std::lock_guard<std::mutex> lock(instance_mutex_);
if (models_initialized_)
return;

#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
models_[i]->RecomputeMaxDepth();
}

models_initialized_ = true;
}
}

Expand Down Expand Up @@ -548,6 +555,10 @@ class GBDT : public GBDTBase {
int max_feature_idx_;
/*! \brief Parser config file content */
std::string parser_config_str_ = "";
/*!
 * \brief Whether RecomputeMaxDepth() has already been called on every tree in models_.
 *
 * Checked only on the contribution-prediction path (is_pred_contrib): the flag is read
 * once before instance_mutex_ is taken, read again while the mutex is held, and set to
 * true after the RecomputeMaxDepth() loop has finished.
 *
 * NOTE(review): the first read happens without holding instance_mutex_, and this is a
 * plain bool. If several threads can reach that check concurrently, the unlocked read
 * races with the locked write; consider std::atomic<bool> or std::call_once - confirm
 * against the callers' threading model.
 * NOTE(review): nothing in the visible code resets this flag; verify whether it must be
 * cleared when models_ is modified after the first contribution prediction.
 */
bool models_initialized_ = false;
/*!
 * \brief Mutex serializing the one-time RecomputeMaxDepth() pass over models_.
 *
 * Held (via std::lock_guard) for the whole recomputation loop and for the update of
 * models_initialized_.
 */
std::mutex instance_mutex_;

#ifdef USE_CUDA
/*! \brief First order derivative of training data */
Expand Down
10 changes: 5 additions & 5 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ class Booster {
*out_len = single_row_predictor->num_pred_in_one_row;
}

Predictor CreatePredictor(int start_iteration, int num_iteration, int predict_type, int ncol, const Config& config) const {
std::shared_ptr<Predictor> CreatePredictor(int start_iteration, int num_iteration, int predict_type, int ncol, const Config& config) const {
if (!config.predict_disable_shape_check && ncol != boosting_->MaxFeatureIdx() + 1) {
Log::Fatal("The number of features in data (%d) is not the same as it was in training data (%d).\n" \
"You can set ``predict_disable_shape_check=true`` to discard this error, but please be aware what you are doing.", ncol, boosting_->MaxFeatureIdx() + 1);
Expand All @@ -478,7 +478,7 @@ class Booster {
is_raw_score = false;
}

return Predictor(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
return std::make_shared<Predictor>(boosting_.get(), start_iteration, num_iteration, is_raw_score, is_predict_leaf, predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
}

Expand All @@ -496,7 +496,7 @@ class Booster {
predict_contrib = true;
}
int64_t num_pred_in_one_row = boosting_->NumPredictOneRow(start_iteration, num_iteration, is_predict_leaf, predict_contrib);
auto pred_fun = predictor.GetPredictFunction();
auto pred_fun = predictor->GetPredictFunction();
OMP_INIT_EX();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
for (int i = 0; i < nrow; ++i) {
Expand All @@ -517,7 +517,7 @@ class Booster {
int32_t** out_indices, void** out_data, int data_type,
bool* is_data_float32_ptr, int num_matrices) const {
auto predictor = CreatePredictor(start_iteration, num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction();
auto pred_sparse_fun = predictor->GetPredictSparseFunction();
std::vector<std::vector<std::unordered_map<int, double>>>& agg = *agg_ptr;
OMP_INIT_EX();
#pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
Expand Down Expand Up @@ -652,7 +652,7 @@ class Booster {
// Get the number of trees per iteration (for multiclass scenario we output multiple sparse matrices)
int num_matrices = boosting_->NumModelPerIteration();
auto predictor = CreatePredictor(start_iteration, num_iteration, predict_type, ncol, config);
auto pred_sparse_fun = predictor.GetPredictSparseFunction();
auto pred_sparse_fun = predictor->GetPredictSparseFunction();
bool is_col_ptr_int32 = false;
bool is_data_float32 = false;
int num_output_cols = ncol + 1;
Expand Down

0 comments on commit 2dda961

Please sign in to comment.