From 707cbbb0f58e30bd1bd47d0035e38578233b428b Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Mon, 30 Dec 2024 16:52:55 +0800 Subject: [PATCH] tpp backend executor for matmul --- .../emitters/snippets/x64/cpu_generator.cpp | 3 +- .../x64/kernel_executors/brgemm_base.cpp | 4 +- .../emitters/tpp/x64/jit_brgemm_emitter.cpp | 62 +++---- .../emitters/tpp/x64/jit_brgemm_emitter.hpp | 30 ++-- .../tpp/x64/kernel_executors/brgemm.cpp | 164 ++++++++++++++++++ .../tpp/x64/kernel_executors/brgemm.hpp | 93 ++++++++++ .../tpp/x64/pass/brgemm_to_brgemm_tpp.cpp | 1 + 7 files changed, 306 insertions(+), 51 deletions(-) create mode 100644 src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp create mode 100644 src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.hpp diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index 39e384837856a1..de3b4f9bd4dd15 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -290,7 +290,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho #endif #ifdef SNIPPETS_LIBXSMM_TPP - jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter); + jitters[intel_cpu::tpp::op::BrgemmTPP::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(BrgemmTppEmitter, configurator->get_kernel_executor_table(), compiled_kernel_cache); jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); jitters[intel_cpu::tpp::op::Subtract::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); jitters[intel_cpu::tpp::op::Multiply::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(BinaryEltwiseTppEmitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp 
b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp index 008237780de3f6..126571c8498842 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp @@ -257,9 +257,9 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres auto LDB = DIM_CAST(snippets::utils::get_dim_stride(expr->get_input_port(1))); const auto& brgemm_node = as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); + // OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "Got invalid node type in update_config"); // In case of data repacking LDB is chosen in accordance with repacking buffer size - if (with_repacking(brgemm_node->get_type())) + if (brgemm_node && with_repacking(brgemm_node->get_type())) LDB = DIM_CAST(brgemm_utils::repacking::compute_LDB(LDB, brgemm_node->get_input_element_type(1))); config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp index 0fcb394a8a5bde..b278b1725ac9ef 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.cpp @@ -3,6 +3,7 @@ // #include "jit_brgemm_emitter.hpp" + #include "emitters/snippets/x64/jit_snippets_emitters.hpp" #include "transformations/tpp/x64/op/brgemm.hpp" @@ -19,27 +20,33 @@ void BrgemmTppEmitter::validate_subtensors(const VectorDims& in_0, const VectorD OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors"); } -BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) +BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, + cpu_isa_t isa, + const ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& 
kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) : TppEmitter(h, isa, expr) { const auto& brgemm_node = as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(brgemm_node && !brgemm_node->is_dynamic(), "Invoked with invalid node type"); + libxsmm_bitfield compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + if (brgemm_node->get_beta() == 0) { + compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0; + } + const auto& brg0Prc = brgemm_node->get_input_element_type(0); + const auto& brg1Prc = brgemm_node->get_input_element_type(1);  // second operand is read from input port 1, not port 0 + BrgemmTppKernelConfig kernel_config(brg0Prc, brg1Prc, compile_flags, LIBXSMM_GEMM_PREFETCH_NONE); + m_kernel_executor = + kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); const auto& input_0_desc = expr->get_input_port_descriptor(0); const auto& input_1_desc = expr->get_input_port_descriptor(1); const auto& output_desc = expr->get_output_port_descriptor(0); - std::vector leading_dimensions {brgemm_node->get_input_stride(0), - brgemm_node->get_input_stride(1), - brgemm_node->get_output_stride(0)}; - auto in_0_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(0)); auto in_1_prec = ov_to_xsmm_dtype(brgemm_node->get_input_element_type(1)); - exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? - LIBXSMM_DATATYPE_I32 : - LIBXSMM_DATATYPE_F32; - auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? - LIBXSMM_DATATYPE_I32 : - LIBXSMM_DATATYPE_F32; + exec_dtype = in_0_prec == LIBXSMM_DATATYPE_I8 || in_0_prec == LIBXSMM_DATATYPE_U8 ? LIBXSMM_DATATYPE_I32 + : LIBXSMM_DATATYPE_F32; + auto out_0_prec = exec_dtype == LIBXSMM_DATATYPE_I32 ? 
LIBXSMM_DATATYPE_I32 : LIBXSMM_DATATYPE_F32; const auto beta = brgemm_node->get_beta(); OV_CPU_JIT_EMITTER_ASSERT(beta == 0 || beta == 1, "Detected unsupported beta value: " + std::to_string(beta)); @@ -54,18 +61,14 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres const auto N = static_cast(*subtensor_in1.rbegin()); const bool is_f32_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_F32; - const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16; + const bool is_bf16_gemm = in_0_prec == in_1_prec && in_0_prec == LIBXSMM_DATATYPE_BF16; const bool is_i8_gemm = in_0_prec == LIBXSMM_DATATYPE_U8 || in_0_prec == LIBXSMM_DATATYPE_I8; - OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm || - (is_bf16_gemm && K % 2 == 0) || - (is_i8_gemm && K % 4 == 0), + OV_CPU_JIT_EMITTER_ASSERT(is_f32_gemm || (is_bf16_gemm && K % 2 == 0) || (is_i8_gemm && K % 4 == 0), "Unsupported parameter combination for kernel configuration"); - m_compile_flags = is_f32_gemm ? - LIBXSMM_GEMM_FLAGS('N', 'N') : - LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') | - LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | - LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG; + m_compile_flags = is_f32_gemm ? 
LIBXSMM_GEMM_FLAGS('N', 'N') + : LIBXSMM_GEMM_VNNI_FLAGS('N', 'N', 'V', 'N') | + LIBXSMM_GEMM_FLAG_NO_SETUP_TILECONFIG | LIBXSMM_GEMM_FLAG_NO_RESET_TILECONFIG; if (beta == 0) m_compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0; @@ -78,12 +81,6 @@ BrgemmTppEmitter::BrgemmTppEmitter(jit_generator* h, cpu_isa_t isa, const Expres in_1_prec = LIBXSMM_DATATYPE_I8; m_compile_flags |= LIBXSMM_GEMM_FLAG_B_UNSIGNED; } - - m_shape = libxsmm_create_gemm_shape(N, M, K, - io_strides[1], io_strides[0], io_strides[2], - in_1_prec, in_0_prec, out_0_prec, - exec_dtype); - m_prefetching_flags = LIBXSMM_GEMM_PREFETCH_NONE; } std::set> BrgemmTppEmitter::get_supported_precisions(const std::shared_ptr& node) { @@ -91,22 +88,17 @@ std::set> BrgemmTppEmitter::get_supported_precisions( return {{element::f32, element::f32}}; } -void BrgemmTppEmitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void BrgemmTppEmitter::validate_arguments(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size())); OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size())); } const uintptr_t BrgemmTppEmitter::get_compiled_kernel_ptr() const { - return COMPILE_TPP_KERNEL(libxsmm_dispatch_gemm(m_shape, m_compile_flags, m_prefetching_flags)); + return reinterpret_cast(m_kernel_executor.get()); } -void BrgemmTppEmitter::execute_brgemm_kernel(libxsmm_gemmfunction brg_kernel, void *in0, void *in1, void *out0) { - libxsmm_gemm_param gemm_p; - gemm_p.a.primary = in1; - gemm_p.b.primary = in0; - gemm_p.c.primary = out0; - OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "Invalid brgemm kernel pointer"); - brg_kernel(&gemm_p); +const uintptr_t BrgemmTppEmitter::get_execute_function_ptr() const { + return reinterpret_cast(BrgemmTppKernelExecutor::execute); } } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp 
b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp index b8e13ab10b5f3a..0fe8bd60dd5b9b 100644 --- a/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/jit_brgemm_emitter.hpp @@ -3,6 +3,7 @@ // #pragma once +#include "emitters/tpp/x64/kernel_executors/brgemm.hpp" #include "jit_tpp_emitter.hpp" namespace ov { @@ -12,22 +13,25 @@ class BrgemmTppEmitter : public TppEmitter { public: BrgemmTppEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); - - size_t get_inputs_num() const override { return 2; } - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); - - static void execute_brgemm_kernel(libxsmm_gemmfunction brgemm_kernel, void *in0, void *in1, void *out0); - - const uintptr_t get_execute_function_ptr() const override { return reinterpret_cast(execute_brgemm_kernel); } + const ov::snippets::lowered::ExpressionPtr& expr, + const snippets::KernelExecutorTablePtr& kernel_table, + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); + size_t get_inputs_num() const override { + return 2; + } + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); + + const uintptr_t get_execute_function_ptr() const override; const uintptr_t get_compiled_kernel_ptr() const override; protected: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; static void validate_subtensors(const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0); - libxsmm_gemm_shape m_shape; - libxsmm_bitfield m_prefetching_flags {0}; + +private: + std::shared_ptr m_kernel_executor = nullptr; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git 
a/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp new file mode 100644 index 00000000000000..f062d0f27b9d72 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp @@ -0,0 +1,164 @@ +// Copyright (C) 2020-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "brgemm.hpp" + +#include "common/utils.hpp" +#include "dnnl_extension_utils.h" +#include "snippets/lowered/pass/insert_specific_iterations.hpp" +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/brgemm_utils.hpp" +#include "transformations/tpp/x64/op/brgemm.hpp" + +using namespace Xbyak; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { +#define COMPILE_BRGEMM_TPP_KERNEL(...) \ + [&]() { \ + setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \ + setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \ + auto res = reinterpret_cast(__VA_ARGS__); \ + unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \ + unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \ + return res; \ + }() + +BrgemmTppKernelConfig::BrgemmTppKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + libxsmm_bitfield compile_flags, + bool prefetching_flags, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : BrgemmBaseKernelConfig(), + m_static_params( + std::make_shared(in0_dtype, in1_dtype, compile_flags, prefetching_flags, primitive_isa)) { + m_hash = compute_hash(); +} + +BrgemmTppKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + libxsmm_bitfield compile_flags, + bool prefetching_flags, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(prefetching_flags)), + m_prefetching_flags(prefetching_flags) {} + +bool 
BrgemmTppKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { + return StaticBaseParams::operator==(rhs) && m_prefetching_flags == rhs.m_prefetching_flags; +} + +size_t BrgemmTppKernelConfig::StaticParams::compute_hash(bool prefetching_flags) { + return hash_combine(0, prefetching_flags); +} + +#ifdef SNIPPETS_DEBUG_CAPS +std::string BrgemmTppKernelConfig::StaticParams::to_string() const { + std::stringstream ss; + ss << StaticBaseParams::to_string(); + ss << "m_prefetching_flags = " << m_prefetching_flags << "\n"; + return ss.str(); +} +#endif + +BrgemmTppKernelExecutor::BrgemmTppKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, + BrgemmTppKernelConfig config) + : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} + +std::shared_ptr BrgemmTppKernelExecutor::compile_kernel( + const BrgemmTppKernelConfig& config) const { + std::shared_ptr compiled_kernel = std::make_shared(); + + // Brgemm is not executable - nothing to compile + if (config.is_empty()) + return compiled_kernel; + + libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(), + config.get_M(), + config.get_K(), + config.get_LDB(), + config.get_LDA(), + config.get_LDC(), + LIBXSMM_DATATYPE_F32, + LIBXSMM_DATATYPE_F32, + LIBXSMM_DATATYPE_F32, + LIBXSMM_DATATYPE_F32); + compiled_kernel->brgemm_kernel = std::make_shared(COMPILE_BRGEMM_TPP_KERNEL( + libxsmm_dispatch_gemm(m_shape, config.get_compile_flags(), config.get_prefetching_flags()))); + + return compiled_kernel; +} + +void BrgemmTppKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmTppKernelConfig& config) const { + const auto& brgemm_node = as_type_ptr(expr->get_node()); + const auto beta = brgemm_node->get_beta(); + + const auto& input_0_desc = expr->get_input_port_descriptor(0); + const auto& input_1_desc = expr->get_input_port_descriptor(1); + const auto& output_desc = 
expr->get_output_port_descriptor(0); + + const auto& subtensor_in0 = input_0_desc->get_subtensor(); + const auto& subtensor_in1 = input_1_desc->get_subtensor(); + const auto& subtensor_out0 = output_desc->get_subtensor(); + auto validate_subtensors = [](const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0) { + bool subtensors_compatible = in_0.size() == in_1.size() && in_0.size() == out_0.size() && in_0.size() == 2 && + in_0[1] == in_1[0] && in_0[0] == out_0[0] && in_1[1] == out_0[1]; + OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors"); + }; + validate_subtensors(subtensor_in0, subtensor_in1, subtensor_out0); + + const auto K = static_cast(*subtensor_in0.rbegin()); + const auto M = static_cast(*++subtensor_in0.rbegin()); + const auto N = static_cast(*subtensor_in1.rbegin()); + + const auto& tpp_mod = std::dynamic_pointer_cast(expr->get_node()); + + auto replace_full_dim = [](size_t dim, size_t replace_dim) { + if (ov::snippets::utils::is_full_dim_value(dim)) + return replace_dim; + return dim; + }; + + const auto num_ins = expr->get_node()->get_input_size(); + const auto num_outs = expr->get_node()->get_output_size(); + + VectorDims io_strides(num_ins + num_outs);  // dynamically sized container: runtime-sized arrays (VLAs) are not standard C++ + + for (size_t i = 0; i < num_ins; i++) { + io_strides[i] = + replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back()); + } + + for (size_t i = 0; i < num_outs; i++) { + const auto i_off = i + num_ins; + io_strides[i_off] = + replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back()); + } + + config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta); + + auto compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); + if (beta == 0) + compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0; + + config.set_compile_flags(compile_flags); +} + +void BrgemmTppKernelExecutor::execute(const BrgemmTppKernelExecutor* executor, void* in0, void* in1, void* out0) { + 
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); + libxsmm_gemm_param gemm_p{};  // value-initialize so fields not assigned below are zero, not indeterminate + gemm_p.a.primary = in1; + gemm_p.b.primary = in0; + gemm_p.c.primary = out0; + auto brg_kernel = executor->get_kernel(); + OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiled kernel"); + OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiled brgemm_kernel"); + (*(brg_kernel->brgemm_kernel))(&gemm_p); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.hpp new file mode 100644 index 00000000000000..b8530253eb92a8 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.hpp @@ -0,0 +1,93 @@ +// Copyright (C) 2020-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "emitters/snippets/x64/kernel_executors/brgemm_base.hpp" +#include "libxsmm.h" + +namespace ov { +namespace intel_cpu { + +struct BrgemmTppKernelConfig : public BrgemmBaseKernelConfig { +public: + BrgemmTppKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + libxsmm_bitfield compile_flags, + bool prefetching_flags, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef); + BrgemmTppKernelConfig() = delete; + + std::unique_ptr get_clone_ptr() const override { + return std::unique_ptr(new BrgemmTppKernelConfig(*this)); + } + + libxsmm_bitfield get_compile_flags() const { + return m_compile_flags; + } + + void set_compile_flags(const libxsmm_bitfield& flag) { + m_compile_flags = flag; + } + + bool get_prefetching_flags() const { + return m_static_params->m_prefetching_flags; + } + +private: + struct StaticParams : StaticBaseParams { + StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + libxsmm_bitfield compile_flags, + bool prefetching_flags, + dnnl::impl::cpu::x64::cpu_isa_t 
primitive_isa); + + bool operator==(const StaticParams& rhs) const; + bool operator!=(const StaticParams& rhs) const { + return !(*this == rhs); + } + const bool m_prefetching_flags{false}; +#ifdef SNIPPETS_DEBUG_CAPS + std::string to_string() const; +#endif + private: + static size_t compute_hash(bool prefetching_flags); + }; + + std::shared_ptr get_static_params() const override { + return m_static_params; + } + + libxsmm_bitfield m_compile_flags{0}; + + std::shared_ptr m_static_params{nullptr}; +}; + +// The `update_kernel` method verifies that a compiled kernel is not nullptr. +// However, the compiled kernel might be empty in cases if nothing is to be compiled (`Config.is_empty() == true`). +// To cover this case, we wrap the `libxsmm_gemmfunction` in the separate structure which may contain empty +// `libxsmm_gemmfunction` +struct BrgemmTppCompiledKernel { + std::shared_ptr brgemm_kernel = nullptr; +}; + +class BrgemmTppKernelExecutor : public BrgemmBaseKernelExecutor, + public CPUKernelExecutor { +public: + BrgemmTppKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmTppKernelConfig config); + + /** Function that will be called in runtime to execute the kernel */ + static void execute(const BrgemmTppKernelExecutor* executor, void* in0, void* in1, void* out0); + +protected: + std::shared_ptr compile_kernel(const BrgemmTppKernelConfig& c) const override; + + void update_config(const ov::snippets::lowered::ExpressionPtr& expr, + const ov::snippets::lowered::LinearIRCPtr& linear_ir, + BrgemmTppKernelConfig& config) const override; +}; +#define GET_OFF_BRGEMM_ARGS(field) offsetof(BrgemmKernelExecutor::call_args, field) + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp index 53992b1e67da9c..e69724c3ff0a52 100644 --- 
a/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp +++ b/src/plugins/intel_cpu/src/transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.cpp @@ -49,6 +49,7 @@ BrgemmToBrgemmTPP::BrgemmToBrgemmTPP() { if (!brgemm || ov::as_type_ptr(node)) OPENVINO_THROW("BrgemmCPU cannot be in body before BrgemmToBrgemmTPP pass"); + // TODO: remove this restriction once the TPP brgemm kernel executor is finished if (brgemm->is_dynamic()) { return false; }