Skip to content

Commit

Permalink
Add an example to convert a pipeline with CatBoostClassifier (#968)
Browse files Browse the repository at this point in the history
* Add an example to convert a pipeline with CatBoostClassifier

Signed-off-by: xadupre <[email protected]>

* fix doc requirements.txt

Signed-off-by: xadupre <[email protected]>

---------

Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre authored Feb 27, 2023
1 parent 847f2dc commit e6c104f
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
autopep8
catboost
category_encoders
codecov
coverage
Expand Down
161 changes: 161 additions & 0 deletions docs/tutorial/plot_gexternal_catboost.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-License-Identifier: Apache-2.0

"""
.. _example-catboost:
Convert a pipeline with a CatBoost classifier
=============================================
.. index:: CatBoost
:epkg:`sklearn-onnx` only converts :epkg:`scikit-learn` models into *ONNX*
but many libraries implement :epkg:`scikit-learn` API so that their models
can be included in a :epkg:`scikit-learn` pipeline. This example considers
a pipeline including a :epkg:`CatBoost` model. :epkg:`sklearn-onnx` can convert
the whole pipeline as long as it knows the converter associated to
a *CatBoostClassifier*. Let's see how to do it.
.. contents::
:local:
Train a CatBoostClassifier
++++++++++++++++++++++++++
"""
from pyquickhelper.helpgen.graphviz_helper import plot_graphviz
import numpy
from onnx.helper import get_attribute_value
from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from mlprodict.onnxrt import OnnxInference
import onnxruntime as rt
from skl2onnx import convert_sklearn, update_registered_converter
from skl2onnx.common.shape_calculator import calculate_linear_classifier_output_shapes # noqa
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, guess_tensor_type
from skl2onnx._parse import _apply_zipmap, _get_sklearn_operator_name
from catboost import CatBoostClassifier
from catboost.utils import convert_to_onnx_object

data = load_iris()
X = data.data[:, :2]
y = data.target

ind = numpy.arange(X.shape[0])
numpy.random.shuffle(ind)
X = X[ind, :].copy()
y = y[ind].copy()

pipe = Pipeline([('scaler', StandardScaler()),
('lgbm', CatBoostClassifier(n_estimators=3))])
pipe.fit(X, y)

######################################
# Register the converter for CatBoostClassifier
# +++++++++++++++++++++++++++++++++++++++++++++
#
# The model has no converter implemented in sklearn-onnx.
# We need to register the one coming from *CatBoost* itself.
# However, the converter does not follow sklearn-onnx design and
# needs to be wrapped.


def skl2onnx_parser_castboost_classifier(scope, model, inputs,
custom_parsers=None):
options = scope.get_options(model, dict(zipmap=True))
no_zipmap = isinstance(options['zipmap'], bool) and not options['zipmap']

alias = _get_sklearn_operator_name(type(model))
this_operator = scope.declare_local_operator(alias, model)
this_operator.inputs = inputs

label_variable = scope.declare_local_variable('label', Int64TensorType())
prob_dtype = guess_tensor_type(inputs[0].type)
probability_tensor_variable = scope.declare_local_variable('probabilities', prob_dtype)
this_operator.outputs.append(label_variable)
this_operator.outputs.append(probability_tensor_variable)
probability_tensor = this_operator.outputs

if no_zipmap:
return probability_tensor

return _apply_zipmap(options['zipmap'], scope, model,
inputs[0].type, probability_tensor)


def skl2onnx_convert_catboost(scope, operator, container):
"""
CatBoost returns an ONNX graph with a single node.
This function adds it to the main graph.
"""
onx = convert_to_onnx_object(operator.raw_operator)
opsets = {d.domain: d.version for d in onx.opset_import}
if '' in opsets and opsets[''] >= container.target_opset:
raise RuntimeError(
"CatBoost uses an opset more recent than the target one.")
if len(onx.graph.initializer) > 0 or len(onx.graph.sparse_initializer) > 0:
raise NotImplementedError(
"CatBoost returns a model initializers. This option is not implemented yet.")
if (len(onx.graph.node) not in (1, 2) or not onx.graph.node[0].op_type.startswith("TreeEnsemble") or
(len(onx.graph.node) == 2 and onx.graph.node[1].op_type != "ZipMap")):
types = ", ".join(map(lambda n: n.op_type, onx.graph.node))
raise NotImplementedError(
f"CatBoost returns {len(onx.graph.node)} != 1 (types={types}). "
f"This option is not implemented yet.")
node = onx.graph.node[0]
atts = {}
for att in node.attribute:
atts[att.name] = get_attribute_value(att)
container.add_node(
node.op_type, [operator.inputs[0].full_name],
[operator.outputs[0].full_name, operator.outputs[1].full_name],
op_domain=node.domain, op_version=opsets.get(node.domain, None),
**atts)


update_registered_converter(
CatBoostClassifier,
'CatBoostCatBoostClassifier',
calculate_linear_classifier_output_shapes,
skl2onnx_convert_catboost,
parser=skl2onnx_parser_castboost_classifier,
options={'nocl': [True, False], 'zipmap': [True, False, 'columns']})

##################################
# Convert
# +++++++

model_onnx = convert_sklearn(
pipe, 'pipeline_catboost',
[('input', FloatTensorType([None, 2]))],
target_opset={'': 12, 'ai.onnx.ml': 2})

# And save.
with open("pipeline_catboost.onnx", "wb") as f:
f.write(model_onnx.SerializeToString())

###########################
# Compare the predictions
# +++++++++++++++++++++++
#
# Predictions with CatBoost.

print("predict", pipe.predict(X[:5]))
print("predict_proba", pipe.predict_proba(X[:1]))

##########################
# Predictions with onnxruntime.

sess = rt.InferenceSession("pipeline_catboost.onnx")

pred_onx = sess.run(None, {"input": X[:5].astype(numpy.float32)})
print("predict", pred_onx[0])
print("predict_proba", pred_onx[1][:1])

#############################
# Final graph
# +++++++++++

oinf = OnnxInference(model_onnx)
ax = plot_graphviz(oinf.to_dot())
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
1 change: 1 addition & 0 deletions docs/tutorial_1-5_external.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ model are part of a pipeline.
auto_tutorial/plot_gexternal_lightgbm
auto_tutorial/plot_gexternal_lightgbm_reg
auto_tutorial/plot_gexternal_xgboost
auto_tutorial/plot_gexternal_catboost

0 comments on commit e6c104f

Please sign in to comment.