forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
13d60b1
commit 707cbbb
Showing
7 changed files
with
306 additions
and
51 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
164 changes: 164 additions & 0 deletions
164
src/plugins/intel_cpu/src/emitters/tpp/x64/kernel_executors/brgemm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
// Copyright (C) 2020-2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "brgemm.hpp" | ||
|
||
#include "common/utils.hpp" | ||
#include "dnnl_extension_utils.h" | ||
#include "snippets/lowered/pass/insert_specific_iterations.hpp" | ||
#include "transformations/snippets/x64/op/brgemm_cpu.hpp" | ||
#include "transformations/snippets/x64/op/brgemm_utils.hpp" | ||
#include "transformations/tpp/x64/op/brgemm.hpp" | ||
|
||
using namespace Xbyak; | ||
using namespace dnnl::impl; | ||
using namespace dnnl::impl::cpu::x64; | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
#define COMPILE_BRGEMM_TPP_KERNEL(...) \ | ||
[&]() { \ | ||
setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \ | ||
setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \ | ||
auto res = reinterpret_cast<const libxsmm_gemmfunction>(__VA_ARGS__); \ | ||
unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \ | ||
unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \ | ||
return res; \ | ||
}() | ||
|
||
BrgemmTppKernelConfig::BrgemmTppKernelConfig(const element::Type& in0_dtype, | ||
const element::Type& in1_dtype, | ||
libxsmm_bitfield compile_flags, | ||
bool prefetching_flags, | ||
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) | ||
: BrgemmBaseKernelConfig(), | ||
m_static_params( | ||
std::make_shared<StaticParams>(in0_dtype, in1_dtype, compile_flags, prefetching_flags, primitive_isa)) { | ||
m_hash = compute_hash(); | ||
} | ||
|
||
BrgemmTppKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, | ||
const element::Type& in1_dtype, | ||
libxsmm_bitfield compile_flags, | ||
bool prefetching_flags, | ||
dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) | ||
: StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(prefetching_flags)), | ||
m_prefetching_flags(prefetching_flags) {} | ||
|
||
bool BrgemmTppKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { | ||
return StaticBaseParams::operator==(rhs) && m_prefetching_flags == rhs.m_prefetching_flags; | ||
} | ||
|
||
size_t BrgemmTppKernelConfig::StaticParams::compute_hash(bool prefetching_flags) { | ||
return hash_combine(0, prefetching_flags); | ||
} | ||
|
||
#ifdef SNIPPETS_DEBUG_CAPS | ||
std::string BrgemmTppKernelConfig::StaticParams::to_string() const { | ||
std::stringstream ss; | ||
ss << StaticBaseParams::to_string(); | ||
ss << "m_prefetching_flags = " << m_prefetching_flags << "\n"; | ||
return ss.str(); | ||
} | ||
#endif | ||
|
||
BrgemmTppKernelExecutor::BrgemmTppKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, | ||
BrgemmTppKernelConfig config) | ||
: CPUKernelExecutor<BrgemmTppKernelConfig, BrgemmTppCompiledKernel>(std::move(kernel_cache), std::move(config)) {} | ||
|
||
std::shared_ptr<BrgemmTppCompiledKernel> BrgemmTppKernelExecutor::compile_kernel( | ||
const BrgemmTppKernelConfig& config) const { | ||
std::shared_ptr<BrgemmTppCompiledKernel> compiled_kernel = std::make_shared<BrgemmTppCompiledKernel>(); | ||
|
||
// Brgemm is not executable - nothing to compile | ||
if (config.is_empty()) | ||
return compiled_kernel; | ||
|
||
libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(), | ||
config.get_M(), | ||
config.get_K(), | ||
config.get_LDB(), | ||
config.get_LDA(), | ||
config.get_LDC(), | ||
LIBXSMM_DATATYPE_F32, | ||
LIBXSMM_DATATYPE_F32, | ||
LIBXSMM_DATATYPE_F32, | ||
LIBXSMM_DATATYPE_F32); | ||
compiled_kernel->brgemm_kernel = std::make_shared<libxsmm_gemmfunction>(COMPILE_BRGEMM_TPP_KERNEL( | ||
libxsmm_dispatch_gemm(m_shape, config.get_compile_flags(), config.get_prefetching_flags()))); | ||
|
||
return compiled_kernel; | ||
} | ||
|
||
void BrgemmTppKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, | ||
const ov::snippets::lowered::LinearIRCPtr& linear_ir, | ||
BrgemmTppKernelConfig& config) const { | ||
const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node()); | ||
const auto beta = brgemm_node->get_beta(); | ||
|
||
const auto& input_0_desc = expr->get_input_port_descriptor(0); | ||
const auto& input_1_desc = expr->get_input_port_descriptor(1); | ||
const auto& output_desc = expr->get_output_port_descriptor(0); | ||
|
||
const auto& subtensor_in0 = input_0_desc->get_subtensor(); | ||
const auto& subtensor_in1 = input_1_desc->get_subtensor(); | ||
const auto& subtensor_out0 = output_desc->get_subtensor(); | ||
auto validate_subtensors = [](const VectorDims& in_0, const VectorDims& in_1, const VectorDims& out_0) { | ||
bool subtensors_compatible = in_0.size() == in_1.size() && in_0.size() == out_0.size() && in_0.size() == 2 && | ||
in_0[1] == in_1[0] && in_0[0] == out_0[0] && in_1[1] == out_0[1]; | ||
OV_CPU_JIT_EMITTER_ASSERT(subtensors_compatible, "Incompatible subtensors"); | ||
}; | ||
validate_subtensors(subtensor_in0, subtensor_in1, subtensor_out0); | ||
|
||
const auto K = static_cast<libxsmm_blasint>(*subtensor_in0.rbegin()); | ||
const auto M = static_cast<libxsmm_blasint>(*++subtensor_in0.rbegin()); | ||
const auto N = static_cast<libxsmm_blasint>(*subtensor_in1.rbegin()); | ||
|
||
const auto& tpp_mod = std::dynamic_pointer_cast<tpp::modifier::TensorProcessingPrimitive>(expr->get_node()); | ||
|
||
auto replace_full_dim = [](size_t dim, size_t replace_dim) { | ||
if (ov::snippets::utils::is_full_dim_value(dim)) | ||
return replace_dim; | ||
return dim; | ||
}; | ||
|
||
const auto num_ins = expr->get_node()->get_input_size(); | ||
const auto num_outs = expr->get_node()->get_output_size(); | ||
|
||
size_t io_strides[num_ins + num_outs]; | ||
|
||
for (size_t i = 0; i < num_ins; i++) { | ||
io_strides[i] = | ||
replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back()); | ||
} | ||
|
||
for (size_t i = 0; i < num_outs; i++) { | ||
const auto i_off = i + num_ins; | ||
io_strides[i_off] = | ||
replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back()); | ||
} | ||
|
||
config.update(M, N, K, io_strides[0], io_strides[1], io_strides[2], beta); | ||
|
||
auto compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); | ||
if (beta == 0) | ||
compile_flags |= LIBXSMM_GEMM_FLAG_BETA_0; | ||
|
||
config.set_compile_flags(compile_flags); | ||
} | ||
|
||
void BrgemmTppKernelExecutor::execute(const BrgemmTppKernelExecutor* executor, void* in0, void* in1, void* out0) { | ||
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); | ||
libxsmm_gemm_param gemm_p; | ||
gemm_p.a.primary = in1; | ||
gemm_p.b.primary = in0; | ||
gemm_p.c.primary = out0; | ||
auto brg_kernel = executor->get_kernel(); | ||
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiler kernel"); | ||
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiler brgemm_kernel"); | ||
(*(brg_kernel->brgemm_kernel))(&gemm_p); | ||
} | ||
|
||
} // namespace intel_cpu | ||
} // namespace ov |
Oops, something went wrong.