From 0e5be1e3c0e88c485f9221e4080548f371f494b3 Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Fri, 15 Nov 2024 02:20:16 -0500 Subject: [PATCH] Apply Alexandra comments --- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 15 +++----- .../snippets/x64/pass/enforce_precision.cpp | 10 ++++-- .../skip_tests_config.cpp | 4 +++ .../shared_tests_instances/snippets/mha.cpp | 35 +++++++++++++------ 4 files changed, 39 insertions(+), 25 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 328045ad4ca7f3..203cc6fea0e7e2 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -449,16 +449,10 @@ void Subgraph::initSupportedPrimitiveDescriptors() { config.inConfs.resize(inputShapes.size()); for (size_t i = 0; i < inputShapes.size(); i++) { const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i); - auto precision = ((originalInputPrecision == ov::element::f32) && - context->getConfig().inferencePrecision == ov::element::bf16 && + const auto precision = ((originalInputPrecision == ov::element::f32) && + one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) ? - static_cast(ov::element::bf16) : - originalInputPrecision; - precision = ((originalInputPrecision == ov::element::f32) && - context->getConfig().inferencePrecision == ov::element::f16 && - subgraph_attrs->snippet->has_domain_sensitive_ops()) ? - static_cast(ov::element::f16) : - precision; + context->getConfig().inferencePrecision : originalInputPrecision; if (supportedPrecisions.count(precision) == 0) OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision."); @@ -643,8 +637,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU); SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization, ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs); - if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16) - && subgraph_attrs->snippet->has_domain_sensitive_ops()) { + if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) { // enforce BF16 precisions to supported operations // MatMul has to be decomposed to Brgemm operations before enforcement // Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp index b90b35f9359aa4..80385c53fcc20d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp @@ -122,8 +122,12 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr& f) { std::set> EnforcePrecision::get_supported_precisions_default( const std::shared_ptr&op) noexcept { - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && ov::is_type(op)) { - return {{element::bf16, element::bf16}}; + std::set> types; + if (ov::is_type(op)) { + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) + types.insert({element::f16, element::f16}); + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) + types.insert({element::bf16, element::bf16}); } - return {}; + return types; } diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 90820d550df179..c229bbaf6a7d13 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -577,6 +577,10 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[\?.\?.\?\].*)"); retVector.emplace_back(R"(.*smoke_Snippets_(MHAINT8MatMul|MHAQuantMatMul0|MHAFQAfterMatMul_4D|smoke_Snippets_MHAFQ).*IS\[0\]=\[\?.\?.\?\.\?].*)"); } + // MHA FP16 precision is only supported on amx_fp16 platform + if (!ov::with_cpu_x86_avx512_core_amx_fp16()) { + retVector.emplace_back(R"(.*smoke_Snippets_MHA.*FP16.*)"); + } #ifdef SNIPPETS_LIBXSMM_TPP // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234) retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)"); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 93e4e3df4e856b..f37ec2400d41d7 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -124,30 +124,43 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D, ::testing::Values(CPUTestUtils::empty_plugin_config)), MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, MHA, ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::bf16), + ::testing::ValuesIn({false}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(7), + ::testing::Values(6), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), ::testing::ValuesIn(precision_fp16_if_supported(4)), - ::testing::Values(ov::element::f32), + ::testing::Values(ov::element::f16), ::testing::ValuesIn({false, true}), ::testing::Values(MHA::default_thread_count), - ::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output - ::testing::Values(1), // MHA + 5 Converts on inputs and output + ::testing::Values(2), + ::testing::Values(1), ::testing::Values(ov::test::utils::DEVICE_CPU), ::testing::Values(CPUTestUtils::empty_plugin_config)), MHA::getTestCaseName); -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16, MHA, - ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), ::testing::ValuesIn(precision_f32(4)), - ::testing::Values(ov::element::bf16), - ::testing::ValuesIn({false}), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({false, true}), ::testing::Values(MHA::default_thread_count), - ::testing::Values(7), - ::testing::Values(6), + ::testing::Values(2), + ::testing::Values(1), ::testing::Values(ov::test::utils::DEVICE_CPU), - ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), MHA::getTestCaseName); } // namespace