Skip to content

Commit

Permalink
refactor tpp on x64 and aarch64
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 14, 2025
1 parent 782f191 commit 1d94800
Show file tree
Hide file tree
Showing 38 changed files with 240 additions and 98 deletions.
8 changes: 6 additions & 2 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

if(AARCH64 OR ARM)
if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

Expand Down Expand Up @@ -202,7 +202,9 @@ if(NOT X86_64)
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*)
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*)
endif()

if (AARCH64)
Expand All @@ -212,7 +214,9 @@ endif()

if(NOT (AARCH64 OR ARM))
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*)
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "emitters/plugin/aarch64/jit_conversion_emitters.hpp"
#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp"
#include "emitters/snippets/aarch64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/aarch64/jit_fill_emitter.hpp"
#include "emitters/snippets/aarch64/jit_kernel_emitter.hpp"
#include "emitters/snippets/aarch64/jit_loop_emitters.hpp"
Expand All @@ -25,7 +24,11 @@
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

#ifdef SNIPPETS_LIBXSMM_TPP
# include "emitters/tpp/aarch64/jit_brgemm_emitter.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
#endif

namespace ov {

Expand Down Expand Up @@ -205,9 +208,11 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

#ifdef SNIPPETS_LIBXSMM_TPP
// brgemm
jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
#endif

// control flow
jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024-2025 Intel Corporation
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

Expand Down Expand Up @@ -327,7 +328,6 @@ void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
brgemm_p.do_post_ops = with_comp;
brgemm_p.do_apply_comp = with_comp;
brgemm_p.skip_accm = 0;

brgemm_p.BS = 1; // default value
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");
(*kernel)(&brgemm_p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
# include "emitters/tpp/x64/jit_eltwise_emitters.hpp"
# include "emitters/tpp/x64/jit_equation_emitter.hpp"
# include "emitters/tpp/x64/jit_scalar_emitter.hpp"
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/common/op/modifiers.hpp"
# include "transformations/tpp/x64/op/eltwise.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/modifiers.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
// Note: for reference implementations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "jit_brgemm_emitter.hpp"

#include "snippets/utils/utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak_aarch64;

Expand Down Expand Up @@ -42,7 +42,10 @@ void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in,
const std::vector<size_t>& out,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const {
validate_arguments(in, out);
emit_impl(in, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/snippets/aarch64/kernel_executors/brgemm.hpp"
#include "emitters/tpp/aarch64/kernel_executors/brgemm.hpp"

namespace ov {
namespace intel_cpu {
Expand All @@ -26,7 +26,10 @@ class jit_brgemm_emitter : public jit_emitter {
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

void emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const;
void emit_code(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs = {},
const std::vector<size_t>& pool_gpr_idxs = {}) const override;

private:
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "brgemm.hpp"

#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -34,8 +35,8 @@ BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)) {
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
m_type_in0 = tpp::ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = tpp::ov_to_xsmm_dtype(in1_dtype);
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
Expand Down
36 changes: 36 additions & 0 deletions src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>

#include "emitters/utils.hpp"
#include "libxsmm.h"

namespace ov {
namespace intel_cpu {
namespace tpp {

/**
 * @brief Maps an OpenVINO element type to the matching libxsmm datatype.
 *
 * Handled element types: f32, bf16, f16, i8 and u8.
 *
 * @param element_type OpenVINO element type to convert.
 * @return The LIBXSMM_DATATYPE_* value that corresponds to @p element_type.
 * @note Any other element type is reported through OV_CPU_JIT_EMITTER_THROW.
 */
inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t element_type) {
    switch (element_type) {
    case ov::element::Type_t::f32:
        return LIBXSMM_DATATYPE_F32;
    case ov::element::Type_t::bf16:
        return LIBXSMM_DATATYPE_BF16;
    case ov::element::Type_t::f16:
        return LIBXSMM_DATATYPE_F16;
    case ov::element::Type_t::i8:
        return LIBXSMM_DATATYPE_I8;
    case ov::element::Type_t::u8:
        return LIBXSMM_DATATYPE_U8;
    default:
        OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
        // Presumably unreachable once the macro above raises; kept so that every
        // control path of the function still returns a value.
        return LIBXSMM_DATATYPE_IMPLICIT;
    }
}

} // namespace tpp
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using jit_generator = dnnl::impl::cpu::x64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t;
Expand Down Expand Up @@ -32,8 +33,8 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
brgemm_node->get_input_stride(1),
brgemm_node->get_output_stride(0)};

auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
auto in_0_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;
Expand Down
17 changes: 3 additions & 14 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "snippets/lowered/port_descriptor.hpp"
#include "transformations/tpp/x64/op/eltwise.hpp"
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/tpp/common/utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -55,7 +56,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
};

for (size_t i = 0; i < num_ins; i++) {
io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i));
io_dtypes[i] = tpp::ov_to_xsmm_dtype(node->get_input_element_type(i));
io_offsets[i] = tpp_mod->get_input_offset(i);
io_strides[i] = replace_full_dim(tpp_mod->get_input_stride(i),
expr->get_input_port_descriptor(i)->get_shape().back());
Expand All @@ -64,7 +65,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i));
io_dtypes[i_off] = tpp::ov_to_xsmm_dtype(node->get_output_element_type(i));
io_offsets[i_off] = tpp_mod->get_output_offset(i);
io_strides[i_off] = replace_full_dim(tpp_mod->get_output_stride(i),
expr->get_output_port_descriptor(i)->get_shape().back());
Expand Down Expand Up @@ -114,17 +115,5 @@ void TppEmitter::emit_impl(const std::vector<size_t>& in, const std::vector<size
spill.postamble();
}

// Converts an OpenVINO element type into the corresponding libxsmm datatype.
// This variant handles f32, bf16, i8 and u8 only; every other element type
// (f16 included, since it has no case here) falls into the default branch,
// which reports the failure through OV_CPU_JIT_EMITTER_THROW.
libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32 : return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16 : return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::i8 : return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8 : return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
// Fallback return after the throwing macro so all paths yield a value.
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class TppEmitter : public jit_emitter {
const ov::snippets::lowered::ExpressionPtr& expr);
void emit_code(const std::vector<size_t> &in,
const std::vector<size_t> &out) const;
static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);

protected:
void emit_impl(const std::vector<size_t>& in,
Expand Down
19 changes: 0 additions & 19 deletions src/plugins/intel_cpu/src/emitters/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <string>

#include "libxsmm.h"
#include "openvino/core/except.hpp"
#include "openvino/core/type/element_type.hpp"

Expand All @@ -24,23 +23,5 @@ std::string jit_emitter_pretty_name(const std::string& pretty_func);
#define OV_CPU_JIT_EMITTER_THROW(...) OPENVINO_THROW(OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)
#define OV_CPU_JIT_EMITTER_ASSERT(cond, ...) OPENVINO_ASSERT((cond), OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)

// Converts an OpenVINO element type into the corresponding libxsmm datatype.
// Handles f32, bf16, f16, i8 and u8; any other element type reaches the
// default branch, which reports the failure through OV_CPU_JIT_EMITTER_THROW.
inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32:
return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16:
return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::f16:
return LIBXSMM_DATATYPE_F16;
case ov::element::Type_t::i8:
return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8:
return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
// Fallback return after the throwing macro so all paths yield a value.
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace intel_cpu
} // namespace ov
25 changes: 14 additions & 11 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@

#ifdef SNIPPETS_LIBXSMM_TPP
# include "snippets/lowered/pass/optimize_domain.hpp"
# include "transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp"
# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp"
# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp"
# include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp"
# include "transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.hpp"
# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp"
# include "transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp"
# include "transformations/tpp/common/pass/lowered/set_tpp_leading_dim.hpp"
# if defined(OPENVINO_ARCH_ARM64)
# include "transformations/tpp/aarch64/pass/lowered/brgemm_tpp_blocking.hpp"
# else
# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp"
# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp"
# include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp"
# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp"
# endif
#endif

namespace ov {
Expand Down Expand Up @@ -522,11 +526,10 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {

Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {
ControlFlowPasses backend_passes;
#if defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
using PassPosition = ov::snippets::pass::PassPosition;
using Place = PassPosition::Place;
#define SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \
std::make_shared<PASS>(__VA_ARGS__))
#endif

#if defined(OPENVINO_ARCH_X86_64)
# define SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(PASS_PLACE, TARGET_PASS, PASS, ...) \
Expand Down Expand Up @@ -561,15 +564,15 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {

SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After,
ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
ov::intel_cpu::tpp::pass::aarch64::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After,
ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);

#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::intel_cpu::pass::BrgemmCPUBlocking,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
ov::intel_cpu::tpp::pass::x64::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::intel_cpu::pass::FuseLoadStoreConvert,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

namespace ov {
namespace snippets {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/modifiers.hpp"
#include "transformations/tpp/common/op/modifiers.hpp"
#include "utils/general_utils.h"

namespace ov {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#ifdef SNIPPETS_LIBXSMM_TPP
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
Expand Down
Loading

0 comments on commit 1d94800

Please sign in to comment.