// Review notes for this patch (placed ahead of the first `diff --git` header):
//  - ts_binary.cpp: adds the openvino/op/fake_quantize.hpp include and appends
//    ov::op::v0::FakeQuantize after ov::op::v0::PRelu in the wrap_type pattern of both
//    TSBinaryForward and TSBinaryBackward. The predicate bodies are context lines, i.e. unchanged:
//    has_static_rank() combined with IfNodeHasTransposeInputs (forward) or
//    HasSameOutputTransposeNodes (backward).
//  - ts_utils.cpp: CanPropagateForwardThrough gains a CHECK_TRANSPOSE_SINKING_SUPPORTED entry for
//    ov::op::v0::FakeQuantize, placed before the final `return false`.
//  - ts_common_test.cpp: adds FakeQuantizeFactory, whose create() passes parent_nodes[0..4] plus the
//    literal 128 to std::make_shared (presumably the FakeQuantize `levels` argument - TODO confirm),
//    the CreateFakeQuantizeFactory helper, the CREATE_FQ_FACTORY macro, and the test_forward_fq /
//    test_backward_fq cases instantiated as TransposeSinkingCommonFQForward / TransposeSinkingCommonFQBackward.
//  - In both new tests the set_unsqueeze_for lambda never reads its `idxs` argument: it always replaces
//    element 2 of the copied vector with a node built from out_vec[2] and the i64 indices {0, 1, 2}.
//  - Forward reference model: set_transpose_for is applied to inputs {1, 2, 3, 4} and to output {0};
//    backward reference model: set_transpose_for is applied to inputs {0, 1, 2, 3, 4} and no output.
// NOTE(review): template argument lists look stripped in this copy (e.g. `wrap_type([](...`,
// `std::make_shared(type_name)`, `vector unary_factories`, `create_model;`) - verify against the
// upstream patch before applying.
diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp index 1e9881785d9107..ff23d2ca9056a4 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_binary.cpp @@ -6,6 +6,7 @@ #include "itt.hpp" #include "openvino/op/constant.hpp" +#include "openvino/op/fake_quantize.hpp" #include "openvino/op/prelu.hpp" #include "openvino/op/transpose.hpp" #include "openvino/op/util/op_types.hpp" @@ -25,7 +26,8 @@ TSBinaryForward::TSBinaryForward() { auto main_node_label = wrap_type([](const Output& output) -> bool { + ov::op::v0::PRelu, + ov::op::v0::FakeQuantize>([](const Output& output) -> bool { return has_static_rank()(output) && IfNodeHasTransposeInputs(output); }); @@ -62,7 +64,8 @@ TSBinaryBackward::TSBinaryBackward() { auto main_node_label = wrap_type([](const Output& output) -> bool { + ov::op::v0::PRelu, + ov::op::v0::FakeQuantize>([](const Output& output) -> bool { return has_static_rank()(output) && HasSameOutputTransposeNodes(output); }); diff --git a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp index c5092f2a562400..f3e360d8793a42 100644 --- a/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp +++ b/src/common/transformations/src/transformations/transpose_sinking/ts_utils.cpp @@ -346,6 +346,7 @@ bool CanPropagateForwardThrough(Node* node) { CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Reshape, node) CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::Unsqueeze, node) CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v1::Transpose, node) + CHECK_TRANSPOSE_SINKING_SUPPORTED(ov::op::v0::FakeQuantize, node) return false; } diff --git a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp 
b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp index 03d99f2ff8e507..7a7a483c1d58e5 100644 --- a/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp +++ b/src/common/transformations/tests/transpose_sinking/ts_common_test.cpp @@ -217,6 +217,22 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) { return std::make_shared(type_name); } +class FakeQuantizeFactory : public IFactory { +public: + explicit FakeQuantizeFactory(const std::string& type_name) : IFactory(type_name) {} + NodePtr create(const OutputVector& parent_nodes) const override { + return std::make_shared(parent_nodes[0], + parent_nodes[1], + parent_nodes[2], + parent_nodes[3], + parent_nodes[4], + 128); + } +}; + +FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) { + return std::make_shared(type_name); +} // ---------------------------------------------------------------------------- #undef CREATE_UNARY_FACTORY @@ -255,6 +271,9 @@ FactoryPtr CreateReshapeFactory(const std::string& type_name) { #undef CREATE_RESHAPE_FACTORY #define CREATE_RESHAPE_FACTORY(type_name) CreateReshapeFactory(#type_name) +#undef CREATE_FQ_FACTORY +#define CREATE_FQ_FACTORY(type_name) common::CreateFakeQuantizeFactory(#type_name) + // ---------------------------------------------------------------------------- vector unary_factories = { @@ -393,6 +412,42 @@ auto test_forward_binary = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TSTestFixture, test_forward_binary()); +auto test_forward_fq = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSBinaryForward); + test_case.num_main_ops = {1, 10}; + test_case.inputs_to_main = { + parameter(element::f32, {1, 96, 55, 55}), + parameter(element::f32, {55, 55, 96, 1}), + parameter(element::f32, {1}), + parameter(element::f32, {55, 1, 1, 1}), + parameter(element::f32, {55, 55, 1, 1}), + }; + + // Test model description: + 
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}}; + test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)}; + test_case.model.model_template = create_model; + + // Reference model description: + auto set_unsqueeze_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector new_out_vec = out_vec; + auto indices = make_shared(element::i64, Shape{3}, std::vector{0, 1, 2}); + new_out_vec[2] = make_shared(out_vec[2], indices); + return new_out_vec; + }; + test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {1, 2, 3, 4}}}; + test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)}; + test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQForward, TSTestFixture, test_forward_fq()); + auto test_forward_concat = []() { TestCase test_case; @@ -867,6 +922,42 @@ auto test_backward_binary = []() { INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryBackward, TSTestFixture, test_backward_binary()); +auto test_backward_fq = []() { + TestCase test_case; + + // Initialize common attributes + test_case.transformation = CREATE_PASS_FACTORY(TSBinaryBackward); + test_case.num_main_ops = {1, 10}; + test_case.inputs_to_main = { + parameter(element::f32, {1, 96, 55, 55}), + parameter(element::f32, {1, 96, 55, 55}), + parameter(element::f32, {1}), + parameter(element::f32, {1, 96, 55, 1}), + parameter(element::f32, {1, 96, 1, 1}), + }; + + // Test model description: + test_case.model.main_op = {CREATE_FQ_FACTORY(FakeQuantize)}; + test_case.model.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}}; + test_case.model.model_template = create_model; + + auto set_unsqueeze_for = [](const vector& idxs, const OutputVector& out_vec) -> OutputVector { + OutputVector new_out_vec = out_vec; + auto indices = 
make_shared(element::i64, Shape{3}, std::vector{0, 1, 2}); + new_out_vec[2] = make_shared(out_vec[2], indices); + return new_out_vec; + }; + + // Reference model description: + test_case.model_ref.preprocess_inputs_to_main = {{set_unsqueeze_for, set_transpose_for}, {{2}, {0, 1, 2, 3, 4}}}; + test_case.model_ref.main_op = {CREATE_FQ_FACTORY(FakeQuantize)}; + test_case.model_ref.model_template = create_model; + + return wrapper(test_case); +}; + +INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonFQBackward, TSTestFixture, test_backward_fq()); + auto test_backward_concat = []() { TestCase test_case;