Skip to content

Commit

Permalink
Apply Alexandra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Dec 1, 2024
1 parent ab62d9e commit 41a7933
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 25 deletions.
15 changes: 4 additions & 11 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -450,16 +450,10 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
config.inConfs.resize(inputShapes.size());
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
auto precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::bf16 &&
const auto precision = ((originalInputPrecision == ov::element::f32) &&
one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::bf16) :
originalInputPrecision;
precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::f16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::f16) :
precision;
context->getConfig().inferencePrecision : originalInputPrecision;
if (supportedPrecisions.count(precision) == 0)
OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");

Expand Down Expand Up @@ -644,8 +638,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16)
&& subgraph_attrs->snippet->has_domain_sensitive_ops()) {
if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
// MatMul has to be decomposed to Brgemm operations before enforcement
// Note, MatMul decomposition will be run later again for case if BF16 enforcement is not happened
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,12 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr<ov::Model>& f) {

std::set<std::vector<ov::element::Type>> EnforcePrecision::get_supported_precisions_default(
const std::shared_ptr<ov::Node>&op) noexcept {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && ov::is_type<snippets::op::Brgemm>(op)) {
return {{element::bf16, element::bf16}};
std::set<std::vector<ov::element::Type>> types;
if (ov::is_type<snippets::op::Brgemm>(op)) {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16))
types.insert({element::f16, element::f16});
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16))
types.insert({element::bf16, element::bf16});
}
return {};
return types;
}
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,10 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(.*smoke_Snippets_MHA.*BF16.*/MHA.*IS\[0\]=\[\?.\?.\?\].*)");
retVector.emplace_back(R"(.*smoke_Snippets_(MHAINT8MatMul|MHAQuantMatMul0|MHAFQAfterMatMul_4D|smoke_Snippets_MHAFQ).*IS\[0\]=\[\?.\?.\?\.\?].*)");
}
// MHA FP16 precision is only supported on amx_fp16 platform
if (!ov::with_cpu_x86_avx512_core_amx_fp16()) {
retVector.emplace_back(R"(.*smoke_Snippets_MHA.*FP16.*)");
}
#ifdef SNIPPETS_LIBXSMM_TPP
// GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234)
retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,30 +124,43 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({false}),
::testing::Values(MHA::default_thread_count),
::testing::Values(7),
::testing::Values(6),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f32),
::testing::Values(ov::element::f16),
::testing::ValuesIn({false, true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(1), // MHA + 5 Converts on inputs and output
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({false}),
::testing::Values(ov::element::f16),
::testing::ValuesIn({false, true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(7),
::testing::Values(6),
::testing::Values(2),
::testing::Values(1),
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)),
::testing::Values(CPUTestUtils::cpu_f16_plugin_config)),
MHA::getTestCaseName);

} // namespace
Expand Down

0 comments on commit 41a7933

Please sign in to comment.