diff --git a/src/common/snippets/include/snippets/emitter.hpp b/src/common/snippets/include/snippets/emitter.hpp index a3a92c7ca7f135..90a767aa801ffd 100644 --- a/src/common/snippets/include/snippets/emitter.hpp +++ b/src/common/snippets/include/snippets/emitter.hpp @@ -16,25 +16,28 @@ namespace snippets { * @interface RegType * @brief Register type of input and output operations */ -enum class RegType { gpr, vec, undefined }; +enum class RegType { gpr, vec, mask, undefined }; /** * @interface Reg * @brief Register representation: type of register and index */ struct Reg { + enum {UNDEFINED_IDX = std::numeric_limits::max()}; Reg() = default; Reg(RegType type_, size_t idx_) : type(type_), idx(idx_) {} - RegType type = RegType::gpr; - size_t idx = 0; + bool is_defined() const { return type != RegType::undefined && idx != UNDEFINED_IDX; } + RegType type = RegType::undefined; + size_t idx = UNDEFINED_IDX; friend bool operator==(const Reg& lhs, const Reg& rhs); + friend bool operator<(const Reg& lhs, const Reg& rhs); + friend bool operator>(const Reg& lhs, const Reg& rhs); friend bool operator!=(const Reg& lhs, const Reg& rhs); + friend std::ostream& operator<<(std::ostream& s, const Reg& r); }; using RegInfo = std::pair, std::vector>; -std::string regTypeToStr(const RegType& type); - /** * @interface Emitter * @brief Base class for all target specific code emitters used by generator. 
diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index 525cd7c73767af..4ea06d0358d068 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -65,7 +65,7 @@ class Schedule { * @brief Target independent code generator interface * @ingroup snippets */ -class Generator { +class Generator : public std::enable_shared_from_this{ public: /** * @brief Default constructor diff --git a/src/common/snippets/include/snippets/lowered/expression.hpp b/src/common/snippets/include/snippets/lowered/expression.hpp index 286f561b5bcb03..779ab6d7d104ca 100644 --- a/src/common/snippets/include/snippets/lowered/expression.hpp +++ b/src/common/snippets/include/snippets/lowered/expression.hpp @@ -35,6 +35,8 @@ class Expression : public std::enable_shared_from_this { RegInfo get_reg_info() const; void set_reg_info(const RegInfo& rinfo); + const std::set& get_live_regs() const {return m_live_regs; } + void set_live_regs(std::set live_regs) { m_live_regs = std::move(live_regs); } double get_exec_num() const { return m_exec_num; } @@ -130,6 +132,7 @@ class Expression : public std::enable_shared_from_this { // 2. This number can be changed and updated during whole pipeline, so its absolute values are meaningless. // 3. This number can be negative, positive and zero. 
double m_exec_num = 0; + std::set m_live_regs; }; } // namespace lowered diff --git a/src/common/snippets/include/snippets/lowered/expression_factory.hpp b/src/common/snippets/include/snippets/lowered/expression_factory.hpp index d617eb3d03b410..30d789380e5386 100644 --- a/src/common/snippets/include/snippets/lowered/expression_factory.hpp +++ b/src/common/snippets/include/snippets/lowered/expression_factory.hpp @@ -10,6 +10,7 @@ #include "snippets/op/loop.hpp" #include "snippets/op/buffer.hpp" #include "snippets/op/perf_count.hpp" +#include "snippets/op/reg_spill.hpp" namespace ov { namespace snippets { @@ -35,6 +36,10 @@ class ExpressionFactory { const std::shared_ptr& shape_infer_factory); static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, const std::shared_ptr& shape_infer_factory); + static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, + const std::shared_ptr& shape_infer_factory); + static ExpressionPtr create(const std::shared_ptr& n, const std::vector& inputs, + const std::shared_ptr& shape_infer_factory); // Note: PerfCountBegin nodes have a PerfCountEnd ov::Output, but corresponding expression should not have any outputs to avoid register allocation #ifdef SNIPPETS_DEBUG_CAPS diff --git a/src/common/snippets/include/snippets/lowered/pass/assign_registers.hpp b/src/common/snippets/include/snippets/lowered/pass/assign_registers.hpp index 147494a56c9904..604370e46c5f4d 100644 --- a/src/common/snippets/include/snippets/lowered/pass/assign_registers.hpp +++ b/src/common/snippets/include/snippets/lowered/pass/assign_registers.hpp @@ -6,6 +6,7 @@ #include "pass.hpp" #include "snippets/generator.hpp" +#include "snippets/lowered/reg_manager.hpp" namespace ov { namespace snippets { @@ -21,15 +22,14 @@ namespace pass { class AssignRegisters : public Pass { public: OPENVINO_RTTI("AssignRegisters", "Pass") - explicit AssignRegisters(const std::function& out)>& mapper, const size_t reg_cnt) - : 
m_reg_type_mapper(mapper), reg_count(reg_cnt) {} + explicit AssignRegisters(RegManager& reg_manager) : m_reg_manager(reg_manager) {} bool run(LinearIR& linear_ir) override; private: - void set_reg_types(LinearIR& linear_ir); + using RegMap = std::map; + RegMap assign_regs_manually(const LinearIR& linear_ir, std::set& gpr_pool, std::set& vec_pool); - std::function& out)> m_reg_type_mapper; - size_t reg_count; + RegManager& m_reg_manager; }; } // namespace pass diff --git a/src/common/snippets/include/snippets/lowered/pass/init_live_ranges.hpp b/src/common/snippets/include/snippets/lowered/pass/init_live_ranges.hpp new file mode 100644 index 00000000000000..04880ef8ad6862 --- /dev/null +++ b/src/common/snippets/include/snippets/lowered/pass/init_live_ranges.hpp @@ -0,0 +1,33 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "pass.hpp" +#include "snippets/generator.hpp" +#include "snippets/lowered/reg_manager.hpp" + +namespace ov { +namespace snippets { +namespace lowered { +namespace pass { + +/** + * @interface InitLiveRanges + * @brief Calculates live ranges of registers. This information will be used to assign registers and optimize ABI reg spills. 
+ * @ingroup snippets + */ +class InitLiveRanges : public Pass { +public: + OPENVINO_RTTI("InitLiveRanges", "Pass") + explicit InitLiveRanges(RegManager& reg_manager) : m_reg_manager(reg_manager) {} + bool run(LinearIR& linear_ir) override; +private: + RegManager& m_reg_manager; +}; + +} // namespace pass +} // namespace lowered +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/lowered/pass/insert_reg_spills.hpp b/src/common/snippets/include/snippets/lowered/pass/insert_reg_spills.hpp new file mode 100644 index 00000000000000..33d39c9d8a8f21 --- /dev/null +++ b/src/common/snippets/include/snippets/lowered/pass/insert_reg_spills.hpp @@ -0,0 +1,32 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "pass.hpp" +#include "snippets/lowered/reg_manager.hpp" + +namespace ov { +namespace snippets { +namespace lowered { +namespace pass { + +/** + * @interface InsertRegSpills + * @brief Insert RegSpill and RegRestore operations for binary call emitters to comply with ABI conventions. 
+ * @ingroup snippets + */ +class InsertRegSpills : public Pass { +public: + OPENVINO_RTTI("InsertRegSpills", "Pass") + explicit InsertRegSpills(RegManager& reg_manager) : m_reg_manager(reg_manager) {} + bool run(LinearIR& linear_ir) override; + + RegManager& m_reg_manager; +}; + +} // namespace pass +} // namespace lowered +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/lowered/reg_manager.hpp b/src/common/snippets/include/snippets/lowered/reg_manager.hpp new file mode 100644 index 00000000000000..700e3be59f43b5 --- /dev/null +++ b/src/common/snippets/include/snippets/lowered/reg_manager.hpp @@ -0,0 +1,67 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "openvino/core/node.hpp" +#include "snippets/emitter.hpp" +#include "snippets/lowered/expression.hpp" +#include "snippets/generator.hpp" +#include "snippets/op/kernel.hpp" + +/** + * @interface RegManager + * @brief The class holds supplementary info about assigned registers and live ranges + * @ingroup snippets + */ +namespace ov { +namespace snippets { +namespace lowered { + +using RegTypeMapper = std::function& out)>; +using LiveInterval = std::pair; +class RegManager { +public: + RegManager() = delete; + RegManager(const std::shared_ptr& generator) : m_generator(generator) {} + inline RegType get_reg_type(const ov::Output& out) const { return m_generator->get_op_out_reg_type(out); } + inline std::vector get_vec_reg_pool() const { return m_generator->get_target_machine()->get_vec_reg_pool(); } + + inline void set_live_range(const Reg& reg, const LiveInterval& interval, bool force = false) { + OPENVINO_ASSERT(force || m_reg_live_range.count(reg) == 0, "Live range for this reg is already set"); + m_reg_live_range[reg] = interval; + } + + inline std::vector get_kernel_call_regs(const std::shared_ptr& kernel) const { + const auto& abi_regs = m_generator->get_target_machine()->get_abi_arg_regs(); + const 
auto num_kernel_args = kernel->get_num_call_args(); + OPENVINO_ASSERT(abi_regs.size() > num_kernel_args, "Too many kernel args requested"); + return {abi_regs.begin(), abi_regs.begin() + static_cast(num_kernel_args)}; + } + + inline std::vector get_gp_regs_except_kernel_call(const std::shared_ptr& kernel) const { + auto res = m_generator->get_target_machine()->get_gp_reg_pool(); + std::set kernel_call; + for (auto r : get_kernel_call_regs(kernel)) + kernel_call.insert(r); + res.erase(std::remove_if(res.begin(), res.end(), [&kernel_call](const Reg& r) {return kernel_call.count(r) != 0; }), res.end()); + return res; + } + + inline const LiveInterval& get_live_range(const Reg& reg) { + OPENVINO_ASSERT(m_reg_live_range.count(reg), "Live range for this reg was not set"); + return m_reg_live_range[reg]; + } + inline std::map get_live_range_map() const { + return m_reg_live_range; + } + +private: + // Maps Register to {Start, Stop} pairs + std::map m_reg_live_range; + const std::shared_ptr m_generator; +}; + +} // namespace lowered +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/op/kernel.hpp b/src/common/snippets/include/snippets/op/kernel.hpp index 0abcbf054b76f0..ed3bed6bc1bed0 100644 --- a/src/common/snippets/include/snippets/op/kernel.hpp +++ b/src/common/snippets/include/snippets/op/kernel.hpp @@ -22,7 +22,9 @@ class Kernel : public ov::op::Op { Kernel() = default; Kernel(lowered::LinearIR region); - static std::shared_ptr make_kernel(const lowered::LinearIR& region); + template + static std::shared_ptr make_kernel(bool is_dynamic, ArgTypes&&... 
args); + virtual size_t get_num_call_args() const = 0; std::shared_ptr region; const void *compile_params = nullptr; @@ -33,6 +35,7 @@ class KernelStatic : public Kernel { OPENVINO_OP("KernelStatic", "SnippetsOpset", Kernel); KernelStatic() = default; KernelStatic(lowered::LinearIR region); + size_t get_num_call_args() const override { return 2; } std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; }; @@ -41,9 +44,19 @@ class KernelDynamic : public Kernel { OPENVINO_OP("KernelDynamic", "SnippetsOpset", Kernel); KernelDynamic() = default; KernelDynamic(lowered::LinearIR region); + size_t get_num_call_args() const override { return 1; } std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; }; +template +std::shared_ptr Kernel::make_kernel(bool is_dynamic, ArgTypes&&... args) { + if (is_dynamic) { + return std::make_shared(std::forward(args)...); + } else { + return std::make_shared(std::forward(args)...); + } +} + } // namespace op } // namespace snippets } // namespace ov diff --git a/src/common/snippets/include/snippets/op/reg_spill.hpp b/src/common/snippets/include/snippets/op/reg_spill.hpp new file mode 100644 index 00000000000000..f7b46ab12a32cb --- /dev/null +++ b/src/common/snippets/include/snippets/op/reg_spill.hpp @@ -0,0 +1,82 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/emitter.hpp" + +#include "openvino/op/op.hpp" +#include "snippets/shape_inference/shape_inference.hpp" + +namespace ov { +namespace snippets { +namespace op { + +/** + * @interface RegSpillBase + * @brief Base class for RegSpillBegin and RegSpillEnd ops + * @ingroup snippets + */ +class RegSpillBase : public ov::op::Op { +public: + OPENVINO_OP("RegSpillBaseBase", "SnippetsOpset"); + RegSpillBase(const std::vector>& args); + RegSpillBase() = default; + virtual std::set get_regs_to_spill() const = 0; + bool visit_attributes(AttributeVisitor& visitor) 
override; +protected: +}; +class RegSpillEnd; +/** + * @interface RegSpillBegin + * @brief Marks the start of the register spill region. + * @ingroup snippets + */ +class RegSpillBegin : public RegSpillBase { +public: + OPENVINO_OP("RegSpillBegin", "SnippetsOpset", RegSpillBase); + RegSpillBegin(std::set regs_to_spill); + + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; + std::shared_ptr get_reg_spill_end() const; + std::set get_regs_to_spill() const override { return m_regs_to_spill; } + + class ShapeInfer : public IShapeInferSnippets { + size_t num_out_shapes = 0; + public: + explicit ShapeInfer(const std::shared_ptr& n); + Result infer(const std::vector& input_shapes) override; + }; +protected: + void validate_and_infer_types_except_RegSpillEnd(); + std::set m_regs_to_spill = {}; +}; +/** + * @interface RegSpillEnd + * @brief Marks the end of the register spill region. + * @ingroup snippets + */ +class RegSpillEnd : public RegSpillBase { +public: + OPENVINO_OP("RegSpillEnd", "SnippetsOpset", RegSpillBase); + RegSpillEnd() = default; + RegSpillEnd(const Output& reg_spill_begin); + + void validate_and_infer_types() override; + + std::shared_ptr clone_with_new_inputs(const OutputVector& inputs) const override; + std::shared_ptr get_reg_spill_begin() const { + auto reg_spill_begin = ov::as_type_ptr(get_input_node_shared_ptr(0)); + OPENVINO_ASSERT(reg_spill_begin, "Can't get reg_spill_begin from reg_spill_end"); + return reg_spill_begin; + } + std::set get_regs_to_spill() const override { + return get_reg_spill_begin()->get_regs_to_spill(); + } +}; + +} // namespace op +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/include/snippets/snippets_isa.hpp b/src/common/snippets/include/snippets/snippets_isa.hpp index 08002fa38ed309..d2a7127eea51e5 100644 --- a/src/common/snippets/include/snippets/snippets_isa.hpp +++ 
b/src/common/snippets/include/snippets/snippets_isa.hpp @@ -28,6 +28,7 @@ #include "op/rank_normalization.hpp" #include "op/perf_count.hpp" #include "op/reduce.hpp" +#include "op/reg_spill.hpp" namespace ov { namespace snippets { diff --git a/src/common/snippets/include/snippets/target_machine.hpp b/src/common/snippets/include/snippets/target_machine.hpp index d9d89264fe1926..a5e95a70a2ab3f 100644 --- a/src/common/snippets/include/snippets/target_machine.hpp +++ b/src/common/snippets/include/snippets/target_machine.hpp @@ -57,10 +57,24 @@ class TargetMachine { virtual size_t get_lanes() const = 0; /** - * @brief gets number of registers for a target machine - * @return number of registers + * @brief Get all possible ABI argument registers. + * The number of actually used registers depends on the signature of the called binary. + * @return vector of snippets::Reg */ - virtual size_t get_reg_count() const = 0; + virtual std::vector get_abi_arg_regs() const = 0; + + /** + * @brief Get all available general-purpose registers. + * Returns only registers that are not reserved for special purposes (e.g. stack pointer or instruction address). + * @return vector of snippets::Reg + */ + virtual std::vector get_gp_reg_pool() const = 0; + /** + * @brief Get all available vector registers. 
+ * Returns only registers that are not reserved for special purposes + * @return vector of snippets::Reg + */ + virtual std::vector get_vec_reg_pool() const = 0; /** * @brief called by generator to all the emitter for a target machine diff --git a/src/common/snippets/include/snippets/utils/reg_utils.hpp b/src/common/snippets/include/snippets/utils/reg_utils.hpp new file mode 100644 index 00000000000000..04d622837aa5a9 --- /dev/null +++ b/src/common/snippets/include/snippets/utils/reg_utils.hpp @@ -0,0 +1,30 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/emitter.hpp" + +namespace ov { +namespace snippets { +namespace utils { +inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs, snippets::RegType expected_type) { + std::vector idxs; + idxs.reserve(regs.size()); + for (const auto& reg : regs) { + OPENVINO_ASSERT(expected_type == snippets::RegType::undefined || reg.type == expected_type, "Reg type mismatch during to_idxs conversion"); + idxs.emplace_back(reg.idx); + } + return idxs; +} +inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs) { + std::vector idxs; + std::transform(regs.begin(), regs.end(), std::back_inserter(idxs), [](const snippets::Reg& r) { return r.idx; }); + return idxs; +} + + +} // namespace utils +} // namespace snippets +} // namespace ov \ No newline at end of file diff --git a/src/common/snippets/src/emitter.cpp b/src/common/snippets/src/emitter.cpp index 26745531cc2ede..382927460f3d1b 100644 --- a/src/common/snippets/src/emitter.cpp +++ b/src/common/snippets/src/emitter.cpp @@ -13,16 +13,34 @@ bool operator==(const Reg& lhs, const Reg& rhs) { bool operator!=(const Reg& lhs, const Reg& rhs) { return !(lhs == rhs); } +bool operator<(const Reg& lhs, const Reg& rhs) { + return lhs.type < rhs.type || + (lhs.type == rhs.type && lhs.idx < rhs.idx); +} +bool operator>(const Reg& lhs, const Reg& rhs) { + return 
lhs.type > rhs.type || + (lhs.type == rhs.type && lhs.idx > rhs.idx); +} -std::string regTypeToStr(const RegType& type) { - switch (type) { - case RegType::vec: - return "vec"; - case RegType::gpr: - return "gpr"; - default: - OPENVINO_THROW("Unexpected RegType"); - } +std::ostream& operator<<(std::ostream& s, const Reg& r) { + auto regTypeToStr = [](const RegType& type) { + switch (type) { + case RegType::vec: + return "vec"; + case RegType::gpr: + return "gpr"; + case RegType::mask: + return "mask"; + case RegType::undefined: + return "undefined"; + default: + OPENVINO_THROW("Unexpected RegType"); + } + }; + s << regTypeToStr(r.type) << "[" << + (r.idx == Reg::UNDEFINED_IDX ? "undefined" : std::to_string(r.idx)) + << "]"; + return s; } } // namespace snippets diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index d059ddd94d5724..779a68fb52bd9f 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -8,6 +8,8 @@ #include "snippets/runtime_configurator.hpp" #include "snippets/lowered/linear_ir.hpp" #include "snippets/lowered/expression.hpp" +#include "snippets/lowered/reg_manager.hpp" +#include "snippets/utils/reg_utils.hpp" #include "snippets/op/kernel.hpp" #include "snippets/op/memory_access.hpp" @@ -26,13 +28,15 @@ LoweringResult Generator::generate(const lowered::LinearIRPtr& linear_ir, const linear_ir->init_emitters(target); OV_ITT_TASK_NEXT(GENERATE, "::EmitCode") - - const auto kernel_op = op::Kernel::make_kernel(*linear_ir); + const auto kernel_op = op::Kernel::make_kernel(linear_ir->is_dynamic(), *linear_ir); kernel_op->compile_params = compile_params; + const lowered::RegManager& reg_manager(shared_from_this()); const auto kernel_expr = linear_ir->get_expr_factory()->build(kernel_op, std::vector{}); const auto kernel = target->get(kernel_expr->get_node()->get_type_info())(kernel_expr); - kernel->emit_code({}, {}); + 
kernel->emit_code(utils::transform_snippets_regs_to_idxs(reg_manager.get_kernel_call_regs(kernel_op)), {}, + utils::transform_snippets_regs_to_idxs(reg_manager.get_vec_reg_pool()), + utils::transform_snippets_regs_to_idxs(reg_manager.get_gp_regs_except_kernel_call(kernel_op))); OV_ITT_TASK_NEXT(GENERATE, "::EmitData") for (auto& l : linear_ir->get_ops()) { @@ -69,37 +73,37 @@ RegType Generator::get_op_out_reg_type(const ov::Output& out) const { if (reg_type != RegType::undefined) return reg_type; const auto op = out.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) + if (is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) #ifdef SNIPPETS_DEBUG_CAPS - || std::dynamic_pointer_cast(op) - || std::dynamic_pointer_cast(op) + || is_type(op) + || is_type(op) #endif ) return RegType::gpr; - else if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || + else if (is_type(op) || + is_type(op) || ov::op::util::is_unary_elementwise_arithmetic(op) || ov::op::util::is_binary_elementwise_arithmetic(op) || ov::op::util::is_binary_elementwise_comparison(op) || ov::op::util::is_binary_elementwise_logical(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + is_type(op) || + 
is_type(op) || + is_type(op)) return RegType::vec; else OPENVINO_THROW("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!"); diff --git a/src/common/snippets/src/lowered/expression.cpp b/src/common/snippets/src/lowered/expression.cpp index 1952b93017aab5..245470ae1a48af 100644 --- a/src/common/snippets/src/lowered/expression.cpp +++ b/src/common/snippets/src/lowered/expression.cpp @@ -182,8 +182,7 @@ bool Expression::visit_attributes(AttributeVisitor &visitor) { return ss.str(); }; - std::vector in_regs, out_regs; - std::vector in_reg_types, out_reg_types; + std::ostringstream in_regs, out_regs; std::vector> shapes; std::vector> subtensors; std::vector>> layouts; @@ -201,8 +200,7 @@ bool Expression::visit_attributes(AttributeVisitor &visitor) { if (!layout.empty() && !utils::is_planar_layout(layout)) layouts.emplace_back("in_layout_" + std::to_string(i), layout); - in_reg_types.emplace_back(regTypeToStr(desc->get_reg().type)); - in_regs.emplace_back(desc->get_reg().idx); + in_regs << desc->get_reg() << " "; } for (size_t i = 0; i < get_output_count(); i++) { const auto& desc = m_output_port_descriptors[i]; @@ -218,17 +216,16 @@ bool Expression::visit_attributes(AttributeVisitor &visitor) { if (!layout.empty() && !utils::is_planar_layout(layout)) layouts.emplace_back("out_layout_" + std::to_string(i), layout); - out_reg_types.emplace_back(regTypeToStr(desc->get_reg().type)); - out_regs.emplace_back(desc->get_reg().idx); + out_regs << desc->get_reg() << " "; } - if (!in_regs.empty()) { - visitor.on_attribute("in_regs", in_regs); - visitor.on_attribute("in_reg_types", in_reg_types); + if (!in_regs.str().empty()) { + std::vector tmp {in_regs.str()}; + visitor.on_attribute("in_regs", tmp); } - if (!out_regs.empty()) { - visitor.on_attribute("out_regs", out_regs); - visitor.on_attribute("out_reg_types", out_reg_types); + if (!out_regs.str().empty()) { + std::vector tmp {out_regs.str()}; + visitor.on_attribute("out_regs", tmp); } 
for (auto& s : shapes) visitor.on_attribute(s.first, s.second); diff --git a/src/common/snippets/src/lowered/expression_factory.cpp b/src/common/snippets/src/lowered/expression_factory.cpp index 668df3b65c415e..139abfd3cb1b55 100644 --- a/src/common/snippets/src/lowered/expression_factory.cpp +++ b/src/common/snippets/src/lowered/expression_factory.cpp @@ -20,6 +20,10 @@ std::shared_ptr ExpressionFactory::build(const std::shared_ptr return create(loop_begin, inputs, m_shape_infer_factory); } else if (const auto loop_end = ov::as_type_ptr(n)) { return create(loop_end, inputs, m_shape_infer_factory); + } else if (const auto spill_begin = ov::as_type_ptr(n)) { + return create(spill_begin, inputs, m_shape_infer_factory); + } else if (const auto spill_end = ov::as_type_ptr(n)) { + return create(spill_end, inputs, m_shape_infer_factory); } else if (const auto buffer = ov::as_type_ptr(n)) { return create(buffer, inputs, m_shape_infer_factory); #ifdef SNIPPETS_DEBUG_CAPS @@ -110,6 +114,36 @@ ExpressionPtr ExpressionFactory::create(const std::shared_ptr& n, c return expr; } +ExpressionPtr ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs, + const std::shared_ptr& shape_infer_factory) { + auto expr = std::shared_ptr(new Expression(n, shape_infer_factory, false)); + OPENVINO_ASSERT(inputs.empty(), "RegSpillBegin expression expects no inputs"); + const auto num_to_spill = n->get_regs_to_spill().size(); + expr->m_output_port_descriptors.resize(num_to_spill, nullptr); + for (size_t i = 0; i < num_to_spill; i++) + expr->m_output_port_descriptors[i] = std::make_shared(); + expr->m_output_port_connectors.resize(num_to_spill, nullptr); + for (size_t i = 0; i < num_to_spill; i++) { + const auto source = expr->get_output_port(i); + expr->m_output_port_connectors[i] = std::make_shared(source); + } + expr->validate(); + return expr; +} + +ExpressionPtr ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs, + const std::shared_ptr& 
shape_infer_factory) { + auto expr = std::shared_ptr(new Expression(n, shape_infer_factory, false)); + const auto spill_begin_node = n->get_reg_spill_begin(); + const auto num_to_spill = spill_begin_node->get_regs_to_spill().size(); + OPENVINO_ASSERT(inputs.size() == num_to_spill, "Invalid num inputs for RegSpillEnd expression"); + expr->m_input_port_descriptors.resize(num_to_spill, std::make_shared()); + init_expression_inputs(expr, inputs); + expr->m_output_port_descriptors.clear(); + expr->validate(); + return expr; +} + #ifdef SNIPPETS_DEBUG_CAPS ExpressionPtr ExpressionFactory::create(const std::shared_ptr& n, const std::vector& inputs, const std::shared_ptr& shape_infer_factory) { diff --git a/src/common/snippets/src/lowered/linear_ir.cpp b/src/common/snippets/src/lowered/linear_ir.cpp index 36ab2e235880af..cff3bcbe927d04 100644 --- a/src/common/snippets/src/lowered/linear_ir.cpp +++ b/src/common/snippets/src/lowered/linear_ir.cpp @@ -126,10 +126,10 @@ void LinearIR::debug_print(bool tds_as_pointers) const { auto print_rinfo = [](const RegInfo& rinfo) { std::cerr << " : {"; for (auto i : rinfo.first) - std::cerr << regTypeToStr(i.type) << "[" << i.idx << "] "; + std::cerr << i << " "; std::cerr << " => "; for (auto i : rinfo.second) - std::cerr << regTypeToStr(i.type) << "[" << i.idx << "] "; + std::cerr << i << " "; std::cerr << "}"; }; std::map td2int; diff --git a/src/common/snippets/src/lowered/pass/assign_registers.cpp b/src/common/snippets/src/lowered/pass/assign_registers.cpp index 2f921214bffed4..6b7aebbeafc204 100644 --- a/src/common/snippets/src/lowered/pass/assign_registers.cpp +++ b/src/common/snippets/src/lowered/pass/assign_registers.cpp @@ -1,4 +1,4 @@ -// Copyright (C) 2023 Intel Corporation +// Copyright (C) 2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -9,6 +9,7 @@ #include "snippets/itt.hpp" #include "snippets/utils/utils.hpp" + // This header is needed to avoid MSVC warning "C2039: 'inserter': is not a member of 
'std'" #include @@ -17,273 +18,120 @@ namespace snippets { namespace lowered { namespace pass { -void AssignRegisters::set_reg_types(LinearIR& linear_ir) { - for (const auto& expr : linear_ir) { - const auto op = expr->get_node(); - if (ov::is_type(op) || - ov::is_type(op) -#ifdef SNIPPETS_DEBUG_CAPS - || ov::is_type(op) - || ov::is_type(op) -#endif - ) - continue; - - OPENVINO_ASSERT(expr->get_output_count() == op->get_output_size(), "Incorrect count of output port descriptors!"); - for (size_t i = 0; i < expr->get_output_count(); ++i) { - const auto reg_type = m_reg_type_mapper(op->output(i)); - expr->get_output_port_descriptor(i)->set_reg_type(reg_type); - // propogate to consumers - for (const auto& consumer : expr->get_output_port_connector(i)->get_consumers()) { - consumer.get_descriptor_ptr()->set_reg_type(reg_type); - } - } +AssignRegisters::RegMap AssignRegisters::assign_regs_manually(const LinearIR& linear_ir, std::set& gpr_pool, std::set& vec_pool) { + RegMap manually_assigned; + for (const auto& param : linear_ir.get_parameters()) { + manually_assigned[param->get_output_port_descriptor(0)->get_reg()] = *gpr_pool.begin(); + gpr_pool.erase(gpr_pool.begin()); } -} - -bool AssignRegisters::run(LinearIR& linear_ir) { - OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AssignRegisters") - using Reg = size_t; - using tensor = PortConnectorPtr; - - set_reg_types(linear_ir); - const auto& exprs = linear_ir.get_ops(); - const auto& params = linear_ir.get_parameters(); - const auto& results = linear_ir.get_results(); - Reg num_expressions = exprs.size(); - Reg num_parameters = params.size(); - Reg num_results = results.size(); - - size_t io_index = 0; - // Define a set of immune tensors that will be ignored by auto reg allocation => their reg allocation is done manually - std::map manually_assigned_gprs, manually_assigned_vecs; - for (const auto& param : params) { - manually_assigned_gprs[param->get_output_port_connector(0)] = io_index; - // 
TODO [96434]: Support shape infer ops in arbitrary place in pipeline, not just after inputs - // shape infer ops sequence after input - const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(param); - for (const auto& child_shape_infer_expr : shape_infer_consumers) { - manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = io_index; - } - io_index++; - } - for (const auto& result : results) { - manually_assigned_gprs[result->get_input_port_connector(0)] = io_index; - // shape infer ops sequence before result - const auto& shape_infer_sources = utils::get_first_parent_shape_infer_expr_seq(result); - for (const auto& parent_shape_infer_expr : shape_infer_sources) { - manually_assigned_gprs[parent_shape_infer_expr->get_input_port_connector(0)] = io_index; - } - io_index++; + for (const auto& result : linear_ir.get_results()) { + manually_assigned[result->get_input_port_descriptor(0)->get_reg()] = *gpr_pool.begin(); + gpr_pool.erase(gpr_pool.begin()); } - size_t counter_vec = 0; - size_t counter_gpr = 0; - std::map regs_vec, regs_gpr; - const auto IS_MANUALLY_ALLOCATED_REG = SIZE_MAX; - auto accumulator_reg = 0lu; - for (const auto& expr : exprs) { + long int max_buffer_group = -1; + for (const auto& expr : linear_ir) { auto op = expr->get_node(); - if (const auto& buffer_expr = ov::as_type_ptr(expr)) { - const auto reg_group = buffer_expr->get_reg_group(); + if (const auto& buffer = ov::as_type_ptr(expr)) { // All buffers have one common data pointer - const auto assigned_reg = num_results + num_parameters + reg_group; - for (const auto& input : expr->get_input_port_connectors()) { - manually_assigned_gprs[input] = static_cast(assigned_reg); - // shape infer ops in the middle of subgraph. Buffer is inserted before reshape as new loop should start. - // child shape info ops share the same memory as Buffer. 
- const auto& shape_infer_consumers = utils::get_first_child_shape_infer_expr_seq(expr); - for (const auto& child_shape_infer_expr : shape_infer_consumers) { - manually_assigned_gprs[child_shape_infer_expr->get_input_port_connector(0)] = - manually_assigned_gprs[child_shape_infer_expr->get_output_port_connector(0)] = - static_cast(assigned_reg); - } - } - manually_assigned_gprs[expr->get_output_port_connector(0)] = static_cast(assigned_reg); + const auto reg_group = static_cast(buffer->get_reg_group()); + max_buffer_group = std::max(max_buffer_group, reg_group); + const auto& assigned = *std::next(gpr_pool.begin(), reg_group); + const auto& out_reg = expr->get_output_port_descriptor(0)->get_reg(); + manually_assigned[out_reg] = assigned; + // Buffer abstract registers validation: + bool all_equal = true; + for (const auto& pd : expr->get_input_port_descriptors()) + all_equal &= pd->get_reg() == out_reg; + for (const auto& pd : expr->get_output_port_descriptors()) + all_equal &= pd->get_reg() == out_reg; + OPENVINO_ASSERT(all_equal, "Buffer must have same register on all inputs and outputs"); } else if (ov::is_type(op) || ov::is_type(op)) { // Only in ReduceDecomposition Reduce ops use HorizonMax/HorizonSum and VectorBuffer. 
// We should manually set the one vector register for VectorBuffer and Max/Sum output to simulate a accumulator // TODO [96351]: We should rewrite accumulator pattern using another way const auto& input_tensor = expr->get_input_port_connector(0); - const auto& input_expr = input_tensor->get_source().get_expr(); - const auto& input_expr_input_tensors = input_expr->get_input_port_connectors(); - for (const auto& tensor : input_expr_input_tensors) { - const auto parent_expr = tensor->get_source().get_expr(); + const auto& input = input_tensor->get_source(); + const auto& assigned = *vec_pool.begin(); + for (const auto& tensor : input.get_expr()->get_input_port_connectors()) { + const auto parent = tensor->get_source(); + const auto parent_expr = parent.get_expr(); if (ov::is_type(parent_expr->get_node())) { if (ov::is_type(parent_expr->get_input_port_connector(0)->get_source().get_expr()->get_node())) { - manually_assigned_vecs[tensor] = static_cast(accumulator_reg); - manually_assigned_vecs[parent_expr->get_input_port_connector(0)] = static_cast(accumulator_reg); + manually_assigned[parent.get_descriptor_ptr()->get_reg()] = + manually_assigned[parent_expr->get_input_port_descriptor(0)->get_reg()] = assigned; } } } - manually_assigned_vecs[input_tensor] = static_cast(accumulator_reg); - accumulator_reg++; + manually_assigned[input.get_descriptor_ptr()->get_reg()] = assigned; + vec_pool.erase(vec_pool.begin()); } } - // Note: have to specify default capture "=" due to MSVC bug (it doesn't capture const expressions implicitly) - // Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified" - // the same problem with all the other lambdas in this file - auto enumerate_out_tensor = [=] (const tensor& out_tensor, - decltype(regs_vec)& reg_map, - const std::map& manually_assigned_regs, - size_t& counter) { - // Note that some ops might have identical input&output tensors (Result and Tile* for 
ex.) - // so we have to check that the tensor has not been enumerated already - if (reg_map.count(out_tensor) == 0) { - reg_map[out_tensor] = manually_assigned_regs.count(out_tensor) == 0 ? counter++ : IS_MANUALLY_ALLOCATED_REG; - } - }; - for (const auto& expr : exprs) { - for (size_t i = 0; i < expr->get_output_count(); ++i) { - const auto& out = expr->get_output_port(i); - switch (out.get_descriptor_ptr()->get_reg().type) { - case RegType::vec: - enumerate_out_tensor(out.get_port_connector_ptr(), regs_vec, manually_assigned_vecs, counter_vec); - break; - case RegType::gpr: - enumerate_out_tensor(out.get_port_connector_ptr(), regs_gpr, manually_assigned_gprs, counter_gpr); - break; - default: - OPENVINO_THROW("Unsupported reg type detected"); - } - } - } - // todo: make one for gpr and one for vector - std::vector> used_gpr, used_vec; // used = used as an input - std::vector> defined_gpr, defined_vec; // defined = used as output - used_gpr.reserve(num_expressions); - used_vec.reserve(num_expressions); - defined_gpr.reserve(num_expressions); - defined_vec.reserve(num_expressions); + gpr_pool.erase(gpr_pool.begin(), std::next(gpr_pool.begin(), max_buffer_group + 1)); + return manually_assigned; +} - auto tensor2reg = [=] (const std::vector& tensors, const std::map& reg_map) { - std::set result; - for (const auto& t : tensors) { - if (reg_map.count(t) == 0) - OPENVINO_THROW("Assign registers: attempt to access not enumerated tensor"); - Reg reg_id = reg_map.at(t); - if (reg_id != IS_MANUALLY_ALLOCATED_REG) - result.insert(reg_id); - } - return result; +bool AssignRegisters::run(LinearIR& linear_ir) { + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::AssignRegisters") + + const auto& exprs = linear_ir.get_ops(); + + const auto& kernel = snippets::op::Kernel::make_kernel(linear_ir.is_dynamic()); + auto vec2set = [](std::vector&& v){ + std::set res; + std::copy(v.begin(), v.end(), std::inserter(res, res.begin())); + return res; }; - for (const 
auto& expr : exprs) { - std::vector used_gpr_tensors, used_vec_tensors, defined_gpr_tensors, defined_vec_tensors; - for (size_t i = 0; i < expr->get_input_count(); ++i) { - const auto& in = expr->get_input_port(i); - switch (in.get_descriptor_ptr()->get_reg().type) { - case RegType::vec: - used_vec_tensors.push_back(in.get_port_connector_ptr()); - break; - case RegType::gpr: - used_gpr_tensors.push_back(in.get_port_connector_ptr()); - break; - default: - OPENVINO_THROW("Unsupported reg type detected"); - } - } - for (size_t i = 0; i < expr->get_output_count(); ++i) { - const auto& out = expr->get_output_port(i); - switch (out.get_descriptor_ptr()->get_reg().type) { - case RegType::vec: - defined_vec_tensors.push_back(out.get_port_connector_ptr()); - break; - case RegType::gpr: - defined_gpr_tensors.push_back(out.get_port_connector_ptr()); - break; - default: - OPENVINO_THROW("Unsupported reg type detected"); - } - } - used_vec.emplace_back(tensor2reg(used_vec_tensors, regs_vec)); - used_gpr.emplace_back(tensor2reg(used_gpr_tensors, regs_gpr)); - defined_vec.emplace_back(tensor2reg(defined_vec_tensors, regs_vec)); - defined_gpr.emplace_back(tensor2reg(defined_gpr_tensors, regs_gpr)); - } - // define life intervals - // liveOut[i] - regs that are live on exit from i-th (topologically ordered) operation - // liveIn[i] - regs that are live on entering the i-th (topologically ordered) operation - std::vector> life_in_vec(std::move(used_vec)), - life_in_gpr(std::move(used_gpr)); - std::vector> life_out_vec(num_expressions, std::set()), - life_out_gpr(num_expressions, std::set()); + std::set global_regs = vec2set(m_reg_manager.get_kernel_call_regs(kernel)); + std::set gpr_pool = vec2set(m_reg_manager.get_gp_regs_except_kernel_call(kernel)); + std::set vec_pool = vec2set(m_reg_manager.get_vec_reg_pool()); + auto assigned_reg_map = assign_regs_manually(linear_ir, gpr_pool, vec_pool); + + for (const auto& item : assigned_reg_map) + global_regs.insert(item.second); - // todo: 
this part if O(N*N), so it's slow for large subgraphs. Can we simplify it? At least add an early stopping criteria - for (size_t i = 0; i < num_expressions; i++) { - for (size_t n = 0; n < num_expressions; n++) { - // Regs that are live on entering the operation = regs used by the op + (all other regs alive - regs defined by the op) - // copy regs from lifeOut to lifeIn while ignoring regs in def - std::set_difference(life_out_gpr[n].begin(), life_out_gpr[n].end(), - defined_gpr[n].begin(), defined_gpr[n].end(), - std::inserter(life_in_gpr[n], life_in_gpr[n].begin())); - std::set_difference(life_out_vec[n].begin(), life_out_vec[n].end(), - defined_vec[n].begin(), defined_vec[n].end(), - std::inserter(life_in_vec[n], life_in_vec[n].begin())); - } - size_t n = 0; - for (const auto& expr : exprs) { - if (is_type(expr->get_node())) - continue; - for (const auto& out : expr->get_output_port_connectors()) { - for (const auto& child_expr_input : out->get_consumers()) { - const auto& child_expr = child_expr_input.get_expr(); - auto child_it = linear_ir.begin(); - std::advance(child_it, n); - size_t k = n; - while (child_it != linear_ir.end() && *child_it != child_expr) { - child_it++; - k++; - } - if (k == num_expressions) - OPENVINO_THROW("assign registers can't find target op in the body"); - life_out_vec[n].insert(life_in_vec[k].begin(), life_in_vec[k].end()); - life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end()); - } - } - n++; - } - } struct by_starting { - auto operator()(const std::pair& lhs, const std::pair& rhs) const -> bool { + auto operator()(const LiveInterval& lhs, const LiveInterval& rhs) const -> bool { return lhs.first < rhs.first|| (lhs.first == rhs.first && lhs.second < rhs.second); } }; struct by_ending { - auto operator()(const std::pair& lhs, const std::pair& rhs) const -> bool { + auto operator()(const LiveInterval& lhs, const LiveInterval& rhs) const -> bool { return lhs.second < rhs.second || (lhs.second == rhs.second && lhs.first 
< rhs.first); } }; - // A variable live interval - is a range (start, stop) of op indexes, such that - // the variable is alive within this range (defined but not used by the last user) - std::map, Reg, by_starting> live_intervals_vec, live_intervals_gpr; - std::reverse(life_in_vec.begin(), life_in_vec.end()); - std::reverse(life_in_gpr.begin(), life_in_gpr.end()); - auto find_last_use = [](decltype(life_in_gpr) life_in, int i) -> int { - int ln = static_cast(life_in.size()) - 1; - for (auto& x : life_in) { - if (x.find(i) != x.end()) { - return ln; - } - ln--; + // A variable live LiveInterval - is a range (start, stop) of op indexes, such that + // the variable is alive within this range (defined but not used by the last user) + std::map live_intervals_vec, live_intervals_gpr; + for (const auto& regint : m_reg_manager.get_live_range_map()) { + const auto& reg = regint.first; + const auto& interval = regint.second; + // If a register is assigned manually, we should ignore it during automatic assignment + if (assigned_reg_map.count(reg)) + continue; + switch (reg.type) { + case (RegType::gpr): + OPENVINO_ASSERT(!live_intervals_gpr.count(interval), "GPR live interval is already in the map"); + live_intervals_gpr[interval] = reg; + break; + case (RegType::vec): + OPENVINO_ASSERT(!live_intervals_vec.count(interval), "VEC live interval is already in the map"); + live_intervals_vec[interval] = reg; + break; + case (RegType::undefined): + default: + OPENVINO_THROW("Unhandled register type"); } - return i; - }; - for (int i = 0; i < static_cast(num_expressions); i++) { - for (const auto& def : defined_vec[i]) - live_intervals_vec[std::make_pair(i, find_last_use(life_in_vec, static_cast(def)))] = def; - for (const auto& def : defined_gpr[i]) - live_intervals_gpr[std::make_pair(i, find_last_use(life_in_gpr, static_cast(def)))] = def; } auto linescan_assign_registers = [](const decltype(live_intervals_vec)& live_intervals, const std::set& reg_pool) { // 
http://web.cs.ucla.edu/~palsberg/course/cs132/linearscan.pdf - // todo: do we need multimap? <=> can an op have two inputs from the same op? - std::map, Reg, by_ending> active; + std::map active; // uniquely defined register => reused reg (reduced subset enabled by reg by reusage) std::map register_map; std::stack bank; @@ -291,7 +139,7 @@ bool AssignRegisters::run(LinearIR& linear_ir) { for (auto rit = reg_pool.crbegin(); rit != reg_pool.crend(); rit++) bank.push(*rit); - std::pair interval, active_interval; + LiveInterval interval, active_interval; Reg unique_reg, active_unique_reg; for (const auto& interval_reg : live_intervals) { std::tie(interval, unique_reg) = interval_reg; @@ -306,51 +154,29 @@ bool AssignRegisters::run(LinearIR& linear_ir) { bank.push(register_map[active_unique_reg]); } // allocate - if (active.size() == reg_pool.size()) { - // todo: if it is LoopBegin or LoopEnd that requires gpr, and we don't have any in the pool, - // then assign SIZE_MAX-1 as a flag to spill a reg inside emitter - OPENVINO_THROW("can't allocate registers for a snippet "); - } else { - register_map[unique_reg] = bank.top(); - bank.pop(); - active.insert(interval_reg); - } + OPENVINO_ASSERT(active.size() != reg_pool.size(), "Can't allocate registers for a snippet: not enough registers"); + register_map[unique_reg] = bank.top(); + bank.pop(); + active.insert(interval_reg); } return register_map; }; - // todo: vec_/gpr_pool are hardware-specific and should be provided by a backend, e.g. 
overloaded generator - std::set vec_pool; - for (Reg i = 0; i < reg_count; i++) - vec_pool.insert(i); - std::set gpr_pool(vec_pool); - for (const auto& t_reg : manually_assigned_vecs) - vec_pool.erase(t_reg.second); - for (const auto& t_reg : manually_assigned_gprs) - gpr_pool.erase(t_reg.second); - auto unique2reused_map_vec = linescan_assign_registers(live_intervals_vec, vec_pool); - auto unique2reused_map_gpr = linescan_assign_registers(live_intervals_gpr, gpr_pool); - std::map assigned_regs(std::move(manually_assigned_gprs)); - assigned_regs.insert(manually_assigned_vecs.begin(), manually_assigned_vecs.end()); - auto register_assigned_regs = [=, &assigned_regs](const std::map& unique_regs, const std::map& unique2reused) { - for (const auto& reg : unique_regs) { - if (reg.second == IS_MANUALLY_ALLOCATED_REG) - continue; - if (unique2reused.count(reg.second) == 0) - OPENVINO_THROW("Assign registers failed to allocate register for a tensor"); - assigned_regs[reg.first] = unique2reused.at(reg.second); - } - }; - register_assigned_regs(regs_vec, unique2reused_map_vec); - register_assigned_regs(regs_gpr, unique2reused_map_gpr); + const auto& map_vec = linescan_assign_registers(live_intervals_vec, vec_pool); + assigned_reg_map.insert(map_vec.begin(), map_vec.end()); + const auto& map_gpr = linescan_assign_registers(live_intervals_gpr, gpr_pool); + assigned_reg_map.insert(map_gpr.begin(), map_gpr.end()); for (const auto& expr : exprs) { - for (size_t i = 0; i < expr->get_input_count(); ++i) { - expr->get_input_port_descriptor(i)->set_reg_idx(assigned_regs[expr->get_input_port_connector(i)]); - } - for (size_t i = 0; i < expr->get_output_count(); ++i) { - expr->get_output_port_descriptor(i)->set_reg_idx(assigned_regs[expr->get_output_port_connector(i)]); - } + // Note: manually assigned regs are always live => add them to all expressions + std::set mapped_live_regs = global_regs; + for (const auto& live_reg : expr->get_live_regs()) + 
mapped_live_regs.insert(assigned_reg_map[live_reg]); + expr->set_live_regs(mapped_live_regs); + for (const auto& in : expr->get_input_port_descriptors()) + in->set_reg(assigned_reg_map[in->get_reg()]); + for (const auto& out : expr->get_output_port_descriptors()) + out->set_reg(assigned_reg_map[out->get_reg()]); } return false; } diff --git a/src/common/snippets/src/lowered/pass/init_live_ranges.cpp b/src/common/snippets/src/lowered/pass/init_live_ranges.cpp new file mode 100644 index 00000000000000..c3bf909adad8be --- /dev/null +++ b/src/common/snippets/src/lowered/pass/init_live_ranges.cpp @@ -0,0 +1,94 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/lowered/pass/init_live_ranges.hpp" +#include "snippets/itt.hpp" +#include "snippets/op/subgraph.hpp" +#include "snippets/lowered/expressions/buffer_expression.hpp" + +#include "snippets/lowered/pass/serialize_control_flow.hpp" +#include "snippets/lowered/pass/serialize_data_flow.hpp" + +namespace ov { +namespace snippets { +namespace lowered { +namespace pass { +namespace { +// Expressions that don't affect lifetime of registers, e.g. 
Buffer or RankNormalization +inline bool pass_through_expr(const ExpressionPtr& expr) { + const auto& node = expr->get_node(); + return op::Subgraph::is_shape_infer_op(node) +#ifdef SNIPPETS_DEBUG_CAPS + || ov::is_type(node) + || ov::is_type(node) +#endif + || ov::is_type(expr); +} + +} // namespace + +bool InitLiveRanges::run(LinearIR& linear_ir) { // gives every not-yet-defined output a new abstract Reg, propagates it through pass-through consumers and records its [start, stop] live range + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InitLiveRanges") + std::map reg_counter; + + // Note: map expiring time to register + std::map> regs_to_expire; + + for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) { + const auto& expr = *expr_it; + const auto op = expr->get_node(); + if (pass_through_expr(expr)) { + expr->set_live_regs(std::prev(expr_it)->get()->get_live_regs()); + continue; + } + + OPENVINO_ASSERT(expr->get_output_count() == op->get_output_size() || + ov::is_type(op) || + ov::is_type(op), "Incorrect count of output port descriptors!"); + const double start = expr->get_exec_num(); + // Remove all regs that expired before start + regs_to_expire.erase(regs_to_expire.begin(), regs_to_expire.lower_bound(start)); // remove all elements lower than start (not equal) + std::set live_regs; + for (const auto& time_reg : regs_to_expire) + live_regs.insert(time_reg.second.begin(), time_reg.second.end()); + + expr->set_live_regs(std::move(live_regs)); + + for (size_t i = 0; i < expr->get_output_count(); ++i) { + const auto& out_pd = expr->get_output_port_descriptor(i); + if (out_pd->get_reg().is_defined()) + continue; + const auto reg_type = m_reg_manager.get_reg_type(op->output(i)); + const auto& reg = Reg(reg_type, reg_counter[reg_type]++); + double stop = start; + // propagate to consumers + std::stack to_visit; + to_visit.push(expr->get_output_port_connector(i)); + while (!to_visit.empty()) { + const auto current = to_visit.top(); // Note: hold by value, since to_visit.pop() below destroys the top element while `current` is still used + current->get_source().get_descriptor_ptr()->set_reg(reg); + to_visit.pop(); + for (const auto& consumer : 
current->get_consumers()) { + consumer.get_descriptor_ptr()->set_reg(reg); + const auto& consumer_expr = consumer.get_expr(); + stop = std::max(stop, consumer_expr->get_exec_num()); + if (pass_through_expr(consumer_expr)) { + for (const auto& connector : consumer_expr->get_output_port_connectors()) + to_visit.push(connector); + } + } + } + regs_to_expire[stop].insert(reg); + m_reg_manager.set_live_range(reg, std::make_pair(start, stop)); + } + } + + return false; +} + +} // namespace pass +} // namespace lowered +} // namespace snippets +} // namespace ov + diff --git a/src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp b/src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp index 541628f71928f6..c6c26cec16ef96 100644 --- a/src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp +++ b/src/common/snippets/src/lowered/pass/insert_broadcastmove.cpp @@ -75,6 +75,9 @@ bool InsertBroadcastMove::run(LinearIR& linear_ir, lowered::LinearIR::constExprI const auto broadcast = std::make_shared(node->get_input_source_output(i), broadcasted_dim); const auto broadcast_expr = *linear_ir.insert_node(broadcast, std::vector{ input }, expr->get_loop_ids(), true, expr_it, { expr->get_input_port(i) }); + // Note: We have to set live regs manually, since this transformation is applied after all register-related passes. + // Since BroadcastMove sets in_regs the same as out_regs, live regs are the same as for the child. 
+ broadcast_expr->set_live_regs(expr->get_live_regs()); // Note that BroadcastMove modified the next expr input shape, so we need to set update // expr's input port descriptor to reflect the changes expr->get_input_port_descriptor(i)->set_shape(broadcast_expr->get_output_port_descriptor(0)->get_shape()); diff --git a/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp b/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp new file mode 100644 index 00000000000000..9c9d4e0c530fd4 --- /dev/null +++ b/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp @@ -0,0 +1,88 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/lowered/pass/insert_reg_spills.hpp" + +#include "snippets/lowered/linear_ir.hpp" +#include "snippets/op/reg_spill.hpp" +#include "snippets/op/brgemm.hpp" +#include "snippets/itt.hpp" +#include "snippets/utils/utils.hpp" + + +namespace ov { +namespace snippets { +namespace lowered { +namespace pass { + +bool InsertRegSpills::run(LinearIR& linear_ir) { // surrounds every expression matched by needs_reg_spill with a RegSpillBegin/RegSpillEnd pair; returns true if anything was inserted + OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::InsertRegSpills") + + auto needs_reg_spill = [](const ExpressionPtr& expr) { + return ov::is_type(expr->get_node()); + }; + bool modified = false; + for (auto it = linear_ir.begin(); it != linear_ir.end(); it++) { + const auto& expr = *it; + if (!needs_reg_spill(expr)) + continue; + auto start_it = std::prev(it); + auto stop_it = std::next(it); + while (ov::is_type(start_it->get()->get_node()) && + ov::is_type(stop_it->get()->get_node())) { + start_it--; + stop_it++; + } + // Note: we need to insert immediately before LoopBegin => increment start_it + start_it++; + const auto& loop_begin_live = start_it->get()->get_live_regs(); + std::set brgemm_used; + const auto& brgemm_reg_info = expr->get_reg_info(); + brgemm_used.insert(brgemm_reg_info.first.begin(), brgemm_reg_info.first.end()); + brgemm_used.insert(brgemm_reg_info.second.begin(), 
brgemm_reg_info.second.end()); + // Note: before the loop, we need to spill all live regs except for the ones used by brgemm + std::set regs_to_spill; + std::set_difference(loop_begin_live.begin(), loop_begin_live.end(), + brgemm_used.begin(), brgemm_used.end(), + std::inserter(regs_to_spill, regs_to_spill.begin())); + // we also need to keep kernel regs alive (actually only abi_param_1 is used in emitters, but save all for consistency) + for (const auto& r : m_reg_manager.get_kernel_call_regs( snippets::op::Kernel::make_kernel(linear_ir.is_dynamic()))) + regs_to_spill.erase(r); + if (regs_to_spill.empty()) + continue; + // All spilled regs are not live anymore => update live_regs for affected expressions + for (auto affected_it = start_it; affected_it != stop_it; affected_it++) { + const auto& affected_expr = *affected_it; + const auto& live_old = affected_expr->get_live_regs(); + std::set live_new; + std::set_difference(live_old.begin(), live_old.end(), + regs_to_spill.begin(), regs_to_spill.end(), + std::inserter(live_new, live_new.begin())); + affected_expr->set_live_regs(live_new); + } + + const auto begin = std::make_shared(regs_to_spill); + const auto end = std::make_shared(begin); + const auto loop_ids = start_it->get()->get_loop_ids(); + OPENVINO_ASSERT(loop_ids == std::prev(stop_it)->get()->get_loop_ids(), "Inconsistent loop ids for RegSpill expressions"); + const auto spill_begin_expr = *linear_ir.insert_node(begin, std::vector{}, loop_ids, + false, start_it, std::vector>{}); + std::vector vregs{regs_to_spill.begin(), regs_to_spill.end()}; + spill_begin_expr->set_reg_info({{}, vregs}); + spill_begin_expr->set_live_regs(std::prev(start_it, 2)->get()->get_live_regs()); + + const auto spill_end_expr = *linear_ir.insert_node(end, spill_begin_expr->get_output_port_connectors(), loop_ids, + false, stop_it, std::vector>{}); + spill_end_expr->set_reg_info({vregs, {}}); + spill_end_expr->set_live_regs(stop_it->get()->get_live_regs()); // Note: must target spill_end_expr, otherwise the live regs set on spill_begin_expr above are overwritten and RegSpillEnd gets none + modified = true; + } + 
return modified; +} + +} // namespace pass +} // namespace lowered +} // namespace snippets +} // namespace ov + diff --git a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp index 27232f1605ea0e..741ab936055c20 100644 --- a/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp +++ b/src/common/snippets/src/lowered/pass/load_movebroadcast_to_broadcastload.cpp @@ -47,10 +47,10 @@ bool LoadMoveBroadcastToBroadcastLoad::run(LinearIR& linear_ir, lowered::LinearI const auto& load_parent_node = load_expr->get_input_port_connector(0)->get_source().get_expr()->get_node(); const auto& outshape = move_broadcast->get_output_partial_shape(0); const auto broadcastload = std::make_shared(load_parent_node, *outshape.rbegin(), load->get_offset()); - // insert at position of load_expr. As BroadcastMove and Load will be removed, preserve expr_it. - expr_it = std::next(expr_it); - linear_ir.replace_with_node({ load_expr, expr }, broadcastload, load_expr->get_loop_ids(), linear_ir.find(load_expr)); - expr_it = std::prev(expr_it); + auto live_regs = load_expr->get_live_regs(); + expr_it = linear_ir.replace_with_node({ load_expr, expr }, broadcastload, load_expr->get_loop_ids(), linear_ir.find(load_expr)); + // Note: We have to set live regs manually, since this transformation is applied after all register-related passes. 
+ expr_it->get()->set_live_regs(std::move(live_regs)); modified |= true; } } diff --git a/src/common/snippets/src/lowered/port_descriptor.cpp b/src/common/snippets/src/lowered/port_descriptor.cpp index f5c99c8c983a6b..7e3234d788d0e6 100644 --- a/src/common/snippets/src/lowered/port_descriptor.cpp +++ b/src/common/snippets/src/lowered/port_descriptor.cpp @@ -74,7 +74,7 @@ std::string PortDescriptor::serialize() const { ss << m_layout.size() << " "; for (auto val : m_layout) ss << val << " "; - ss << regTypeToStr(m_reg.type) << "["<< m_reg.idx << "]"; + ss << m_reg; return ss.str(); } bool operator==(const PortDescriptor& lhs, const PortDescriptor& rhs) { diff --git a/src/common/snippets/src/op/kernel.cpp b/src/common/snippets/src/op/kernel.cpp index 0eb94b426b0519..347dc083ba0ac5 100644 --- a/src/common/snippets/src/op/kernel.cpp +++ b/src/common/snippets/src/op/kernel.cpp @@ -12,14 +12,6 @@ namespace op { Kernel::Kernel(lowered::LinearIR nested) : Op(), region(std::make_shared(std::move(nested))) {} -std::shared_ptr Kernel::make_kernel(const lowered::LinearIR& region) { - if (region.is_dynamic()) { - return std::make_shared(region); - } else { - return std::make_shared(region); - } -} - KernelStatic::KernelStatic(lowered::LinearIR nested) : Kernel(std::move(nested)) {} KernelDynamic::KernelDynamic(lowered::LinearIR nested) : Kernel(std::move(nested)) {} diff --git a/src/common/snippets/src/op/reg_spill.cpp b/src/common/snippets/src/op/reg_spill.cpp new file mode 100644 index 00000000000000..0eef459a47ac62 --- /dev/null +++ b/src/common/snippets/src/op/reg_spill.cpp @@ -0,0 +1,87 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/op/reg_spill.hpp" + +#include "snippets/utils/utils.hpp" + +namespace ov { +namespace snippets { +namespace op { + +RegSpillBase::RegSpillBase(const std::vector> &args) : Op(args) {} + +bool RegSpillBase::visit_attributes(AttributeVisitor &visitor) { + std::stringstream ss; + 
const auto& regs_to_spill = get_regs_to_spill(); + for (auto reg_it = regs_to_spill.begin(); reg_it != regs_to_spill.end(); reg_it++) { + ss << *reg_it; + if (std::next(reg_it) != regs_to_spill.end()) + ss << ", "; + } + std::string spilled = ss.str(); + visitor.on_attribute("regs_to_spill", spilled); + return true; +} + +RegSpillBegin::RegSpillBegin(std::set regs_to_spill) : m_regs_to_spill(std::move(regs_to_spill)) { + validate_and_infer_types_except_RegSpillEnd(); +} + +void RegSpillBegin::validate_and_infer_types_except_RegSpillEnd() { + NODE_VALIDATION_CHECK(this, get_input_size() == 0, "RegSpillBegin doesn't expect any inputs"); + set_output_type(0, element::f32, ov::PartialShape{ov::Shape{}}); +} + +void RegSpillBegin::validate_and_infer_types() { + validate_and_infer_types_except_RegSpillEnd(); + OPENVINO_ASSERT(get_output_size() == 1, "RegSpillBegin must have only one output"); + const auto& last_output_inputs = get_output_target_inputs(0); + OPENVINO_ASSERT(last_output_inputs.size() == 1, "RegSpillBegin must have exactly one input attached to the last output"); + OPENVINO_ASSERT(ov::is_type(last_output_inputs.begin()->get_node()), + "RegSpillBegin must have RegSpillEnd connected to its last output"); +} + +std::shared_ptr RegSpillBegin::clone_with_new_inputs(const OutputVector& inputs) const { + OPENVINO_ASSERT(inputs.empty(), "RegSpillBegin should not contain inputs"); + return std::make_shared(m_regs_to_spill); +} + +std::shared_ptr RegSpillBegin::get_reg_spill_end() const { + const auto& last_output_inputs = get_output_target_inputs(0); + OPENVINO_ASSERT(last_output_inputs.size() == 1, "RegSpillBegin has more than one inputs attached to the last output"); + const auto& loop_end = ov::as_type_ptr(last_output_inputs.begin()->get_node()->shared_from_this()); + OPENVINO_ASSERT(loop_end != nullptr, "RegSpillBegin must have RegSpillEnd connected to its last output"); + return loop_end; +} + +RegSpillBegin::ShapeInfer::ShapeInfer(const std::shared_ptr& n) { + 
auto reg_spill_begin = ov::as_type_ptr(n); + OPENVINO_ASSERT(reg_spill_begin, "Invalid node passed to RegSpillBegin::ShapeInfer"); + num_out_shapes = reg_spill_begin->get_regs_to_spill().size(); +} + +RegSpillBegin::ShapeInfer::Result RegSpillBegin::ShapeInfer::infer(const std::vector& input_shapes) { + return {std::vector(num_out_shapes, VectorDims{1}), ShapeInferStatus::success}; +} + +RegSpillEnd::RegSpillEnd(const Output& reg_spill_begin) : RegSpillBase({reg_spill_begin}) { + constructor_validate_and_infer_types(); +} + +void RegSpillEnd::validate_and_infer_types() { + NODE_VALIDATION_CHECK(this, get_input_size() == 1 && ov::is_type(get_input_node_shared_ptr(0)), + "RegSpillEnd must have one input of RegSPillBegin type"); + set_output_type(0, element::f32, ov::PartialShape{}); +} + +std::shared_ptr RegSpillEnd::clone_with_new_inputs(const OutputVector& inputs) const { + check_new_args_count(this, inputs); + return std::make_shared(inputs.at(0)); +} + + +} // namespace op +} // namespace snippets +} // namespace ov diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 98e3392a65e1e2..70202b9ddd937f 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -48,7 +48,6 @@ #include "snippets/lowered/pass/validate.hpp" #include "snippets/lowered/pass/pass_config.hpp" #include "snippets/lowered/pass/reduce_decomposition.hpp" -#include "snippets/lowered/pass/assign_registers.hpp" #include "snippets/lowered/pass/cleanup_loop_offsets.hpp" #include "snippets/lowered/pass/insert_specific_iterations.hpp" #include "snippets/lowered/pass/optimize_loop_single_evaluation.hpp" @@ -57,12 +56,16 @@ #include "snippets/lowered/pass/set_load_store_scalar.hpp" #include "snippets/lowered/pass/extract_loop_invariants.hpp" +#include "snippets/lowered/reg_manager.hpp" +#include "snippets/lowered/pass/assign_registers.hpp" +#include "snippets/lowered/pass/init_live_ranges.hpp" +#include 
"snippets/lowered/pass/insert_reg_spills.hpp" + #include "transformations/utils/utils.hpp" #include "snippets/pass/manager.hpp" #include "openvino/pass/constant_folding.hpp" #include "ov_ops/type_relaxed.hpp" -#include "openvino/pass/serialize.hpp" #include #include @@ -498,10 +501,6 @@ void Subgraph::control_flow_transformations(size_t min_parallel_work_amount, siz OV_ITT_TASK_NEXT(CONTROL_FLOW, "::pre_generation_pipeline") - std::function& out)> reg_type_mapper = [&](const ov::Output& out) -> RegType { - return get_generator()->get_op_out_reg_type(out); - }; - lowered::pass::PassPipeline gen_pipeline(lowered_pass_config); // Note: the order of all passes in this pipeline must not be changed since they have hard dependencies // 1. InsertSpecificIterations must be called after AssignRegisters since tail loop expressions must have the same @@ -510,7 +509,11 @@ void Subgraph::control_flow_transformations(size_t min_parallel_work_amount, siz // (this might happen if tail loop and main loop have different increments) // 3. 
OptimizeLoopSingleEvaluation must be called after CleanupLoopOffsets // since CleanupLoopOffsets can't handle loops with evaluate_once = true - gen_pipeline.register_pass(reg_type_mapper, get_generator()->get_target_machine()->get_reg_count()); + + lowered::RegManager reg_manager(get_generator()); + gen_pipeline.register_pass(reg_manager); + gen_pipeline.register_pass(reg_manager); + gen_pipeline.register_pass(reg_manager); gen_pipeline.register_pass(); gen_pipeline.register_pass(); gen_pipeline.register_pass(); diff --git a/src/common/snippets/src/shape_inference/shape_inference.cpp b/src/common/snippets/src/shape_inference/shape_inference.cpp index 76a4c491c66983..cddf585cbd919b 100644 --- a/src/common/snippets/src/shape_inference/shape_inference.cpp +++ b/src/common/snippets/src/shape_inference/shape_inference.cpp @@ -46,10 +46,12 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry SHAPE_INFER_PREDEFINED(op::HorizonMax, HorizonOpShapeInfer), SHAPE_INFER_PREDEFINED(op::HorizonSum, HorizonOpShapeInfer), // - SHAPE_INFER_PREDEFINED(op::LoopBegin, SingleElementShapeInfer), SHAPE_INFER_PREDEFINED(op::Scalar, SingleElementShapeInfer), SHAPE_INFER_PREDEFINED(op::VectorBuffer, SingleElementShapeInfer), + SHAPE_INFER_PREDEFINED(op::LoopBegin, SingleElementShapeInfer), SHAPE_INFER_PREDEFINED(op::LoopEnd, EmptyShapeInfer), + SHAPE_INFER_OP_SPECIFIC(op::RegSpillBegin), + SHAPE_INFER_PREDEFINED(op::RegSpillEnd, EmptyShapeInfer), #ifdef SNIPPETS_DEBUG_CAPS SHAPE_INFER_PREDEFINED(op::PerfCountBegin, EmptyShapeInfer), SHAPE_INFER_PREDEFINED(op::PerfCountEnd, EmptyShapeInfer), diff --git a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index b9ecbb5570481c..ba87b3d1a067fa 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -49,7 +49,9 @@ class DummyTargetMachine : public ov::snippets::TargetMachine { 
ov::snippets::CompiledSnippetPtr get_snippet() override { return std::make_shared(); } size_t get_lanes() const override { return 10; } std::shared_ptr clone() const override { return std::make_shared(); } - size_t get_reg_count() const override { return 16; } + std::vector get_abi_arg_regs() const override; + std::vector get_gp_reg_pool() const override; + std::vector get_vec_reg_pool() const override; }; class DummyGenerator : public ov::snippets::Generator { @@ -58,7 +60,6 @@ class DummyGenerator : public ov::snippets::Generator { DummyGenerator(const std::shared_ptr& t) : ov::snippets::Generator(t) {} std::shared_ptr clone() const override { return std::make_shared(target); } -protected: ov::snippets::RegType get_op_out_reg_type(const ov::Output& out) const override { return ov::snippets::RegType::vec; }; }; diff --git a/src/common/snippets/tests/src/lir_comparator.cpp b/src/common/snippets/tests/src/lir_comparator.cpp index 82d5b9dcf91441..4a0e05fe8b92d5 100644 --- a/src/common/snippets/tests/src/lir_comparator.cpp +++ b/src/common/snippets/tests/src/lir_comparator.cpp @@ -17,7 +17,9 @@ inline string to_string(const vector& vec) { } inline string to_string(const ov::snippets::Reg& reg) { - return string("Reg(type = " + ov::snippets::regTypeToStr(reg.type) + ", idx = " + to_string(reg.idx) + ")"); + stringstream ss; + ss << reg; + return ss.str(); } inline string to_string(const ov::Node::type_info_t& info) { diff --git a/src/common/snippets/tests/src/lowering_utils.cpp b/src/common/snippets/tests/src/lowering_utils.cpp index e9ed04bf8da5a4..d3b44221177253 100644 --- a/src/common/snippets/tests/src/lowering_utils.cpp +++ b/src/common/snippets/tests/src/lowering_utils.cpp @@ -36,6 +36,7 @@ DummyTargetMachine::DummyTargetMachine(const std::vector& jitters[ov::snippets::op::HorizonMax::get_type_info_static()] = dummy_functor; jitters[ov::snippets::op::HorizonSum::get_type_info_static()] = dummy_functor; jitters[ov::snippets::op::Load::get_type_info_static()] = 
dummy_functor; + jitters[ov::snippets::op::LoadReshape::get_type_info_static()] = dummy_functor; jitters[ov::snippets::op::BroadcastLoad::get_type_info_static()] = dummy_functor; jitters[ov::snippets::op::Store::get_type_info_static()] = dummy_functor; @@ -63,6 +64,33 @@ DummyTargetMachine::DummyTargetMachine(const std::vector& } } +std::vector DummyTargetMachine::get_abi_arg_regs() const { + const auto num_abi_regs = 4; + std::vector reg_pool; + reg_pool.reserve(num_abi_regs); + for (size_t i = 0; i < num_abi_regs; i++) + reg_pool.emplace_back(ov::snippets::RegType::gpr, i); + return reg_pool; +} + +std::vector DummyTargetMachine::get_gp_reg_pool() const { + const auto num_gp_regs = 16; + std::vector reg_pool; + reg_pool.reserve(num_gp_regs); + for (size_t i = 0; i < num_gp_regs; i++) + reg_pool.emplace_back(ov::snippets::RegType::gpr, i); + return reg_pool; +} + +std::vector DummyTargetMachine::get_vec_reg_pool() const { + const auto num_vec_regs = 16; + std::vector reg_pool; + reg_pool.reserve(num_vec_regs); + for (size_t i = 0; i < num_vec_regs; i++) + reg_pool.emplace_back(ov::snippets::RegType::vec, i); + return reg_pool; +} + LoweringTests::LoweringTests() : TransformationTestsF() { // external subgraph input shape and internal parameters shapes // might differ due to the blocked layout diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp index 534470c746f2fe..bb7f0b35a4b35d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_eltwise_emitters.cpp @@ -1769,7 +1769,7 @@ void jit_mish_emitter::emit_isa(const std::vector& in_vec_idxs, const st const TReg vmm_src(in_vec_idxs[0]); const TReg vmm_dst(out_vec_idxs[0]); const TReg vmm_aux0(aux_vec_idxs[0]); - const TReg vmm_aux2(std::max(exp_emitter->get_aux_vecs_count(), 1)); + const TReg 
vmm_aux2(aux_vec_idxs[std::max(exp_emitter->get_aux_vecs_count(), 1)]); h->ld1r(vmm_aux0.s, table_val2("fwd_mish_max_x_for_equation_f")); h->fminnm(vmm_aux2.s, vmm_src.s, vmm_aux0.s); diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp index 4c0b0f95f783c2..5033f645413557 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.cpp @@ -72,6 +72,10 @@ void jit_emitter::emit_data() const { } } +emitter_in_out_map jit_emitter::get_in_out_type() const { + return in_out_type_; +} + std::set> jit_emitter::get_supported_precisions(const std::shared_ptr& node) { return {}; } diff --git a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp index 9ce8203afe7783..c0bfb4114f9c17 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/aarch64/jit_emitter.hpp @@ -60,6 +60,7 @@ class jit_emitter : public ov::snippets::Emitter { virtual size_t get_inputs_count() const = 0; virtual size_t get_aux_vecs_count() const; virtual size_t get_aux_gprs_count() const; + emitter_in_out_map get_in_out_type() const; /** * @brief Returns supported precisions. 
diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp index 420e9691ebc73c..533918d0e04fbc 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp @@ -5,6 +5,7 @@ #include "utils.hpp" #include "emitters/utils.hpp" +#include "snippets/utils/utils.hpp" namespace ov { namespace intel_cpu { @@ -12,102 +13,138 @@ namespace intel_cpu { using namespace Xbyak; using namespace dnnl::impl::cpu::x64; -EmitABIRegSpills::EmitABIRegSpills(jit_generator* h) : h(h), isa(get_isa()) {} +namespace { +inline snippets::Reg Xbyak2SnippetsReg(const Xbyak::Reg& xb_reg) { + auto get_reg_type = [](const Xbyak::Reg& xb_reg) { + switch (xb_reg.getKind()) { + case Xbyak::Reg::REG: + return snippets::RegType::gpr; + case Xbyak::Reg::XMM: + case Xbyak::Reg::YMM: + case Xbyak::Reg::ZMM: + return snippets::RegType::vec; + case Xbyak::Reg::OPMASK: + return snippets::RegType::mask; + default: + OPENVINO_THROW("Unhandled Xbyak reg type in conversion to snippets reg type"); + } + }; + return {get_reg_type(xb_reg), static_cast(xb_reg.getIdx())}; +} + +template ::type = true> +struct regs_to_spill { + static std::vector get(const std::set& live_regs) { + std::vector regs_to_spill; + auto push_if_live = [&live_regs, ®s_to_spill](Xbyak::Reg&& reg) { + if (live_regs.empty() || live_regs.count(Xbyak2SnippetsReg(reg))) + regs_to_spill.emplace_back(reg); + }; + for (int i = 0; i < 16; i++) { + // do not spill rsp; + if (i != 4) + push_if_live(Reg64(i)); + } + + for (int i = 0; i < cpu_isa_traits::n_vregs; ++i) + push_if_live(typename cpu_isa_traits::Vmm(i)); + + const int num_k_mask = isa == cpu_isa_t::avx512_core ? 
8 : 0; + for (int i = 0; i < num_k_mask; ++i) + push_if_live(Xbyak::Opmask(i)); + return regs_to_spill; + } +}; + +std::vector get_regs_to_spill(cpu_isa_t isa, const std::set& live_regs) { + switch (isa) { + case sse41: + return regs_to_spill::get(live_regs); + case avx2: + return regs_to_spill::get(live_regs); + case avx512_core: + return regs_to_spill::get(live_regs); + default: + OPENVINO_THROW("Unhandled isa in get_regs_to_spill"); + } +} +} // namespace + +EmitABIRegSpills::EmitABIRegSpills(jit_generator* h_arg) : h(h_arg), isa(get_isa()) {} EmitABIRegSpills::~EmitABIRegSpills() { OPENVINO_ASSERT(spill_status, "postamble or preamble is missed"); OPENVINO_ASSERT(rsp_status, "rsp_align or rsp_restore is missed"); } -void EmitABIRegSpills::preamble() { - // gprs - Xbyak::Operand gprs_to_save[] = {h->r8, - h->r9, - h->r10, - h->r11, - h->r12, - h->r13, - h->r14, - h->r15, - h->rax, - h->rbx, - h->rcx, - h->rdx, - h->rdi, - h->rsi, - h->rbp}; - size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); - - h->sub(h->rsp, n_gprs_to_save * gpr_size); - for (size_t i = 0; i < n_gprs_to_save; ++i) - h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); - - if (isa == avx512_core) { - h->sub(h->rsp, k_mask_num * k_mask_size); - for (size_t i = 0; i < k_mask_num; ++i) { - h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast(i))); - } +void EmitABIRegSpills::preamble(const std::set& live_regs) { + OPENVINO_ASSERT(spill_status, "Attempt to spill ABI registers twice in a row"); + // all regs to spill according to ABI + m_regs_to_spill = get_regs_to_spill(isa, live_regs); + for (const auto& reg : m_regs_to_spill) { + const auto reg_bit_size = reg.getBit(); + OPENVINO_ASSERT(reg_bit_size % 8 == 0, "Unexpected reg bit size"); + m_bytes_to_spill += reg_bit_size / 8; } - - h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); - for (size_t i = 0; i < get_max_vecs_count(); ++i) { - const auto addr = h->ptr[h->rsp + i * get_vec_length()]; - if 
(isa == sse41) { - h->uni_vmovups(addr, Xmm(i)); - } else if (isa == avx2) { - h->uni_vmovups(addr, Ymm(i)); - } else { - h->uni_vmovups(addr, Zmm(i)); + h->sub(h->rsp, m_bytes_to_spill); + uint32_t byte_stack_offset = 0; + for (const auto& reg : m_regs_to_spill) { + Xbyak::Address addr = h->ptr[h->rsp + byte_stack_offset]; + byte_stack_offset += reg.getBit() / 8; + switch (reg.getKind()) { + case Xbyak::Reg::REG: + h->mov(addr, reg); + break; + case Xbyak::Reg::XMM: + h->uni_vmovups(addr, Xmm(reg.getIdx())); + break; + case Xbyak::Reg::YMM: + h->uni_vmovups(addr, Ymm(reg.getIdx())); + break; + case Xbyak::Reg::ZMM: + h->uni_vmovups(addr, Zmm(reg.getIdx())); + break; + case Xbyak::Reg::OPMASK: + h->kmovq(addr, Opmask(reg.getIdx())); + break; + default: + OPENVINO_THROW("Unhandled Xbyak reg type in conversion"); } } - // Update the status spill_status = false; } void EmitABIRegSpills::postamble() { - // restore vector registers - for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { - const auto addr = h->ptr[h->rsp + i * get_vec_length()]; - if (isa == sse41) { - h->uni_vmovups(Xmm(i), addr); - } else if (isa == avx2) { - h->uni_vmovups(Ymm(i), addr); - } else { - h->uni_vmovups(Zmm(i), addr); + OPENVINO_ASSERT(!spill_status, "Attempt to restore ABI registers that were not spilled"); + uint32_t byte_stack_offset = m_bytes_to_spill; + for (size_t i = m_regs_to_spill.size(); i > 0; i--) { + const auto& reg = m_regs_to_spill[i - 1]; + byte_stack_offset -= reg.getBit() / 8; + Xbyak::Address addr = h->ptr[h->rsp + byte_stack_offset]; + switch (reg.getKind()) { + case Xbyak::Reg::REG: + h->mov(reg, addr); + break; + case Xbyak::Reg::XMM: + h->uni_vmovups(Xmm(reg.getIdx()), addr); + break; + case Xbyak::Reg::YMM: + h->uni_vmovups(Ymm(reg.getIdx()), addr); + break; + case Xbyak::Reg::ZMM: + h->uni_vmovups(Zmm(reg.getIdx()), addr); + break; + case Xbyak::Reg::OPMASK: + h->kmovq(Xbyak::Opmask(reg.getIdx()), addr); + break; + default: + OPENVINO_THROW("Unhandled 
Xbyak reg type in conversion"); } } - h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); - - // restore k reg - if (isa == avx512_core) { - for (int i = k_mask_num - 1; i >= 0; --i) { - h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - } - h->add(h->rsp, k_mask_num * k_mask_size); - } - - // restore gpr registers - Xbyak::Operand gprs_to_save[] = {h->r8, - h->r9, - h->r10, - h->r11, - h->r12, - h->r13, - h->r14, - h->r15, - h->rax, - h->rbx, - h->rcx, - h->rdx, - h->rdi, - h->rsi, - h->rbp}; - size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); - for (int i = n_gprs_to_save - 1; i >= 0; --i) - h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); - h->add(h->rsp, n_gprs_to_save * gpr_size); - + h->add(h->rsp, m_bytes_to_spill); + m_regs_to_spill.clear(); // Update the status spill_status = true; } diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp index ba956f3375f054..1cf3ce26d59cd9 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp @@ -5,6 +5,7 @@ #pragma once #include "cpu/x64/jit_generator.hpp" +#include "snippets/emitter.hpp" namespace ov { namespace intel_cpu { @@ -14,9 +15,11 @@ class EmitABIRegSpills { public: EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h); ~EmitABIRegSpills(); - + size_t get_num_spilled_regs() const { + return m_regs_to_spill.size(); + } // push (save) all registers on the stack - void preamble(); + void preamble(const std::set& live_regs = {}); // pop (take) all registers from the stack void postamble(); @@ -27,22 +30,11 @@ class EmitABIRegSpills { private: EmitABIRegSpills() = default; - static dnnl::impl::cpu::x64::cpu_isa_t get_isa(); - - inline size_t get_max_vecs_count() const { - return dnnl::impl::cpu::x64::isa_num_vregs(isa); - } - inline size_t get_vec_length() const { - return dnnl::impl::cpu::x64::isa_max_vlen(isa); - } - 
dnnl::impl::cpu::x64::jit_generator* h{nullptr}; const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::cpu_isa_t::isa_undef}; - - static constexpr int k_mask_size = 8; - static constexpr int k_mask_num = 8; - static constexpr int gpr_size = 8; + std::vector m_regs_to_spill; + uint32_t m_bytes_to_spill = 0; bool spill_status = true; bool rsp_status = true; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp index 0f6b2c24c13df7..534e835aade264 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.cpp @@ -194,9 +194,40 @@ size_t CPUTargetMachine::get_lanes() const { } } -// TODO [139932]: Support separate vec_count and gpr_count -size_t CPUTargetMachine::get_reg_count() const { - return 32; +std::vector CPUTargetMachine::get_abi_arg_regs() const { + using namespace dnnl::impl::cpu::aarch64; + std::vector res; + for (const auto& r : + {abi_param1, abi_param2, abi_param3, abi_param4, abi_param5, abi_param6, abi_param7, abi_param8}) + res.emplace_back(snippets::RegType::gpr, r.getIdx()); + return res; +} + +std::vector CPUTargetMachine::get_gp_reg_pool() const { + using Xbyak_aarch64::Operand; + const auto num_gp_regs = 32; + std::vector reg_pool; + for (size_t i = 0; i < num_gp_regs; i++) { + // Note: more details on the usage of reserved registers in aarch64/jit_kernel_emitter.cpp + if (!one_of(i, Operand::SP, Operand::X18, Operand::X23, Operand::X24, Operand::X28, Operand::X29)) + reg_pool.emplace_back(snippets::RegType::gpr, i); + } + return reg_pool; +} + +std::vector CPUTargetMachine::get_vec_reg_pool() const { + const auto num_vec_regs = [this]() { + switch (isa) { + case dnnl::impl::cpu::aarch64::asimd: + return dnnl::impl::cpu::aarch64::cpu_isa_traits::n_vregs; + default: + OPENVINO_THROW("unknown isa ", isa); + } + }(); + std::vector reg_pool; + for 
(int i = 0; i < num_vec_regs; i++) + reg_pool.emplace_back(snippets::RegType::vec, static_cast(i)); + return reg_pool; } dnnl::impl::cpu::aarch64::cpu_isa_t CPUTargetMachine::get_isa() const { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.hpp index 4006fc01b9a1f5..d784bc3ff0cb28 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/cpu_generator.hpp @@ -30,7 +30,11 @@ class CPUTargetMachine : public snippets::TargetMachine { bool is_supported() const override; snippets::CompiledSnippetPtr get_snippet() override; size_t get_lanes() const override; - size_t get_reg_count() const override; + + std::vector get_abi_arg_regs() const override; + std::vector get_gp_reg_pool() const override; + std::vector get_vec_reg_pool() const override; + dnnl::impl::cpu::aarch64::cpu_isa_t get_isa() const; private: diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp index 32ed1a844b6724..417e33d339816c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.cpp @@ -5,6 +5,8 @@ #include "jit_kernel_emitter.hpp" #include "emitters/utils.hpp" +#include "jit_snippets_emitters.hpp" +#include "snippets/utils/reg_utils.hpp" #include "snippets/utils/utils.hpp" using namespace Xbyak_aarch64; @@ -24,55 +26,7 @@ inline static std::vector transform_idxs_to_regs(const std::vector }); return regs; } - -inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs) { - std::vector idxs(regs.size()); - std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { - return reg.idx; - }); - return idxs; -} - 
-jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, - cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), - reg_runtime_params_idx(Operand::X0) { - const auto kernel = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "Invoked with invalid op argument"); - OV_CPU_JIT_EMITTER_ASSERT(!kernel->region->empty(), "Invoked with empty body"); - body = kernel->region; - jcp = *reinterpret_cast(kernel->compile_params); - const auto& parameters = body->get_parameters(); - const auto& results = body->get_results(); - const auto& buffers = body->get_buffers(); - num_inputs = parameters.size(); - num_outputs = results.size(); - for (const auto& param : parameters) - mem_access_exprs.push_back(param); - for (const auto& result : results) - mem_access_exprs.push_back(result); - - std::set unique_buffers; - for (const auto& buffer_expr : buffers) { - const auto buffer_reg_group = buffer_expr->get_reg_group(); - if (unique_buffers.count(buffer_reg_group) == 0) { - mem_access_exprs.push_back(buffer_expr); - unique_buffers.insert(buffer_reg_group); - } - } - - using ExprSet = std::unordered_set; - const ExprSet params_set(parameters.cbegin(), parameters.cend()); - const ExprSet results_set(results.cbegin(), results.cend()); - const ExprSet buffers_set(buffers.cbegin(), buffers.cend()); - for (const auto& expr : *body) { - if (params_set.count(expr) == 0 && results_set.count(expr) == 0 && buffers_set.count(expr) == 0) - general_exprs.emplace_back(expr); - } - num_unique_buffers = unique_buffers.size(); -} - +// Useful register mapping info: //==================================================================================== // GPR | Description | Usage | Purpose // =================================================================================== @@ -112,31 +66,39 @@ jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, // Note that 2 of the 25 marked Data pointer registers will be used as work_amounts 
in // two-level loops, so the actual number of Data pointer register is 23. //==================================================================================== -void jit_kernel_emitter::init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist) { - gp_regs_pool.resize(32); - vec_regs_pool.resize(32); - // It's easier to remove the last item during mapping, so fill descending to map ascending - for (size_t i = 0; i < 32; i++) - gp_regs_pool[i] = vec_regs_pool[i] = 31 - i; - auto remove_regs_from_pool = [](std::vector& pool, const std::set& to_remove) { - // It's important to keep the order of other elements - pool.erase(std::remove_if(pool.begin(), - pool.end(), - [&](size_t x) { - return to_remove.count(x) != 0; - }), - pool.end()); - }; - std::set gprs_blacklist_extended{Operand::X18, - Operand::X23, - Operand::X24, - Operand::X28, - Operand::X29, - Operand::SP}; - gprs_blacklist_extended.insert(gpr_blacklist.begin(), gpr_blacklist.end()); - // Reserve reg_indexes_idx and reg_runtime_params_idx, since they'll be used to pass runtime call args to kernel - remove_regs_from_pool(gp_regs_pool, gprs_blacklist_extended); - remove_regs_from_pool(vec_regs_pool, vec_blacklist); + +jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, + cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr) + : jit_emitter(h, isa) { + const auto kernel = ov::as_type_ptr(expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "Invoked with invalid op argument"); + OV_CPU_JIT_EMITTER_ASSERT(!kernel->region->empty(), "Invoked with empty body"); + body = kernel->region; + jcp = *reinterpret_cast(kernel->compile_params); + const auto& parameters = body->get_parameters(); + const auto& results = body->get_results(); + const auto& buffers = body->get_buffers(); + num_inputs = parameters.size(); + num_outputs = results.size(); + std::vector data_ptr_regs; + data_ptr_regs.reserve(num_inputs + num_outputs); + for (const auto& param : parameters) + 
data_ptr_regs.push_back(param->get_output_port_descriptor(0)->get_reg()); + for (const auto& result : results) + data_ptr_regs.push_back(result->get_input_port_descriptor(0)->get_reg()); + + std::set unique_buffers; + for (const auto& buffer_expr : buffers) { + const auto buffer_reg_group = buffer_expr->get_reg_group(); + if (unique_buffers.count(buffer_reg_group) == 0) { + data_ptr_regs.push_back(buffer_expr->get_output_port_descriptor(0)->get_reg()); + unique_buffers.insert(buffer_reg_group); + } + } + + num_unique_buffers = unique_buffers.size(); + data_ptr_regs_idx = snippets::utils::transform_snippets_regs_to_idxs(data_ptr_regs, snippets::RegType::gpr); } void jit_kernel_emitter::emit_code(const std::vector& in, @@ -144,11 +106,14 @@ void jit_kernel_emitter::emit_code(const std::vector& in, const std::vector& pool_vec_idxs, const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); + aux_vec_idxs = pool_vec_idxs; + aux_gpr_idxs = pool_gpr_idxs; emit_impl(in, out); } void jit_kernel_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { - OV_CPU_JIT_EMITTER_ASSERT(in.empty() && out.empty(), ": Expects 0 registers on input and output"); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == get_inputs_count() && out.empty(), + "Unexpected number of input/output arguments"); const auto num_params = num_inputs + num_outputs + num_unique_buffers; // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount OV_CPU_JIT_EMITTER_ASSERT(data_ptr_regs_idx.size() == num_params, @@ -158,40 +123,78 @@ void jit_kernel_emitter::validate_arguments(const std::vector& in, const data_ptr_regs_idx.size()); } -void jit_kernel_emitter::init_body_regs(const std::set& kernel_regs, - const std::vector& pool_vec_idxs, - const std::vector& pool_gpr_idxs) { - // Initialize pools of gp and vec registers - // Reserve kernel regs (reg_indexes_idx and, if there is, reg_runtime_params_idx), since they'll be used to 
pass - // runtime call args to kernel - init_reg_pools(kernel_regs, {}); - - mapping_info gpr_map_pool({}, gp_regs_pool); - mapping_info vec_map_pool({}, vec_regs_pool); - - // Note that we can't use kernel_regs to store data pointers because - // these regs are used to calculate offsets for the data pointers - map_abstract_registers(gpr_map_pool, vec_map_pool, mem_access_exprs); - for (const auto& abstract_to_physical : gpr_map_pool.first) - data_ptr_regs_idx.push_back(abstract_to_physical.second); - - vec_map_pool.second.insert(vec_map_pool.second.end(), pool_vec_idxs.cbegin(), pool_vec_idxs.cend()); - gpr_map_pool.second.insert(gpr_map_pool.second.end(), pool_gpr_idxs.cbegin(), pool_gpr_idxs.cend()); - map_abstract_registers(gpr_map_pool, vec_map_pool, general_exprs); -} - void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vector& out) const { h->preamble(); + std::set available_gpr; + std::set available_vec; + auto reg_type = snippets::RegType::gpr; + auto convert = [®_type](size_t i) -> snippets::Reg { + return {reg_type, i}; + }; + std::transform(aux_gpr_idxs.begin(), + aux_gpr_idxs.end(), + std::inserter(available_gpr, available_gpr.begin()), + convert); + // Note: data_ptr regs are globally live, so it makes no sense to keep them in the pool + for (auto idx : data_ptr_regs_idx) + available_gpr.erase({snippets::RegType::gpr, idx}); + reg_type = snippets::RegType::vec; + std::transform(aux_vec_idxs.begin(), + aux_vec_idxs.end(), + std::inserter(available_vec, available_vec.begin()), + convert); + auto data_ptr_regs = transform_idxs_to_regs(data_ptr_regs_idx); - init_data_pointers(data_ptr_regs); + auto get_expected_reg_types = + [](const std::shared_ptr& emitter) -> std::pair { + switch (emitter->get_in_out_type()) { + case emitter_in_out_map::gpr_to_vec: + return {snippets::RegType::gpr, snippets::RegType::vec}; + case emitter_in_out_map::gpr_to_gpr: + return {snippets::RegType::gpr, snippets::RegType::gpr}; + case 
emitter_in_out_map::vec_to_gpr: + return {snippets::RegType::vec, snippets::RegType::gpr}; + case emitter_in_out_map::vec_to_vec: + return {snippets::RegType::vec, snippets::RegType::vec}; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported emitter_in_ou_map instance"); + } + }; + init_data_pointers(transform_idxs_to_regs(in), data_ptr_regs); for (const auto& expression : *body) { const auto reg_info = expression->get_reg_info(); - auto in_regs = transform_snippets_regs_to_idxs(reg_info.first); - auto out_regs = transform_snippets_regs_to_idxs(reg_info.second); - const auto& emitter = expression->get_emitter(); - emitter->emit_code(in_regs, out_regs, vec_regs_pool, gp_regs_pool); + const auto& emitter = std::dynamic_pointer_cast(expression->get_emitter()); + OV_CPU_JIT_EMITTER_ASSERT(emitter, "Unexpected emitter type"); + auto expected_in_type = snippets::RegType::undefined; + auto expected_out_type = snippets::RegType::undefined; + const auto& node = expression->get_node(); + // Note: currently only a few operations are allowed to have mixed in/out register types => skip validation here + if (!ov::is_type(node) && !ov::is_type(node) && + !std::dynamic_pointer_cast(emitter)) + std::tie(expected_in_type, expected_out_type) = get_expected_reg_types(emitter); + // Note: live regs = regs live on input of the expression. 
We also need to exclude output regs from the pool + auto live_regs = expression->get_live_regs(); + for (auto r : reg_info.second) + live_regs.insert(r); + std::vector pool_gp_reg; + std::vector pool_vec_reg; + std::set_difference(available_gpr.begin(), + available_gpr.end(), + live_regs.begin(), + live_regs.end(), + std::back_inserter(pool_gp_reg)); + std::set_difference(available_vec.begin(), + available_vec.end(), + live_regs.begin(), + live_regs.end(), + std::back_inserter(pool_vec_reg)); + auto in_regs = snippets::utils::transform_snippets_regs_to_idxs(reg_info.first, expected_in_type); + auto out_regs = snippets::utils::transform_snippets_regs_to_idxs(reg_info.second, expected_out_type); + auto gpr_pool = snippets::utils::transform_snippets_regs_to_idxs(pool_gp_reg); + auto vec_pool = snippets::utils::transform_snippets_regs_to_idxs(pool_vec_reg); + emitter->emit_code(in_regs, out_regs, vec_pool, gpr_pool); } h->postamble(); @@ -200,25 +203,22 @@ void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vec jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_kernel_emitter(h, isa, expr), - reg_indexes_idx(Operand::X1) { + : jit_kernel_emitter(h, isa, expr) { const auto kernel = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "Expectes KernelStatic expression"); + OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "expects KernelStatic expression"); jcp = *reinterpret_cast(kernel->compile_params); master_shape = jcp.exec_domain; data_offsets = jcp.data_offsets; OV_CPU_JIT_EMITTER_ASSERT(data_offsets.size() == num_inputs + num_outputs, "Incompatible count of data offsets!"); OV_CPU_JIT_EMITTER_ASSERT(data_offsets.front().size() == master_shape.size(), "Incompatible rank of data offsets!"); - - // - Reserve reg_indexes_idx and reg_runtime_params_idx, since they'll be used to pass 
runtime call args to kernel - // - However we can use reg_indexes_idx for non memory access operations - // since we won't need them after offsets calculation - init_body_regs({reg_indexes_idx, reg_runtime_params_idx}, {}, {reg_indexes_idx}); } -void jit_kernel_static_emitter::init_data_pointers(const std::vector& data_ptr_regs) const { - XReg reg_indexes = XReg(static_cast(reg_indexes_idx)); - XReg reg_runtime_params = XReg(static_cast(reg_runtime_params_idx)); +void jit_kernel_static_emitter::init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs) const { + OV_CPU_JIT_EMITTER_ASSERT(arg_regs.size() == 2, "Invalid arg regs size"); + XReg reg_runtime_params = arg_regs[0]; + XReg reg_indexes = arg_regs[1]; + XReg reg_tmp = XReg(h->X_TMP_0); XReg reg_aux = XReg(h->X_TMP_1); @@ -261,18 +261,14 @@ jit_kernel_dynamic_emitter::jit_kernel_dynamic_emitter(dnnl::impl::cpu::aarch64: dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) : jit_kernel_emitter(h, isa, expr) { - const auto kernel = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(kernel, "Expectes KernelDynamic expression"); - - // - Reserve reg_runtime_params_idx, since it wll be used to pass runtime call args to all dynamic emitters that - // needs runtime args - // - We cannot assign this register to the body emitters since runtime params MUST be valid during whole execution - // for all dynamic emitters - init_body_regs({reg_runtime_params_idx}); + OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(expr->get_node()), + "expects KernelDynamic expression"); } -void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector& data_ptr_regs) const { - XReg reg_runtime_params = XReg(static_cast(reg_runtime_params_idx)); +void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs) const { + OV_CPU_JIT_EMITTER_ASSERT(arg_regs.size() == 1, "Invalid arg regs size"); + XReg reg_runtime_params 
= arg_regs[0]; const auto num_params = num_inputs + num_outputs; for (size_t i = 0; i < num_unique_buffers; ++i) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp index 0ede91f100f110..c4844bc2545f82 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_kernel_emitter.hpp @@ -5,7 +5,6 @@ #pragma once #include "emitters/plugin/aarch64/jit_emitter.hpp" -#include "emitters/snippets/jit_container_emitter.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" namespace ov { @@ -30,7 +29,7 @@ namespace aarch64 { /// Note that Kernel doesn't accept any input arguments. /// -class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { +class jit_kernel_emitter : public jit_emitter { public: jit_kernel_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, @@ -41,14 +40,11 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { } void emit_code(const std::vector& in_idxs, const std::vector& out_idxs, - const std::vector& pool_vec_idxs = {}, - const std::vector& pool_gpr_idxs = {}) const override; + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; protected: void validate_arguments(const std::vector& in, const std::vector& out) const override; - void init_body_regs(const std::set& kernel_regs, - const std::vector& pool_vec_idxs = {}, - const std::vector& pool_gpr_idxs = {}); /** * @brief populates physical registers pools for x86 (both vec and gp). * Skips stack-related gprs and extra gprs passed as arguments. 
@@ -57,7 +53,8 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { */ void init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist); - virtual void init_data_pointers(const std::vector& data_ptr_regs) const = 0; + virtual void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs) const = 0; void emit_impl(const std::vector& in, const std::vector& out) const override; @@ -70,11 +67,6 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { size_t num_outputs = 0; size_t num_unique_buffers = 0; - snippets::lowered::LinearIR::container mem_access_exprs; - snippets::lowered::LinearIR::container general_exprs; - - const size_t reg_runtime_params_idx{0}; - std::shared_ptr body; #ifdef SNIPPETS_DEBUG_CAPS @@ -87,11 +79,14 @@ class jit_kernel_static_emitter : public jit_kernel_emitter { jit_kernel_static_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + size_t get_inputs_count() const override { + return 2; + } private: - void init_data_pointers(const std::vector& data_ptr_regs) const override; + void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs) const override; - const size_t reg_indexes_idx{1}; std::vector master_shape; std::vector> data_offsets; @@ -105,9 +100,13 @@ class jit_kernel_dynamic_emitter : public jit_kernel_emitter { jit_kernel_dynamic_emitter(dnnl::impl::cpu::aarch64::jit_generator* h, dnnl::impl::cpu::aarch64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + size_t get_inputs_count() const override { + return 1; + } private: - void init_data_pointers(const std::vector& data_ptr_regs) const override; + void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs) const override; #ifdef SNIPPETS_DEBUG_CAPS friend std::string init_info_jit_kernel_dynamic_emitter(const 
jit_kernel_dynamic_emitter* emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp deleted file mode 100644 index ceee57f3c0cd28..00000000000000 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#include "jit_container_emitter.hpp" - -#include "emitters/utils.hpp" -#include "utils/general_utils.h" - -namespace ov { -namespace intel_cpu { - -void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, - mapping_info& vec_map_pool, - snippets::lowered::LinearIR::container& expressions) const { - OV_CPU_JIT_EMITTER_ASSERT(!expressions.empty(), - "Cannot map registers when there is no allocated_emitters provided"); - - auto map_regs = [&](const std::vector& abstract_regs) { - std::vector physical_regs = abstract_regs; - for (size_t i = 0; i < abstract_regs.size(); ++i) { - const auto& abstract_reg = abstract_regs[i]; - const auto& type = abstract_reg.type; - const auto& abstract = abstract_reg.idx; - OV_CPU_JIT_EMITTER_ASSERT(one_of(type, snippets::RegType::gpr, snippets::RegType::vec), - "Incorrect reg type detected!"); - auto& mapping = type == snippets::RegType::gpr ? 
gpr_map_pool : vec_map_pool; - auto& abstract_to_physical = mapping.first; - auto& regs_pool = mapping.second; - auto& physical = physical_regs[i]; - if (abstract_to_physical.count(abstract) == 0) { - OV_CPU_JIT_EMITTER_ASSERT( - !regs_pool.empty(), - "Cannot map registers for jit_container_emitter: not enough regs in the pool"); - physical.idx = regs_pool.back(); - regs_pool.pop_back(); - abstract_to_physical[abstract] = physical.idx; - } else { - physical.idx = abstract_to_physical[abstract]; - } - } - return physical_regs; - }; - - for (const auto& expression : expressions) { - std::vector in_physical_regs, out_physical_regs; - std::vector in_abstract_regs, out_abstract_regs; - std::tie(in_abstract_regs, out_abstract_regs) = expression->get_reg_info(); - in_physical_regs = map_regs(in_abstract_regs); - out_physical_regs = map_regs(out_abstract_regs); - expression->set_reg_info({in_physical_regs, out_physical_regs}); - if (auto container = std::dynamic_pointer_cast(expression->get_emitter())) - container->map_abstract_registers(gpr_map_pool, vec_map_pool, expressions); - } -} - -} // namespace intel_cpu -} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp deleted file mode 100644 index 7737e7e1150926..00000000000000 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (C) 2024 Intel Corporation -// SPDX-License-Identifier: Apache-2.0 -// - -#pragma once - -#include "snippets/lowered/linear_ir.hpp" - -namespace ov { -namespace intel_cpu { - -/// -/// \brief jit_container_emitter designed provide common interface for register mapping -/// (abstract to physical) and nested code access. 
-/// -class jit_container_emitter { -public: - // mapping info contains abstract_to_physical map + regs_pool - using mapping_info = std::pair, std::vector&>; - -protected: - // maps gpr and vec abstract registers to physical ones. - void map_abstract_registers(mapping_info& gpr_map_pool, - mapping_info& vec_map_pool, - snippets::lowered::LinearIR::container& expressions) const; -}; - -} // namespace intel_cpu -} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index 39e384837856a1..58201b035bf040 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -7,10 +7,8 @@ #include #include "emitters/plugin/x64/jit_conversion_emitters.hpp" -#include "emitters/plugin/x64/jit_dnnl_emitters.hpp" #include "emitters/plugin/x64/jit_dnnl_ext_emitters.hpp" #include "emitters/plugin/x64/jit_eltwise_emitters.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" #include "emitters/snippets/cpu_runtime_configurator.hpp" #include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp" #include "emitters/snippets/x64/jit_brgemm_emitter.hpp" @@ -19,6 +17,7 @@ #include "emitters/snippets/x64/jit_kernel_emitter.hpp" #include "emitters/snippets/x64/jit_loop_emitters.hpp" #include "emitters/snippets/x64/jit_memory_emitters.hpp" +#include "emitters/snippets/x64/jit_reg_spill_emitters.hpp" #include "emitters/snippets/x64/jit_snippets_emitters.hpp" #include "snippets/snippets_isa.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" @@ -258,13 +257,6 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter); jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter); - 
jitters[snippets::op::KernelStatic::get_type_info_static()] = - CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_static_emitter); - jitters[snippets::op::KernelDynamic::get_type_info_static()] = - CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_dynamic_emitter); - jitters[snippets::op::LoopBegin::get_type_info_static()] = - CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_begin_emitter); - jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_end_emitter); // Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes // additional arguments jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = @@ -277,6 +269,18 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho compiled_kernel_cache); jitters[snippets::op::ReduceMax::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); jitters[snippets::op::ReduceSum::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); + // Service + jitters[snippets::op::KernelStatic::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_static_emitter); + jitters[snippets::op::KernelDynamic::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_dynamic_emitter); + jitters[snippets::op::LoopBegin::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_begin_emitter); + jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_end_emitter); + jitters[snippets::op::RegSpillBegin::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_reg_spill_begin_emitter); + jitters[snippets::op::RegSpillEnd::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_reg_spill_end_emitter); #ifdef SNIPPETS_DEBUG_CAPS jitters[snippets::op::PerfCountBegin::get_type_info_static()] = @@ -336,8 +340,41 @@ size_t intel_cpu::CPUTargetMachine::get_lanes() const { } } -size_t 
intel_cpu::CPUTargetMachine::get_reg_count() const { - return 16; +std::vector intel_cpu::CPUTargetMachine::get_abi_arg_regs() const { + const auto& abi_regs = dnnl::impl::cpu::x64::abi_param_regs; + std::vector res; + for (const auto& r : abi_regs) + res.emplace_back(snippets::RegType::gpr, r); + return res; +} + +std::vector intel_cpu::CPUTargetMachine::get_gp_reg_pool() const { + const auto num_gp_regs = 16; + std::vector reg_pool; + for (size_t i = 0; i < num_gp_regs; i++) { + if (!one_of(i, Xbyak::Operand::RSP, Xbyak::Operand::RBP)) + reg_pool.emplace_back(snippets::RegType::gpr, i); + } + return reg_pool; +} + +std::vector intel_cpu::CPUTargetMachine::get_vec_reg_pool() const { + const auto num_vec_regs = [this]() { + switch (isa) { + case dnnl::impl::cpu::x64::avx2: + return dnnl::impl::cpu::x64::cpu_isa_traits::n_vregs; + case dnnl::impl::cpu::x64::sse41: + return dnnl::impl::cpu::x64::cpu_isa_traits::n_vregs; + case dnnl::impl::cpu::x64::avx512_core: + return dnnl::impl::cpu::x64::cpu_isa_traits::n_vregs; + default: + OPENVINO_THROW("unknown isa ", isa); + } + }(); + std::vector reg_pool; + for (int i = 0; i < num_vec_regs; i++) + reg_pool.emplace_back(snippets::RegType::vec, static_cast(i)); + return reg_pool; } dnnl::impl::cpu::x64::cpu_isa_t intel_cpu::CPUTargetMachine::get_isa() const { @@ -389,15 +426,14 @@ std::shared_ptr intel_cpu::CPUGenerator::clone() const { ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(const ov::Output& out) const { const auto op = out.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(op) || + if (is_type(op) || #ifdef SNIPPETS_LIBXSMM_TPP std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || + is_type(op) || #endif - std::dynamic_pointer_cast(op)) + is_type(op)) return ov::snippets::RegType::gpr; - else if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) + else if (is_type(op) || is_type(op)) return ov::snippets::RegType::vec; else return 
ov::snippets::RegType::undefined; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.hpp index a86f3050580ed4..c4476d38088c97 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.hpp @@ -35,7 +35,11 @@ class CPUTargetMachine : public snippets::TargetMachine { bool is_supported() const override; snippets::CompiledSnippetPtr get_snippet() override; size_t get_lanes() const override; - size_t get_reg_count() const override; + + std::vector get_abi_arg_regs() const override; + std::vector get_gp_reg_pool() const override; + std::vector get_vec_reg_pool() const override; + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const; #ifdef SNIPPETS_DEBUG_CAPS SnippetsDebugCapsConfig debug_config; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 8d343cec908732..6470fc2f50b0f5 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -58,6 +58,7 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, m_memory_offsets.push_back(brgemm_node->get_offset_scratch()); m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_input_port(2))); } + m_live_regs = expr->get_live_regs(); } std::set> jit_brgemm_emitter::get_supported_precisions( @@ -106,18 +107,31 @@ void jit_brgemm_emitter::emit_impl(const std::vector& in, const std::vec template ::value, bool>::type> void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) const { + std::set regs_to_spill = m_live_regs; + // Note: these 3 registers will be corrupted by the caller during the ABI call + regs_to_spill.emplace(snippets::RegType::gpr, abi_param1.getIdx()); + regs_to_spill.emplace(snippets::RegType::gpr, 
abi_param2.getIdx()); + regs_to_spill.emplace(snippets::RegType::gpr, h->rbp.getIdx()); + const bool is_dynamic_case = + std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); + // Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly + Xbyak::Reg64 aux_reg = abi_param1; + if (std::is_same() || is_dynamic_case) { + if (!aux_gpr_idxs.empty()) { + aux_reg = Xbyak::Reg64(static_cast(aux_gpr_idxs[0])); + } else { + aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs); + regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx()); + } + } EmitABIRegSpills spill(h); - spill.preamble(); + spill.preamble(regs_to_spill); h->mov(h->rbp, reinterpret_cast(T::execute)); auto reserved_stack_size = sizeof(typename T::call_args); // Reserve memory on the stack h->sub(h->rsp, reserved_stack_size); - const bool is_dynamic_case = - std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); - Xbyak::Reg64 aux_reg = is_dynamic_case ? 
ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); - #define GET_OFF_CALL_ARGS(field) offsetof(typename T::call_args, field) const std::vector brgemm_args_offsets = {GET_OFF_CALL_ARGS(A), GET_OFF_CALL_ARGS(B), @@ -127,14 +141,16 @@ void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) con const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs); for (size_t i = 0; i < mem_ptrs.size(); i++) { - if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) + if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) { + OV_CPU_JIT_EMITTER_ASSERT(aux_reg != abi_param1, "Aux reg is needed, but wasn't allocated"); utils::push_ptr_with_runtime_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], aux_reg, GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); - else + } else { utils::push_ptr_with_static_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); + } } // No scratchpad => need to write nullptr manually @@ -143,8 +159,9 @@ void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) con // abi_param1 always contains jit_snippets_call_args which has amx tile config for each thread if (std::is_same()) { - h->lea(h->r10, h->ptr[abi_param1 + GET_OFF(amx_tile_config)]); - h->mov(h->qword[h->rsp + GET_OFF_BRGEMM_AMX_ARGS(amx_tile_config)], h->r10); + OV_CPU_JIT_EMITTER_ASSERT(aux_reg != abi_param1, "Aux reg is needed, but wasn't allocated"); + h->lea(aux_reg, h->ptr[abi_param1 + GET_OFF(amx_tile_config)]); + h->mov(h->qword[h->rsp + GET_OFF_BRGEMM_AMX_ARGS(amx_tile_config)], aux_reg); } h->mov(abi_param1, reinterpret_cast(m_kernel_executor.get())); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp index 9d072065c0fe52..20b969fcc3fc2c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp +++ 
b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp @@ -38,6 +38,7 @@ class jit_brgemm_emitter : public jit_emitter { // Note: cluster ids order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if there is no buffer std::vector m_buffer_ids{}; std::shared_ptr m_kernel_executor = nullptr; + std::set m_live_regs{}; #ifdef SNIPPETS_DEBUG_CAPS friend std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter* emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp index bd5a3227e1e125..5e56199b5ddbe7 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp @@ -4,7 +4,8 @@ #include "jit_kernel_emitter.hpp" -#include "snippets/utils/utils.hpp" +#include "jit_snippets_emitters.hpp" +#include "snippets/utils/reg_utils.hpp" #include "utils.hpp" using namespace Xbyak; @@ -17,8 +18,7 @@ namespace intel_cpu { jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), - reg_runtime_params_idx(abi_param1.getIdx()) { + : jit_emitter(h, isa) { const auto kernel = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "invoked with invalid op argument"); OV_CPU_JIT_EMITTER_ASSERT(!kernel->region->empty(), "invoked with empty body"); @@ -29,52 +29,23 @@ jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, const auto& buffers = body->get_buffers(); num_inputs = parameters.size(); num_outputs = results.size(); + std::vector data_ptr_regs; + data_ptr_regs.reserve(num_inputs + num_outputs); for (const auto& param : parameters) - mem_access_exprs.push_back(param); + data_ptr_regs.push_back(param->get_output_port_descriptor(0)->get_reg()); for (const auto& result : results) - mem_access_exprs.push_back(result); + 
data_ptr_regs.push_back(result->get_input_port_descriptor(0)->get_reg()); std::set unique_buffers; for (const auto& buffer_expr : buffers) { const auto buffer_reg_group = buffer_expr->get_reg_group(); if (unique_buffers.count(buffer_reg_group) == 0) { - mem_access_exprs.push_back(buffer_expr); + data_ptr_regs.push_back(buffer_expr->get_output_port_descriptor(0)->get_reg()); unique_buffers.insert(buffer_reg_group); } } - - using ExprSet = std::unordered_set; - const ExprSet params_set(parameters.cbegin(), parameters.cend()); - const ExprSet results_set(results.cbegin(), results.cend()); - const ExprSet buffers_set(buffers.cbegin(), buffers.cend()); - for (const auto& expr : *body) { - if (params_set.count(expr) == 0 && results_set.count(expr) == 0 && buffers_set.count(expr) == 0) - general_exprs.emplace_back(expr); - } num_unique_buffers = unique_buffers.size(); -} - -void jit_kernel_emitter::init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist) { - gp_regs_pool.resize(16); - vec_regs_pool.resize(16); - // It's easier to remove the last item during mapping, so fill descending to map ascending - for (size_t i = 0; i < 16; i++) - gp_regs_pool[i] = vec_regs_pool[i] = 15 - i; - auto remove_regs_from_pool = [](std::vector& pool, const std::set& to_remove) { - // It's important to keep the order of other elements - pool.erase(std::remove_if(pool.begin(), - pool.end(), - [&](size_t x) { - return to_remove.count(x) != 0; - }), - pool.end()); - }; - // Reserve stack base and pointer for push(...) and pop(...) 
operations - std::set gprs_blacklist_extended{Xbyak::Operand::RSP, Xbyak::Operand::RBP}; - gprs_blacklist_extended.insert(gpr_blacklist.begin(), gpr_blacklist.end()); - // Reserve abi_param1 and abi_param2, since they'll be used to pass runtime call args to kernel - remove_regs_from_pool(gp_regs_pool, gprs_blacklist_extended); - remove_regs_from_pool(vec_regs_pool, vec_blacklist); + data_ptr_regs_idx = snippets::utils::transform_snippets_regs_to_idxs(data_ptr_regs, snippets::RegType::gpr); } void jit_kernel_emitter::emit_code(const std::vector& in, @@ -82,54 +53,98 @@ void jit_kernel_emitter::emit_code(const std::vector& in, const std::vector& pool_vec_idxs, const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); + aux_vec_idxs = pool_vec_idxs; + aux_gpr_idxs = pool_gpr_idxs; emit_impl(in, out); } void jit_kernel_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { - OV_CPU_JIT_EMITTER_ASSERT(in.empty() && out.empty(), ": expects 0 registers on input and output"); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == get_inputs_num() && out.empty(), + "Unexpected number of input/output arguments"); const auto num_params = num_inputs + num_outputs + num_unique_buffers; // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount OV_CPU_JIT_EMITTER_ASSERT(data_ptr_regs_idx.size() == num_params, - "number of inputs and outputs is inconsistent with the number of allocated registers ", + "Number of inputs and outputs is inconsistent with the number of allocated registers ", num_params, " data_ptr_regs_idx.size() = ", data_ptr_regs_idx.size()); } -void jit_kernel_emitter::init_body_regs(const std::set& kernel_regs, - const std::vector& pool_vec_idxs, - const std::vector& pool_gpr_idxs) { - // Initialize pools of gp and vec registers - // Reserve kernel regs (abi_param1 and, if there is, abi_param2), since they'll be used to pass runtime call args to - // kernel - init_reg_pools(kernel_regs, 
{}); - - mapping_info gpr_map_pool({}, gp_regs_pool); - mapping_info vec_map_pool({}, vec_regs_pool); - - // Note that we can't use kernel_regs to store data pointers because - // these regs are used to calculate offsets for the data pointers - map_abstract_registers(gpr_map_pool, vec_map_pool, mem_access_exprs); - for (const auto& abstract_to_physical : gpr_map_pool.first) - data_ptr_regs_idx.push_back(abstract_to_physical.second); - - gpr_map_pool.second.insert(gpr_map_pool.second.end(), pool_gpr_idxs.cbegin(), pool_gpr_idxs.cend()); - vec_map_pool.second.insert(vec_map_pool.second.end(), pool_vec_idxs.cbegin(), pool_vec_idxs.cend()); - map_abstract_registers(gpr_map_pool, vec_map_pool, general_exprs); -} - void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vector& out) const { h->preamble(); + std::set available_gpr; + std::set available_vec; + auto reg_type = snippets::RegType::gpr; + auto convert = [&reg_type](size_t i) -> snippets::Reg { + return {reg_type, i}; + }; + std::transform(aux_gpr_idxs.begin(), + aux_gpr_idxs.end(), + std::inserter(available_gpr, available_gpr.begin()), + convert); + // Note: data_ptr regs are globally live, so it makes no sense to keep them in the pool + for (auto idx : data_ptr_regs_idx) + available_gpr.erase({snippets::RegType::gpr, idx}); + reg_type = snippets::RegType::vec; + std::transform(aux_vec_idxs.begin(), + aux_vec_idxs.end(), + std::inserter(available_vec, available_vec.begin()), + convert); + auto data_ptr_regs = utils::transform_idxs_to_regs(data_ptr_regs_idx); - init_data_pointers(data_ptr_regs); + auto get_expected_reg_types = + [](const std::shared_ptr& emitter) -> std::pair { + switch (emitter->get_in_out_type()) { + case emitter_in_out_map::gpr_to_vec: + return {snippets::RegType::gpr, snippets::RegType::vec}; + case emitter_in_out_map::gpr_to_gpr: + return {snippets::RegType::gpr, snippets::RegType::gpr}; + case emitter_in_out_map::vec_to_gpr: + return {snippets::RegType::vec, 
snippets::RegType::gpr}; + case emitter_in_out_map::vec_to_vec: + return {snippets::RegType::vec, snippets::RegType::vec}; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported emitter_in_ou_map instance"); + } + }; + std::vector aux_tmp_regs{}; + if (!available_gpr.empty()) + aux_tmp_regs.emplace_back(available_gpr.begin()->idx); + init_data_pointers(utils::transform_idxs_to_regs(in), data_ptr_regs, aux_tmp_regs); for (const auto& expression : *body) { const auto reg_info = expression->get_reg_info(); - auto in_regs = utils::transform_snippets_regs_to_idxs(reg_info.first); - auto out_regs = utils::transform_snippets_regs_to_idxs(reg_info.second); - const auto& emitter = expression->get_emitter(); - emitter->emit_code(in_regs, out_regs, vec_regs_pool, gp_regs_pool); + const auto& emitter = std::dynamic_pointer_cast(expression->get_emitter()); + OV_CPU_JIT_EMITTER_ASSERT(emitter, "Unexpected emitter type"); + auto expected_in_type = snippets::RegType::undefined; + auto expected_out_type = snippets::RegType::undefined; + const auto& node = expression->get_node(); + // Note: currently only a few operations are allowed to have mixed in/out register types => skip validation here + if (!ov::is_type(node) && !ov::is_type(node) && + !std::dynamic_pointer_cast(emitter)) + std::tie(expected_in_type, expected_out_type) = get_expected_reg_types(emitter); + // Note: live regs = regs live on input of the expression. 
We also need to exclude output regs from the pool + auto live_regs = expression->get_live_regs(); + for (auto r : reg_info.second) + live_regs.insert(r); + std::vector pool_gp_reg; + std::vector pool_vec_reg; + std::set_difference(available_gpr.begin(), + available_gpr.end(), + live_regs.begin(), + live_regs.end(), + std::back_inserter(pool_gp_reg)); + std::set_difference(available_vec.begin(), + available_vec.end(), + live_regs.begin(), + live_regs.end(), + std::back_inserter(pool_vec_reg)); + auto in_regs = snippets::utils::transform_snippets_regs_to_idxs(reg_info.first, expected_in_type); + auto out_regs = snippets::utils::transform_snippets_regs_to_idxs(reg_info.second, expected_out_type); + auto gpr_pool = snippets::utils::transform_snippets_regs_to_idxs(pool_gp_reg); + auto vec_pool = snippets::utils::transform_snippets_regs_to_idxs(pool_vec_reg); + emitter->emit_code(in_regs, out_regs, vec_pool, gpr_pool); } h->postamble(); @@ -138,25 +153,22 @@ void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vec jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_kernel_emitter(h, isa, expr), - reg_indexes_idx(abi_param2.getIdx()) { + : jit_kernel_emitter(h, isa, expr) { const auto kernel = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "expectes KernelStatic expression"); + OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "expects KernelStatic expression"); jcp = *reinterpret_cast(kernel->compile_params); master_shape = jcp.exec_domain; data_offsets = jcp.data_offsets; OV_CPU_JIT_EMITTER_ASSERT(data_offsets.size() == num_inputs + num_outputs, "Incompatible count of data offsets!"); OV_CPU_JIT_EMITTER_ASSERT(data_offsets.front().size() == master_shape.size(), "Incompatible rank of data offsets!"); - - // - Reserve abi_param1 and abi_param2, since they'll be used to pass runtime call args 
to kernel - // - However we can use reg_indexes_idx for non memory access operations - // since we won't need them after offsets calculation - init_body_regs({reg_indexes_idx, reg_runtime_params_idx}, {}, {reg_indexes_idx}); } -void jit_kernel_static_emitter::init_data_pointers(const std::vector& data_ptr_regs) const { - Xbyak::Reg64 reg_indexes = Xbyak::Reg64(static_cast(reg_indexes_idx)); - Xbyak::Reg64 reg_runtime_params = Xbyak::Reg64(static_cast(reg_runtime_params_idx)); +void jit_kernel_static_emitter::init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs, + const std::vector& aux_gprs) const { + OV_CPU_JIT_EMITTER_ASSERT(arg_regs.size() == 2, "Invalid arg regs size"); + Xbyak::Reg64 reg_runtime_params = arg_regs[0]; + Xbyak::Reg64 reg_indexes = arg_regs[1]; const auto num_params = num_inputs + num_outputs; // Note that we don't need offset for the last dim, since it's handled directly by Tile emitter @@ -173,12 +185,8 @@ void jit_kernel_static_emitter::init_data_pointers(const std::vector(*spare_corruptable_gpr)); + const bool last_iter_explicitly = aux_gprs.empty(); + Reg64 reg_tmp = last_iter_explicitly ? data_ptr_regs[num_params - 1] : *aux_gprs.begin(); // Vector "data_ptr_regs" is sorted by abstract regs. 
// It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer] // So we can initialize buffer register firstly as last value of vector "data_ptr_regs" @@ -212,18 +220,15 @@ jit_kernel_dynamic_emitter::jit_kernel_dynamic_emitter(dnnl::impl::cpu::x64::jit dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) : jit_kernel_emitter(h, isa, expr) { - const auto kernel = ov::as_type_ptr(expr->get_node()); - OV_CPU_JIT_EMITTER_ASSERT(kernel, "expectes KernelDynamic expression"); - - // - Reserve abi_param1, since it wll be used to pass runtime call args to all dynamic emitters that needs runtime - // args - // - We cannot assign this register to the body emitters since runtime params MUST be valid during whole execution - // for all dynamic emitters - init_body_regs({reg_runtime_params_idx}); + OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(expr->get_node()), + "expects KernelDynamic expression"); } -void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector& data_ptr_regs) const { - Xbyak::Reg64 reg_runtime_params = Xbyak::Reg64(static_cast(reg_runtime_params_idx)); +void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs, + const std::vector& aux_gprs) const { + OV_CPU_JIT_EMITTER_ASSERT(arg_regs.size() == 1, "Invalid arg regs size"); + Xbyak::Reg64 reg_runtime_params = arg_regs[0]; const auto num_params = num_inputs + num_outputs; for (size_t i = 0; i < num_unique_buffers; ++i) { diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.hpp index 68ea6684cbcd17..995dad15c734d1 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.hpp @@ -5,7 +5,6 @@ #pragma once #include "emitters/plugin/x64/jit_emitter.hpp" -#include 
"emitters/snippets/jit_container_emitter.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" namespace ov { @@ -29,7 +28,7 @@ namespace intel_cpu { /// Note that Kernel doesn't accept any input arguments. /// -class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { +class jit_kernel_emitter : public jit_emitter { public: jit_kernel_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, @@ -40,40 +39,25 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { } void emit_code(const std::vector& in_idxs, const std::vector& out_idxs, - const std::vector& pool_vec_idxs = {}, - const std::vector& pool_gpr_idxs = {}) const override; + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; protected: void validate_arguments(const std::vector& in, const std::vector& out) const override; - void init_body_regs(const std::set& kernel_regs, - const std::vector& pool_vec_idxs = {}, - const std::vector& pool_gpr_idxs = {}); - /** - * @brief populates physical registers pools for x86 (both vec and gp). - * Skips stack-related gprs and extra gprs passed as arguments. 
- * @arg gpr_blacklist - set of gp registers that should not be added to register pool - * @arg vec_blacklist - set of vec registers should not be added to register pool - */ - void init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist); - - virtual void init_data_pointers(const std::vector& data_ptr_regs) const = 0; + + virtual void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs, + const std::vector& aux_gprs) const = 0; void emit_impl(const std::vector& in, const std::vector& out) const override; jit_snippets_compile_args jcp; // gpr's used to store data pointers, track them to apply offsets in Kernel std::vector data_ptr_regs_idx; - std::vector vec_regs_pool; - std::vector gp_regs_pool; size_t num_inputs = 0; size_t num_outputs = 0; size_t num_unique_buffers = 0; - snippets::lowered::LinearIR::container mem_access_exprs; - snippets::lowered::LinearIR::container general_exprs; - - const size_t reg_runtime_params_idx{0}; - std::shared_ptr body; #ifdef SNIPPETS_DEBUG_CAPS @@ -86,11 +70,15 @@ class jit_kernel_static_emitter : public jit_kernel_emitter { jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + size_t get_inputs_num() const override { + return 2; + } private: - void init_data_pointers(const std::vector& data_ptr_regs) const override; + void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs, + const std::vector& aux_gprs) const override; - const size_t reg_indexes_idx{1}; std::vector master_shape; std::vector> data_offsets; @@ -104,9 +92,14 @@ class jit_kernel_dynamic_emitter : public jit_kernel_emitter { jit_kernel_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + size_t get_inputs_num() const override { + return 1; + } private: - void init_data_pointers(const 
std::vector& data_ptr_regs) const override; + void init_data_pointers(const std::vector& arg_regs, + const std::vector& data_ptr_regs, + const std::vector& aux_gprs) const override; #ifdef SNIPPETS_DEBUG_CAPS friend std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter* emitter); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp new file mode 100644 index 00000000000000..9303ab765eb4b5 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.cpp @@ -0,0 +1,92 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "jit_reg_spill_emitters.hpp" + +#include "emitters/plugin/x64/utils.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "emitters/snippets/x64/utils.hpp" +#include "snippets/utils/utils.hpp" + +using namespace Xbyak; +using namespace dnnl::impl; +using namespace dnnl::impl::cpu::x64; + +namespace ov { +namespace intel_cpu { + +/* ================== jit_reg_spill_begin_emitters ====================== */ + +jit_reg_spill_begin_emitter::jit_reg_spill_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr) + : jit_emitter(h, isa), + m_reg_spill_begin_expr(expr) { + const auto& reg_spill_node = ov::as_type_ptr(m_reg_spill_begin_expr->get_node()); + OV_CPU_JIT_EMITTER_ASSERT(reg_spill_node, "expects RegSpillBegin expression"); + m_num_spilled = reg_spill_node->get_regs_to_spill().size(); + // todo: abstract registers were not mapped on physical onces at this point (Kernel emitter does that), + // so we have to declare reg_spiller mutable + m_abi_reg_spiller = std::make_shared(h); + in_out_type_ = emitter_in_out_map::gpr_to_gpr; +} + +void jit_reg_spill_begin_emitter::validate_arguments(const std::vector& in, + const std::vector& 
out) const { + OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "In regs should be empty for reg_spill_begin emitter"); + OV_CPU_JIT_EMITTER_ASSERT(out.size() == m_num_spilled, "Invalid number of out regs for reg_spill_begin emitter"); +} + +void jit_reg_spill_begin_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { + validate_arguments(in, out); + emit_impl(in, out); +} + +void jit_reg_spill_begin_emitter::emit_impl(const std::vector& in, const std::vector& out) const { + const auto& reg_info = m_reg_spill_begin_expr->get_reg_info(); + m_abi_reg_spiller->preamble({reg_info.second.begin(), reg_info.second.end()}); +} + +/* ============================================================== */ + +/* ================== jit_reg_spill_end_emitter ====================== */ + +jit_reg_spill_end_emitter::jit_reg_spill_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr) + : jit_emitter(h, isa) { + in_out_type_ = emitter_in_out_map::gpr_to_gpr; + OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(expr->get_node()) && expr->get_input_count() > 0, + "Invalid expression in RegSpillEnd emitter"); + const auto& parent_expr = expr->get_input_port_connector(0)->get_source().get_expr(); + const auto& reg_spill_begin_emitter = + std::dynamic_pointer_cast(parent_expr->get_emitter()); + OV_CPU_JIT_EMITTER_ASSERT(reg_spill_begin_emitter, "Failed to obtain reg_spill_begin emitter"); + m_abi_reg_spiller = reg_spill_begin_emitter->m_abi_reg_spiller; +} + +void jit_reg_spill_end_emitter::validate_arguments(const std::vector& in, + const std::vector& out) const { + OV_CPU_JIT_EMITTER_ASSERT(out.empty(), "Out regs should be empty for reg_spill_end emitter"); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == m_abi_reg_spiller->get_num_spilled_regs(), + "Invalid number of in regs for reg_spill_end emitter"); +} + +void 
jit_reg_spill_end_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { + validate_arguments(in, out); + emit_impl(in, out); +} + +void jit_reg_spill_end_emitter::emit_impl(const std::vector& in, const std::vector& out) const { + m_abi_reg_spiller->postamble(); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.hpp new file mode 100644 index 00000000000000..ab475c692e0410 --- /dev/null +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_reg_spill_emitters.hpp @@ -0,0 +1,72 @@ +// Copyright (C) 2020-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "emitters/plugin/x64/jit_emitter.hpp" + +namespace ov { +namespace intel_cpu { + +/* ================== jit_reg_spill_begin_emitters ====================== */ +class EmitABIRegSpills; +class jit_reg_spill_end_emitter; +class jit_reg_spill_begin_emitter : public jit_emitter { + friend jit_reg_spill_end_emitter; + +public: + jit_reg_spill_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); + + size_t get_inputs_num() const override { + return 0; + } + + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; + +protected: + void validate_arguments(const std::vector& in, const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; + const ov::snippets::lowered::ExpressionPtr m_reg_spill_begin_expr; + std::shared_ptr m_abi_reg_spiller; + size_t m_num_spilled = SIZE_MAX; +}; + +/* ============================================================== */ + +/* ================== 
jit_reg_spill_end_emitter ====================== */ + +class jit_reg_spill_end_emitter : public jit_emitter { +public: + jit_reg_spill_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); + + size_t get_inputs_num() const override { + return 0; + } + + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; + +protected: + void validate_arguments(const std::vector& in, const std::vector& out) const override; + void emit_impl(const std::vector& in, const std::vector& out) const override; + + size_t aux_gprs_count() const override { + return 0; + } + std::shared_ptr m_abi_reg_spiller; +}; + +/* ============================================================== */ + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp index 3d8026ea33c750..d8967ca684fa64 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp @@ -19,14 +19,6 @@ inline static std::vector transform_idxs_to_regs(const std::vector return regs; } -inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs) { - std::vector idxs(regs.size()); - std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { - return reg.idx; - }); - return idxs; -} - /** * @brief If the passed `port` is connected to a Buffer, return its cluster ID. 
* Otherwise returns SIZE_MAX diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp index 9ac7f0d5cd0ffc..4338000d6982fa 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp @@ -107,13 +107,9 @@ std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* std::stringstream ss; ss << "Emitter_type_name:jit_kernel_static_emitter" << " jcp.exec_domain:" << vector_to_string(emitter->jcp.exec_domain) - << " gp_regs_pool:" << vector_to_string(emitter->gp_regs_pool) << " master_shape:" << vector_to_string(emitter->master_shape) << " num_inputs:" << emitter->num_inputs << " num_outputs:" << emitter->num_outputs << " num_unique_buffers:" << emitter->num_unique_buffers - << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx) - << " vec_regs_pool:" << vector_to_string(emitter->vec_regs_pool) - << " reg_indexes_idx:" << emitter->reg_indexes_idx - << " reg_runtime_params_idx:" << emitter->reg_runtime_params_idx; + << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx); for (size_t i = 0; i < emitter->data_offsets.size(); ++i) ss << " data_offsets for " << i << " is:" << vector_to_string(emitter->data_offsets[i]); return ss.str(); @@ -122,11 +118,9 @@ std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter* emitter) { std::stringstream ss; ss << "Emitter_type_name:jit_kernel_dynamic_emitter" - << " gp_regs_pool:" << vector_to_string(emitter->gp_regs_pool) << " num_inputs:" << emitter->num_inputs - << " num_outputs:" << emitter->num_outputs << " num_unique_buffers:" << emitter->num_unique_buffers - << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx) - << " vec_regs_pool:" << vector_to_string(emitter->vec_regs_pool) - << " 
reg_runtime_params_idx:" << emitter->reg_runtime_params_idx; + << " num_inputs:" << emitter->num_inputs << " num_outputs:" << emitter->num_outputs + << " num_unique_buffers:" << emitter->num_unique_buffers + << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx); return ss.str(); }