From 2d1c26b38a0a48f0d719783dc216f942ca9c41ee Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Thu, 19 Dec 2024 07:31:38 +0100 Subject: [PATCH] Optimize memory reallocations (#11112) --- src/tree/common_row_partitioner.h | 12 ++++++++++-- src/tree/updater_quantile_hist.cc | 12 +++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/tree/common_row_partitioner.h b/src/tree/common_row_partitioner.h index 281861a367a1..45926b4ea22d 100644 --- a/src/tree/common_row_partitioner.h +++ b/src/tree/common_row_partitioner.h @@ -131,12 +131,20 @@ class CommonRowPartitioner { CommonRowPartitioner(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid, bool is_col_split) : base_rowid{_base_rowid}, is_col_split_{is_col_split} { - row_set_collection_.Clear(); + Reset(ctx, num_row, _base_rowid, is_col_split); + } + + void Reset(Context const* ctx, bst_idx_t num_row, bst_idx_t _base_rowid, bool is_col_split) { + base_rowid = _base_rowid; + is_col_split_ = is_col_split; + std::vector<bst_idx_t>& row_indices = *row_set_collection_.Data(); row_indices.resize(num_row); bst_idx_t* p_row_indices = row_indices.data(); - common::Iota(ctx, p_row_indices, p_row_indices + row_indices.size(), base_rowid); + common::Iota(ctx, p_row_indices, p_row_indices + num_row, base_rowid); + + row_set_collection_.Clear(); row_set_collection_.Init(); if (is_col_split_) { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 277c844162dd..51b26b781148 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -346,15 +346,21 @@ class HistUpdater { void InitData(DMatrix *fmat, RegTree const *p_tree) { monitor_->Start(__func__); bst_bin_t n_total_bins{0}; - partitioner_.clear(); + size_t page_idx = 0; for (auto const &page : fmat->GetBatches(ctx_, HistBatch(param_))) { if (n_total_bins == 0) { n_total_bins = page.cut.TotalBins(); } else { CHECK_EQ(n_total_bins, page.cut.TotalBins()); } - 
partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, - fmat->Info().IsColumnSplit()); + if (page_idx < partitioner_.size()) { + partitioner_[page_idx].Reset(this->ctx_, page.Size(), page.base_rowid, + fmat->Info().IsColumnSplit()); + } else { + partitioner_.emplace_back(this->ctx_, page.Size(), page.base_rowid, + fmat->Info().IsColumnSplit()); + } + page_idx++; } histogram_builder_->Reset(ctx_, n_total_bins, 1, HistBatch(param_), collective::IsDistributed(), fmat->Info().IsColumnSplit(), hist_param_);