tpp backend executor for matmul
chenhu-wang committed Dec 30, 2024
1 parent 13d60b1 commit 707cbbb
Showing 7 changed files with 306 additions and 51 deletions.
@@ -290,7 +290,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
#endif

#ifdef SNIPPETS_LIBXSMM_TPP
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter);
jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] =
CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache);
jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter);
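
For readers unfamiliar with the snippets registration above: the map entry now hands the configurator's kernel-executor table and the compiled-kernel cache to the emitter factory, so each BrgemmTppEmitter can register its own kernel executor at construction time. Below is a minimal, self-contained sketch of that factory pattern; every type name is a hypothetical stand-in, since CREATE_SNIPPETS_EMITTER hides the real machinery.

// Sketch only: factory registration that captures a kernel-executor table and a
// compiled-kernel cache (stand-in types; the real ones live in the intel_cpu plugin).
#include <functional>
#include <map>
#include <memory>
#include <string>

struct Expression {};
struct KernelExecutorTable {};
struct MultiCache {};

struct Emitter {
    virtual ~Emitter() = default;
};

struct BrgemmTppEmitterSketch : Emitter {
    BrgemmTppEmitterSketch(const std::shared_ptr<Expression>&,
                           const std::shared_ptr<KernelExecutorTable>&,
                           const std::weak_ptr<MultiCache>&) {}
};

using EmitterFactory = std::function<std::shared_ptr<Emitter>(const std::shared_ptr<Expression>&)>;

int main() {
    std::map<std::string, EmitterFactory> jitters;
    auto kernel_table = std::make_shared<KernelExecutorTable>();
    auto cache = std::make_shared<MultiCache>();

    // The factory captures both resources, mirroring the extra arguments that
    // CREATE_SNIPPETS_EMITTER now forwards to the BrgemmTppEmitter constructor.
    jitters["BrgemmTPP"] = [kernel_table, weak_cache = std::weak_ptr<MultiCache>(cache)](
                               const std::shared_ptr<Expression>& expr) -> std::shared_ptr<Emitter> {
        return std::make_shared<BrgemmTppEmitterSketch>(expr, kernel_table, weak_cache);
    };

    auto emitter = jitters["BrgemmTPP"](std::make_shared<Expression>());
    return emitter ? 0 : 1;
}
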
@@ -257,9 +257,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres
auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1)));

const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config");
// In case of data repacking LDB is chosen in accordance with repacking buffer size
if (with_repacking(brgemm_node->get_type()))
if (brgemm_node && with_repacking(brgemm_node->get_type()))
LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1)));

config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta);
62 changes: 27 additions & 35 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp
@@ -3,6 +3,7 @@
//

#include "jit_brgemm_emitter.hpp"

#include "emitters/snippets/x64/jit_snippets_emitters.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

@@ -19,27 +20,33 @@ void BrgemmTppEmitter::validate_subtensors(const VectorDims& in_0, const VectorD
OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors");
}

BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h,
cpu_isa_t isa,
const ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: TppEmitter(h, isa, expr) {
const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_node && !brgemm_node->is_dynamic(), "Invoked with invalid node type");
libxsmm_bitfield compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
if (brgemm_node->get_beta() == 0) {
compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0;
}
const auto& brg0Prc = brgemm_node->get_input_element_type(0);
const auto& brg1Prc = brgemm_node->get_input_element_type(1);
BrgemmTppKernelConfig kernel_config(brg0Prc, brg1Prc, compile_flags, LIBXSMM_GEMM_PREFETCH_NONE);
m_kernel_executor =
kernel_table->register_kernel<BrgemmTppKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

const auto& input_0_desc = expr->get_input_port_descriptor(0);
const auto& input_1_desc = expr->get_input_port_descriptor(1);
const auto& output_desc = expr->get_output_port_descriptor(0);

std::vector<size_t> leading_dimensions {brgemm_node->get_input_stride(0),
brgemm_node->get_input_stride(1),
brgemm_node->get_output_stride(0)};

auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0));
auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1));
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;
auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ?
LIBXSMM_DATATYPE_I32 :
LIBXSMM_DATATYPE_F32;
exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? LIBXSMM_DATATYPE_I32
: LIBXSMM_DATATYPE_F32;
auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? LIBXSMM_DATATYPE_I32 : LIBXSMM_DATATYPE_F32;

const auto beta = brgemm_node->get_beta();
OV_CPU_JIT_EMITTER_ASSERT(beta == 0 || beta == 1, "Detected unsupported beta value: " + std::to_string(beta));
@@ -54,18 +61,14 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
const auto N = static_cast<libxsmm_blasint>(*subtensor_in1.rbegin());

const bool is_f32_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_F32;
const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16;
const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16;
const bool is_i8_gemm = in_0_prec == LIBXSMM_DATATYPE_U8 || in_0_prec == LIBXSMM_DATATYPE_I8;
OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm ||
(is_bf16_gemm && K % 2 == 0) ||
(is_i8_gemm && K % 4 == 0),
OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm || (is_bf16_gemm && K % 2 == 0) || (is_i8_gemm && K % 4 == 0),
"Unsupported parameter combination for kernel configuration");

m_compile_flags = is_f32_gemm ?
LIBXSMM_GEMM_FLAGS('N', 'N') :
LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') |
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG |
LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG;
m_compile_flags = is_f32_gemm ? LIBXSMM_GEMM_FLAGS('N', 'N')
: LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') |
LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG;

if (beta == 0)
m_compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0;
@@ -78,35 +81,24 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres
in_1_prec = LIBXSMM_DATATYPE_I8;
m_compile_flags |= LIBXSMM_GEMM_FLAG_B_UNSIGNED;
}

m_shape = libxsmm_create_gemm_shape(N, M, K,
io_strides[1], io_strides[0], io_strides[2],
in_1_prec, in_0_prec, out_0_prec,
exec_dtype);
m_prefetching_flags = LIBXSMM_GEMM_PREFETCH_NONE;
}

std::set<std::vector<element::Type>> BrgemmTppEmitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
// Note: BrgemmTpp currently supports only fp32
return {{element::f32, element::f32}};
}

void BrgemmTppEmitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
void BrgemmTppEmitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size()));
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size()));
}

const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const {
return COMPILE_TPP_KERNEL(libxsmm_dispatch_gemm(m_shape, m_compile_flags, m_prefetching_flags));
return reinterpret_cast<const uintptr_t>(m_kernel_executor.get());
}

void BrgemmTppEmitter::execute_brgemm_kernel(libxsmm_gemmfunction brg_kernel, void *in0, void *in1, void *out0) {
libxsmm_gemm_param gemm_p;
gemm_p.a.primary = in1;
gemm_p.b.primary = in0;
gemm_p.c.primary = out0;
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "Invalid brgemm kernel pointer");
brg_kernel(&gemm_p);
const uintptr_t BrgemmTppEmitter::get_execute_function_ptr() const {
return reinterpret_cast<const uintptr_t>(BrgemmTppKernelExecutor::execute);
}

} // namespace intel_cpu
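
With this change the emitter no longer dispatches a libxsmm kernel itself: get_compiled_kernel_ptr() returns the address of the registered BrgemmTppKernelExecutor, and get_execute_function_ptr() returns the static BrgemmTppKernelExecutor::execute, so the generated code presumably calls execute(executor, in0, in1, out0) at run time. The following minimal sketch models that indirection with stand-in types (FakeExecutor and friends are hypothetical, not part of the commit).

// Sketch of the pointer indirection: the JIT code only sees two uintptr_t values,
// an execute-function address and an opaque executor address, and calls one with the other.
#include <cstdint>
#include <cstdio>

struct FakeExecutor {  // stands in for BrgemmTppKernelExecutor
    // A trivial "kernel": out0[0] = in0[0] * in1[0]
    static void execute(const FakeExecutor* self, void* in0, void* in1, void* out0) {
        (void)self;
        *static_cast<float*>(out0) = *static_cast<float*>(in0) * *static_cast<float*>(in1);
    }
};

int main() {
    FakeExecutor executor;

    // What the emitter hands to the generated code, analogous to
    // get_execute_function_ptr() and get_compiled_kernel_ptr() above.
    const uintptr_t fn_ptr = reinterpret_cast<uintptr_t>(&FakeExecutor::execute);
    const uintptr_t exec_ptr = reinterpret_cast<uintptr_t>(&executor);

    // What the generated code effectively does at run time.
    using ExecFn = void (*)(const FakeExecutor*, void*, void*, void*);
    float a = 2.0f, b = 3.0f, c = 0.0f;
    reinterpret_cast<ExecFn>(fn_ptr)(reinterpret_cast<const FakeExecutor*>(exec_ptr), &a, &b, &c);

    std::printf("c = %f\n", c);  // prints 6.0
    return 0;
}
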
30 changes: 17 additions & 13 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp
@@ -3,6 +3,7 @@
//

#pragma once
#include "emitters/tpp/x64/kernel_executors/brgemm.hpp"
#include "jit_tpp_emitter.hpp"

namespace ov {
@@ -12,22 +13,25 @@ class BrgemmTppEmitter : public TppEmitter {
public:
BrgemmTppEmitter(dnnl::impl::cpu::x64::jit_generator* h,
dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr);

size_t get_inputs_num() const override { return 2; }
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

static void execute_brgemm_kernel(libxsmm_gemmfunction brgemm_kernel, void *in0, void *in1, void *out0);

const uintptr_t get_execute_function_ptr() const override { return reinterpret_cast<const uintptr_t>(execute_brgemm_kernel); }
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);
size_t get_inputs_num() const override {
return 2;
}
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

const uintptr_t get_execute_function_ptr() const override;
const uintptr_t get_compiled_kernel_ptr() const override;

protected:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
static void validate_subtensors(const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0);
libxsmm_gemm_shape m_shape;
libxsmm_bitfield m_prefetching_flags {0};

private:
std::shared_ptr<BrgemmTppKernelExecutor> m_kernel_executor = nullptr;
};

} // namespace intel_cpu
} // namespace ov
} // namespace intel_cpu
} // namespace ov
164 changes: 164 additions & 0 deletions src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp
@@ -0,0 +1,164 @@
// Copyright (C) 2020-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm.hpp"

#include "common/utils.hpp"
#include "dnnl_extension_utils.h"
#include "snippets/lowered/pass/insert_specific_iterations.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;

namespace ov {
namespace intel_cpu {
#define COMPILE_BRGEMM_TPP_KERNEL(...) \
[&]() { \
setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \
setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \
auto res = reinterpret_cast<const libxsmm_gemmfunction>(__VA_ARGS__); \
unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \
unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \
return res; \
}()

BrgemmTppKernelConfig::BrgemmTppKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
libxsmm_bitfield compile_flags,
bool prefetching_flags,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
m_static_params(
std::make_shared<StaticParams>(in0_dtype, in1_dtype, compile_flags, prefetching_flags, primitive_isa)) {
m_hash = compute_hash();
}

BrgemmTppKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
libxsmm_bitfield compile_flags,
bool prefetching_flags,
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(prefetching_flags)),
m_prefetching_flags(prefetching_flags) {}

bool BrgemmTppKernelConfig::StaticParams::operator==(const StaticParams& rhs) const {
return StaticBaseParams::operator==(rhs) && m_prefetching_flags == rhs.m_prefetching_flags;
}

size_t BrgemmTppKernelConfig::StaticParams::compute_hash(bool prefetching_flags) {
return hash_combine(0, prefetching_flags);
}

#ifdef SNIPPETS_DEBUG_CAPS
std::string BrgemmTppKernelConfig::StaticParams::to_string() const {
std::stringstream ss;
ss << StaticBaseParams::to_string();
ss << "m_prefetching_flags = " << m_prefetching_flags << "\n";
return ss.str();
}
#endif

BrgemmTppKernelExecutor::BrgemmTppKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache,
BrgemmTppKernelConfig config)
: CPUKernelExecutor<BrgemmTppKernelConfig, BrgemmTppCompiledKernel>(std::move(kernel_cache), std::move(config)) {}

std::shared_ptr<BrgemmTppCompiledKernel> BrgemmTppKernelExecutor::compile_kernel(
const BrgemmTppKernelConfig& config) const {
std::shared_ptr<BrgemmTppCompiledKernel> compiled_kernel = std::make_shared<BrgemmTppCompiledKernel>();

// Brgemm is not executable - nothing to compile
if (config.is_empty())
return compiled_kernel;

libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(),
config.get_M(),
config.get_K(),
config.get_LDB(),
config.get_LDA(),
config.get_LDC(),
LIBXSMM_DATATYPE_F32,
LIBXSMM_DATATYPE_F32,
LIBXSMM_DATATYPE_F32,
LIBXSMM_DATATYPE_F32);
compiled_kernel->brgemm_kernel = std::make_shared<libxsmm_gemmfunction>(COMPILE_BRGEMM_TPP_KERNEL(
libxsmm_dispatch_gemm(m_shape, config.get_compile_flags(), config.get_prefetching_flags())));

return compiled_kernel;
}

void BrgemmTppKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
BrgemmTppKernelConfig& config) const {
const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
const auto beta = brgemm_node->get_beta();

const auto& input_0_desc = expr->get_input_port_descriptor(0);
const auto& input_1_desc = expr->get_input_port_descriptor(1);
const auto& output_desc = expr->get_output_port_descriptor(0);

const auto& subtensor_in0 = input_0_desc->get_subtensor();
const auto& subtensor_in1 = input_1_desc->get_subtensor();
const auto& subtensor_out0 = output_desc->get_subtensor();
auto validate_subtensors = [](const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0) {
bool subtensors_compatible = in_0.size() == in_1.size() && in_0.size() == out_0.size() && in_0.size() == 2 &&
in_0[1] == in_1[0] && in_0[0] == out_0[0] && in_1[1] == out_0[1];
OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors");
};
validate_subtensors(subtensor_in0, subtensor_in1, subtensor_out0);

const auto K = static_cast<libxsmm_blasint>(*subtensor_in0.rbegin());
const auto M = static_cast<libxsmm_blasint>(*++subtensor_in0.rbegin());
const auto N = static_cast<libxsmm_blasint>(*subtensor_in1.rbegin());

const auto& tpp_mod = std::dynamic_pointer_cast<tpp::modifier::TensorProcessingPrimitive>(expr->get_node());

auto replace_full_dim = [](size_t dim, size_t replace_dim) {
if (ov::snippets::utils::is_full_dim_value(dim))
return replace_dim;
return dim;
};

const auto num_ins = expr->get_node()->get_input_size();
const auto num_outs = expr->get_node()->get_output_size();

std::vector<size_t> io_strides(num_ins + num_outs);

for (size_t i = 0; i < num_ins; i++) {
io_strides[i] =
replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back());
}

for (size_t i = 0; i < num_outs; i++) {
const auto i_off = i + num_ins;
io_strides[i_off] =
replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back());
}

config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta);

auto compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
if (beta == 0)
compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0;

config.set_compile_flags(compile_flags);
}

void BrgemmTppKernelExecutor::execute(const BrgemmTppKernelExecutor* executor, void* in0, void* in1, void* out0) {
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor");
libxsmm_gemm_param gemm_p;
gemm_p.a.primary = in1;
gemm_p.b.primary = in0;
gemm_p.c.primary = out0;
auto brg_kernel = executor->get_kernel();
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiled kernel");
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiled brgemm_kernel");
(*(brg_kernel->brgemm_kernel))(&gemm_p);
}

} // namespace intel_cpu
} // namespace ov
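
For reference, here is a standalone sketch of the libxsmm dispatch-and-execute flow this executor wraps. It reuses the same calls that appear in the file above (libxsmm_create_gemm_shape, libxsmm_dispatch_gemm, libxsmm_gemm_param), assumes a libxsmm version exposing that API, and is only illustrative rather than part of the commit. Note the swapped A/B roles and the (N, M, K) / (LDB, LDA, LDC) argument order, matching compile_kernel() and execute() above: libxsmm operates on column-major data while the snippet tensors are row-major.

// Minimal f32 GEMM through libxsmm, mirroring what BrgemmTppKernelExecutor does.
#include <libxsmm.h>
#include <cstdio>
#include <vector>

int main() {
    const libxsmm_blasint M = 4, N = 8, K = 16;

    // Row-major A(MxK) * B(KxN) = C(MxN) expressed as a column-major problem:
    // libxsmm's "A" slot gets our B, hence the (N, M, K) and (LDB, LDA, LDC) ordering.
    const libxsmm_gemm_shape shape = libxsmm_create_gemm_shape(N, M, K,
                                                               /*lda=*/N, /*ldb=*/K, /*ldc=*/N,
                                                               LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
                                                               LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32);
    const libxsmm_bitfield flags = LIBXSMM_GEMM_FLAGS('N', 'N') | LIBXSMM_GEMM_FLAG_BETA_0;
    const libxsmm_gemmfunction kernel = libxsmm_dispatch_gemm(shape, flags, LIBXSMM_GEMM_PREFETCH_NONE);
    if (!kernel) {
        std::fprintf(stderr, "libxsmm could not generate a kernel for this shape\n");
        return 1;
    }

    std::vector<float> a(M * K, 1.0f), b(K * N, 1.0f), c(M * N, 0.0f);

    libxsmm_gemm_param gemm_p{};
    gemm_p.a.primary = b.data();  // B goes into the "a" slot, as in execute() above
    gemm_p.b.primary = a.data();
    gemm_p.c.primary = c.data();
    kernel(&gemm_p);

    std::printf("c[0] = %f (expected %d)\n", c[0], K);  // each C element is a dot product of K ones
    return 0;
}
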