Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Jan 31, 2025
1 parent e27cbb9 commit 156d88b
Show file tree
Hide file tree
Showing 9 changed files with 649 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/plugins/intel_cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ if(NOT AARCH64)
endif()

if (NOT RISCV64)
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/riscv64/*)
list(APPEND EXCLUDE_PATHS ${CMAKE_CURRENT_SOURCE_DIR}/src/emitters/plugin/riscv64/*
${CMAKE_CURRENT_SOURCE_DIR}/src/nodes/kernels/riscv64/*)
endif()

if (NOT ENABLE_MLAS_FOR_CPU)
Expand Down
15 changes: 15 additions & 0 deletions src/plugins/intel_cpu/src/emitters/plugin/riscv64/jit_emitter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_emitter.hpp"

namespace ov {
namespace intel_cpu {
namespace riscv64 {



} // namespace riscv64
} // namespace intel_cpu
} // namespace ov
148 changes: 148 additions & 0 deletions src/plugins/intel_cpu/src/emitters/plugin/riscv64/jit_emitter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <node.h>

#include <set>

#include "nodes/kernels/riscv64/jit_generator.hpp"
#include "emitters/utils.hpp"
#include "snippets/generator.hpp"
#include "snippets/snippets_isa.hpp"


namespace ov {
namespace intel_cpu {
namespace riscv64 {

enum emitter_in_out_map {
vec_to_vec,
vec_to_gpr,
gpr_to_vec,
gpr_to_gpr,
};

// structure for storage of emitter parameters to hash in map
struct emitter_params {
virtual size_t hash() const = 0;
};

class jit_emitter : public ov::snippets::Emitter {
public:
jit_emitter(ov::intel_cpu::riscv64::jit_generator* host,
ov::element::Type exec_prc = ov::element::f32,
emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec);

void emit_code(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs = {},
const std::vector<size_t>& pool_gpr_idxs = {}) const override;
void emit_data() const override;

virtual size_t get_inputs_num() const = 0;
virtual size_t aux_vecs_count() const;
virtual size_t aux_gprs_count() const;
emitter_in_out_map get_in_out_type() const;

/**
* @brief Returns supported precisions.
* Precisions are ordered, the first bigger bitness precision with the same type will be selected.
* Empty collection means the emitter supports any input precisions.
*/
static std::set<std::vector<element::Type>> get_supported_precisions(
const std::shared_ptr<ov::Node>& node = nullptr);

protected:
size_t get_max_vecs_count() const;
size_t get_vec_length() const;

virtual void prepare_table();
virtual void register_table_entries() {}

//void load_table_addr() const {
// h->mov(p_table, *l_table.get());
//}

virtual void emit_impl(const std::vector<size_t>& in_idxs, const std::vector<size_t>& out_idxs) const = 0;

virtual void emitter_preamble(const std::vector<size_t>& in_idxs,
const std::vector<size_t>& out_idxs,
const std::vector<size_t>& pool_vec_idxs,
const std::vector<size_t>& pool_gpr_idxs) const;
virtual void emitter_postamble() const;

void store_context(const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const;
void restore_context(const std::vector<size_t>& gpr_regs,
const std::vector<size_t>& vec_regs,
const std::unordered_set<size_t>& ignore_vec_regs) const;

//Xbyak_riscv64::Address table_val(const std::string& key, size_t key_off_val_shift = 0) const {
// auto off = table_off(key, key_off_val_shift);
// return h->ptr[p_table + off];
//}

// we accept only 32bit hexadecimal table values to avoid any rounding
using table_entry_val_t = uint32_t;
using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table
using table_entry_bcast_t = bool; // true => bcast value

struct table_entry_t {
table_entry_val_t val;
table_entry_bcast_t bcast;
};
struct mapped_table_entry_t {
table_entry_offset_t off;
table_entry_val_t val;
table_entry_bcast_t bcast;
};

using table_t = std::multimap<std::string, table_entry_t>;
using mapped_table_t = std::multimap<std::string, mapped_table_entry_t>;

void push_arg_entry_of(const std::string& key, const table_entry_val_t val, const bool broadcast) {
mapped_table_entry_t te{0, val, broadcast};
entry_map_.insert(std::make_pair(key, te));
}

void push_entries_of(const table_t& t) {
for (auto it = t.begin(); it != t.end(); it++) {
auto key = (*it).first;
auto te = (*it).second; // copy values from table
push_arg_entry_of(key, te.val, te.bcast);
}
}

virtual void validate_arguments(const std::vector<size_t>&, const std::vector<size_t>&) const {}

ov::intel_cpu::riscv64::jit_generator* h;
ov::element::Type exec_prc_;

mutable Xbyak_riscv::Reg p_table;
mutable std::shared_ptr<Xbyak_riscv::Label> l_table;
mutable std::vector<size_t> aux_vec_idxs;
mutable std::vector<size_t> aux_gpr_idxs;

mapped_table_t entry_map_;
emitter_in_out_map in_out_type_;

private:
mutable std::vector<size_t> preserved_vec_idxs;
mutable std::vector<size_t> preserved_gpr_idxs;

size_t table_off(const std::string& key, size_t key_off_val_shift = 0) const {
const auto it = entry_map_.find(key); // search an entry for a key
OV_CPU_JIT_EMITTER_ASSERT(it != entry_map_.end(), "Value has not been found in the table");
const auto& te = (*it).second;
const auto scale = te.bcast ? get_vec_length() : sizeof(table_entry_val_t);
return te.off + key_off_val_shift * scale;
}
};

} // namespace riscv64
} // namespace intel_cpu
} // namespace ov
8 changes: 8 additions & 0 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
# include "kernels/x64/jit_uni_eltwise_generic.hpp"
#endif

