-
Notifications
You must be signed in to change notification settings - Fork 111
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add QuadraticDiscriminantAnalysis converter (#915)
* add QDA converter Signed-off-by: xiaowuhu <[email protected]> * fix flake8 Signed-off-by: xiaowuhu <[email protected]> * another flake8 Signed-off-by: xiaowuhu <[email protected]> * Update _supported_operators.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update quadratic_discriminant_analysis.py Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * add double dtype testing as output Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * add double output type test case Signed-off-by: xiaowuhu <[email protected]> * Update test_quadratic_discriminant_analysis_converter.py Signed-off-by: xiaowuhu <[email protected]> * change file name to standard one Signed-off-by: xiaowuhu <[email protected]> * upgrade version to 1.13 Signed-off-by: xiaowuhu <[email protected]> Signed-off-by: xiaowuhu <[email protected]>
- Loading branch information
Showing
7 changed files
with
339 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
138 changes: 138 additions & 0 deletions
138
skl2onnx/operator_converters/quadratic_discriminant_analysis.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
from ..common._apply_operation import ( | ||
apply_add, apply_argmax, apply_cast, apply_concat, apply_div, apply_exp, | ||
apply_log, apply_matmul, apply_mul, apply_pow, | ||
apply_reducesum, apply_reshape, apply_sub, apply_transpose) | ||
from ..common.data_types import ( | ||
BooleanTensorType, Int64TensorType, guess_proto_type) | ||
from ..common._registration import register_converter | ||
from ..common._topology import Scope, Operator | ||
from ..common._container import ModelComponentContainer | ||
from ..proto import onnx_proto | ||
|
||
|
||
def convert_quadratic_discriminant_analysis_classifier( | ||
scope: Scope, operator: Operator, container: ModelComponentContainer): | ||
|
||
input_name = operator.inputs[0].full_name | ||
model = operator.raw_operator | ||
|
||
n_classes = len(model.classes_) | ||
|
||
proto_dtype = guess_proto_type(operator.inputs[0].type) | ||
if proto_dtype != onnx_proto.TensorProto.DOUBLE: | ||
proto_dtype = onnx_proto.TensorProto.FLOAT | ||
|
||
if isinstance(operator.inputs[0].type, | ||
(BooleanTensorType, Int64TensorType)): | ||
cast_input_name = scope.get_unique_variable_name('cast_input') | ||
apply_cast(scope, operator.input_full_names, cast_input_name, | ||
container, to=proto_dtype) | ||
input_name = cast_input_name | ||
|
||
norm_array_name = [] | ||
sum_array_name = [] | ||
|
||
container.add_initializer('const_n05', proto_dtype, [], [-0.5]) | ||
container.add_initializer('const_p2', proto_dtype, [], [2]) | ||
|
||
for i in range(n_classes): | ||
R = model.rotations_[i] | ||
rotation_name = scope.get_unique_variable_name('rotations') | ||
container.add_initializer(rotation_name, proto_dtype, | ||
[R.shape[0], R.shape[1]], R) | ||
|
||
S = model.scalings_[i] | ||
scaling_name = scope.get_unique_variable_name('scalings') | ||
container.add_initializer( | ||
scaling_name, proto_dtype, [S.shape[0], ], S) | ||
|
||
mean = model.means_[i] | ||
mean_name = scope.get_unique_variable_name('means') | ||
container.add_initializer(mean_name, proto_dtype, mean.shape, mean) | ||
|
||
Xm_name = scope.get_unique_variable_name('Xm') | ||
apply_sub(scope, [input_name, mean_name], [Xm_name], container) | ||
|
||
s_pow_name = scope.get_unique_variable_name('s_pow_n05') | ||
apply_pow(scope, [scaling_name, 'const_n05'], [s_pow_name], container) | ||
|
||
mul_name = scope.get_unique_variable_name('mul') | ||
apply_mul(scope, [rotation_name, s_pow_name], [mul_name], container) | ||
|
||
x2_name = scope.get_unique_variable_name('matmul') | ||
apply_matmul(scope, [Xm_name, mul_name], [x2_name], container) | ||
|
||
pow_x2_name = scope.get_unique_variable_name('pow_x2') | ||
apply_pow(scope, [x2_name, 'const_p2'], [pow_x2_name], container) | ||
|
||
sum_name = scope.get_unique_variable_name('sum') | ||
apply_reducesum(scope, [pow_x2_name], [sum_name], | ||
container, axes=[1], keepdims=1) | ||
norm_array_name.append(sum_name) | ||
|
||
log_name = scope.get_unique_variable_name('log') | ||
apply_log(scope, [scaling_name], [log_name], container) | ||
|
||
sum_log_name = scope.get_unique_variable_name('sum_log') | ||
apply_reducesum( | ||
scope, [log_name], [sum_log_name], container, keepdims=1) | ||
sum_array_name.append(sum_log_name) | ||
|
||
concat_norm_name = scope.get_unique_variable_name('concat_norm') | ||
apply_concat(scope, norm_array_name, [concat_norm_name], container) | ||
|
||
reshape_norm_name = scope.get_unique_variable_name('reshape_concat_norm') | ||
apply_reshape(scope, [concat_norm_name], [reshape_norm_name], | ||
container, desired_shape=[n_classes, -1]) | ||
|
||
transpose_norm_name = scope.get_unique_variable_name('transpose_norm') | ||
apply_transpose(scope, [reshape_norm_name], [transpose_norm_name], | ||
container, perm=(1, 0)) | ||
|
||
apply_concat(scope, sum_array_name, ['concat_logsum'], container) | ||
|
||
add_norm2_u_name = scope.get_unique_variable_name('add_norm2_u') | ||
apply_add(scope, [transpose_norm_name, 'concat_logsum'], | ||
[add_norm2_u_name], container) | ||
|
||
norm2_u_n05_name = scope.get_unique_variable_name('norm2_u_n05') | ||
apply_mul( | ||
scope, ['const_n05', add_norm2_u_name], [norm2_u_n05_name], container) | ||
|
||
container.add_initializer( | ||
'priors', proto_dtype, [n_classes, ], model.priors_) | ||
apply_log(scope, ['priors'], ['log_p'], container) | ||
|
||
apply_add(scope, [norm2_u_n05_name, 'log_p'], ['decision_fun'], container) | ||
|
||
apply_argmax(scope, ['decision_fun'], ['argmax_out'], container, axis=1) | ||
|
||
container.add_initializer( | ||
'classes', onnx_proto.TensorProto.INT64, [n_classes], model.classes_) | ||
|
||
container.add_node( | ||
'ArrayFeatureExtractor', | ||
['classes', 'argmax_out'], | ||
[operator.outputs[0].full_name], | ||
op_domain='ai.onnx.ml' | ||
) | ||
|
||
attr = {'axes': [1]} | ||
container.add_node( | ||
'ReduceMax', ['decision_fun'], ['df_max'], **attr) | ||
apply_sub(scope, ['decision_fun', 'df_max'], ['df_sub_max'], container) | ||
apply_exp(scope, ['df_sub_max'], ['likelihood'], container) | ||
apply_reducesum(scope, ['likelihood'], ['likelihood_sum'], container, | ||
axes=[1], keepdims=1) | ||
apply_div(scope, ['likelihood', 'likelihood_sum'], | ||
[operator.outputs[1].full_name], container, ) | ||
|
||
|
||
register_converter('SklearnQuadraticDiscriminantAnalysis', | ||
convert_quadratic_discriminant_analysis_classifier, | ||
options={'zipmap': [True, False, 'columns'], | ||
'nocl': [True, False], | ||
'output_class_labels': [False, True]}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
15 changes: 15 additions & 0 deletions
15
skl2onnx/shape_calculators/quadratic_discriminant_analysis.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from ..common._registration import register_shape_calculator | ||
from ..common.data_types import Int64TensorType | ||
|
||
|
||
def calculate_quadratic_discriminant_analysis_shapes(operator): | ||
N = len(operator.raw_operator.classes_) | ||
operator.outputs[0].type = Int64TensorType([1, N]) | ||
operator.outputs[1].type.shape = [None, N] | ||
|
||
|
||
register_shape_calculator( | ||
'SklearnQuadraticDiscriminantAnalysis', | ||
calculate_quadratic_discriminant_analysis_shapes) |
175 changes: 175 additions & 0 deletions
175
tests/test_sklearn_quadratic_discriminant_analysis_converter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""Tests scikit-learn's SGDClassifier converter.""" | ||
|
||
import sklearn | ||
import unittest | ||
import numpy as np | ||
import packaging.version as pv | ||
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis | ||
from onnxruntime import __version__ as ort_version | ||
from onnx import __version__ as onnx_version | ||
from skl2onnx import convert_sklearn | ||
from skl2onnx.common.data_types import ( | ||
FloatTensorType, | ||
DoubleTensorType | ||
) | ||
|
||
from test_utils import ( | ||
dump_data_and_model, | ||
TARGET_OPSET | ||
) | ||
|
||
ort_version = ".".join(ort_version.split(".")[:2]) | ||
onnx_version = ".".join(onnx_version.split('.')[:2]) | ||
|
||
|
||
class TestQuadraticDiscriminantAnalysisConverter(unittest.TestCase): | ||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_2c2f_float(self): | ||
# 2 classes, 2 features | ||
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) | ||
y = np.array([1, 1, 1, 2, 2, 2]) | ||
X_test = np.array([[-0.8, -1], [0.8, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", FloatTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.float32), skl_model, onnx_model, | ||
basename="SklearnQDA_2c2f_Float") | ||
|
||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_2c3f_float(self): | ||
# 2 classes, 3 features | ||
X = np.array([[-1, -1, 0], [-2, -1, 1], [-3, -2, 0], | ||
[1, 1, 0], [2, 1, 1], [3, 2, 1]]) | ||
y = np.array([1, 1, 1, 2, 2, 2]) | ||
X_test = np.array([[-0.8, -1, 0], [-1, -1.6, 0], | ||
[1, 1.5, 1], [3.1, 2.1, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", FloatTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.float32), skl_model, onnx_model, | ||
basename="SklearnQDA_2c3f_Float") | ||
|
||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_3c2f_float(self): | ||
# 3 classes, 2 features | ||
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], | ||
[2, 1], [3, 2], [-1, 2], [-2, 3], [-2, 2]]) | ||
y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) | ||
X_test = np.array([[-0.8, -1], [0.8, 1], [-0.8, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", FloatTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.float32), skl_model, onnx_model, | ||
basename="SklearnQDA_3c2f_Float") | ||
|
||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_2c2f_double(self): | ||
# 2 classes, 2 features | ||
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], | ||
[2, 1], [3, 2]]).astype(np.double) | ||
y = np.array([1, 1, 1, 2, 2, 2]) | ||
X_test = np.array([[-0.8, -1], [0.8, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", DoubleTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET, options={'zipmap': False}) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.double), skl_model, onnx_model, | ||
basename="SklearnQDA_2c2f_Double") | ||
|
||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_2c3f_double(self): | ||
# 2 classes, 3 features | ||
X = np.array([[-1, -1, 0], [-2, -1, 1], [-3, -2, 0], | ||
[1, 1, 0], [2, 1, 1], [3, 2, 1]]).astype(np.double) | ||
y = np.array([1, 1, 1, 2, 2, 2]) | ||
X_test = np.array([[-0.8, -1, 0], [-1, -1.6, 0], | ||
[1, 1.5, 1], [3.1, 2.1, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", DoubleTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET, options={'zipmap': False}) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.double), skl_model, onnx_model, | ||
basename="SklearnQDA_2c3f_Double") | ||
|
||
@unittest.skipIf(pv.Version(sklearn.__version__) < pv.Version('1.0'), | ||
reason="scikit-learn<1.0") | ||
@unittest.skipIf(pv.Version(onnx_version) < pv.Version('1.11'), | ||
reason="fails with onnx 1.10") | ||
def test_model_qda_3c2f_double(self): | ||
# 3 classes, 2 features | ||
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2], | ||
[-1, 2], [-2, 3], [-2, 2]]).astype(np.double) | ||
y = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3]) | ||
X_test = np.array([[-0.8, -1], [0.8, 1], [-0.8, 1]]) | ||
|
||
skl_model = QuadraticDiscriminantAnalysis() | ||
skl_model.fit(X, y) | ||
|
||
onnx_model = convert_sklearn( | ||
skl_model, | ||
"scikit-learn QDA", | ||
[("input", DoubleTensorType([None, X.shape[1]]))], | ||
target_opset=TARGET_OPSET, options={'zipmap': False}) | ||
|
||
self.assertIsNotNone(onnx_model) | ||
dump_data_and_model(X_test.astype(np.double), skl_model, onnx_model, | ||
basename="SklearnQDA_3c2f_Double") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main(verbosity=3) |