Cleanup CPU predict function. (#11139)
trivialfis authored Jan 11, 2025
1 parent b4a7cd1 commit 712e39d
Showing 19 changed files with 324 additions and 465 deletions.
27 changes: 5 additions & 22 deletions include/xgboost/gbm.h
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by XGBoost Contributors
* Copyright 2014-2025, XGBoost Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
@@ -15,10 +15,8 @@
#include <xgboost/model.h>

#include <vector>
#include <utility>
#include <string>
#include <functional>
#include <unordered_map>
#include <memory>

namespace xgboost {
@@ -42,13 +40,13 @@ class GradientBooster : public Model, public Configurable {
public:
/*! \brief virtual destructor */
~GradientBooster() override = default;
/*!
* \brief Set the configuration of gradient boosting.
/**
* @brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
*
* \param cfg configurations on both training and model parameters.
* @param cfg configurations on both training and model parameters.
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
virtual void Configure(Args const& cfg) = 0;
/*!
* \brief load model from stream
* \param fi input stream.
@@ -117,21 +115,6 @@ class GradientBooster : public Model, public Configurable {
bst_layer_t) const {
LOG(FATAL) << "Inplace predict is not supported by the current booster.";
}
/*!
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is usually
* more efficient than online prediction
* This function is NOT threadsafe, make sure you only call from one thread
*
* \param inst the instance you want to predict
* \param out_preds output vector to hold the predictions
* \param layer_begin Beginning of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
* \sa Predict
*/
virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
unsigned layer_begin, unsigned layer_end) = 0;
/*!
* \brief predict the leaf index of each tree, the output will be nsample * ntree vector
* this is only valid in gbtree predictor
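
Note on the Configure() change above: Args appears to be XGBoost's alias for std::vector<std::pair<std::string, std::string>> (see xgboost/base.h), so the new signature reads as a cleanup rather than a behavioural change. A minimal standalone sketch of the pattern; the alias definition and the DemoBooster class below are illustrative stand-ins, not the real XGBoost types.

#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Stand-in for xgboost::Args (assumed alias, matching base.h).
using Args = std::vector<std::pair<std::string, std::string>>;

class DemoBooster {
 public:
  // Mirrors the cleaned-up interface: a single alias instead of the spelled-out pair vector.
  void Configure(Args const& cfg) {
    for (auto const& [key, value] : cfg) {
      std::cout << key << " = " << value << "\n";
    }
  }
};

int main() {
  DemoBooster booster;
  booster.Configure({{"eta", "0.3"}, {"max_depth", "6"}});
  return 0;
}
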
58 changes: 19 additions & 39 deletions include/xgboost/predictor.h
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2024, XGBoost Contributors
* Copyright 2017-2025, XGBoost Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
@@ -28,7 +28,7 @@ namespace xgboost {
*/
struct PredictionCacheEntry {
// A storage for caching prediction values
HostDeviceVector<bst_float> predictions;
HostDeviceVector<float> predictions;
// The version of current cache, corresponding number of layers of trees
std::uint32_t version{0};

@@ -91,7 +91,7 @@ class Predictor {
* \param out_predt Prediction vector to be initialized.
* \param model Tree model used for prediction.
*/
virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<float>* out_predt,
const gbm::GBTreeModel& model) const;

/**
@@ -105,8 +105,8 @@
* \param tree_end The tree end index.
*/
virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t tree_end = 0) const = 0;
gbm::GBTreeModel const& model, bst_tree_t tree_begin,
bst_tree_t tree_end = 0) const = 0;

/**
* \brief Inplace prediction.
@@ -123,25 +123,7 @@
*/
virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
float missing, PredictionCacheEntry* out_preds,
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
/**
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is
* usually more efficient than online prediction This function is NOT
* threadsafe, make sure you only call from one thread.
*
* \param inst The instance to predict.
* \param [in,out] out_preds The output preds.
* \param model The model to predict from
* \param tree_end (Optional) The tree end index.
* \param is_column_split (Optional) If the data is split column-wise.
*/

virtual void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
bool is_column_split = false) const = 0;
bst_tree_t tree_begin = 0, bst_tree_t tree_end = 0) const = 0;

/**
* \brief predict the leaf index of each tree, the output will be nsample *
@@ -153,9 +135,8 @@
* \param tree_end (Optional) The tree end index.
*/

virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model,
unsigned tree_end = 0) const = 0;
virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<float>* out_preds,
gbm::GBTreeModel const& model, bst_tree_t tree_end = 0) const = 0;

/**
* \brief feature contributions to individual predictions; the output will be
@@ -172,18 +153,17 @@
* \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/

virtual void
PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) const = 0;

virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false) const = 0;
virtual void PredictContribution(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
gbm::GBTreeModel const& model, bst_tree_t tree_end = 0,
std::vector<float> const* tree_weights = nullptr,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) const = 0;

virtual void PredictInteractionContributions(DMatrix* dmat, HostDeviceVector<float>* out_contribs,
gbm::GBTreeModel const& model,
bst_tree_t tree_end = 0,
std::vector<float> const* tree_weights = nullptr,
bool approximate = false) const = 0;

/**
* \brief Creates a new Predictor*.
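
The predictor interface above now spells tree ranges with bst_tree_t and uses plain float instead of the bst_float typedef. A small standalone mirror of the narrowed batch-prediction signature; it assumes bst_tree_t is a 32-bit signed index as in xgboost/base.h, and the Mini*/Zero* types are hypothetical.

#include <cstdint>
#include <vector>

using bst_tree_t = std::int32_t;  // assumption: matches xgboost::bst_tree_t

class MiniPredictor {
 public:
  virtual ~MiniPredictor() = default;
  // tree_end == 0 keeps the "no limit on trees" convention documented in predictor.h.
  virtual void PredictBatch(std::vector<float> const& row, std::vector<float>* out_preds,
                            bst_tree_t tree_begin, bst_tree_t tree_end = 0) const = 0;
};

class ZeroPredictor : public MiniPredictor {
 public:
  void PredictBatch(std::vector<float> const&, std::vector<float>* out_preds, bst_tree_t,
                    bst_tree_t) const override {
    out_preds->assign(1, 0.5f);  // dummy base score, enough to exercise the signature
  }
};

int main() {
  ZeroPredictor pred;
  std::vector<float> out;
  pred.PredictBatch({1.0f, 2.0f}, &out, 0, 0);
  return out.size() == 1 ? 0 : 1;
}
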
59 changes: 23 additions & 36 deletions include/xgboost/tree_model.h
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2024, XGBoost Contributors
* Copyright 2014-2025, XGBoost Contributors
* \file tree_model.h
* \brief model structure for tree
* \author Tianqi Chen
@@ -23,7 +23,6 @@
#include <memory> // for make_unique
#include <stack>
#include <string>
#include <tuple>
#include <vector>

namespace xgboost {
@@ -562,7 +561,7 @@ class RegTree : public Model {
* \brief fill the vector with sparse vector
* \param inst The sparse instance to fill.
*/
void Fill(const SparsePage::Inst& inst);
void Fill(SparsePage::Inst const& inst);

/*!
* \brief drop the trace after fill, must be called after fill.
@@ -587,18 +586,17 @@
*/
[[nodiscard]] bool IsMissing(size_t i) const;
[[nodiscard]] bool HasMissing() const;
void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }

[[nodiscard]] common::Span<float> Data() { return data_; }

private:
/*!
* \brief a union value of value and flag
* when flag == -1, this indicate the value is missing
/**
* @brief A dense vector for a single sample.
*
* It's nan if the value is missing.
*/
union Entry {
bst_float fvalue;
int flag;
};
std::vector<Entry> data_;
std::vector<float> data_;
bool has_missing_;
};

@@ -793,46 +791,35 @@ class RegTree : public Model {
};

inline void RegTree::FVec::Init(size_t size) {
Entry e; e.flag = -1;
data_.resize(size);
std::fill(data_.begin(), data_.end(), e);
std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
has_missing_ = true;
}

inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
size_t feature_count = 0;
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].fvalue = entry.fvalue;
++feature_count;
inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
auto p_data = inst.data();
auto p_out = data_.data();

for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
auto const& entry = p_data[i];
p_out[entry.index] = entry.fvalue;
}
has_missing_ = data_.size() != feature_count;
has_missing_ = data_.size() != inst.size();
}

inline void RegTree::FVec::Drop() {
Entry e{};
e.flag = -1;
std::fill_n(data_.data(), data_.size(), e);
has_missing_ = true;
}
inline void RegTree::FVec::Drop() { this->Init(this->Size()); }

inline size_t RegTree::FVec::Size() const {
return data_.size();
}

inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
inline float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i];
}

inline bool RegTree::FVec::IsMissing(size_t i) const {
return data_[i].flag == -1;
}
inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }

inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}
inline bool RegTree::FVec::HasMissing() const { return has_missing_; }

// Multi-target tree not yet implemented error
inline StringView MTNotImplemented() {
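
The FVec rework above drops the fvalue/flag union in favour of a dense std::vector<float> in which NaN marks a missing feature, which is why Drop() can now simply re-run Init(). A self-contained sketch of that representation; DenseRow and its (index, value) input format are illustrative, not the real RegTree::FVec or SparsePage::Inst.

#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Dense per-sample feature vector: NaN means "missing", as in the new FVec.
class DenseRow {
 public:
  void Init(std::size_t n_features) {
    data_.assign(n_features, std::numeric_limits<float>::quiet_NaN());
  }
  // Fill from a sparse (feature index, value) list; untouched slots stay NaN.
  void Fill(std::vector<std::pair<std::size_t, float>> const& inst) {
    for (auto const& [idx, value] : inst) {
      data_[idx] = value;
    }
  }
  bool IsMissing(std::size_t i) const { return std::isnan(data_[i]); }
  float GetFvalue(std::size_t i) const { return data_[i]; }
  void Drop() { this->Init(data_.size()); }  // same trick as the new FVec::Drop

 private:
  std::vector<float> data_;
};

int main() {
  DenseRow row;
  row.Init(4);
  row.Fill({{0, 1.5f}, {2, -3.0f}});
  return (row.IsMissing(1) && !row.IsMissing(2)) ? 0 : 1;
}
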
22 changes: 7 additions & 15 deletions plugin/sycl/predictor/predictor.cc
@@ -201,8 +201,8 @@ class Predictor : public xgboost::Predictor {
}

void PredictBatch(DMatrix *dmat, PredictionCacheEntry *predts,
const gbm::GBTreeModel &model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
const gbm::GBTreeModel &model, bst_tree_t tree_begin,
bst_tree_t tree_end = 0) const override {
auto* out_preds = &predts->predictions;
out_preds->SetDevice(ctx_->Device());
if (tree_end == 0) {
@@ -221,28 +221,20 @@

bool InplacePredict(std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds, uint32_t tree_begin,
unsigned tree_end) const override {
PredictionCacheEntry *out_preds, bst_tree_t tree_begin,
bst_tree_t tree_end) const override {
LOG(WARNING) << "InplacePredict is not yet implemented for SYCL. CPU Predictor is used.";
return cpu_predictor->InplacePredict(p_m, model, missing, out_preds, tree_begin, tree_end);
}

void PredictInstance(const SparsePage::Inst& inst,
std::vector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit,
bool is_column_split) const override {
LOG(WARNING) << "PredictInstance is not yet implemented for SYCL. CPU Predictor is used.";
cpu_predictor->PredictInstance(inst, out_preds, model, ntree_limit, is_column_split);
}

void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
const gbm::GBTreeModel& model, bst_tree_t ntree_limit) const override {
LOG(WARNING) << "PredictLeaf is not yet implemented for SYCL. CPU Predictor is used.";
cpu_predictor->PredictLeaf(p_fmat, out_preds, model, ntree_limit);
}

void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
const gbm::GBTreeModel& model, uint32_t ntree_limit,
const gbm::GBTreeModel& model, bst_tree_t ntree_limit,
const std::vector<bst_float>* tree_weights,
bool approximate, int condition,
unsigned condition_feature) const override {
@@ -252,7 +244,7 @@
}

void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
const gbm::GBTreeModel& model, bst_tree_t ntree_limit,
const std::vector<bst_float>* tree_weights,
bool approximate) const override {
LOG(WARNING) << "PredictInteractionContributions is not yet implemented for SYCL. "
6 changes: 3 additions & 3 deletions src/common/column_matrix.h
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2024, XGBoost Contributors
* Copyright 2017-2025, XGBoost Contributors
* \file column_matrix.h
* \brief Utility for fast column-wise access
* \author Philip Cho
@@ -45,15 +45,15 @@ class Column {
virtual ~Column() = default;

[[nodiscard]] bst_bin_t GetGlobalBinIdx(size_t idx) const {
return index_base_ + static_cast<bst_bin_t>(index_[idx]);
return index_base_ + static_cast<bst_bin_t>(index_.data()[idx]);
}

/* returns number of elements in column */
[[nodiscard]] size_t Size() const { return index_.size(); }

private:
/* bin indexes in range [0, max_bins - 1] */
common::Span<const BinIdxType> index_;
common::Span<BinIdxType const> index_;
/* bin index offset for specific feature */
bst_bin_t const index_base_;
};
5 changes: 1 addition & 4 deletions src/common/hist_util.h
@@ -83,10 +83,7 @@ class HistogramCuts {
[[nodiscard]] bst_bin_t FeatureBins(bst_feature_t feature) const {
return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature];
}
[[nodiscard]] bst_feature_t NumFeatures() const {
CHECK_EQ(this->min_vals_.Size(), this->cut_ptrs_.Size() - 1);
return this->min_vals_.Size();
}
[[nodiscard]] bst_feature_t NumFeatures() const { return this->cut_ptrs_.Size() - 1; }

std::vector<uint32_t> const& Ptrs() const { return cut_ptrs_.ConstHostVector(); }
std::vector<float> const& Values() const { return cut_values_.ConstHostVector(); }
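
The simplified NumFeatures() above relies on cut_ptrs_ being a CSR-style offset array with one trailing sentinel, the same layout FeatureBins() already assumes, so the feature count is the pointer count minus one. A tiny illustration with made-up offsets:

#include <cstdint>
#include <vector>

int main() {
  // Offsets into the flat cut-value array: three features with 3, 2 and 4 bins.
  std::vector<std::uint32_t> cut_ptrs{0, 3, 5, 9};
  auto num_features = cut_ptrs.size() - 1;         // 3, no separate CHECK needed
  auto bins_feature1 = cut_ptrs[2] - cut_ptrs[1];  // 2, mirrors FeatureBins()
  return (num_features == 3 && bins_feature1 == 2) ? 0 : 1;
}
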
10 changes: 9 additions & 1 deletion src/common/ref_resource_view.h
@@ -1,5 +1,5 @@
/**
* Copyright 2023-2024, XGBoost Contributors
* Copyright 2023-2025, XGBoost Contributors
*/
#ifndef XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
#define XGBOOST_COMMON_REF_RESOURCE_VIEW_H_
@@ -88,6 +88,14 @@ class RefResourceView {

[[nodiscard]] value_type& operator[](size_type i) { return ptr_[i]; }
[[nodiscard]] value_type const& operator[](size_type i) const { return ptr_[i]; }
[[nodiscard]] value_type& at(size_type i) { // NOLINT
SPAN_LT(i, this->size_);
return ptr_[i];
}
[[nodiscard]] value_type const& at(size_type i) const { // NOLINT
SPAN_LT(i, this->size_);
return ptr_[i];
}

/**
* @brief Get the underlying resource.
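
The new at() accessors above add a bounds check (via SPAN_LT) on top of the unchecked operator[]. A standalone sketch of the same pattern, using assert in place of SPAN_LT; MiniView is hypothetical and only mirrors the accessor pair.

#include <cassert>
#include <cstddef>

template <typename T>
class MiniView {
 public:
  MiniView(T* ptr, std::size_t n) : ptr_{ptr}, size_{n} {}
  T& operator[](std::size_t i) { return ptr_[i]; }  // unchecked hot-path access
  T& at(std::size_t i) {                            // checked accessor
    assert(i < size_);                              // stand-in for SPAN_LT(i, this->size_)
    return ptr_[i];
  }

 private:
  T* ptr_;
  std::size_t size_;
};

int main() {
  int buf[3]{1, 2, 3};
  MiniView<int> view{buf, 3};
  return view.at(2) == 3 ? 0 : 1;
}
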