#if defined(OPENVINO_ARCH_RISCV64)
# include "kernels/riscv64/jit_uni_eltwise_generic.hpp"
#endif

using namespace dnnl::impl::utils;
using namespace dnnl::impl::cpu;

Expand Down Expand Up @@ -688,6 +692,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
}
#endif // OPENVINO_ARCH_ARM64

_pKernel.reset(new ov::intel_cpu::riscv64::jit_uni_eltwise_generic(jep, eltwise_data));
if (_pKernel) {
_pKernel->create_ker();
}
Expand Down Expand Up @@ -1466,6 +1471,9 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
OPENVINO_THROW("Unknow CPU architecture");
#endif

const bool useJit = getAlgorithm() == Algorithm::EltwiseAdd;
implType = useJit ? EltwiseImplType::optimized : EltwiseImplType::reference;

#if defined(OV_CPU_WITH_ACL)
auto filterPrecision = [&](const ov::element::Type& prc, const ov::element::Type& forcedPrec) {
if (isBitwise(algorithm)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,37 @@

#include "jit_generator.hpp"

#include "utils/general_utils.h"

namespace ov {
namespace intel_cpu {
namespace riscv64 {

using namespace Xbyak_riscv;

void jit_generator::preamble() {
const int frame_size = rnd_up((num_abi_save_gpr_regs + 1) * xlen, sp_aligment);
// TODO: FP gpr ?
addi(sp, sp, -num_abi_save_gpr_regs * xlen);
addi(sp, sp, -frame_size);
int imm = 0;
for (const auto& gpr : abi_save_gpr_regs) {
sw(gpr, sp, imm);
imm += 4;
sd(gpr, sp, imm);
imm += xlen;
}
sd(ra, sp, imm);
}

void jit_generator::postamble() {
const int frame_size = rnd_up((num_abi_save_gpr_regs + 1) * xlen, sp_aligment);
int imm = 0;
for (const auto& gpr : abi_save_gpr_regs) {
lw(gpr, sp, imm);
imm += 4;
ld(gpr, sp, imm);
imm += xlen;
}
addi(sp, sp, num_abi_save_gpr_regs * xlen);
ld(ra, sp, imm);

addi(sp, sp, frame_size);

ret();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@

#include "openvino/core/except.hpp"


namespace ov {
namespace intel_cpu {
namespace riscv64 {

#define DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_name) \
const char *name() const override { return #jit_name; } \
const char *source_file() const override { return __FILE__; }

// RISCV-64 specific registers mapping
// reg | ABI Name | descripion | saved by
// =====================================================
Expand All @@ -39,15 +44,6 @@ namespace riscv64 {
// f18-27 | fs2-11 | FP Saved registers | Callee
// f28-31 | ft8-11 | FP Temporaries | Caller

// Callee-saved registers
constexpr Xbyak_riscv::Reg abi_save_gpr_regs[] = {Xbyak_riscv::s0, Xbyak_riscv::s1, Xbyak_riscv::s2, Xbyak_riscv::s3,
Xbyak_riscv::s4, Xbyak_riscv::s5, Xbyak_riscv::s6, Xbyak_riscv::s7,
Xbyak_riscv::s8, Xbyak_riscv::s9, Xbyak_riscv::s10, Xbyak_riscv::s11};

static const Xbyak_riscv::Reg abi_param1(Xbyak_riscv::a0), abi_param2(Xbyak_riscv::a1), abi_param3(Xbyak_riscv::a2),
abi_param4(Xbyak_riscv::a3), abi_param5(Xbyak_riscv::a4), abi_param6(Xbyak_riscv::a5),
abi_param7(Xbyak_riscv::a6), abi_param8(Xbyak_riscv::a7);

class jit_generator : public Xbyak_riscv::CodeGenerator {
public:
jit_generator(size_t maxSize = Xbyak_riscv::DEFAULT_MAX_CODE_SIZE,
Expand Down Expand Up @@ -79,15 +75,33 @@ class jit_generator : public Xbyak_riscv::CodeGenerator {

// Disallow char-based labels completely
void L(const char *label) = delete;
void L(Xbyak_riscv::Label &label) {
Xbyak_riscv::CodeGenerator::L(label);
}

jit_generator(const jit_generator &) = delete;
jit_generator &operator=(const jit_generator &) = delete;

virtual const char *name() const = 0;
virtual const char *source_file() const = 0;

// Callee-saved registers
static constexpr Xbyak_riscv::Reg abi_save_gpr_regs[] = {Xbyak_riscv::s0, Xbyak_riscv::s1, Xbyak_riscv::s2, Xbyak_riscv::s3,
Xbyak_riscv::s4, Xbyak_riscv::s5, Xbyak_riscv::s6, Xbyak_riscv::s7,
Xbyak_riscv::s8, Xbyak_riscv::s9, Xbyak_riscv::s10, Xbyak_riscv::s11};
// ABI-arguments registers
static constexpr Xbyak_riscv::Reg abi_param_regs[] = {Xbyak_riscv::a0, Xbyak_riscv::a1, Xbyak_riscv::a2, Xbyak_riscv::a3,
Xbyak_riscv::a4, Xbyak_riscv::a5, Xbyak_riscv::a6, Xbyak_riscv::a7};

protected:
virtual void generate() = 0;

const uint8_t *jit_ker_ = nullptr;

// In the standard RISC-V calling convention, the stack pointer is always kept 16-byte aligned
const size_t sp_aligment = 16;
// Vector register count
const size_t vec_count = 32;
// integer gpr byte size
const size_t xlen = Xbyak_riscv::CPU().getXlen() / 8;
// fp gpr byte size
Expand Down
Loading

0 comments on commit 156d88b

Please sign in to comment.