Skip to content

Commit

Permalink
Support stat with categorical features in graphviz dump. (#11053)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Dec 6, 2024
1 parent c7c158d commit 18b013a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
39 changes: 20 additions & 19 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,8 @@ class TextGenerator : public TreeGenerator {
return result;
}

std::string SplitNodeImpl(
RegTree const& tree, int32_t nid, std::string const& template_str,
std::string cond, uint32_t depth) const {
std::string SplitNodeImpl(RegTree const& tree, bst_node_t nid, std::string const& template_str,
std::string cond, uint32_t depth) const {
auto split_index = tree[nid].SplitIndex();
std::string const result = SuperT::Match(
template_str,
Expand Down Expand Up @@ -345,18 +344,16 @@ class TextGenerator : public TreeGenerator {
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}

std::string Categorical(RegTree const &tree, int32_t nid,
uint32_t depth) const override {
std::string Categorical(RegTree const& tree, bst_node_t nid, uint32_t depth) const override {
auto cats = GetSplitCategories(tree, nid);
std::string cats_str = PrintCatsAsSet(cats);
static std::string const kNodeTemplate =
"{tabs}{nid}:[{fname}:{cond}] yes={right},no={left},missing={missing}";
std::string const result =
SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
std::string const result = SplitNodeImpl(tree, nid, kNodeTemplate, cats_str, depth);
return result;
}

std::string NodeStat(RegTree const& tree, int32_t nid) const override {
std::string NodeStat(RegTree const& tree, bst_node_t nid) const override {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
Expand Down Expand Up @@ -679,15 +676,12 @@ class GraphvizGenerator : public TreeGenerator {
std::string result;
if (this->with_stats_) {
CHECK(!tree.IsMultiTarget()) << MTNotImplemented();
result = SuperT::Match(
kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{stat}", Match("\ncover={cover}\ngain={gain}",
{{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)},
{"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}})},
{"{params}", param_.condition_node_params}});
result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{stat}", this->NodeStat(tree, nidx)},
{"{params}", param_.condition_node_params}});
} else {
result = SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
Expand All @@ -703,9 +697,15 @@ class GraphvizGenerator : public TreeGenerator {
return result;
};

std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
std::string NodeStat(RegTree const& tree, bst_node_t nidx) const override {
return Match("\ngain={gain}\ncover={cover}",
{{"{cover}", std::to_string(tree.Stat(nidx).sum_hess)},
{"{gain}", std::to_string(tree.Stat(nidx).loss_chg)}});
}

std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t /*depth*/) const override {
static std::string const kLabelTemplate =
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
" {nid} [ label=\"{fname}:{cond}{stat}\" {params}]\n";
auto cats = GetSplitCategories(tree, nidx);
auto cats_str = PrintCatsAsSet(cats);
auto split_index = tree.SplitIndex(nidx);
Expand All @@ -714,6 +714,7 @@ class GraphvizGenerator : public TreeGenerator {
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{cond}", cats_str},
{"{stat}", this->NodeStat(tree, nidx)},
{"{params}", param_.condition_node_params}});

result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
Expand Down
1 change: 1 addition & 0 deletions tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ void TestCategoricalTreeDump(std::string format, std::string sep) {
ASSERT_NE(pos, std::string::npos);
pos = str.find(cond_str, pos + 1);
ASSERT_NE(pos, std::string::npos);
ASSERT_NE(str.find("gain"), std::string::npos);

if (format == "json") {
// Make sure it's valid JSON
Expand Down

0 comments on commit 18b013a

Please sign in to comment.