Skip to content

Commit

Permalink
refactor tpp on x64 and aarch64
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 21, 2025
1 parent c8808a2 commit 6b55e68
Show file tree
Hide file tree
Showing 38 changed files with 244 additions and 103 deletions.
8 changes: 6 additions & 2 deletions src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ if(ENABLE_CPU_DEBUG_CAPS)
add_definitions(-DCPU_DEBUG_CAPS)
endif()

if(AARCH64 OR ARM)
if(AARCH64 AND (NOT ANDROID))
set(ENABLE_SNIPPETS_LIBXSMM_TPP ON)
endif()

Expand Down Expand Up @@ -202,7 +202,9 @@ if(NOT X86_64)
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/snippets/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*)
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/x64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/x64/*)
endif()

if (AARCH64)
Expand All @@ -212,7 +214,9 @@ endif()

if(NOT (AARCH64 OR ARM))
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/cpu_opset/arm/*
${CMAKE_CURRENT_SOURCE_DIR}/src/transformations/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/tpp/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/executors/aarch64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/aarch64/*)
endif()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include "emitters/plugin/aarch64/jit_conversion_emitters.hpp"
#include "emitters/plugin/aarch64/jit_eltwise_emitters.hpp"
#include "emitters/snippets/aarch64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/aarch64/jit_fill_emitter.hpp"
#include "emitters/snippets/aarch64/jit_kernel_emitter.hpp"
#include "emitters/snippets/aarch64/jit_loop_emitters.hpp"
Expand All @@ -25,7 +24,11 @@
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

#ifdef SNIPPETS_LIBXSMM_TPP
# include "emitters/tpp/aarch64/jit_brgemm_emitter.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
#endif

namespace ov {

Expand Down Expand Up @@ -205,9 +208,11 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

#ifdef SNIPPETS_LIBXSMM_TPP
// brgemm
jitters[ov::intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(jit_brgemm_emitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
#endif

// control flow
jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(jit_kernel_static_emitter);
Expand Down Expand Up @@ -250,11 +255,12 @@ size_t CPUTargetMachine::get_lanes() const {
}

std::vector<snippets::Reg> CPUTargetMachine::get_abi_arg_regs() const {
using namespace dnnl::impl::cpu::aarch64;
// using namespace dnnl::impl::cpu::aarch64;
std::vector<snippets::Reg> res;
for (const auto& r :
{abi_param1, abi_param2, abi_param3, abi_param4, abi_param5, abi_param6, abi_param7, abi_param8})
res.emplace_back(snippets::RegType::gpr, r.getIdx());
// // for (const auto& r :
// for (const Xbyak_aarch64::XReg& r :
// {abi_param1, abi_param2, abi_param3, abi_param4, abi_param5, abi_param6, abi_param7, abi_param8})
// res.emplace_back(snippets::RegType::gpr, r.getIdx());
return res;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2024-2025 Intel Corporation
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
}

Expand Down Expand Up @@ -327,7 +328,6 @@ void BrgemmBaseKernelExecutor::execute_brgemm_kernel(
brgemm_p.do_post_ops = with_comp;
brgemm_p.do_apply_comp = with_comp;
brgemm_p.skip_accm = 0;

brgemm_p.BS = 1; // default value
OV_CPU_JIT_EMITTER_ASSERT(kernel, "has nullptr Brgemm kernel");
(*kernel)(&brgemm_p);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@
# include "emitters/tpp/x64/jit_eltwise_emitters.hpp"
# include "emitters/tpp/x64/jit_equation_emitter.hpp"
# include "emitters/tpp/x64/jit_scalar_emitter.hpp"
# include "transformations/tpp/x64/op/brgemm.hpp"
# include "transformations/tpp/common/op/brgemm.hpp"
# include "transformations/tpp/common/op/modifiers.hpp"
# include "transformations/tpp/x64/op/eltwise.hpp"
# include "transformations/tpp/x64/op/equation.hpp"
# include "transformations/tpp/x64/op/modifiers.hpp"
# include "transformations/tpp/x64/op/reduce.hpp"
# include "transformations/tpp/x64/op/scalar.hpp"
// Note: for reference implementations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "jit_brgemm_emitter.hpp"

#include "snippets/utils/utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak_aarch64;

Expand Down Expand Up @@ -42,7 +42,10 @@ void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
void jit_brgemm_emitter::emit_code(const std::vector<size_t>& in,
const std::vector<size_t>& out,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const {
validate_arguments(in, out);
emit_impl(in, out);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#pragma once

#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/snippets/aarch64/kernel_executors/brgemm.hpp"
#include "emitters/tpp/aarch64/kernel_executors/brgemm.hpp"

namespace ov {
namespace intel_cpu {
Expand All @@ -26,7 +26,10 @@ class jit_brgemm_emitter : public jit_emitter {
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

void emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const;
void emit_code(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs = {},
const std::vector<size_t>& pool_gpr_idxs = {}) const override;

private:
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

#include "brgemm.hpp"

#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -34,8 +35,8 @@ BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)) {
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
m_type_in0 = tpp::ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = tpp::ov_to_xsmm_dtype(in1_dtype);
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
Expand Down
36 changes: 36 additions & 0 deletions src/plugins/intel_cpu/src/emitters/tpp/common/utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <string>

#include "emitters/utils.hpp"
#include "libxsmm.h"

namespace ov {
namespace intel_cpu {
namespace tpp {

/// Translates an OpenVINO element type into the corresponding libxsmm datatype.
///
/// Recognized inputs are f32, bf16, f16, i8 and u8. Every other value is reported
/// through OV_CPU_JIT_EMITTER_THROW.
///
/// @param elemet_type OpenVINO element type to translate.
/// @return The libxsmm datatype paired with `elemet_type`.
inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
    if (elemet_type == ov::element::Type_t::f32) {
        return LIBXSMM_DATATYPE_F32;
    }
    if (elemet_type == ov::element::Type_t::bf16) {
        return LIBXSMM_DATATYPE_BF16;
    }
    if (elemet_type == ov::element::Type_t::f16) {
        return LIBXSMM_DATATYPE_F16;
    }
    if (elemet_type == ov::element::Type_t::i8) {
        return LIBXSMM_DATATYPE_I8;
    }
    if (elemet_type == ov::element::Type_t::u8) {
        return LIBXSMM_DATATYPE_U8;
    }

    // None of the supported types matched.
    OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
    // NOTE(review): presumably only here so that every control path yields a value;
    // not expected to execute if the macro above throws - confirm.
    return LIBXSMM_DATATYPE_IMPLICIT;
}

} // namespace tpp
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#include "jit_brgemm_emitter.hpp"

#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "transformations/tpp/common/op/brgemm.hpp"

using jit_generator = dnnl::impl::cpu::x64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t;
Expand Down Expand Up @@ -33,8 +34,8 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
brgemm_node->get_input_stride(1),
brgemm_node->get_output_stride(0)};

auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
auto in_0_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = tpp::ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? LIBXSMM_DATATYPE_I32
: LIBXSMM_DATATYPE_F32;
auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? LIBXSMM_DATATYPE_I32 : LIBXSMM_DATATYPE_F32;
Expand Down
21 changes: 3 additions & 18 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "jit_tpp_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/tpp/common/utils.hpp"
#include "snippets/lowered/port_descriptor.hpp"
#include "transformations/tpp/x64/op/eltwise.hpp"

Expand Down Expand Up @@ -56,7 +57,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
};

for (size_t i = 0; i < num_ins; i++) {
io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i));
io_dtypes[i] = tpp::ov_to_xsmm_dtype(node->get_input_element_type(i));
io_offsets[i] = tpp_mod->get_input_offset(i);
io_strides[i] =
replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back());
Expand All @@ -65,7 +66,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i));
io_dtypes[i_off] = tpp::ov_to_xsmm_dtype(node->get_output_element_type(i));
io_offsets[i_off] = tpp_mod->get_output_offset(i);
io_strides[i_off] =
replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back());
Expand Down Expand Up @@ -121,21 +122,5 @@ void TppEmitter::emit_impl(const std::vector<size_t>& in, const std::vector<size
spill.postamble();
}

// Converts an OpenVINO element type into the matching libxsmm datatype.
// Handles f32, bf16, i8 and u8 (there is no f16 case in this version); any other
// value falls into the default branch, which invokes OV_CPU_JIT_EMITTER_THROW.
libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32:
return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16:
return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::i8:
return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8:
return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
// NOTE(review): presumably a fallback so that every path yields a value; not
// expected to execute if the macro above throws - confirm.
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class TppEmitter : public jit_binary_call_emitter {
dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);
void emit_code(const std::vector<size_t>& in, const std::vector<size_t>& out) const;
static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);

protected:
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
Expand Down
19 changes: 0 additions & 19 deletions src/plugins/intel_cpu/src/emitters/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <string>

#include "libxsmm.h"
#include "openvino/core/except.hpp"
#include "openvino/core/type/element_type.hpp"

Expand All @@ -24,23 +23,5 @@ std::string jit_emitter_pretty_name(const std::string& pretty_func);
#define OV_CPU_JIT_EMITTER_THROW(...) OPENVINO_THROW(OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)
#define OV_CPU_JIT_EMITTER_ASSERT(cond, ...) OPENVINO_ASSERT((cond), OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)

// Converts an OpenVINO element type into the matching libxsmm datatype.
// Handles f32, bf16, f16, i8 and u8; any other value falls into the default
// branch, which invokes OV_CPU_JIT_EMITTER_THROW.
inline libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
switch (elemet_type) {
case ov::element::Type_t::f32:
return LIBXSMM_DATATYPE_F32;
case ov::element::Type_t::bf16:
return LIBXSMM_DATATYPE_BF16;
case ov::element::Type_t::f16:
return LIBXSMM_DATATYPE_F16;
case ov::element::Type_t::i8:
return LIBXSMM_DATATYPE_I8;
case ov::element::Type_t::u8:
return LIBXSMM_DATATYPE_U8;
default:
OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
// NOTE(review): presumably a fallback so that every path yields a value; not
// expected to execute if the macro above throws - confirm.
return LIBXSMM_DATATYPE_IMPLICIT;
}
}

} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit 6b55e68

Please sign in to comment.