Skip to content

Commit

Permalink
[Snippets] Assign registers and ABI call optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanNovoselov committed Dec 31, 2024
1 parent a091846 commit cb519eb
Show file tree
Hide file tree
Showing 52 changed files with 1,493 additions and 848 deletions.
13 changes: 8 additions & 5 deletions src/common/snippets/include/snippets/emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,28 @@ namespace snippets {
* @interface RegType
* @brief Register type of input and output operations
*/
enum class RegType { gpr, vec, undefined };
enum class RegType { gpr, vec, mask, undefined };
/**
* @interface Reg
* @brief Register representation: type of register and index
*/
struct Reg {
enum {UNDEFINED_IDX = std::numeric_limits<size_t>::max()};
Reg() = default;
Reg(RegType type_, size_t idx_) : type(type_), idx(idx_) {}

RegType type = RegType::gpr;
size_t idx = 0;
bool is_defined() const { return type != RegType::undefined && idx != UNDEFINED_IDX; }
RegType type = RegType::undefined;
size_t idx = UNDEFINED_IDX;

friend bool operator==(const Reg& lhs, const Reg& rhs);
friend bool operator<(const Reg& lhs, const Reg& rhs);
friend bool operator>(const Reg& lhs, const Reg& rhs);
friend bool operator!=(const Reg& lhs, const Reg& rhs);
friend std::ostream& operator<<(std::ostream& s, const Reg& r);
};
using RegInfo = std::pair<std::vector<Reg>, std::vector<Reg>>;

std::string regTypeToStr(const RegType& type);

/**
* @interface Emitter
* @brief Base class for all target specific code emitters used by generator.
Expand Down
2 changes: 1 addition & 1 deletion src/common/snippets/include/snippets/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class Schedule {
* @brief Target independent code generator interface
* @ingroup snippets
*/
class Generator {
class Generator : public std::enable_shared_from_this<Generator>{
public:
/**
* @brief Default constructor
Expand Down
3 changes: 3 additions & 0 deletions src/common/snippets/include/snippets/lowered/expression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class Expression : public std::enable_shared_from_this<Expression> {

RegInfo get_reg_info() const;
void set_reg_info(const RegInfo& rinfo);
const std::set<Reg>& get_live_regs() const {return m_live_regs; }
void set_live_regs(std::set<Reg> live_regs) { m_live_regs = std::move(live_regs); }

double get_exec_num() const { return m_exec_num; }

Expand Down Expand Up @@ -130,6 +132,7 @@ class Expression : public std::enable_shared_from_this<Expression> {
// 2. This number can be changed and updated during whole pipeline, so its absolute values are meaningless.
// 3. This number can be negative, positive and zero.
double m_exec_num = 0;
std::set<Reg> m_live_regs;
};

} // namespace lowered
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "snippets/op/loop.hpp"
#include "snippets/op/buffer.hpp"
#include "snippets/op/perf_count.hpp"
#include "snippets/op/reg_spill.hpp"

namespace ov {
namespace snippets {
Expand All @@ -35,6 +36,10 @@ class ExpressionFactory {
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::LoopEnd>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::RegSpillBegin>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);
static ExpressionPtr create(const std::shared_ptr<op::RegSpillEnd>& n, const std::vector<PortConnectorPtr>& inputs,
const std::shared_ptr<IShapeInferSnippetsFactory>& shape_infer_factory);

// Note: PerfCountBegin nodes have a PerfCountEnd ov::Output, but corresponding expression should not have any outputs to avoid register allocation
#ifdef SNIPPETS_DEBUG_CAPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "pass.hpp"
#include "snippets/generator.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
Expand All @@ -21,15 +22,14 @@ namespace pass {
class AssignRegisters : public Pass {
public:
OPENVINO_RTTI("AssignRegisters", "Pass")
explicit AssignRegisters(const std::function<RegType(const ov::Output<Node>& out)>& mapper, const size_t reg_cnt)
: m_reg_type_mapper(mapper), reg_count(reg_cnt) {}
explicit AssignRegisters(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
bool run(LinearIR& linear_ir) override;

private:
void set_reg_types(LinearIR& linear_ir);
using RegMap = std::map<Reg, Reg>;
RegMap assign_regs_manually(const LinearIR& linear_ir, std::set<Reg>& gpr_pool, std::set<Reg>& vec_pool);

std::function<RegType(const ov::Output<Node>& out)> m_reg_type_mapper;
size_t reg_count;
RegManager& m_reg_manager;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/generator.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface InitLiveRanges
* @brief Calculates live ranges of registers. This information will be used to assign registers and optimize ABI reg spills.
* @ingroup snippets
*/
class InitLiveRanges : public Pass {
public:
OPENVINO_RTTI("InitLiveRanges", "Pass")
// Stores a reference to reg_manager (no copy is made), so the RegManager must outlive this pass.
explicit InitLiveRanges(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
// Pass entry point, defined out of line.
// NOTE(review): the bool result presumably reports whether linear_ir was modified — confirm against Pass.
bool run(LinearIR& linear_ir) override;
private:
// Non-owning reference to the register manager supplied at construction.
RegManager& m_reg_manager;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"
#include "snippets/lowered/reg_manager.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface InsertRegSpills
* @brief Insert RegSpill and RegRestore operations for binary call emitters to comply with ABI conventions.
* @ingroup snippets
*/
class InsertRegSpills : public Pass {
public:
OPENVINO_RTTI("InsertRegSpills", "Pass")
// Stores a reference to reg_manager (no copy is made), so the RegManager must outlive this pass.
explicit InsertRegSpills(RegManager& reg_manager) : m_reg_manager(reg_manager) {}
// Pass entry point, defined out of line.
// NOTE(review): the bool result presumably reports whether linear_ir was modified — confirm against Pass.
bool run(LinearIR& linear_ir) override;

// Non-owning reference to the register manager supplied at construction.
// NOTE(review): this member sits in the public section (there is no `private:` label before it),
// unlike InitLiveRanges/AssignRegisters where the same member is private — confirm whether that is intended.
RegManager& m_reg_manager;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
67 changes: 67 additions & 0 deletions src/common/snippets/include/snippets/lowered/reg_manager.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include "openvino/core/node.hpp"
#include "snippets/emitter.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/generator.hpp"
#include "snippets/op/kernel.hpp"

/**
* @interface RegManager
* @brief The class holds supplementary info about assigned registers and live ranges
* @ingroup snippets
*/
namespace ov {
namespace snippets {
namespace lowered {

using RegTypeMapper = std::function<RegType(const ov::Output<Node>& out)>;
using LiveInterval = std::pair<double, double>;
class RegManager {
public:
    RegManager() = delete;
    // Keeps a shared reference to the generator; every target-specific query below is forwarded to it.
    RegManager(const std::shared_ptr<const Generator>& generator) : m_generator(generator) {}

    // Register type of the given node output, as reported by the generator.
    inline RegType get_reg_type(const ov::Output<Node>& out) const { return m_generator->get_op_out_reg_type(out); }
    // Vector register pool of the generator's target machine.
    inline std::vector<Reg> get_vec_reg_pool() const { return m_generator->get_target_machine()->get_vec_reg_pool(); }

    // Records the live interval of a register.
    // Asserts if an interval is already stored for reg, unless force is true, in which case it is overwritten.
    inline void set_live_range(const Reg& reg, const LiveInterval& interval, bool force = false) {
        OPENVINO_ASSERT(force || m_reg_live_range.count(reg) == 0, "Live range for this reg is already set");
        m_reg_live_range[reg] = interval;
    }

    // Returns the first get_num_call_args() ABI argument registers of the target machine.
    inline std::vector<Reg> get_kernel_call_regs(const std::shared_ptr<snippets::op::Kernel>& kernel) const {
        const auto& abi_regs = m_generator->get_target_machine()->get_abi_arg_regs();
        const auto num_kernel_args = kernel->get_num_call_args();
        // NOTE(review): the strict '>' also rejects a kernel that needs exactly all ABI registers,
        // although the range below would still be valid in that case — confirm whether '>=' was intended.
        OPENVINO_ASSERT(abi_regs.size() > num_kernel_args, "Too many kernel args requested");
        return {abi_regs.begin(), abi_regs.begin() + static_cast<int64_t>(num_kernel_args)};
    }

    // Returns the general-purpose register pool with the kernel call registers filtered out.
    inline std::vector<Reg> get_gp_regs_except_kernel_call(const std::shared_ptr<snippets::op::Kernel>& kernel) const {
        auto res = m_generator->get_target_machine()->get_gp_reg_pool();
        const auto call_regs = get_kernel_call_regs(kernel);
        // A set gives logarithmic membership checks inside the filtering predicate.
        const std::set<Reg> kernel_call(call_regs.begin(), call_regs.end());
        res.erase(std::remove_if(res.begin(), res.end(), [&kernel_call](const Reg& r) {return kernel_call.count(r) != 0; }), res.end());
        return res;
    }

    // Returns the live interval previously stored for reg; asserts if none was set.
    // Read-only, hence const: a single find() replaces the former count() + operator[] pair,
    // so the getter can be used through a const RegManager and never inserts into the map.
    inline const LiveInterval& get_live_range(const Reg& reg) const {
        const auto it = m_reg_live_range.find(reg);
        OPENVINO_ASSERT(it != m_reg_live_range.end(), "Live range for this reg was not set");
        return it->second;
    }
    // Returns a copy of the whole register -> live interval map.
    inline std::map<Reg, LiveInterval> get_live_range_map() const {
        return m_reg_live_range;
    }

private:
    // Maps Register to {Start, Stop} pairs
    std::map<Reg, LiveInterval> m_reg_live_range;
    const std::shared_ptr<const Generator> m_generator;
};

} // namespace lowered
} // namespace snippets
} // namespace ov
15 changes: 14 additions & 1 deletion src/common/snippets/include/snippets/op/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class Kernel : public ov::op::Op {
Kernel() = default;
Kernel(lowered::LinearIR region);

static std::shared_ptr<Kernel> make_kernel(const lowered::LinearIR& region);
template<typename ... ArgTypes>
static std::shared_ptr<Kernel> make_kernel(bool is_dynamic, ArgTypes&&... args);
virtual size_t get_num_call_args() const = 0;

std::shared_ptr<lowered::LinearIR> region;
const void *compile_params = nullptr;
Expand All @@ -33,6 +35,7 @@ class KernelStatic : public Kernel {
OPENVINO_OP("KernelStatic", "SnippetsOpset", Kernel);
KernelStatic() = default;
KernelStatic(lowered::LinearIR region);
size_t get_num_call_args() const override { return 2; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
};

Expand All @@ -41,9 +44,19 @@ class KernelDynamic : public Kernel {
OPENVINO_OP("KernelDynamic", "SnippetsOpset", Kernel);
KernelDynamic() = default;
KernelDynamic(lowered::LinearIR region);
size_t get_num_call_args() const override { return 1; }
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
};

/**
 * @brief Creates a KernelDynamic when is_dynamic is true and a KernelStatic otherwise.
 *        All remaining arguments are perfectly forwarded to the constructor of the selected class.
 */
template<typename ... ArgTypes>
std::shared_ptr<Kernel> Kernel::make_kernel(bool is_dynamic, ArgTypes&&... args) {
    // The two branches produce different pointer types, so each one returns on its own
    // and is converted to the common base pointer at the return statement.
    if (!is_dynamic)
        return std::make_shared<KernelStatic>(std::forward<ArgTypes>(args)...);
    return std::make_shared<KernelDynamic>(std::forward<ArgTypes>(args)...);
}

} // namespace op
} // namespace snippets
} // namespace ov
82 changes: 82 additions & 0 deletions src/common/snippets/include/snippets/op/reg_spill.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/emitter.hpp"

#include "openvino/op/op.hpp"
#include "snippets/shape_inference/shape_inference.hpp"

namespace ov {
namespace snippets {
namespace op {

/**
* @interface RegSpillBase
* @brief Base class for RegSpillBegin and RegSpillEnd ops
* @ingroup snippets
*/
class RegSpillBase : public ov::op::Op {
public:
    // The registered type name now matches the class name (it used to read "RegSpillBaseBase").
    OPENVINO_OP("RegSpillBase", "SnippetsOpset");
    RegSpillBase(const std::vector<Output<Node>>& args);
    RegSpillBase() = default;
    // Set of registers covered by the spill region; implemented by RegSpillBegin and RegSpillEnd.
    virtual std::set<Reg> get_regs_to_spill() const = 0;
    bool visit_attributes(AttributeVisitor& visitor) override;
};
class RegSpillEnd;
/**
* @interface RegSpillBegin
* @brief Marks the start of the register spill region.
* @ingroup snippets
*/
class RegSpillBegin : public RegSpillBase {
public:
OPENVINO_OP("RegSpillBegin", "SnippetsOpset", RegSpillBase);
// Takes the register set by value; defined out of line.
// NOTE(review): presumably stored in m_regs_to_spill — confirm in the .cpp.
RegSpillBegin(std::set<Reg> regs_to_spill);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
// Returns the RegSpillEnd paired with this op; defined out of line, lookup strategy not visible here.
std::shared_ptr<RegSpillEnd> get_reg_spill_end() const;
// Returns a copy of the stored register set.
std::set<Reg> get_regs_to_spill() const override { return m_regs_to_spill; }

// Shape inference implementation attached to this op; both members are defined out of line.
class ShapeInfer : public IShapeInferSnippets {
// NOTE(review): presumably the number of output shapes produced by infer() — confirm in the .cpp.
size_t num_out_shapes = 0;
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
protected:
// NOTE(review): judging by the name, validation that skips the check of the paired RegSpillEnd — confirm in the .cpp.
void validate_and_infer_types_except_RegSpillEnd();
// Registers to spill; empty by default.
std::set<Reg> m_regs_to_spill = {};
};
/**
* @interface RegSpillEnd
* @brief Marks the end of the register spill region.
* @ingroup snippets
*/
class RegSpillEnd : public RegSpillBase {
public:
OPENVINO_OP("RegSpillEnd", "SnippetsOpset", RegSpillBase);
RegSpillEnd() = default;
// Defined out of line.
// NOTE(review): presumably reg_spill_begin becomes input 0, which get_reg_spill_begin() reads — confirm in the .cpp.
RegSpillEnd(const Output<Node>& reg_spill_begin);

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const override;
// Returns the node connected to input 0, cast to RegSpillBegin.
// Asserts when the cast result is null.
std::shared_ptr<RegSpillBegin> get_reg_spill_begin() const {
auto reg_spill_begin = ov::as_type_ptr<RegSpillBegin>(get_input_node_shared_ptr(0));
OPENVINO_ASSERT(reg_spill_begin, "Can't get reg_spill_begin from reg_spill_end");
return reg_spill_begin;
}
// The end op keeps no register set of its own: it forwards the query to the paired RegSpillBegin.
std::set<Reg> get_regs_to_spill() const override {
return get_reg_spill_begin()->get_regs_to_spill();
}
};

} // namespace op
} // namespace snippets
} // namespace ov
1 change: 1 addition & 0 deletions src/common/snippets/include/snippets/snippets_isa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "op/rank_normalization.hpp"
#include "op/perf_count.hpp"
#include "op/reduce.hpp"
#include "op/reg_spill.hpp"

namespace ov {
namespace snippets {
Expand Down
20 changes: 17 additions & 3 deletions src/common/snippets/include/snippets/target_machine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,24 @@ class TargetMachine {
virtual size_t get_lanes() const = 0;

/**
* @brief gets number of registers for a target machine
* @return number of registers
* @brief Get all possible ABI argument registers.
* The number of registers actually used depends on the signature of the called binary.
* @return vector of snippets::Reg
*/
virtual size_t get_reg_count() const = 0;
virtual std::vector<snippets::Reg> get_abi_arg_regs() const = 0;

/**
* @brief Get all available general-purpose registers.
* Returns only registers that are not reserved for special purposes (e.g. stack pointer or instruction address).
* @return vector of snippets::Reg
*/
virtual std::vector<snippets::Reg> get_gp_reg_pool() const = 0;
/**
* @brief Get all available vector registers.
* Returns only registers that are not reserved for special purposes
* @return vector of snippets::Reg
*/
virtual std::vector<snippets::Reg> get_vec_reg_pool() const = 0;

/**
* @brief called by generator to all the emitter for a target machine
Expand Down
Loading

0 comments on commit cb519eb

Please sign in to comment.