Skip to content

Commit

Permalink
feat: add support to sklearn TargetEncoder
Browse files Browse the repository at this point in the history
Signed-off-by: boccaff <[email protected]>
  • Loading branch information
boccaff committed Nov 15, 2024
1 parent 2bd2d8f commit 95111ff
Show file tree
Hide file tree
Showing 8 changed files with 448 additions and 0 deletions.
2 changes: 2 additions & 0 deletions skl2onnx/_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@
LabelEncoder,
Normalizer,
OneHotEncoder,
TargetEncoder,
)

try:
Expand Down Expand Up @@ -511,6 +512,7 @@ def build_sklearn_operator_name_map():
RidgeClassifierCV: "SklearnLinearClassifier",
SGDRegressor: "SklearnLinearRegressor",
StandardScaler: "SklearnScaler",
TargetEncoder: "SklearnTargetEncoder",
TheilSenRegressor: "SklearnLinearRegressor",
}
)
Expand Down
2 changes: 2 additions & 0 deletions skl2onnx/operator_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from . import sgd_oneclass_svm
from . import stacking
from . import support_vector_machines
from . import target_encoder
from . import text_vectoriser
from . import tfidf_transformer
from . import tfidf_vectoriser
Expand Down Expand Up @@ -128,6 +129,7 @@
sgd_oneclass_svm,
stacking,
support_vector_machines,
target_encoder,
text_vectoriser,
tfidf_transformer,
tfidf_vectoriser,
Expand Down
100 changes: 100 additions & 0 deletions skl2onnx/operator_converters/target_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np

from ..common._apply_operation import apply_cast, apply_concat, apply_reshape
from ..common._container import ModelComponentContainer
from ..common.data_types import (
FloatTensorType,
Int64TensorType,
)
from ..common._registration import register_converter
from ..common._topology import Scope, Operator
from ..proto import onnx_proto


def convert_sklearn_target_encoder(
    scope: Scope, operator: Operator, container: ModelComponentContainer
):
    """
    Convert a fitted scikit-learn ``TargetEncoder`` into ONNX nodes.

    Every feature with at least one fitted category goes through an
    ``ai.onnx.ml`` ``LabelEncoder`` node whose keys are the fitted categories
    and whose float values are the matching encodings; ``target_mean_`` is
    passed as the node's default value. Each result is reshaped to a single
    column, the columns are concatenated along axis 1 and the concatenation
    is cast to float into the operator output.

    :param scope: used to create unique variable and operator names
    :param operator: wraps the fitted estimator (``raw_operator``) together
        with its inputs and outputs
    :param container: receives the initializers and nodes that are created
    :raises NotImplementedError: if the encoder was fitted on a multiclass
        target
    """
    op = operator.raw_operator
    result = []
    input_idx = 0
    dimension_idx = 0

    # Only single-output encoders are supported. ``classes_`` may be missing
    # or None (e.g. an encoder fitted on a continuous target), so it is read
    # defensively instead of being dereferenced unconditionally.
    classes = getattr(op, "classes_", None)
    if op.target_type_ == "multiclass" or (
        classes is not None and len(classes) > 2
    ):
        raise NotImplementedError("multiclass TargetEncoder is not supported")

    for categories, encodings in zip(op.categories_, op.encodings_):
        if len(categories) == 0:
            continue

        current_input = operator.inputs[input_idx]
        if current_input.get_second_dimension() == 1:
            # The input already is a single column: use it directly.
            feature_column = current_input
            input_idx += 1
        else:
            # Pick column ``dimension_idx`` out of a multi-column input.
            index_name = scope.get_unique_variable_name("index")
            container.add_initializer(
                index_name, onnx_proto.TensorProto.INT64, [], [dimension_idx]
            )

            feature_column = scope.declare_local_variable(
                "feature_column",
                current_input.type.__class__(
                    [current_input.get_first_dimension(), 1]
                ),
            )

            container.add_node(
                "ArrayFeatureExtractor",
                [current_input.onnx_name, index_name],
                feature_column.onnx_name,
                op_domain="ai.onnx.ml",
                name=scope.get_unique_operator_name("ArrayFeatureExtractor"),
            )

            # Move on to the next input once all of its columns are consumed.
            dimension_idx += 1
            if dimension_idx == current_input.get_second_dimension():
                dimension_idx = 0
                input_idx += 1

        # The key attribute must match the type of the column being encoded.
        attrs = {"name": scope.get_unique_operator_name("LabelEncoder")}
        if isinstance(feature_column.type, FloatTensorType):
            attrs["keys_floats"] = np.array(
                [float(s) for s in categories], dtype=np.float32
            )
        elif isinstance(feature_column.type, Int64TensorType):
            attrs["keys_int64s"] = np.array(
                [int(s) for s in categories], dtype=np.int64
            )
        else:
            attrs["keys_strings"] = np.array(
                [str(s).encode("utf-8") for s in categories]
            )
        attrs["values_floats"] = encodings
        attrs["default_float"] = op.target_mean_

        result.append(scope.get_unique_variable_name("ordinal_output"))
        label_encoder_output = scope.get_unique_variable_name("label_encoder")

        container.add_node(
            "LabelEncoder",
            feature_column.onnx_name,
            label_encoder_output,
            op_domain="ai.onnx.ml",
            op_version=2,
            **attrs,
        )
        # Force an (N, 1) column so the per-feature results concatenate
        # along axis 1.
        apply_reshape(
            scope,
            label_encoder_output,
            result[-1],
            container,
            desired_shape=(-1, 1),
        )

    concat_result_name = scope.get_unique_variable_name("concat_result")
    apply_concat(scope, result, concat_result_name, container, axis=1)
    apply_cast(
        scope,
        concat_result_name,
        operator.output_full_names,
        container,
        to=onnx_proto.TensorProto.FLOAT,
    )


