Skip to content

Commit

Permalink
reshape
Browse files Browse the repository at this point in the history
  • Loading branch information
chenhu-wang committed Jan 3, 2024
1 parent d24e71d commit be9bc14
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/common/snippets/include/snippets/generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,9 @@ class Generator {
* gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc)
* gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc.
* vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc.
* none->none: some operations that actually will not emit code, no registers is need: Reshape, etc.
*/
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec};
enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec, none2none};
/**
* @brief gets register type by op type
* TODO: Should be static attribute of emitters
Expand Down
2 changes: 2 additions & 0 deletions src/common/snippets/src/generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr<Node>& op)
std::dynamic_pointer_cast<op::HorizonSum>(op) ||
std::dynamic_pointer_cast<op::Fill>(op))
return vec2vec;
else if (std::dynamic_pointer_cast<ov::op::v1::Reshape>(op))
return none2none;
else
return get_specific_op_reg_type(op);
}
Expand Down
8 changes: 8 additions & 0 deletions src/common/snippets/src/lowered/pass/assign_registers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
for (auto& expr : expressions) {
auto op = expr->get_node();
auto reg_type = m_reg_type_mapper(op);
if (reg_type == Generator::opRegType::none2none)
continue;
typed_ops.emplace_back(reg_type, expr);
num_parameters += is_type<ov::op::v0::Parameter>(op);
num_results += is_type<ov::op::v0::Result>(op);
Expand Down Expand Up @@ -134,6 +136,8 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
case Generator::opRegType::vec2gpr:
enumerate_out_tensors(t_op.second, regs_gpr, manually_assigned_gprs, counter_gpr);
break;
default:
break;
}
}
// todo: make one for gpr and one for vector
Expand Down Expand Up @@ -177,6 +181,8 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
used_vec[i] = tensor2reg(used_tensors, regs_vec);
defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr);
break;
default:
break;
}
}

Expand Down Expand Up @@ -225,6 +231,8 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
case Generator::opRegType::gpr2vec:
life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end());
break;
default:
break;
}
}
}
Expand Down

0 comments on commit be9bc14

Please sign in to comment.