Skip to content

Commit

Permalink
test update
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Nov 13, 2024
1 parent 807ae4c commit a80aea6
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,9 @@ TEST_P(MHATest, CompareWithRefs) {
if (inputPrecisions[0] == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
GTEST_SKIP();

if (inputPrecisions[0] == ElementType::f16 && !ov::with_cpu_x86_avx512_core_amx_fp16())
GTEST_SKIP();

if (!ov::with_cpu_x86_avx512_core())
GTEST_SKIP();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,19 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16_4D,
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFP16_4D,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
::testing::ValuesIn(precision_fp16_if_supported(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(MHA::default_thread_count),
::testing::Values(1), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(1), // MHA + 5 Converts on inputs and output
::testing::Values(ov::test::utils::DEVICE_CPU),
::testing::Values(CPUTestUtils::empty_plugin_config)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16,
MHA,
::testing::Combine(::testing::ValuesIn(transposedShape_4D()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ static inline bool is_bf16_supported_by_brgemm() {
return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16();
}

static inline bool is_fp16_supported_by_brgemm() {
return ov::with_cpu_x86_avx512_core_amx_fp16();
}

static inline bool is_i8_supported_by_brgemm() {
return ov::with_cpu_x86_avx512_core_vnni() || ov::with_cpu_x86_avx512_core_amx_int8();
}
Expand All @@ -33,6 +37,13 @@ static inline std::vector<std::vector<element::Type>> precision_bf16_if_supporte
return prc;
}

static inline std::vector<std::vector<element::Type>> precision_fp16_if_supported(size_t count) {
std::vector<std::vector<element::Type>> prc;
if (is_fp16_supported_by_brgemm())
prc.emplace_back(std::vector<element::Type>(count, element::f16));
return prc;
}

static inline std::vector<std::vector<element::Type>> quantized_precisions_if_supported() {
std::vector<std::vector<element::Type>> prc = {};
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
Expand Down

0 comments on commit a80aea6

Please sign in to comment.