Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Matmul emitters with TPP backend executor #28229

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
#endif

#ifdef SNIPPETS_LIBXSMM_TPP
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter);
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "dnnl_extension_utils.h"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

#define DIM_CAST(X) static_cast<dnnl_dim_t>(X)
#define DTYPE_CAST(X) static_cast<dnnl_data_type_t>(DnnlExtensionUtils::ElementTypeToDataType(X))
Expand Down Expand Up @@ -253,9 +254,10 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
const auto& brgemm_tpp_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node || brgemm_tpp_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
if (brgemm_node && with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
Expand Down
86 changes: 15 additions & 71 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "utils.hpp"

using jit_generator = dnnl::impl::cpu::x64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t;
Expand All @@ -19,94 +20,37 @@ void BrgemmTppEmitter::validate_subtensors(const VectorDims& in_0, const VectorD
OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors");
}

BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h,
                                   cpu_isa_t isa,
                                   const ExpressionPtr& expr,
                                   const snippets::KernelExecutorTablePtr& kernel_table,
                                   const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : TppEmitter(h, isa, expr) {
    // Kernel compilation and dispatch are delegated to a BrgemmTppKernelExecutor
    // registered in the shared kernel table; this ctor only builds its config.
    const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_node && !brgemm_node->is_dynamic(), "Invoked with invalid node type");

    // Input precisions drive the executor's kernel configuration; prefetching is disabled.
    const auto& brg0Prc = brgemm_node->get_input_element_type(0);
    const auto& brg1Prc = brgemm_node->get_input_element_type(1);
    BrgemmTppKernelConfig kernel_config(brg0Prc, brg1Prc, LIBXSMM_GEMM_PREFETCH_NONE, isa);
    m_kernel_executor =
        kernel_table->register_kernel<BrgemmTppKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}

std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    // BrgemmTpp is currently limited to fp32 on both input ports.
    std::set<std::vector<element::Type>> precisions;
    precisions.insert({element::f32, element::f32});
    return precisions;
}

void BrgemmTppEmitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
void BrgemmTppEmitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
    // The "kernel" argument baked into the generated code is the executor object
    // itself; the runtime entry point (see get_execute_function_ptr) receives it
    // back and performs the actual libxsmm dispatch/execution.
    return reinterpret_cast<const uintptr_t>(m_kernel_executor.get());
}

void BrgemmTppEmitter::execute_brgemm_kernel(libxsmm_gemmfunction brg_kernel, void *in0, void *in1, void *out0) {
libxsmm_gemm_param gemm_p;
gemm_p.a.primary = in1;
gemm_p.b.primary = in0;
gemm_p.c.primary = out0;
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "Invalid brgemm kernel pointer");
brg_kernel(&gemm_p);
const uintptr_t BrgemmTppEmitter::get_execute_function_ptr() const {
return reinterpret_cast<const uintptr_t>(BrgemmTppKernelExecutor::execute);
}

} // namespace intel_cpu
Expand Down
30 changes: 17 additions & 13 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#pragma once
#include "emitters/tpp/x64/kernel_executors/brgemm.hpp"
#include "jit_tpp_emitter.hpp"

namespace ov {
Expand All @@ -12,22 +13,25 @@ class BrgemmTppEmitter : public TppEmitter {
public:
BrgemmTppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);

size_t get_inputs_num() const override { return 2; }
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

static void execute_brgemm_kernel(libxsmm_gemmfunction brgemm_kernel, void *in0, void *in1, void *out0);

const uintptr_t get_execute_function_ptr() const override { return reinterpret_cast<const uintptr_t>(execute_brgemm_kernel); }
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
size_t get_inputs_num() const override {
return 2;
}
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

const uintptr_t get_execute_function_ptr() const override;
const uintptr_t get_compiled_kernel_ptr() const override;

protected:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
static void validate_subtensors(const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0);
libxsmm_gemm_shape m_shape;
libxsmm_bitfield m_prefetching_flags {0};

private:
std::shared_ptr<BrgemmTppKernelExecutor> m_kernel_executor = nullptr;
};

} // namespace intel_cpu
} // namespace ov
} // namespace intel_cpu
} // namespace ov
17 changes: 3 additions & 14 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "snippets/lowered/port_descriptor.hpp"
#include "transformations/tpp/x64/op/eltwise.hpp"
#include "emitters/plugin/x64/utils.hpp"
#include "utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -55,7 +56,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
};

for (size_t i = 0; i < num_ins; i++) {
io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i));
io_dtypes[i] = tpp::utils::ov_to_xsmm_dtype(node->get_input_element_type(i));
io_offsets[i] = tpp_mod->get_input_offset(i);
io_strides[i] = replace_full_dim(tpp_mod->get_input_stride(i),
expr->get_input_port_descriptor(i)->get_shape().back());
Expand All @@ -64,7 +65,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i));
io_dtypes[i_off] = tpp::utils::ov_to_xsmm_dtype(node->get_output_element_type(i));
io_offsets[i_off] = tpp_mod->get_output_offset(i);
io_strides[i_off] = replace_full_dim(tpp_mod->get_output_stride(i),
expr->get_output_port_descriptor(i)->get_shape().back());
Expand Down Expand Up @@ -114,17 +115,5 @@ void TppEmitter::emit_impl(const std::vector<size_t>& in, const std::vector<size
spill.postamble();
}

// Map an OV element type onto the corresponding libxsmm datatype.
// Only the precisions TPP kernels handle are supported; any other type is an error.
libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t element_type) {
    // Fix: parameter was misspelled "elemet_type".
    switch (element_type) {
    case ov::element::Type_t::f32 : return LIBXSMM_DATATYPE_F32;
    case ov::element::Type_t::bf16 : return LIBXSMM_DATATYPE_BF16;
    case ov::element::Type_t::i8 : return LIBXSMM_DATATYPE_I8;
    case ov::element::Type_t::u8 : return LIBXSMM_DATATYPE_U8;
    default:
        OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
        // Unreachable after the throw; kept to satisfy compilers that require a return.
        return LIBXSMM_DATATYPE_IMPLICIT;
    }
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class TppEmitter : public jit_emitter {
const ov::snippets::lowered::ExpressionPtr& expr);
void emit_code(const std::vector<size_t> &in,
const std::vector<size_t> &out) const;
static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);

protected:
void emit_impl(const std::vector<size_t>& in,
Expand Down
Loading
Loading