Skip to content

Commit

Permalink
tpp backend executor for matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 2, 2025
1 parent 200b67a commit 3091e05
Show file tree
Hide file tree
Showing 11 changed files with 340 additions and 111 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
#endif

#ifdef SNIPPETS_LIBXSMM_TPP
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter);
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
if (brgemm_node && with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
Expand Down
89 changes: 17 additions & 72 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//

#include "jit_brgemm_emitter.hpp"

#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "utils.hpp"

using jit_generator = dnnl::impl::cpu::x64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t;
Expand All @@ -19,94 +21,37 @@ void BrgemmTppEmitter::validate_subtensors(const VectorDims& in_0, const VectorD
OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors");
}

BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h,
cpu_isa_t isa,
const ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: TppEmitter(h, isa, expr) {
const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node && !brgemm_node->is_dynamic(), "Invoked with invalid node type");

const auto& input_0_desc = expr->get_input_port_descriptor(0);
const auto& input_1_desc = expr->get_input_port_descriptor(1);
const auto& output_desc = expr->get_output_port_descriptor(0);

std::vector<size_t> leading_dimensions {brgemm_node->get_input_stride(0),
brgemm_node->get_input_stride(1),
brgemm_node->get_output_stride(0)};

auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;
auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;

const auto beta = brgemm_node->get_beta();
OV_CPU_JIT_EMITTER_ASSERT(beta == 0 || beta == 1, "Detected unsupported beta value: " + std::to_string(beta));

const auto& subtensor_in0 = input_0_desc->get_subtensor();
const auto& subtensor_in1 = input_1_desc->get_subtensor();
const auto& subtensor_out0 = output_desc->get_subtensor();
validate_subtensors(subtensor_in0, subtensor_in1, subtensor_out0);

const auto K = static_cast<libxsmm_blasint>(*subtensor_in0.rbegin());
const auto M = static_cast<libxsmm_blasint>(*++subtensor_in0.rbegin());
const auto N = static_cast<libxsmm_blasint>(*subtensor_in1.rbegin());

const bool is_f32_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_F32;
const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16;
const bool is_i8_gemm = in_0_prec == LIBXSMM_DATATYPE_U8 || in_0_prec == LIBXSMM_DATATYPE_I8;
OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm ||
(is_bf16_gemm && K % 2 == 0) ||
(is_i8_gemm && K % 4 == 0),
"Unsupported parameter combination for kernel configuration");

m_compile_flags = is_f32_gemm ?
LIBXSMM_GEMM_FLAGS('N', 'N') :
LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') |
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG |
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG;

if (beta == 0)
m_compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0;

if (in_0_prec == LIBXSMM_DATATYPE_U8) {
in_0_prec = LIBXSMM_DATATYPE_I8;
m_compile_flags |= LIBXSMM_GEMM_FLAG_A_UNSIGNED;
}
if (in_1_prec == LIBXSMM_DATATYPE_U8) {
in_1_prec = LIBXSMM_DATATYPE_I8;
m_compile_flags |= LIBXSMM_GEMM_FLAG_B_UNSIGNED;
}

m_shape = libxsmm_create_gemm_shape(N, M, K,
io_strides[1], io_strides[0], io_strides[2],
in_1_prec, in_0_prec, out_0_prec,
exec_dtype);
m_prefetching_flags = LIBXSMM_GEMM_PREFETCH_NONE;
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Invoked with invalid node type");
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
BrgemmTppKernelConfig kernel_config(brg0Prc, brg1Prc, LIBXSMM_GEMM_PREFETCH_NONE, isa);
m_kernel_executor =
kernel_table->register_kernel<BrgemmTppKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}

// Reports the input precision combinations accepted by this emitter.
// The `node` argument is not consulted: the answer is the same for every node.
std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
    // Note: BrgemmTpp currently supports only fp32
    const std::vector<element::Type> fp32_inputs{element::f32, element::f32};
    return {fp32_inputs};
}

void BrgemmTppEmitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
void BrgemmTppEmitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
return COMPILE_TPP_KERNEL(libxsmm_dispatch_gemm(m_shape, m_compile_flags, m_prefetching_flags));
return reinterpret_cast<const uintptr_t>(m_kernel_executor.get());
}

