Skip to content

Commit

Permalink
brgemm emitter and executor
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 8, 2025
1 parent acae7cc commit 7f0a72b
Show file tree
Hide file tree
Showing 5 changed files with 424 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm.hpp"
#include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp"
#include "snippets/utils/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"

using namespace Xbyak_aarch64;

namespace ov {
namespace intel_cpu {
namespace aarch64 {

using jit_generator = dnnl::impl::cpu::aarch64::jit_generator;
using cpu_isa_t = dnnl::impl::cpu::aarch64::cpu_isa_t;
using ExpressionPtr = ov::snippets::lowered::ExpressionPtr;

// Builds the emitter for a Brgemm expression.
//
// h                     - code generator the emitter writes instructions into
// isa                   - target ISA, forwarded to jit_emitter and to the kernel config
// expr                  - lowered expression; its node must be a tpp::op::BrgemmTPP
// kernel_table          - table in which the BrgemmKernelExecutor for `expr` is registered
// compiled_kernel_cache - cache handed over to the registered executor
jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h,
                                       cpu_isa_t isa,
                                       const ExpressionPtr& expr,
                                       const snippets::KernelExecutorTablePtr& kernel_table,
                                       const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
    : jit_emitter(h, isa) {
    in_out_type_ = emitter_in_out_map::gpr_to_gpr;
    // Hold the cast result by value and fail loudly on a node of another type
    // instead of dereferencing a null pointer below.
    const auto brgemm_node = as_type_ptr<intel_cpu::tpp::op::BrgemmTPP>(expr->get_node());
    OV_CPU_JIT_EMITTER_ASSERT(brgemm_node, "expects BrgemmTPP node");
    const auto& brg0Prc = brgemm_node->get_input_element_type(0);
    const auto& brg1Prc = brgemm_node->get_input_element_type(1);
    BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, isa);
    m_kernel_executor =
        kernel_table->register_kernel<BrgemmKernelExecutor>(expr, compiled_kernel_cache, kernel_config);
}

// Returns the input precision combinations accepted by this emitter.
// The `node` argument is not consulted.
std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precisions(
    const std::shared_ptr<ov::Node>& node) {
    // Note: Brgemm currently supports only fp32
    const std::vector<element::Type> fp32_inputs{element::f32, element::f32};
    std::set<std::vector<element::Type>> supported;
    supported.insert(fp32_inputs);
    return supported;
}

// Checks that the emitter received exactly two input register indices and one
// output register index; raises through OV_CPU_JIT_EMITTER_ASSERT otherwise.
void jit_brgemm_emitter::validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
    // Trailing space keeps the count from being glued to the word "got".
    OV_CPU_JIT_EMITTER_ASSERT(in.size() == 2, "Expects 2 input regs, got " + std::to_string(in.size()));
    OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Expects 1 output reg, got " + std::to_string(out.size()));
}

// Public entry point: emits the Brgemm call for the given input/output
// register indices.
void jit_brgemm_emitter::emit_code(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
    // emit_impl() validates the register lists itself, so validating here as
    // well would only repeat the same checks.
    emit_impl(in, out);
}

void jit_brgemm_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);
std::unordered_set<size_t> exclude = {};
store_context(exclude);

Xbyak_aarch64::XReg func_reg(9);
h->mov(func_reg, get_execute_function_ptr());

Xbyak_aarch64::XReg x0(0);
Xbyak_aarch64::XReg x1(1);
Xbyak_aarch64::XReg x2(2);
Xbyak_aarch64::XReg x3(3);

const auto& compiled_kernel = get_compiled_kernel_ptr();
h->mov(x0, compiled_kernel);
h->mov(x1, Xbyak_aarch64::XReg(in[0]));
h->mov(x2, Xbyak_aarch64::XReg(in[1]));
h->mov(x3, Xbyak_aarch64::XReg(out[0]));
h->blr(func_reg);

restore_context(exclude);
}