# Register the converter function under the "SklearnTargetEncoder" alias.
register_converter("SklearnTargetEncoder", convert_sklearn_target_encoder)
2 changes: 2 additions & 0 deletions skl2onnx/shape_calculators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from . import sgd_oneclass_svm
from . import svd
from . import support_vector_machines
from . import target_encoder
from . import text_vectorizer
from . import tuned_threshold_classifier
from . import tfidf_transformer
Expand Down Expand Up @@ -99,6 +100,7 @@
sgd_oneclass_svm,
svd,
support_vector_machines,
target_encoder,
text_vectorizer,
tfidf_transformer,
tuned_threshold_classifier,
Expand Down
30 changes: 30 additions & 0 deletions skl2onnx/shape_calculators/target_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: Apache-2.0


import copy
from ..common._registration import register_shape_calculator
from ..common.data_types import FloatTensorType
from ..common.data_types import Int64TensorType, StringTensorType
from ..common.utils import check_input_and_output_numbers
from ..common.utils import check_input_and_output_types


def calculate_sklearn_target_encoder_output_shapes(operator):
    """
    Set the output type of a TargetEncoder operator.

    The single output is declared as a float tensor that keeps the first
    dimension of the first input and has one column per entry of the fitted
    estimator's ``categories_``. Inputs must be float, int64 or string
    tensors.
    """
    accepted_input_types = [FloatTensorType, Int64TensorType, StringTensorType]
    check_input_and_output_numbers(operator, output_count_range=1)
    check_input_and_output_types(operator, good_input_types=accepted_input_types)

    first_dim = operator.inputs[0].get_first_dimension()
    n_columns = len(operator.raw_operator.categories_)

    operator.outputs[0].type = FloatTensorType(shape=[first_dim, n_columns])


# Register the shape calculator under the "SklearnTargetEncoder" alias.
register_shape_calculator(
"SklearnTargetEncoder", calculate_sklearn_target_encoder_output_shapes
)
Loading

0 comments on commit 95111ff

Please sign in to comment.