Skip to content

Commit

Permalink
refactor tpp on x64 and aarch64
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 13, 2025
1 parent 782f191 commit ddc1f03
Show file tree
Hide file tree
Showing 34 changed files with 233 additions and 90 deletions.
8 changes: 6 additions & 2 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,13 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

if(AARCH64 OR ARM)
if(AARCH64)
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

if (ENABLE_SNIPPETS_LIBXSMM_TPP)
# Note: LIBXSMM_DEFAULT_CONFIG needed so libxsmm_config can be included without issues
ov_add_compiler_flags(-Wno-missing-declarations)
add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG)
endif()

Expand Down Expand Up @@ -202,7 +203,9 @@ if(NOT X86_64)
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*)
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*)
endif()

if (AARCH64)
Expand All @@ -212,6 +215,7 @@ endif()

if(NOT (AARCH64 OR ARM))
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

namespace ov {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "jit_brgemm_emitter.hpp"

#include "snippets/utils/utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak_aarch64;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "brgemm.hpp"

#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -34,8 +35,8 @@ BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)) {
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
m_type_in0 = tpp::ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = tpp::ov_to_xsmm_dtype(in1_dtype);
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
# include "emitters/tpp/x64/jit_eltwise_emitters.hpp"
# include "emitters/tpp/x64/jit_equation_emitter.hpp"
# include "emitters/tpp/x64/jit_scalar_emitter.hpp"
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/common/op/modifiers.hpp"
# include "transformations/tpp/x64/op/eltwise.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/modifiers.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
// Note: for reference implementations
Expand Down
36 changes: 36 additions & 0 deletions src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>

#include "emitters/utils.hpp"
#include "libxsmm.h"

namespace ov {
namespace intel_cpu {
namespace tpp {

inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32:
return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16:
return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::f16:
return LIBXSMM_DATATYPE_F16;
case ov::element::Type_t::i8:
return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8:
return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace tpp
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using jit_generator = dnnl::impl::cpu::x64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t;
Expand Down Expand Up @@ -32,8 +33,8 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
brgemm_node->get_input_stride(1),
brgemm_node->get_output_stride(0)};

auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
auto in_0_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;
Expand Down
28 changes: 14 additions & 14 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "snippets/lowered/port_descriptor.hpp"
#include "transformations/tpp/x64/op/eltwise.hpp"
#include "emitters/plugin/x64/utils.hpp"
#include "emitters/tpp/common/utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -55,7 +56,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
};

for (size_t i = 0; i < num_ins; i++) {
io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i));
io_dtypes[i] = tpp::ov_to_xsmm_dtype(node->get_input_element_type(i));
io_offsets[i] = tpp_mod->get_input_offset(i);
io_strides[i] = replace_full_dim(tpp_mod->get_input_stride(i),
expr->get_input_port_descriptor(i)->get_shape().back());
Expand All @@ -64,7 +65,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i));
io_dtypes[i_off] = tpp::ov_to_xsmm_dtype(node->get_output_element_type(i));
io_offsets[i_off] = tpp_mod->get_output_offset(i);
io_strides[i_off] = replace_full_dim(tpp_mod->get_output_stride(i),
expr->get_output_port_descriptor(i)->get_shape().back());
Expand Down Expand Up @@ -113,18 +114,17 @@ void TppEmitter::emit_impl(const std::vector<size_t>& in, const std::vector<size

spill.postamble();
}

libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32 : return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16 : return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::i8 : return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8 : return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
return LIBXSMM_DATATYPE_IMPLICIT;
}
}
// libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
// switch (elemet_type) {
// case ov::element::Type_t::f32 : return LIBXSMM_DATATYPE_F32;
// case ov::element::Type_t::bf16 : return LIBXSMM_DATATYPE_BF16;
// case ov::element::Type_t::i8 : return LIBXSMM_DATATYPE_I8;
// case ov::element::Type_t::u8 : return LIBXSMM_DATATYPE_U8;
// default:
// OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
// return LIBXSMM_DATATYPE_IMPLICIT;
// }
// }

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class TppEmitter : public jit_emitter {
const ov::snippets::lowered::ExpressionPtr& expr);
void emit_code(const std::vector<size_t> &in,
const std::vector<size_t> &out) const;
static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);
// static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);

protected:
void emit_impl(const std::vector<size_t>& in,
Expand Down
19 changes: 0 additions & 19 deletions src/plugins/intel_cpu/src/emitters/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <string>

#include "libxsmm.h"
#include "openvino/core/except.hpp"
#include "openvino/core/type/element_type.hpp"

Expand All @@ -24,23 +23,5 @@ std::string jit_emitter_pretty_name(const std::string& pretty_func);
#define OV_CPU_JIT_EMITTER_THROW(...) OPENVINO_THROW(OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)
#define OV_CPU_JIT_EMITTER_ASSERT(cond, ...) OPENVINO_ASSERT((cond), OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)

inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32:
return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16:
return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::f16:
return LIBXSMM_DATATYPE_F16;
case ov::element::Type_t::i8:
return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8:
return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace intel_cpu
} // namespace ov
23 changes: 15 additions & 8 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,16 @@

