diff --git a/src/common/snippets/src/lowered/pass/init_live_ranges.cpp b/src/common/snippets/src/lowered/pass/init_live_ranges.cpp index ba6cf0b7ca3694..00e19d83666a7e 100644 --- a/src/common/snippets/src/lowered/pass/init_live_ranges.cpp +++ b/src/common/snippets/src/lowered/pass/init_live_ranges.cpp @@ -39,10 +39,6 @@ bool InitLiveRanges::run(LinearIR& linear_ir) { expr->set_live_regs(std::prev(expr_it)->get()->get_live_regs()); continue; } - - OPENVINO_ASSERT(expr->get_output_count() == op->get_output_size() || - ov::is_type(op) || - ov::is_type(op), "Incorrect count of output port descriptors!"); const double start = expr->get_exec_num(); // Remove all regs that expired before start regs_to_expire.erase(regs_to_expire.begin(), regs_to_expire.lower_bound(start)); // remove all elements lower than start (not equal) diff --git a/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp b/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp index 9c9d4e0c530fd4..3d7c4517545ffa 100644 --- a/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp +++ b/src/common/snippets/src/lowered/pass/insert_reg_spills.cpp @@ -37,14 +37,14 @@ bool InsertRegSpills::run(LinearIR& linear_ir) { // Note: we need to insert immediately before LoopBegin => increment start_it start_it++; const auto& loop_begin_live = start_it->get()->get_live_regs(); - std::set brgemm_used; - const auto& brgemm_reg_info = expr->get_reg_info(); - brgemm_used.insert(brgemm_reg_info.first.begin(), brgemm_reg_info.first.end()); - brgemm_used.insert(brgemm_reg_info.second.begin(), brgemm_reg_info.second.end()); - // Note: before the loop, we need to spill all live regs except for the ones used by brgemm + std::set used; + const auto& reg_info = expr->get_reg_info(); + used.insert(reg_info.first.begin(), reg_info.first.end()); + used.insert(reg_info.second.begin(), reg_info.second.end()); + // Note: before the loop, we need to spill all live regs except for the ones used by the target expression std::set regs_to_spill; std::set_difference(loop_begin_live.begin(), loop_begin_live.end(), - brgemm_used.begin(), brgemm_used.end(), + used.begin(), used.end(), std::inserter(regs_to_spill, regs_to_spill.begin())); // we also need to keep kernel regs alive (actually only abi_param_1 is used in emitters, but save all for consistency) for (const auto& r : m_reg_manager.get_kernel_call_regs( snippets::op::Kernel::make_kernel(linear_ir.is_dynamic()))) diff --git a/src/common/snippets/src/lowered/pass/validate.cpp b/src/common/snippets/src/lowered/pass/validate.cpp index 2e9e5813c03264..94df5f0e8f5597 100644 --- a/src/common/snippets/src/lowered/pass/validate.cpp +++ b/src/common/snippets/src/lowered/pass/validate.cpp @@ -97,8 +97,13 @@ void validate_buffer(const ExpressionPtr& expr, const LinearIR& linear_ir) { void validate_loop_end(const ExpressionPtr& expr, const LinearIR& linear_ir) { const auto loop_end = ov::as_type_ptr(expr->get_node()); OPENVINO_ASSERT(loop_end, "LoopEnd validation expects LoopEnd op"); - OPENVINO_ASSERT(loop_end->get_loop_begin() != nullptr, + const auto& loop_begin = loop_end->get_loop_begin(); + OPENVINO_ASSERT(loop_begin != nullptr, "LoopEnd must be connected to the LoopBegin"); + const auto num_inputs = expr->get_input_count(); + OPENVINO_ASSERT(num_inputs >= 1, "LoopEnd expression must have at least 1 input"); + OPENVINO_ASSERT(expr->get_input_port_connector(num_inputs - 1)->get_source().get_expr()->get_node() == loop_begin, + "LopEnd expression must have LoopBegin attached to the last connector"); const auto& loop_manager = linear_ir.get_loop_manager(); const auto& loop_info = loop_manager->get_loop_info(loop_end->get_id()); @@ -148,6 +153,9 @@ bool Validate::run(LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lo if (found != m_validation_map.cend()) { (found->second)(expr, linear_ir); } + OPENVINO_ASSERT(expr->get_output_count() == node->get_output_size() || + ov::is_type(node) || + ov::is_type(node), "Incorrect count of output port descriptors!"); expr->validate(); // Loop expr doesn't have shapes and layouts if (!ov::is_type(node)) diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp index 6df658d8d72d0c..3add94baa93939 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp @@ -66,6 +66,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, m_memory_offsets.push_back(brgemm_repack->get_offset_compensations()); m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_output_port(1))); } + m_live_regs = expr->get_live_regs(); } void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector& in, @@ -81,31 +82,38 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector& in, const s if (out.size() > 1) mem_ptrs_idxs.emplace_back(out[1]); + std::set regs_to_spill = m_live_regs; + // Note: these 3 registers will be corrupted by the caller during the ABI call + regs_to_spill.emplace(snippets::RegType::gpr, abi_param1.getIdx()); + regs_to_spill.emplace(snippets::RegType::gpr, abi_param2.getIdx()); + regs_to_spill.emplace(snippets::RegType::gpr, h->rbp.getIdx()); + // Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly + Xbyak::Reg64 aux_reg = abi_param1; + utils::init_memory_access_aux_gpr(mem_ptrs_idxs, m_memory_offsets, aux_gpr_idxs, regs_to_spill, aux_reg); + EmitABIRegSpills spill(h); - spill.preamble(); + spill.preamble(regs_to_spill); h->mov(h->rbp, reinterpret_cast(BrgemmCopyBKernelExecutor::execute)); auto reserved_stack_size = sizeof(BrgemmCopyBKernel::call_args); // Reserve memory on the stack h->sub(h->rsp, reserved_stack_size); - const bool is_dynamic_case = - std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); - Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); - const std::vector args_offsets{GET_OFF_BRGEMM_COPY_B_ARGS(src), GET_OFF_BRGEMM_COPY_B_ARGS(tr_src), GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)}; const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs); for (size_t i = 0; i < mem_ptrs.size(); i++) { - if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) + if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) { + OV_CPU_JIT_EMITTER_ASSERT(aux_reg != abi_param1, "Aux reg is needed, but wasn't allocated"); utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg, GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); - else + } else { utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); + } } // No scratchpad => need to write nullptr manually diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp index d937e646b603da..ffcc842f0084bb 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp @@ -32,6 +32,7 @@ class jit_brgemm_copy_b_emitter : public jit_emitter { std::vector m_memory_offsets{}; std::vector m_buffer_ids{}; std::shared_ptr m_kernel_executor{nullptr}; + std::set m_live_regs{}; bool m_with_comp{false}; #ifdef SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 6470fc2f50b0f5..a30b3ac80f75f2 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -112,18 +112,9 @@ void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) con regs_to_spill.emplace(snippets::RegType::gpr, abi_param1.getIdx()); regs_to_spill.emplace(snippets::RegType::gpr, abi_param2.getIdx()); regs_to_spill.emplace(snippets::RegType::gpr, h->rbp.getIdx()); - const bool is_dynamic_case = - std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); // Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly Xbyak::Reg64 aux_reg = abi_param1; - if (std::is_same() || is_dynamic_case) { - if (!aux_gpr_idxs.empty()) { - aux_reg = Xbyak::Reg64(static_cast(aux_gpr_idxs[0])); - } else { - aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs); - regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx()); - } - } + utils::init_memory_access_aux_gpr(mem_ptrs_idxs, m_memory_offsets, aux_gpr_idxs, regs_to_spill, aux_reg); EmitABIRegSpills spill(h); spill.preamble(regs_to_spill); diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp index 0604792fc22573..98da2565c1368e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.cpp @@ -59,6 +59,24 @@ Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs) { OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR"); } +void init_memory_access_aux_gpr(const std::vector& mem_ptr_reg_idxs, + const std::vector& memory_offsets, + const std::vector& aux_gpr_idxs, + std::set& regs_to_spill, + Xbyak::Reg64& aux_reg) { + const bool is_dynamic_case = + std::any_of(memory_offsets.cbegin(), memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); + // Note: abi_param_1 is a default invalid value to check later that the aux reg was allocated properly + if (is_dynamic_case) { + if (!aux_gpr_idxs.empty()) { + aux_reg = Xbyak::Reg64(static_cast(aux_gpr_idxs[0])); + } else { + aux_reg = ov::intel_cpu::utils::get_aux_gpr(mem_ptr_reg_idxs); + regs_to_spill.emplace(snippets::RegType::gpr, aux_reg.getIdx()); + } + } +}; + void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, Xbyak::Reg64 ptr_reg, diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp index d8967ca684fa64..2b61ace0fad499 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp @@ -34,6 +34,22 @@ size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port); */ Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs); +/** + * @brief Initializes aux gpr register for dynamic memory access emitters. If any of the `memory_offsets` is dynamic, + * then try to assign `aux_reg` a register from `aux_gpr_idxs`. If `aux_gpr_idxs` is empty, then choose a register that + * is not in `mem_ptr_reg_idxs` and add it to `regs_to_spill`. + * @param mem_ptr_reg_idxs register indexes reserved to store memory pointers in this emitter + * @param memory_offsets memory offsets, could be dynamic + * @param aux_gpr_idxs pool of available gp register indexes + * @param regs_to_spill set of live registers to be spilled before ABI call + * @param aux_reg auxiliary register that should be initialized + */ +void init_memory_access_aux_gpr(const std::vector& mem_ptr_reg_idxs, + const std::vector& memory_offsets, + const std::vector& aux_gpr_idxs, + std::set& regs_to_spill, + Xbyak::Reg64& aux_reg); + /** * @brief Push data pointer on stack adding offset. The offset is taken from runtime params `abi_param1` * @param h generator