void BrgemmTppEmitter::execute_brgemm_kernel(libxsmm_gemmfunction brg_kernel, void *in0, void *in1, void *out0) {
libxsmm_gemm_param gemm_p;
gemm_p.a.primary = in1;
gemm_p.b.primary = in0;
gemm_p.c.primary = out0;
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "Invalid brgemm kernel pointer");
brg_kernel(&gemm_p);
// Returns the address of BrgemmTppKernelExecutor::execute converted to an integer value.
const uintptr_t BrgemmTppEmitter::get_execute_function_ptr() const {
    const auto execute_fn = BrgemmTppKernelExecutor::execute;
    return reinterpret_cast<const uintptr_t>(execute_fn);
}

} // namespace intel_cpu
Expand Down
30 changes: 17 additions & 13 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#pragma once
#include "emitters/tpp/x64/kernel_executors/brgemm.hpp"
#include "jit_tpp_emitter.hpp"

namespace ov {
Expand All @@ -12,22 +13,25 @@ class BrgemmTppEmitter : public TppEmitter {
public:
BrgemmTppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);

size_t get_inputs_num() const override { return 2; }
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

static void execute_brgemm_kernel(libxsmm_gemmfunction brgemm_kernel, void *in0, void *in1, void *out0);

const uintptr_t get_execute_function_ptr() const override { return reinterpret_cast<const uintptr_t>(execute_brgemm_kernel); }
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
size_t get_inputs_num() const override {
return 2;
}
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

const uintptr_t get_execute_function_ptr() const override;
const uintptr_t get_compiled_kernel_ptr() const override;

protected:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
static void validate_subtensors(const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0);
libxsmm_gemm_shape m_shape;
libxsmm_bitfield m_prefetching_flags {0};

private:
std::shared_ptr<BrgemmTppKernelExecutor> m_kernel_executor = nullptr;
};

} // namespace intel_cpu
} // namespace ov
} // namespace intel_cpu
} // namespace ov
17 changes: 3 additions & 14 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_tpp_emitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "snippets/lowered/port_descriptor.hpp"
#include "transformations/tpp/x64/op/eltwise.hpp"
#include "emitters/plugin/x64/utils.hpp"
#include "utils.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
Expand Down Expand Up @@ -55,7 +56,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
};

for (size_t i = 0; i < num_ins; i++) {
io_dtypes[i] = ov_to_xsmm_dtype(node->get_input_element_type(i));
io_dtypes[i] = tpp::utils::ov_to_xsmm_dtype(node->get_input_element_type(i));
io_offsets[i] = tpp_mod->get_input_offset(i);
io_strides[i] = replace_full_dim(tpp_mod->get_input_stride(i),
expr->get_input_port_descriptor(i)->get_shape().back());
Expand All @@ -64,7 +65,7 @@ TppEmitter::TppEmitter(dnnl::impl::cpu::x64::jit_generator* h,

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_dtypes[i_off] = ov_to_xsmm_dtype(node->get_output_element_type(i));
io_dtypes[i_off] = tpp::utils::ov_to_xsmm_dtype(node->get_output_element_type(i));
io_offsets[i_off] = tpp_mod->get_output_offset(i);
io_strides[i_off] = replace_full_dim(tpp_mod->get_output_stride(i),
expr->get_output_port_descriptor(i)->get_shape().back());
Expand Down Expand Up @@ -114,17 +115,5 @@ void TppEmitter::emit_impl(const std::vector<size_t>& in, const std::vector<size
spill.postamble();
}

// Translates an OpenVINO element type into the matching libxsmm datatype.
// Only f32, bf16, i8 and u8 are handled; any other value is reported via OV_CPU_JIT_EMITTER_THROW.
libxsmm_datatype TppEmitter::ov_to_xsmm_dtype(ov::element::Type_t elemet_type) {
    if (elemet_type == ov::element::Type_t::f32) {
        return LIBXSMM_DATATYPE_F32;
    }
    if (elemet_type == ov::element::Type_t::bf16) {
        return LIBXSMM_DATATYPE_BF16;
    }
    if (elemet_type == ov::element::Type_t::i8) {
        return LIBXSMM_DATATYPE_I8;
    }
    if (elemet_type == ov::element::Type_t::u8) {
        return LIBXSMM_DATATYPE_U8;
    }
    OV_CPU_JIT_EMITTER_THROW("Attempt to convert unsupported ov data type");
    // Fallback value kept after the throw macro so every path syntactically returns.
    return LIBXSMM_DATATYPE_IMPLICIT;
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ class TppEmitter : public jit_emitter {
const ov::snippets::lowered::ExpressionPtr& expr);
void emit_code(const std::vector<size_t> &in,
const std::vector<size_t> &out) const;
static libxsmm_datatype ov_to_xsmm_dtype(ov::element::Type_t elemet_type);

protected:
void emit_impl(const std::vector<size_t>& in,
Expand Down
Loading

0 comments on commit 3091e05

Please sign in to comment.