#ifdef SNIPPETS_LIBXSMM_TPP
# include "snippets/lowered/pass/optimize_domain.hpp"
# include "transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp"
# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp"
# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp"
# include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp"
# include "transformations/tpp/x64/pass/lowered/set_tpp_leading_dim.hpp"
# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp"
# include "transformations/tpp/common/pass/brgemm_to_brgemm_tpp.hpp"
# include "transformations/tpp/common/pass/lowered/set_tpp_leading_dim.hpp"
# if defined(OPENVINO_ARCH_ARM64)
# include "transformations/tpp/aarch64/pass/lowered/brgemm_tpp_blocking.hpp"
# else
# include "transformations/tpp/x64/pass/eltwise_to_eltwise_tpp.hpp"
# include "transformations/tpp/x64/pass/fuse_tpp_to_equations.hpp"
# include "transformations/tpp/x64/pass/lowered/brgemm_tpp_blocking.hpp"
# include "transformations/tpp/x64/pass/scalar_to_scalar_tpp.hpp"
# endif
#endif

namespace ov {
Expand Down Expand Up @@ -523,7 +527,10 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {
ControlFlowPasses backend_passes;
using PassPosition = ov::snippets::pass::PassPosition;
#if defined(OPENVINO_ARCH_X86_64) || defined(OPENVINO_ARCH_ARM64)
using Place = PassPosition::Place;
#endif

#define SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(PASS_PLACE, TARGET_PASS, PASS, ...) \
backend_passes.emplace_back(PassPosition(PASS_PLACE, TARGET_PASS::get_type_info_static()), \
std::make_shared<PASS>(__VA_ARGS__))
Expand Down Expand Up @@ -561,15 +568,15 @@ Subgraph::ControlFlowPasses Subgraph::getControlFlowPasses() const {

SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After,
ov::snippets::lowered::pass::MarkLoops,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
ov::intel_cpu::tpp::pass::aarch64::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE_ARM64(Place::After,
ov::snippets::lowered::pass::InsertLoops,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);

#ifdef SNIPPETS_LIBXSMM_TPP
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before,
ov::intel_cpu::pass::BrgemmCPUBlocking,
ov::intel_cpu::tpp::pass::BrgemmTPPBlocking);
ov::intel_cpu::tpp::pass::x64::BrgemmTPPBlocking);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After,
ov::intel_cpu::pass::FuseLoadStoreConvert,
ov::intel_cpu::tpp::pass::SetTPPLeadingDim);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

namespace ov {
namespace snippets {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/modifiers.hpp"
#include "transformations/tpp/common/op/modifiers.hpp"
#include "utils/general_utils.h"

namespace ov {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#ifdef SNIPPETS_LIBXSMM_TPP
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm_tpp_blocking.hpp"

#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils/utils.hpp"

namespace ov {
namespace intel_cpu {
namespace tpp {
namespace pass {
namespace aarch64 {
using namespace ov::snippets::utils;

bool BrgemmTPPBlocking::SetBrgemmBeta::run(ov::snippets::lowered::LinearIR& linear_ir,
ov::snippets::lowered::LinearIR::constExprIt begin,
ov::snippets::lowered::LinearIR::constExprIt end) {
for (auto expr_it = begin; expr_it != end; ++expr_it) {
if (const auto brgemm = ov::as_type_ptr<ov::intel_cpu::tpp::op::BrgemmTPP>(expr_it->get()->get_node()))
brgemm->set_beta(0);
}
return true;
}

std::shared_ptr<snippets::lowered::pass::PassBase> BrgemmTPPBlocking::SetBrgemmBeta::merge(
const std::shared_ptr<snippets::lowered::pass::PassBase>& other) {
return !other || ov::is_type<SetBrgemmBeta>(other) ? std::make_shared<SetBrgemmBeta>() : nullptr;
}

std::tuple<size_t, size_t, size_t> BrgemmTPPBlocking::get_blocking_params(
const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const {
size_t m, n, k;
std::tie(m, n, k) = get_brgemm_dimensions(brgemm_expr);
OPENVINO_ASSERT(!is_dynamic_value(m) && !is_dynamic_value(n) && !is_dynamic_value(n),
"BrgemmTPP doesn't support dynamic shapes");

size_t m_blk, n_blk, k_blk;
std::tie(m_blk, n_blk, k_blk) = BrgemmBlockingBase::get_blocking_params(brgemm_expr);

auto get_projected_blk = [](const size_t dim, const size_t blk) {
return ov::snippets::utils::is_full_dim_value(blk) ? dim : blk;
};
return std::make_tuple(get_projected_blk(m, m_blk), get_projected_blk(n, n_blk), get_projected_blk(k, k_blk));
}

ov::snippets::lowered::SpecificIterationHandlers BrgemmTPPBlocking::get_k_loop_handlers(size_t work_amount,
size_t block_size) const {
ov::snippets::lowered::SpecificIterationHandlers handlers =
ov::snippets::lowered::pass::BrgemmBlockingBase::get_k_loop_handlers(work_amount, block_size);
handlers.register_pass<ov::snippets::lowered::SpecificLoopIterType::FIRST_ITER, SetBrgemmBeta>();
return handlers;
}

} // namespace aarch64
} // namespace pass
} // namespace tpp
} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit ddc1f03

Please sign in to comment.