Skip to content

Commit

Permalink
update arm passes, test enable
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 9, 2025
1 parent 6ca4f1b commit 6e4e86e
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

if (ENABLE_SNIPPETS_LIBXSMM_TPP)
if (ENABLE_SNIPPETS_LIBXSMM_TPP OR AARCH64 OR ARM)
# Note: LIBXSMM_DEFAULT_CONFIG needed so libxsmm_config can be included without issues
add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG)
endif()
Expand Down
62 changes: 40 additions & 22 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs,
broadcastable_inputs);
#if defined(OPENVINO_ARCH_ARM64)
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::Before,
ov::snippets::pass::PropagatePrecision,
ov::intel_cpu::tpp::pass::BrgemmToBrgemmTPP);
#endif
if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16 precisions to supported operations
Expand Down Expand Up @@ -505,39 +510,52 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {

Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {
ControlFlowPasses backend_passes;

#if defined(OPENVINO_ARCH_X86_64)
using PassPosition = ov::snippets::pass::PassPosition;
using Place = PassPosition::Place;
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...) \
#define SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \
std::make_shared<PASS>(__VA_ARGS__))
#if defined(OPENVINO_ARCH_X86_64)
# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \
std::make_shared<PASS>(__VA_ARGS__))
#else
# define SNIPPETS_REGISTER_PASS_RELATIVE(PASS_PLACE, TARGET_PASS, PASS, ...)
# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...)
#endif // OPENVINO_ARCH_X86_64

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::pass::BrgemmCPUBlocking);
#if defined(OPENVINO_ARCH_ARM64)
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After,
ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After,
ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
#endif
#endif

SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::pass::BrgemmCPUBlocking);

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::InitLoops,
ov::intel_cpu::pass::AdjustBrgemmCopyBLoopPorts);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::lowered::pass::InitLoops,
ov::intel_cpu::pass::AdjustBrgemmCopyBLoopPorts);

SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::pass::FuseLoadStoreConvert);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before,
ov::snippets::lowered::pass::InsertBuffers,
ov::intel_cpu::pass::InsertBrgemmCopyBuffers);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::pass::FuseLoadStoreConvert);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::snippets::lowered::pass::InsertBuffers,
ov::intel_cpu::pass::InsertBrgemmCopyBuffers);

#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE(Place::Before,
ov::intel_cpu::pass::BrgemmCPUBlocking,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE(Place::After,
ov::intel_cpu::pass::FuseLoadStoreConvert,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::intel_cpu::pass::BrgemmCPUBlocking,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::intel_cpu::pass::FuseLoadStoreConvert,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
#endif

#undef SNIPPETS_REGISTER_PASS_RELATIVE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

namespace ov {
namespace snippets {
Expand Down Expand Up @@ -42,6 +43,7 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov
const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry{
SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::SwishNode, PassThroughShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::tpp::op::BrgemmTPP, BrgemmShapeInfer),
};
#undef SHAPE_INFER_OP_SPECIFIC
#undef SHAPE_INFER_PREDEFINED
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1106,7 +1106,7 @@ void Transformations::MainSnippets(void) {
};
#endif // OPENVINO_ARCH_X86_64

auto is_supported_op = [](const std::shared_ptr<const ov::Node>& n) -> bool {
auto is_supported_op = [ignoreCallback](const std::shared_ptr<const ov::Node>& n) -> bool {
#if defined(OPENVINO_ARCH_ARM64)
return (ov::is_type<ov::op::v0::Abs>(n) || ov::is_type<ov::op::v1::Add>(n) ||
ov::is_type<ov::op::v0::Clamp>(n) || ov::is_type<ov::op::v0::Ceiling>(n) ||
Expand All @@ -1122,7 +1122,8 @@ void Transformations::MainSnippets(void) {
ov::is_type<ov::op::v0::PRelu>(n) || ov::is_type<ov::op::v0::Relu>(n) ||
ov::is_type<ov::op::v5::Round>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
ov::is_type<ov::op::v0::Sqrt>(n) || ov::is_type<ov::op::v1::Subtract>(n) ||
ov::is_type<ov::op::v4::Swish>(n) || ov::is_type<ov::op::v0::Tanh>(n));
ov::is_type<ov::op::v4::Swish>(n) || ov::is_type<ov::op::v0::Tanh>(n) ||
(ov::is_type<ov::op::v0::MatMul>(n) && ignoreCallback));
#else
// CPU Plugin supports Swish in Subgraph via conversion to SwishCPU, which assumes the second input to be constant,
// and CPU Plugin does not support Mish for x64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,9 @@ std::vector<std::string> disabledTestPatterns() {
// Issue: 126738
retVector.emplace_back(R"(smoke_Snippets.*\[.*\?.*\].*)");
retVector.emplace_back(R"(smoke_Snippets_Eltwise.*\[1.1..10.1..8.1..4\].*)");
// smoke_Snippets test cases are not supported on arm64 platforms, except for smoke_Snippets_Eltwise
retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise|_Convert).*)");
// smoke_Snippets test cases are not supported on arm64 platforms,
// except for smoke_Snippets_Eltwise, smoke_Snippets_Convert, smoke_Snippets_MatMul and smoke_Snippets_MatMult
retVector.emplace_back(R"(smoke_Snippets(?!_Eltwise|_Convert|_MatMul/|_MatMult/).*)");
// ARM snippets do not yet support sve_128, which is required by the dnnl injector jit_uni_eltwise_injector_f32
retVector.emplace_back(R"(smoke_Snippets_Eltwise_TwoResults.*)");
retVector.emplace_back(R"(smoke_Snippets_Eltwise/TwoInputsAndOutputs.*)");
Expand Down Expand Up @@ -525,13 +526,13 @@ std::vector<std::string> disabledTestPatterns() {
retVector.emplace_back(R"(.*smoke_LPT/RecurrentCellTransformation.CompareWithRefImpl/f32_\[1,1,3\]_CPU_f32FQ_X_level=256_.*_FQ_W_level=255.*)");
retVector.emplace_back(R"(.*smoke_static/ConvertFqRnnToQuantizedRnn.CompareWithRefs/Type=GRUSequence.*2.5.10.*2.1.4.*2.1.4.*)");
}
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
if (!ov::with_cpu_x86_avx2()) {
// MatMul in Snippets uses BRGEMM that is supported only on AVX2 (and newer) platforms
// Disabled Snippets MHA tests as well because MHA pattern contains MatMul
retVector.emplace_back(R"(.*Snippets.*MHA.*)");
retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)");
}
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
if (!ov::with_cpu_x86_avx512_core_fp16()) {
// Skip fp16 tests for platforms that don't support fp16 precision
retVector.emplace_back(R"(.*INFERENCE_PRECISION_HINT=(F|f)16.*)");
Expand Down

0 comments on commit 6e4e86e

Please sign in to comment.