// Returns the address of the kernel executor owned by this emitter as an
// integer, suitable for loading into a register as an immediate.
const uintptr_t jit_brgemm_emitter::get_compiled_kernel_ptr() const {
    const auto* executor = m_kernel_executor.get();
    return reinterpret_cast<const uintptr_t>(executor);
}

// Returns the address of BrgemmKernelExecutor::execute as an integer, suitable
// for loading into a register as an immediate.
const uintptr_t jit_brgemm_emitter::get_execute_function_ptr() const {
    const auto execute_fn = &BrgemmKernelExecutor::execute;
    return reinterpret_cast<const uintptr_t>(execute_fn);
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/aarch64/jit_emitter.hpp"
#include "emitters/snippets/aarch64/kernel_executors/brgemm.hpp"

namespace ov {
namespace intel_cpu {
namespace aarch64 {

// JIT emitter that generates a call to BrgemmKernelExecutor::execute for a
// Brgemm expression. It takes two input pointer registers and one output
// pointer register.
class jit_brgemm_emitter : public jit_emitter {
public:
    jit_brgemm_emitter(dnnl::impl::cpu::aarch64::jit_generator* h,
                       dnnl::impl::cpu::aarch64::cpu_isa_t isa,
                       const ov::snippets::lowered::ExpressionPtr& expr,
                       const snippets::KernelExecutorTablePtr& kernel_table,
                       const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

    // The emitter always consumes exactly two input registers, which is what
    // validate_arguments() enforces. The count must not be derived from
    // m_memory_offsets: that vector is never filled by the constructor, so
    // `size() - 1` wrapped around to SIZE_MAX.
    size_t get_inputs_count() const override {
        return 2;
    }

    // Returns the supported input precision combinations; `node` is optional.
    static std::set<std::vector<element::Type>> get_supported_precisions(
        const std::shared_ptr<ov::Node>& node = nullptr);

    // Emits the kernel call for the given input/output register indices.
    void emit_code(const std::vector<size_t> &in,
                   const std::vector<size_t> &out) const;

private:
    void validate_arguments(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;
    void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

    // Address of BrgemmKernelExecutor::execute as an integer.
    const uintptr_t get_execute_function_ptr() const;
    // Address of the owned kernel executor as an integer.
    const uintptr_t get_compiled_kernel_ptr() const;

    // Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in
    // runtime
    std::vector<size_t> m_memory_offsets{};
    // Note: cluster ids order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if there is no buffer
    std::vector<size_t> m_buffer_ids{};
    std::shared_ptr<BrgemmKernelExecutor> m_kernel_executor = nullptr;
};

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "brgemm.hpp"

#include <vector>

#include "transformations/tpp/x64/op/brgemm.hpp"

using namespace Xbyak;
using namespace dnnl::impl;
using namespace dnnl::impl::cpu::x64;

// Folds X into a variable named `seed` via hash_combine. A `seed` variable
// must exist in the scope where the macro is expanded.
#define HASH(X) seed = hash_combine(seed, X)

namespace ov {
namespace intel_cpu {
namespace aarch64 {
// Evaluates the given expression (a libxsmm dispatch call) inside an
// immediately-invoked lambda and returns the result cast to
// libxsmm_gemmfunction.
// Two environment variables are set before the expression is evaluated and
// unconditionally unset afterwards, so any value the process already had for
// them is lost.
// NOTE(review): setenv/unsetenv change process-wide state; confirm that kernel
// compilation is never run concurrently with other environment access.
#define COMPILE_BRGEMM_TPP_KERNEL(...) \
[&]() { \
setenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX", "1", 1); \
setenv("LIBXSMM_GEMM_K_A_PF_DIST", "4", 1); \
auto res = reinterpret_cast<const libxsmm_gemmfunction>(__VA_ARGS__); \
unsetenv("LIBXSMM_X86_HINT_USE_HIGH_PREC_ELTWISE_APPROX"); \
unsetenv("LIBXSMM_GEMM_K_A_PF_DIST"); \
return res; \
}()

BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: BrgemmBaseKernelConfig(),
m_static_params(
std::make_shared<StaticParams>(in0_dtype, in1_dtype, primitive_isa)) {
m_hash = compute_hash();
}

// Builds the static (shape-independent) parameters of a Brgemm kernel config.
//
// in0_dtype, in1_dtype - element types of the two inputs
// primitive_isa        - aarch64 ISA the kernel is built for
//
// The base class receives the x64 `isa_undef` value together with a hash of
// the aarch64 ISA; the aarch64 ISA itself is kept in `isa`.
BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype,
const element::Type& in1_dtype,
dnnl::impl::cpu::aarch64::cpu_isa_t primitive_isa)
: StaticBaseParams(in0_dtype, in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t::isa_undef, compute_hash(primitive_isa)),
m_prefetching_flags(false),
isa(primitive_isa) {
// Input types are converted to their libxsmm counterparts.
m_type_in0 = ov_to_xsmm_dtype(in0_dtype);
m_type_in1 = ov_to_xsmm_dtype(in1_dtype);
// Execution and output types are fixed to F32.
m_type_exec = LIBXSMM_DATATYPE_F32;
m_type_out0 = LIBXSMM_DATATYPE_F32;
// NOTE(review): ('N', 'N') presumably means no transposition for either operand - confirm against libxsmm.
m_compile_flags = LIBXSMM_GEMM_FLAGS('N', 'N');
}

// Hashes the aarch64 ISA value, starting from a zero seed.
size_t BrgemmKernelConfig::StaticParams::compute_hash(dnnl::impl::cpu::aarch64::cpu_isa_t aarch_isa) {
    const size_t isa_hash = hash_combine(0, aarch_isa);
    return isa_hash;
}

// Two StaticParams are equal when the base-class parameters and every
// libxsmm-specific field match. The checks run in the same order as a
// short-circuit && chain and stop at the first mismatch.
bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const {
    if (!StaticBaseParams::operator==(rhs)) {
        return false;
    }
    if (!(isa == rhs.isa)) {
        return false;
    }
    if (!(m_type_in0 == rhs.m_type_in0)) {
        return false;
    }
    if (!(m_type_in1 == rhs.m_type_in1)) {
        return false;
    }
    if (!(m_type_exec == rhs.m_type_exec)) {
        return false;
    }
    if (!(m_type_out0 == rhs.m_type_out0)) {
        return false;
    }
    if (!(m_compile_flags == rhs.m_compile_flags)) {
        return false;
    }
    return m_prefetching_flags == rhs.m_prefetching_flags;
}

// Constructs the executor by moving the kernel cache handle and the config
// into the CPUKernelExecutor base; no other work is done here.
BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache,
BrgemmKernelConfig config)
: CPUKernelExecutor<BrgemmKernelConfig, BrgemmTppCompiledKernel>(std::move(kernel_cache), std::move(config)) {}

std::shared_ptr<BrgemmTppCompiledKernel> BrgemmKernelExecutor::compile_kernel(
const BrgemmKernelConfig& config) const {
std::shared_ptr<BrgemmTppCompiledKernel> compiled_kernel = std::make_shared<BrgemmTppCompiledKernel>();

// Brgemm is not executable - nothing to compile
if (config.is_empty())
return compiled_kernel;

libxsmm_gemm_shape m_shape = libxsmm_create_gemm_shape(config.get_N(),
config.get_M(),
config.get_K(),
config.get_LDB(),
config.get_LDA(),
config.get_LDC(),
config.get_type_in0(),
config.get_type_in1(),
config.get_type_out0(),
config.get_type_exec());
const auto& compile_flag = config.get_compile_flags();
auto refreshed_compile_flag = config.get_beta() == 0 ? config.get_compile_flags() | LIBXSMM_GEMM_FLAG_BETA_0 :
compile_flag;
compiled_kernel->brgemm_kernel = std::make_shared<libxsmm_gemmfunction>(COMPILE_BRGEMM_TPP_KERNEL(
libxsmm_dispatch_gemm(m_shape, refreshed_compile_flag, config.get_prefetching_flags())));

return compiled_kernel;
}

// Refreshes `config` for the given expression.
//
// After the base-class update, the leading dimensions are taken from the
// node's TensorProcessingPrimitive strides; a "full dim" stride is replaced by
// the innermost dimension of the corresponding port shape. The compile flags
// are then refreshed from the current beta value.
void BrgemmKernelExecutor::update_config(const ov::snippets::lowered::ExpressionPtr& expr,
                                         const ov::snippets::lowered::LinearIRCPtr& linear_ir,
                                         BrgemmKernelConfig& config) const {
    BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config);

    const auto node = expr->get_node();
    const auto tpp_mod = std::dynamic_pointer_cast<tpp::modifier::TensorProcessingPrimitive>(node);
    // The cast yields nullptr for any other node type; fail instead of dereferencing it.
    OV_CPU_JIT_EMITTER_ASSERT(tpp_mod, "expects a node implementing TensorProcessingPrimitive");

    auto replace_full_dim = [](size_t dim, size_t replace_dim) {
        if (ov::snippets::utils::is_full_dim_value(dim))
            return replace_dim;
        return dim;
    };

    const auto num_ins = node->get_input_size();
    const auto num_outs = node->get_output_size();
    // Three strides are read below; make sure they exist.
    OV_CPU_JIT_EMITTER_ASSERT(num_ins + num_outs >= 3,
                              "expects at least 3 ports in total, got " + std::to_string(num_ins + num_outs));

    // std::vector instead of a variable-length array, which is not standard C++.
    std::vector<size_t> io_strides(num_ins + num_outs);

    for (size_t i = 0; i < num_ins; i++) {
        io_strides[i] =
            replace_full_dim(tpp_mod->get_input_stride(i), expr->get_input_port_descriptor(i)->get_shape().back());
    }

    for (size_t i = 0; i < num_outs; i++) {
        const auto i_off = i + num_ins;
        io_strides[i_off] =
            replace_full_dim(tpp_mod->get_output_stride(i), expr->get_output_port_descriptor(i)->get_shape().back());
    }

    config.update(config.get_M(),
                  config.get_N(),
                  config.get_K(),
                  io_strides[0],
                  io_strides[1],
                  io_strides[2],
                  config.get_beta());
    // update compile flag, which is depend on beta. should be part of hash.
    config.set_compile_flags(config.get_beta() == 0);
}

// Runtime entry point called from JIT-generated code.
//
// executor - executor whose compiled kernel is run; must not be null
// in0, in1 - input buffers
// out0     - output buffer
//
// Raises through OV_CPU_JIT_EMITTER_ASSERT if the executor, its compiled
// kernel object, or the kernel function holder is null.
void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, void* in0, void* in1, void* out0) {
    OV_CPU_JIT_EMITTER_ASSERT(executor, "has nullptr executor");
    // Value-initialize so that the fields not set below are zero rather than
    // indeterminate when the kernel receives the structure.
    libxsmm_gemm_param gemm_p{};
    // Operands are swapped on purpose: in1 goes to `a` and in0 to `b`.
    gemm_p.a.primary = in1;
    gemm_p.b.primary = in0;
    gemm_p.c.primary = out0;
    auto brg_kernel = executor->get_kernel();
    OV_CPU_JIT_EMITTER_ASSERT(brg_kernel, "has nullptr compiler kernel");
    OV_CPU_JIT_EMITTER_ASSERT(brg_kernel->brgemm_kernel, "has nullptr compiler brgemm_kernel");
    (*(brg_kernel->brgemm_kernel))(&gemm_p);
}

} // namespace aarch64
} // namespace intel_cpu
} // namespace ov
Loading

0 comments on commit 7f0a72b

Please sign in to comment.