forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
acae7cc
commit 7f0a72b
Showing
5 changed files
with
424 additions
and
0 deletions.
There are no files selected for viewing
89 changes: 89 additions & 0 deletions
89
src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_brgemm_emitter.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Copyright (C) 2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "jit_brgemm_emitter.hpp" | ||
|
||
#include "emitters/plugin/x64/utils.hpp" | ||
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp" | ||
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp" | ||
#include "snippets/utils/utils.hpp" | ||
#include "transformations/snippets/x64/op/brgemm_cpu.hpp" | ||
#include "transformations/snippets/x64/op/brgemm_utils.hpp" | ||
#include "transformations/tpp/x64/op/brgemm.hpp" | ||
|
||
using namespace Xbyak_aarch64; | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace aarch64 { | ||
|
||
using jit_generator = dnnl::impl::cpu::aarch64::jit_generator; | ||
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t; | ||
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; | ||
|
||
jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, | ||
cpu_isa_t isa, | ||
const ExpressionPtr& expr, | ||
const snippets::KernelExecutorTablePtr& kernel_table, | ||
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) | ||
: jit_emitter(h, isa) { | ||
in_out_type_ = emitter_in_out_map::gpr_to_gpr; | ||
const auto& brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node()); | ||
const auto& brg0Prc = brgemm_node->get_input_element_type(0); | ||
const auto& brg1Prc = brgemm_node->get_input_element_type(1); | ||
BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa); | ||
m_kernel_executor = | ||
kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config); | ||
} | ||
|
||
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions( | ||
const std::shared_ptr<ov::Node>& node) { | ||
// Note: Brgemm currently supports only fp32 | ||
return {{element::f32, element::f32}}; | ||
} | ||
|
||
void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const { | ||
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got" + std::to_string(in.size())); | ||
OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got" + std::to_string(out.size())); | ||
} | ||
|
||
void jit_brgemm_emitter::emit_code(const std::vector<size_t> &in, const std::vector<size_t> &out) const { | ||
validate_arguments(in, out); | ||
emit_impl(in, out); | ||
} | ||
|
||
void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const { | ||
validate_arguments(in, out); | ||
std::unordered_set<size_t> exclude = {}; | ||
store_context(exclude); | ||
|
||
Xbyak_aarch64::XReg func_reg(9); | ||
h->mov(func_reg, get_execute_function_ptr()); | ||
|
||
Xbyak_aarch64::XReg x0(0); | ||
Xbyak_aarch64::XReg x1(1); | ||
Xbyak_aarch64::XReg x2(2); | ||
Xbyak_aarch64::XReg x3(3); | ||
|
||
const auto& compiled_kernel = get_compiled_kernel_ptr(); | ||
h->mov(x0, compiled_kernel); | ||
h->mov(x1, Xbyak_aarch64::XReg(in[0])); | ||
h->mov(x2, Xbyak_aarch64::XReg(in[1])); | ||
h->mov(x3, Xbyak_aarch64::XReg(out[0])); | ||
h->blr(func_reg); | ||
|
||
restore_context(exclude); | ||
} | ||
|
||
const uintptr_t jit_brgemm_emitter::get_compiled_kernel_ptr() const { | ||
return reinterpret_cast<const uintptr_t>(m_kernel_executor.get()); | ||
} | ||
|
||
const uintptr_t jit_brgemm_emitter::get_execute_function_ptr() const { | ||
return reinterpret_cast<const uintptr_t>(BrgemmKernelExecutor::execute); | ||
} | ||
|
||
} // namespace aarch64 | ||
} // namespace intel_cpu | ||
} // namespace ov |
49 changes: 49 additions & 0 deletions
49
src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_brgemm_emitter.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
// Copyright (C) 2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "emitters/plugin/aarch64/jit_emitter.hpp" | ||
#include "emitters/snippets/aarch64/kernel_executors/brgemm.hpp" | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace aarch64 { | ||
|
||
class jit_brgemm_emitter : public jit_emitter { | ||
public: | ||
jit_brgemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, | ||
dnnl::impl::cpu::aarch64::cpu_isa_t isa, | ||
const ov::snippets::lowered::ExpressionPtr& expr, | ||
const snippets::KernelExecutorTablePtr& kernel_table, | ||
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); | ||
|
||
size_t get_inputs_count() const override { | ||
return m_memory_offsets.size() - 1; | ||
} | ||
|
||
static std::set<std::vector<element::Type>> get_supported_precisions( | ||
const std::shared_ptr<ov::Node>& node = nullptr); | ||
|
||
void emit_code(const std::vector<size_t> &in, | ||
const std::vector<size_t> &out) const; | ||
|
||
private: | ||
void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override; | ||
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override; | ||
|
||
const uintptr_t get_execute_function_ptr() const; | ||
const uintptr_t get_compiled_kernel_ptr() const; | ||
|
||
// Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in | ||
// runtime | ||
std::vector<size_t> m_memory_offsets{}; | ||
// Note: cluster ids order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if there is no buffer | ||
std::vector<size_t> m_buffer_ids{}; | ||
std::shared_ptr<BrgemmKernelExecutor> m_kernel_executor = nullptr; | ||
}; | ||
|
||
} // namespace aarch64 | ||
} // namespace intel_cpu | ||
} // namespace ov |
141 changes: 141 additions & 0 deletions
141
src/plugins/intel_cpu/src/emitters/snippets/aarch64/kernel_executors/brgemm.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
// Copyright (C) 2025 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "brgemm.hpp" | ||
#include "transformations/tpp/x64/op/brgemm.hpp" | ||
|
||
using namespace Xbyak; | ||
using namespace dnnl::impl; | ||
using namespace dnnl::impl::cpu::x64; | ||
|
||
#define HASH(X) seed = hash_combine(seed, X) | ||
|
||
namespace ov { | ||
namespace intel_cpu { | ||
namespace aarch64 { | ||
#define COMPILE_BRGEMM_TPP_KERNEL(...) \ | ||
[&]() { \ | ||
setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \ | ||
setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \ | ||
auto res = reinterpret_cast<const libxsmm_gemmfunction>(__VA_ARGS__); \ | ||
unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \ | ||
unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \ | ||
return res; \ | ||
}() | ||
|
||
BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, | ||
const element::Type& in1_dtype, | ||
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa) | ||
: BrgemmBaseKernelConfig(), | ||
m_static_params( | ||
std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) { | ||
m_hash = compute_hash(); | ||
} | ||
|
||
BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, | ||
const element::Type& in1_dtype, | ||
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa) | ||
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)), | ||
m_prefetching_flags(false), | ||
isa(primitive_isa) { | ||
m_type_in0 = ov_to_xsmm_dtype(in0_dtype); | ||
m_type_in1 = ov_to_xsmm_dtype(in1_dtype); | ||
m_type_exec = LIBXSMM_DATATYPE_F32; | ||
m_type_out0 = LIBXSMM_DATATYPE_F32; | ||
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N'); | ||
} | ||
|
||
size_t BrgemmKernelConfig::StaticParams::compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa) { | ||
return hash_combine(0, aarch_isa); | ||
} | ||
|
||
bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { | ||
return StaticBaseParams::operator==(rhs) && | ||
isa == rhs.isa && | ||
m_type_in0 == rhs.m_type_in0 && | ||
m_type_in1 == rhs.m_type_in1 && | ||
m_type_exec == rhs.m_type_exec && | ||
m_type_out0 == rhs.m_type_out0 && | ||
m_compile_flags == rhs.m_compile_flags && | ||
m_prefetching_flags == rhs.m_prefetching_flags; | ||
} | ||
|
||
BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, | ||
BrgemmKernelConfig config) | ||
: CPUKernelExecutor<BrgemmKernelConfig, BrgemmTppCompiledKernel>(std::move(kernel_cache), std::move(config)) {} | ||
|
||
std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel( | ||
const BrgemmKernelConfig& config) const { | ||
std::shared_ptr<BrgemmTppCompiledKernel> compiled_kernel = std::make_shared<BrgemmTppCompiledKernel>(); | ||
|
||
// Brgemm is not executable - nothing to compile | ||
if (config.is_empty()) | ||
return compiled_kernel; | ||
|
||
libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(), | ||
config.get_M(), | ||
config.get_K(), | ||
config.get_LDB(), | ||
config.get_LDA(), | ||
config.get_LDC(), | ||
config.get_type_in0(), | ||
config.get_type_in1(), | ||
config.get_type_out0(), | ||
config.get_type_exec()); | ||
const auto& compile_flag = config.get_compile_flags(); | ||
auto refreshed_compile_flag = config.get_beta() == 0 ? config.get_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0 : | ||
compile_flag; | ||
compiled_kernel->brgemm_kernel = std::make_shared<libxsmm_gemmfunction>(COMPILE_BRGEMM_TPP_KERNEL( | ||
libxsmm_dispatch_gemm(m_shape, refreshed_compile_flag, config.get_prefetching_flags()))); | ||
|
||
return compiled_kernel; | ||
} | ||
|
||
void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr, | ||
const ov::snippets::lowered::LinearIRCPtr& linear_ir, | ||
BrgemmKernelConfig& config) const { | ||
BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); | ||
const auto& tpp_mod = std::dynamic_pointer_cast<tpp::modifier::TensorProcessingPrimitive>(expr->get_node()); | ||
auto replace_full_dim = [](size_t dim, size_t replace_dim) { | ||
if (ov::snippets::utils::is_full_dim_value(dim)) | ||
return replace_dim; | ||
return dim; | ||
}; | ||
|
||
const auto num_ins = expr->get_node()->get_input_size(); | ||
const auto num_outs = expr->get_node()->get_output_size(); | ||
|
||
size_t io_strides[num_ins + num_outs]; | ||
|
||
for (size_t i = 0; i < num_ins; i++) { | ||
io_strides[i] = | ||
replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back()); | ||
} | ||
|
||
for (size_t i = 0; i < num_outs; i++) { | ||
const auto i_off = i + num_ins; | ||
io_strides[i_off] = | ||
replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back()); | ||
} | ||
|
||
config.update(config.get_M(), config.get_N(), config.get_K(), io_strides[0], io_strides[1], io_strides[2], config.get_beta()); | ||
// update compile flag, which is depend on beta. should be part of hash. | ||
config.set_compile_flags(config.get_beta() == 0); | ||
} | ||
|
||
void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) { | ||
OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor"); | ||
libxsmm_gemm_param gemm_p; | ||
gemm_p.a.primary = in1; | ||
gemm_p.b.primary = in0; | ||
gemm_p.c.primary = out0; | ||
auto brg_kernel = executor->get_kernel(); | ||
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiler kernel"); | ||
OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiler brgemm_kernel"); | ||
(*(brg_kernel->brgemm_kernel))(&gemm_p); | ||
} | ||
|
||
} // namespace aarch64 | ||
} // namespace intel_cpu | ||
} // namespace ov |
Oops, something went wrong.